1#ifndef KOKKOS_EXECUTION_GRAPH_WHEN_ALL_HPP
2#define KOKKOS_EXECUTION_GRAPH_WHEN_ALL_HPP
6#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
10#include "Kokkos_Graph.hpp"
33template <Kokkos::ExecutionSpace Exec, stdexec::receiver Rcvr, stdexec::sender... Sndrs>
43 using root_t =
typename state_t::graph_t::root_t;
53 using children_opstates_t = stdexec::__tuple<stdexec::connect_result_t<Sndrs, WhenAllChildReceiver>...>;
55#if defined(KOKKOS_ENABLE_DEBUG)
59 stdexec::__mall_of<stdexec::__q<Impl::queryable_for<get_node_t>::type>>,
62 "Child senders of the 'when_all' must lead to 'get_node_t' queryable operation states.");
68 using node_t =
decltype(stdexec::__apply(
69 [](
const auto&... ops) {
return Kokkos::Experimental::when_all(ops.query(
get_node)...); },
70 std::declval<const children_opstates_t&>()));
77 std::atomic<size_t>
count =
sizeof...(Sndrs);
99 [](
const auto&... child_op) {
100 auto agg = Kokkos::Experimental::when_all(child_op.query(
get_node)...);
104 children_opstates)) {
141 if (
count.fetch_sub(1) == 1) {
146 void submit() &
noexcept requires(as_one)
151 template <
typename Tag,
typename... Args>
152 requires(!std::same_as<Tag, stdexec::set_value_t>)
159#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
160 PLOG_INFO <<
"Starting all branches before submission.";
162 stdexec::__apply([](
auto&... ops) ->
void { (stdexec::start(ops), ...); },
children_opstates);
168#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
169 PLOG_INFO <<
"Submit the graph directly without starting the branches.";
175#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
177 << Kokkos::Tools::Experimental::device_id(
state.get_device_handle().m_exec) <<
'.';
182 stdexec::set_error(std::move(this->
completion_signal.rcvr), std::current_exception());
192template <stdexec::operation_state OpState, Kokkos::ExecutionSpace Exec>
194 stdexec::__is_instance_of<OpState, Kokkos::Execution::GraphImpl::WhenAllOpState>
195 && std::same_as<typename OpState::execution_space, Exec>)
199template <stdexec::operation_state OpState, Kokkos::ExecutionSpace Exec>
202 && stdexec::__is_instance_of<OpState, Kokkos::Execution::GraphImpl::WhenAllOpState>)
204 template <stdexec::operation_state ChildOpState>
207 static constexpr bool value = stdexec::__mapply<
208 stdexec::__mall_of<stdexec::__q<RemainsOnGraphForChild>>,
209 typename OpState::children_opstates_t
214template <Kokkos::ExecutionSpace Exec, stdexec::sender... Sndrs>
221 template <
typename... Env>
224 query(stdexec::get_completion_domain_t<stdexec::set_value_t>,
const Env&...)
const noexcept ->
Domain {
230 template <
typename Self,
typename... Env>
232 return stdexec::completion_signatures<stdexec::set_value_t(), stdexec::set_error_t(std::exception_ptr)>{};
235 template <stdexec::receiver Rcvr>
236 stdexec::operation_state
auto connect(Rcvr rcvr) &&
noexcept(
241 constexpr auto get_env() const noexcept -> attrs {
248struct BECAUSE_THE_EXECUTION_SPACE_TYPE_IS_NOT_HOMOGENEOUS;
250template <
size_t Index,
typename Sndr>
253template <
size_t Index,
typename Sndr>
258 template <
typename Env,
typename... Sndrs>
261 template <
typename Env,
typename... Sndrs>
262 auto operator()(
const Env&, stdexec::when_all_t, stdexec::__ignore, Sndrs&&... sndrs)
const
263 noexcept(std::is_nothrow_constructible_v<
typename trnsfrmd_sndr_t<Env, Sndrs...>::sndrs_t, Sndrs&&...>) {
268 if constexpr ((std::same_as<Impl::exec_of_t<Sndrs, Env>,
execution_space> && ...)) {
269 return trnsfrmd_sndr_t<Env, Sndrs...>{.sndrs = {std::forward<Sndrs>(sndrs)...}};
272 STDEXEC_CONSTEXPR_LOCAL
bool map[] = {
273 !std::same_as<Impl::exec_of_t<stdexec::__m_at_c<0, Sndrs>, Env>,
execution_space>...};
274 STDEXEC_CONSTEXPR_LOCAL std::size_t index = stdexec::__pos_of(map, map +
sizeof...(Sndrs));
275 using invalid_sndr_t = stdexec::__m_at_c<index, Sndrs...>;
276 return stdexec::__not_a_sender<
277 stdexec::_WHAT_(CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_GRAPH_SCHEDULER),
278 stdexec::_WHY_(BECAUSE_THE_EXECUTION_SPACE_TYPE_IS_NOT_HOMOGENEOUS),
279 stdexec::_WHERE_(stdexec::_IN_ALGORITHM_, stdexec::when_all_t),
281 stdexec::_WITH_PRETTY_SENDERS_<Sndrs...>,
282 stdexec::_WITH_ENVIRONMENT_(Env)
288 STDEXEC_CONSTEXPR_LOCAL std::size_t index = stdexec::__pos_of(map, map +
sizeof...(Sndrs));
289 using invalid_sndr_t = stdexec::__m_at_c<index, Sndrs...>;
299template <
typename... Sndrs>
300extern __mtype<Kokkos::Execution::GraphImpl::WhenAllSender<__demangle_t<Sndrs>...>>
Concept for a sender whose completion scheduler is Kokkos::Execution::GraphImpl::Scheduler.
#define KOKKOS_EXECUTION_GET_ENV(_type_, _obj_)
Retrieve the environment of _obj_. // NOLINTNEXTLINE(cppcoreguidelines-macro-usage).
void graph_add_aggregate_node_event(const NodeType &aggregate, const Predecessors &... predecessors)
Record an event for an aggregate node added after predecessors.
WITH_SENDER_AT_INDEX< Index, stdexec::__demangle_t< Sndr > > WITH_PRETTY_SENDER_AT_INDEX
constexpr get_node_t get_node
auto * get_graph_impl_ptr(const NodeType &node) noexcept
Retrieve the raw graph pointer from a node.
void submit_graph(const Kokkos::Experimental::Graph< Exec > &graph, const Exec &exec)
Submit a graph and record the associated event with graph_submit_event.
auto no_graph_scheduler_in_env() noexcept
Show a better compile diagnostic when there is no Kokkos::Execution::GraphImpl::Scheduler found.
typename ExecOf< Args... >::type exec_of_t
void complete(stdexec::set_error_t, Error &&error) noexcept
completion_signal_t completion_signal
constexpr OpStateBase(Rcvr rcvr) noexcept(std::is_nothrow_constructible_v< completion_signal_t, Rcvr && >)
static constexpr bool value
RemainsOnGraphFor< ChildOpState, Exec > RemainsOnGraphForChild
Receiver for a child of stdexec::when_all.
constexpr auto query(get_node_t) const &noexcept -> const root_t &
Operation state for stdexec::when_all.
const auto & query(get_node_t) const &noexcept
WhenAllOpState(stdexec::__tuple< Sndrs... > &&sndrs_, Rcvr &&rcvr_)
State< GraphComposition::Create, execution_space > state_t
stdexec::operation_state_tag operation_state_concept
static constexpr bool as_one
Determine if all branches remain fully on the graph, if connected to WhenAllChildReceiver.
void start() &noexcept
If as_one is true, there is no need to start the branches.
typename state_t::graph_t::root_t root_t
void complete(Tag, Args &&... args) &noexcept
std::atomic< size_t > count
decltype(stdexec::__apply([](const auto &... ops) { return Kokkos::Experimental::when_all(ops.query(get_node)...);}, std::declval< const children_opstates_t & >())) node_t
const auto & query(get_graph_t) const &noexcept
stdexec::__tuple< stdexec::connect_result_t< Sndrs, WhenAllChildReceiver >... > children_opstates_t
OpStateBase< Exec, Rcvr > base_t
children_opstates_t children_opstates
void submit_graph() &noexcept
constexpr auto query(stdexec::get_completion_domain_t< stdexec::set_value_t >, const Env &...) const noexcept -> Domain
Sender for stdexec::when_all.
stdexec::sender_tag sender_concept
constexpr auto get_env() const noexcept -> attrs
static consteval auto get_completion_signatures()
stdexec::__tuple< Sndrs... > sndrs_t
stdexec::operation_state auto connect(Rcvr rcvr) &&noexcept(std::is_nothrow_constructible_v< WhenAllOpState< Exec, Rcvr, Sndrs... >, sndrs_t &&, Rcvr && >)
Receiver for an object parent_op that implements complete.
WhenAllOpState * parent_op
Kokkos::DefaultExecutionSpace execution_space