kokkos-execution 0.0.1
Loading...
Searching...
No Matches
test_fork_join.cpp
Go to the documentation of this file.
2PRAGMA_DIAGNOSTIC_PUSH
4#include "exec/fork_join.hpp"
5#include "exec/static_thread_pool.hpp"
6PRAGMA_DIAGNOSTIC_POP
7
10
18
30
32
33using namespace Kokkos::utils::callbacks;
34
44
// Diamond-shaped fork/join test.
// Work starts on a static_thread_pool, forks into two branches, joins, then
// continues on the execution-space scheduler `esc` and finally on
// stdexec::inline_scheduler.
// The counter `data` is expected to evolve as:
// 0 -> 4 (on the pool) -> 6 after the fork/join (checked by the next LoadCheckAdd)
// -> 9 (on `esc`) -> 14 (inline).
// NOTE(review): this listing is missing source lines 57, 70 and 71: the tail of
// the first (`esc`) branch and the leading expected events. Those parts are not
// documented here.
46TEST_F(ForkJoinTest, diamond) {
// NOTE(review): unlike the other tests in this file, view_alloc is not given `exec`
// here. Confirm this is intentional.
47 const view_s_t data(Kokkos::view_alloc("data - shared space"));
48
49 experimental::execution::static_thread_pool pool{4};
50 const context_t esc{exec};
51
// Build the sender chain; it is only started by sync_wait further below.
52 auto chain =
53 stdexec::schedule(pool.get_scheduler())
// On the pool: check data == 0, then add 4.
54 | stdexec::then(Tests::Utils::Functors::LoadCheckAdd<int, false>{.prev = 0, .value = 4, .data = data.data()})
55 | experimental::execution::fork_join(
// First branch: continue on the execution-space scheduler.
56 stdexec::continues_on(esc.get_scheduler())
// Second branch: continue on the pool and increment data atomically.
58 stdexec::continues_on(pool.get_scheduler()) | THEN_INCREMENT_ATOMIC(data))
// After the join, move to `esc`: expect data == 6, then add 3.
59 | stdexec::continues_on(esc.get_scheduler())
60 | stdexec::then(
61 Tests::Utils::Functors::LoadCheckAdd<int, on_device>{.prev = 6, .value = 3, .data = data.data()})
// Finally run inline: expect data == 9, then add 5.
62 | stdexec::continues_on(stdexec::inline_scheduler{})
63 | stdexec::then(Tests::Utils::Functors::LoadCheckAdd<int, false>{.prev = 9, .value = 5, .data = data.data()});
64
// Building the chain must not have run any of its work yet.
65 ASSERT_EQ(data(), 0) << "Eager execution is not allowed.";
66
// Run the chain to completion under the event recorder. recorder_listener_t is
// parameterized on BeginFenceEvent and BeginParallelForEvent. The recorded
// sequence must match exactly and ends with a fence on `exec` labelled
// via dispatch_label(exec, "schedule_from").
67 ASSERT_THAT(
68 recorder_listener_t::record([chain = std::move(chain)]() mutable { stdexec::sync_wait(std::move(chain)); }),
69 testing::ElementsAre(
72 MATCHER_FOR_BEGIN_FENCE(exec, dispatch_label(exec, "schedule_from"))));
73
// 0 + 4 + (fork/join contribution to reach 6) + 3 + 5.
74 ASSERT_EQ(data(), 14);
75}
76
// fork_join with a single branch that re-schedules onto the same execution-space
// scheduler `esc` that the input sender already completes on.
// The counter `data` is expected to go from 0 to 3.
// NOTE(review): this listing is missing source lines 101-102: the expected
// recorded events and the closing parentheses of ASSERT_THAT.
82TEST_F(ForkJoinTest, continues_on) {
83 const view_s_t data(Kokkos::view_alloc(exec, "data - shared space"));
84
85 const context_t esc{exec};
86
// Input sender: an empty just() moved onto `esc`.
87 auto sndr =
88 stdexec::just() | stdexec::continues_on(esc.get_scheduler())
89 | experimental::execution::fork_join(
// Single branch: continue on `esc` again, check data == 0, then add 3.
90 stdexec::continues_on(esc.get_scheduler())
91 | stdexec::then(
92 Tests::Utils::Functors::LoadCheckAdd<int, on_device>{.prev = 0, .value = 3, .data = data.data()}));
93
// Building the sender must not have run any of its work yet.
94 ASSERT_EQ(data(), 0) << "Eager execution is not allowed.";
95
// Run the sender to completion under the event recorder and match the events.
96 ASSERT_THAT(
97 recorder_listener_t::record([sndr = std::move(sndr)]() mutable { // NOLINT(performance-move-const-arg)
98 stdexec::sync_wait(std::move(sndr)); // NOLINT(performance-move-const-arg)
99 }),
100 testing::ElementsAre(
103
104 ASSERT_EQ(data(), 3);
105}
106
// Same shape as the continues_on test, but a bulk stage (BULK_SUM_INDICES of size 2)
// runs on `esc` before the fork_join. The bulk stage is expected to leave data == 1,
// as checked by the LoadCheckAdd with prev = 1. The branch then adds 2, for a final
// value of 3.
// NOTE(review): this listing is missing source lines 131-133: the expected
// recorded events and the closing parentheses of ASSERT_THAT.
112TEST_F(ForkJoinTest, continues_on_bulk) {
113 const view_s_t data(Kokkos::view_alloc(exec, "data - shared space"));
114
115 const context_t esc{exec};
116
// Input sender: move onto `esc`, then run a bulk stage of size 2 that updates data.
117 auto sndr =
118 stdexec::just() | stdexec::continues_on(esc.get_scheduler()) | BULK_SUM_INDICES(2, data)
119 | experimental::execution::fork_join(
// Single branch: continue on `esc`, check data == 1, then add 2.
120 stdexec::continues_on(esc.get_scheduler())
121 | stdexec::then(
122 Tests::Utils::Functors::LoadCheckAdd<int, on_device>{.prev = 1, .value = 2, .data = data.data()}));
123
// Building the sender must not have run any of its work yet.
124 ASSERT_EQ(data(), 0) << "Eager execution is not allowed.";
125
// Run the sender to completion under the event recorder and match the events.
126 ASSERT_THAT(
127 recorder_listener_t::record([sndr = std::move(sndr)]() mutable { // NOLINT(performance-move-const-arg)
128 stdexec::sync_wait(std::move(sndr)); // NOLINT(performance-move-const-arg)
129 }),
130 testing::ElementsAre(
134
135 ASSERT_EQ(data(), 3);
136}
137
138} // namespace Tests::ExecutionSpaceImpl
constexpr std::string dispatch_label(const Exec &, Label &&label)
Get the dispatch label from Exec and label.
#define MATCHER_FOR_BEGIN_PFOR(_exec_, _label_)
#define MATCHER_FOR_BEGIN_FENCE(_exec_, _label_)
RecorderListener< EventDiscardMatcher< TEST_EXECUTION_SPACE >, BeginFenceEvent, BeginParallelForEvent > recorder_listener_t
#define KOKKOS_EXECUTION_STDEXEC_PRAGMA_DIAGNOSTIC_IGNORED
Basic list of ignored diagnostics when including anything from stdexec.
#define THEN_INCREMENT_ATOMIC(_data_)
Same as THEN_INCREMENT, but using Tests::Utils::atomic_add. // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)
Definition increment.hpp:39
constexpr check_scheduler_type_t< Tag, Schd > check_scheduler_type
constexpr bool on_device()
Definition kokkos.hpp:22
auto get_scheduler() const noexcept -> ExecutionSpaceImpl::Scheduler< Exec >
Kokkos::Execution::ExecutionSpaceContext< Exec > context_t
Definition context.hpp:25
Load the value at data and check it is equal to prev. Then, add value to it.
#define BULK_SUM_INDICES(_size_, _data_)
Add a bulk stage using Tests::Utils::Functors::SumIndices. // NOLINTNEXTLINE(cppcoreguidelines-macro-usage)