kokkos-execution 0.0.1
Loading...
Searching...
No Matches
when_all.hpp
Go to the documentation of this file.
1#ifndef KOKKOS_EXECUTION_GRAPH_WHEN_ALL_HPP
2#define KOKKOS_EXECUTION_GRAPH_WHEN_ALL_HPP
3
5
6#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
7# include "plog/Log.h"
8#endif
9
10#include "Kokkos_Graph.hpp"
11
13
24
26
33template <Kokkos::ExecutionSpace Exec, stdexec::receiver Rcvr, stdexec::sender... Sndrs>
35 : public Impl::Immovable
36 , public OpStateBase<Exec, Rcvr> {
37 using operation_state_concept = stdexec::operation_state_tag;
38
40 using execution_space = Exec;
41
43 using root_t = typename state_t::graph_t::root_t;
44
// Receiver connected to every child sender of the when_all (see children_opstates_t).
// It answers get_node with the parent operation state's root node, i.e. the root of the
// graph created by this operation state; presumably each branch builds its nodes off that
// root — confirm against the child operation states.
46 struct WhenAllChildReceiver : public Impl::Receiver<WhenAllOpState, stdexec::env_of_t<Rcvr>> {
47 [[nodiscard]]
48 constexpr auto query(get_node_t) const & noexcept -> const root_t& {
49 return this->parent_op->root;
50 }
51 };
52
53 using children_opstates_t = stdexec::__tuple<stdexec::connect_result_t<Sndrs, WhenAllChildReceiver>...>;
54
55#if defined(KOKKOS_ENABLE_DEBUG)
57 static_assert(
58 stdexec::__mapply<
59 stdexec::__mall_of<stdexec::__q<Impl::queryable_for<get_node_t>::type>>,
61 >::value,
62 "Child senders of the 'when_all' must lead to 'get_node_t' queryable operation states.");
63#endif
64
67
68 using node_t = decltype(stdexec::__apply(
69 [](const auto&... ops) { return Kokkos::Experimental::when_all(ops.query(get_node)...); },
70 std::declval<const children_opstates_t&>()));
71
77 std::atomic<size_t> count = sizeof...(Sndrs);
78
80 WhenAllOpState(stdexec::__tuple<Sndrs...>&& sndrs_, Rcvr&& rcvr_)
81 : base_t(std::move(rcvr_))
88 , state{Kokkos::Experimental::get_device_handle(execution_space{})}
89 , root(state.graph.root_node())
91 stdexec::__apply(
92 [this]<typename... Children>(Children&&... children) -> children_opstates_t {
94 stdexec::connect(std::forward<Children>(children), WhenAllChildReceiver{this})...};
95 },
96 std::move(sndrs_)))
97 , node(
98 stdexec::__apply(
99 [](const auto&... child_op) {
100 auto agg = Kokkos::Experimental::when_all(child_op.query(get_node)...);
101 graph_add_aggregate_node_event(agg, child_op.query(get_node)...);
102 return agg;
103 },
104 children_opstates)) {
105 }
106
// Return the aggregate node built in the constructor by applying
// Kokkos::Experimental::when_all to the nodes of all child operation states.
107 const auto& query(get_node_t) const & noexcept {
108 return node;
109 }
110
// Return the graph owned by this operation state's `state`; this is the same graph that
// submit_graph() submits.
111 const auto& query(get_graph_t) const & noexcept {
112 return state.graph;
113 }
114
// Variant used when the branches do not all remain on the graph (as_one == false).
// `count` starts at sizeof...(Sndrs); fetch_sub returns the previous value, so only the call
// that brings the counter to zero submits the graph. Presumably invoked once per child
// branch through WhenAllChildReceiver — confirm in Impl::Receiver.
139 void submit() & noexcept requires(!as_one)
140 {
141 if (count.fetch_sub(1) == 1) {
142 this->submit_graph();
143 }
144 }
145
// Variant used when all branches remain on the graph (as_one == true): no counting is
// needed, the graph is submitted directly.
146 void submit() & noexcept requires(as_one)
147 {
148 this->submit_graph();
149 }
150
// Forward every non-value completion (any Tag other than set_value_t) straight to the
// base class, which completes the downstream receiver. Value completions are excluded by
// the constraint and are not handled by this overload.
// NOTE(review): if several children complete with a non-value signal, each one is forwarded;
// whether base_t::complete guards against completing the receiver more than once is not
// visible here — confirm.
151 template <typename Tag, typename... Args>
152 requires(!std::same_as<Tag, stdexec::set_value_t>)
153 void complete(Tag, Args&&... args) & noexcept {
154 base_t::complete(Tag{}, std::forward<Args>(args)...);
155 }
156
// Variant used when as_one is false: start every child operation state. The graph itself
// is submitted later, from submit(), once the branch counter reaches zero.
157 void start() & noexcept requires(!as_one)
158 {
159#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
160 PLOG_INFO << "Starting all branches before submission.";
161#endif
162 stdexec::__apply([](auto&... ops) -> void { (stdexec::start(ops), ...); }, children_opstates);
163 }
164
// Variant used when as_one is true: the branches are not started individually; the graph
// is submitted right away through submit().
166 void start() & noexcept requires(as_one)
167 {
168#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
169 PLOG_INFO << "Submit the graph directly without starting the branches.";
170#endif
171 this->submit();
172 }
173
// Submit `state.graph` on the execution space instance of the state's device handle.
// On success, completion is propagated from that same execution space instance; if
// submission throws, the exception is delivered to the downstream receiver as set_error.
174 void submit_graph() & noexcept {
175#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
176 PLOG_INFO << "Submitting graph " << get_graph_impl_ptr(state.graph.root_node()) << " on "
177 << Kokkos::Tools::Experimental::device_id(state.get_device_handle().m_exec) << '.';
178#endif
179 try {
180 Kokkos::Execution::GraphImpl::submit_graph(state.graph, state.get_device_handle().m_exec);
181 } catch (...) {
// This function is noexcept: route the failure to the receiver instead of letting it escape,
// and skip propagation since the operation has already completed with an error.
182 stdexec::set_error(std::move(this->completion_signal.rcvr), std::current_exception());
183 return;
184 }
185 this->completion_signal.propagate(state.get_device_handle().m_exec);
186 }
187
189};
190
// Trait specialization: any WhenAllOpState instantiation is a graph operation state for
// Exec exactly when its execution_space type is Exec.
192template <stdexec::operation_state OpState, Kokkos::ExecutionSpace Exec>
193requires(
194 stdexec::__is_instance_of<OpState, Kokkos::Execution::GraphImpl::WhenAllOpState>
195 && std::same_as<typename OpState::execution_space, Exec>)
196struct GraphOperationStateFor<OpState, Exec> : public std::true_type { };
197
199template <stdexec::operation_state OpState, Kokkos::ExecutionSpace Exec>
200requires(
202 && stdexec::__is_instance_of<OpState, Kokkos::Execution::GraphImpl::WhenAllOpState>)
204 template <stdexec::operation_state ChildOpState>
206
207 static constexpr bool value = stdexec::__mapply<
208 stdexec::__mall_of<stdexec::__q<RemainsOnGraphForChild>>,
209 typename OpState::children_opstates_t
210 >::value;
211};
212
214template <Kokkos::ExecutionSpace Exec, stdexec::sender... Sndrs>
216 using sender_concept = stdexec::sender_tag;
217
218 using sndrs_t = stdexec::__tuple<Sndrs...>;
219
// Stateless sender attributes: for any environment, report a default-constructed `Domain`
// as the completion domain of set_value.
220 struct attrs {
221 template <typename... Env>
222 [[nodiscard]]
223 constexpr auto
224 query(stdexec::get_completion_domain_t<stdexec::set_value_t>, const Env&...) const noexcept -> Domain {
225 return {};
226 }
227 };
228
// Completion signatures, independent of Self and Env: set_value_t() with no values, or
// set_error_t(std::exception_ptr) (the error form matches what submit_graph() reports when
// submission throws). No set_stopped_t signature is declared.
230 template <typename Self, typename... Env>
231 static consteval auto get_completion_signatures() {
232 return stdexec::completion_signatures<stdexec::set_value_t(), stdexec::set_error_t(std::exception_ptr)>{};
233 }
234
// Connect by moving the stored child senders (`sndrs`) and the receiver into a new
// WhenAllOpState. Rvalue-only; noexcept exactly when that construction is nothrow.
235 template <stdexec::receiver Rcvr>
236 stdexec::operation_state auto connect(Rcvr rcvr) && noexcept(
237 std::is_nothrow_constructible_v<WhenAllOpState<Exec, Rcvr, Sndrs...>, sndrs_t&&, Rcvr&&>) {
238 return WhenAllOpState<Exec, Rcvr, Sndrs...>(std::move(sndrs), std::move(rcvr));
239 }
240
// Return the sender's attributes; attrs holds no data, so a default-constructed value suffices.
241 constexpr auto get_env() const noexcept -> attrs {
242 return {};
243 }
244
246};
247
248struct BECAUSE_THE_EXECUTION_SPACE_TYPE_IS_NOT_HOMOGENEOUS;
249
250template <size_t Index, typename Sndr>
252
253template <size_t Index, typename Sndr>
255
256template <>
257struct TransformSenderFor<stdexec::when_all_t> {
258 template <typename Env, typename... Sndrs>
259 using trnsfrmd_sndr_t = WhenAllSender<Impl::exec_of_t<stdexec::__m_at_c<0, Sndrs...>, Env>, Sndrs...>;
260
261 template <typename Env, typename... Sndrs>
262 auto operator()(const Env&, stdexec::when_all_t, stdexec::__ignore, Sndrs&&... sndrs) const
263 noexcept(std::is_nothrow_constructible_v<typename trnsfrmd_sndr_t<Env, Sndrs...>::sndrs_t, Sndrs&&...>) {
264 if constexpr ((graph_completing_sender<Sndrs, Env> && ...)) {
265 using execution_space = Impl::exec_of_t<stdexec::__m_at_c<0, Sndrs...>, Env>;
266
268 if constexpr ((std::same_as<Impl::exec_of_t<Sndrs, Env>, execution_space> && ...)) {
269 return trnsfrmd_sndr_t<Env, Sndrs...>{.sndrs = {std::forward<Sndrs>(sndrs)...}};
270 } else {
272 STDEXEC_CONSTEXPR_LOCAL bool map[] = {
273 !std::same_as<Impl::exec_of_t<stdexec::__m_at_c<0, Sndrs>, Env>, execution_space>...};
274 STDEXEC_CONSTEXPR_LOCAL std::size_t index = stdexec::__pos_of(map, map + sizeof...(Sndrs));
275 using invalid_sndr_t = stdexec::__m_at_c<index, Sndrs...>;
276 return stdexec::__not_a_sender<
277 stdexec::_WHAT_(CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_GRAPH_SCHEDULER),
278 stdexec::_WHY_(BECAUSE_THE_EXECUTION_SPACE_TYPE_IS_NOT_HOMOGENEOUS),
279 stdexec::_WHERE_(stdexec::_IN_ALGORITHM_, stdexec::when_all_t),
281 stdexec::_WITH_PRETTY_SENDERS_<Sndrs...>,
282 stdexec::_WITH_ENVIRONMENT_(Env)
283 >{};
284 }
285 } else {
287 STDEXEC_CONSTEXPR_LOCAL bool map[] = {!graph_completing_sender<Sndrs, Env>...};
288 STDEXEC_CONSTEXPR_LOCAL std::size_t index = stdexec::__pos_of(map, map + sizeof...(Sndrs));
289 using invalid_sndr_t = stdexec::__m_at_c<index, Sndrs...>;
291 }
292 }
293};
294
295} // namespace Kokkos::Execution::GraphImpl
296
// Specialize stdexec's internal __demangle_v so WhenAllSender<Ts...> maps to
// WhenAllSender<__demangle_t<Ts>...>. Note that Ts here also binds the leading execution
// space argument of WhenAllSender. Presumably this makes stdexec diagnostics print readable
// nested sender types — confirm against stdexec. The NOLINT markers silence clang-tidy's
// reserved-identifier check for these stdexec internals.
297// NOLINTBEGIN(bugprone-reserved-identifier)
298namespace stdexec::__detail {
299template <typename... Sndrs>
300extern __mtype<Kokkos::Execution::GraphImpl::WhenAllSender<__demangle_t<Sndrs>...>>
301 __demangle_v<Kokkos::Execution::GraphImpl::WhenAllSender<Sndrs...>>;
302} // namespace stdexec::__detail
303// NOLINTEND(bugprone-reserved-identifier)
304
305#endif // KOKKOS_EXECUTION_GRAPH_WHEN_ALL_HPP
Concept for a sender whose completion scheduler is Kokkos::Execution::GraphImpl::Scheduler.
#define KOKKOS_EXECUTION_GET_ENV(_type_, _obj_)
Retrieve the environment of _obj_.
Definition env.hpp:14
void graph_add_aggregate_node_event(const NodeType &aggregate, const Predecessors &... predecessors)
Record an event for an aggregate node added after predecessors.
Definition events.hpp:147
WITH_SENDER_AT_INDEX< Index, stdexec::__demangle_t< Sndr > > WITH_PRETTY_SENDER_AT_INDEX
Definition when_all.hpp:254
constexpr get_node_t get_node
Definition get_node.hpp:15
auto * get_graph_impl_ptr(const NodeType &node) noexcept
Retrieve the raw graph pointer from a node.
Definition events.hpp:93
void submit_graph(const Kokkos::Experimental::Graph< Exec > &graph, const Exec &exec)
Submit a graph and record the associated event with graph_submit_event.
Definition events.hpp:178
auto no_graph_scheduler_in_env() noexcept
Emit a clearer compile-time diagnostic when no Kokkos::Execution::GraphImpl::Scheduler is found.
typename ExecOf< Args... >::type exec_of_t
Definition get_exec.hpp:37
void complete(stdexec::set_error_t, Error &&error) noexcept
constexpr OpStateBase(Rcvr rcvr) noexcept(std::is_nothrow_constructible_v< completion_signal_t, Rcvr && >)
RemainsOnGraphFor< ChildOpState, Exec > RemainsOnGraphForChild
Definition when_all.hpp:205
WhenAllSender< Impl::exec_of_t< stdexec::__m_at_c< 0, Sndrs... >, Env >, Sndrs... > trnsfrmd_sndr_t
Definition when_all.hpp:259
auto operator()(const Env &, stdexec::when_all_t, stdexec::__ignore, Sndrs &&... sndrs) const noexcept(std::is_nothrow_constructible_v< typename trnsfrmd_sndr_t< Env, Sndrs... >::sndrs_t, Sndrs &&... >)
Definition when_all.hpp:262
constexpr auto query(get_node_t) const &noexcept -> const root_t &
Definition when_all.hpp:48
Operation state for stdexec::when_all.
Definition when_all.hpp:36
const auto & query(get_node_t) const &noexcept
Definition when_all.hpp:107
WhenAllOpState(stdexec::__tuple< Sndrs... > &&sndrs_, Rcvr &&rcvr_)
Definition when_all.hpp:80
State< GraphComposition::Create, execution_space > state_t
Definition when_all.hpp:42
stdexec::operation_state_tag operation_state_concept
Definition when_all.hpp:37
static constexpr bool as_one
Whether all branches remain fully on the graph when connected to WhenAllChildReceiver.
Definition when_all.hpp:66
void start() &noexcept
If as_one is true, there is no need to start the branches.
Definition when_all.hpp:166
typename state_t::graph_t::root_t root_t
Definition when_all.hpp:43
void complete(Tag, Args &&... args) &noexcept
Definition when_all.hpp:153
decltype(stdexec::__apply([](const auto &... ops) { return Kokkos::Experimental::when_all(ops.query(get_node)...);}, std::declval< const children_opstates_t & >())) node_t
Definition when_all.hpp:68
const auto & query(get_graph_t) const &noexcept
Definition when_all.hpp:111
stdexec::__tuple< stdexec::connect_result_t< Sndrs, WhenAllChildReceiver >... > children_opstates_t
Definition when_all.hpp:53
constexpr auto query(stdexec::get_completion_domain_t< stdexec::set_value_t >, const Env &...) const noexcept -> Domain
Definition when_all.hpp:224
Sender for stdexec::when_all.
Definition when_all.hpp:215
constexpr auto get_env() const noexcept -> attrs
Definition when_all.hpp:241
static consteval auto get_completion_signatures()
Definition when_all.hpp:231
stdexec::__tuple< Sndrs... > sndrs_t
Definition when_all.hpp:218
stdexec::operation_state auto connect(Rcvr rcvr) &&noexcept(std::is_nothrow_constructible_v< WhenAllOpState< Exec, Rcvr, Sndrs... >, sndrs_t &&, Rcvr && >)
Definition when_all.hpp:236
Receiver bound to a parent_op object that implements complete.
Definition receiver.hpp:13
Kokkos::DefaultExecutionSpace execution_space