kokkos-execution 0.0.1
Loading...
Searching...
No Matches
operation_state.hpp
Go to the documentation of this file.
1#ifndef KOKKOS_EXECUTION_GRAPH_OPERATION_STATE_HPP
2#define KOKKOS_EXECUTION_GRAPH_OPERATION_STATE_HPP
3
5
6#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
7# include "plog/Log.h"
8#endif
9
10#include "Kokkos_Core.hpp"
11#include "Kokkos_Graph.hpp"
12
23
25
26template <typename Clsr>
27concept Closure = requires {
28 typename Clsr::execution_space;
29 requires std::same_as<typename Clsr::device_handle_t, Kokkos::Impl::DeviceHandle<typename Clsr::execution_space>>;
30 typename Clsr::node_props_t;
31};
32
// Primary template of the per-operation state, declared only.
// It is specialized below for GraphComposition::Attach and GraphComposition::Create.
33template <typename GraphCompositionPolicy, Kokkos::ExecutionSpace Exec>
34struct State;
35
37template <Kokkos::ExecutionSpace Exec>
38struct State<GraphComposition::Attach, Exec> {
40
41 explicit State(const Kokkos::Impl::DeviceHandle<Exec>&) {
42 }
43};
44
46template <Kokkos::ExecutionSpace Exec>
47struct State<GraphComposition::Create, Exec> {
49
50 using graph_t = Kokkos::Experimental::Graph<Exec>;
51
53
54 explicit State(const Kokkos::Impl::DeviceHandle<Exec>& device_handle)
55 : graph(Kokkos::Execution::GraphImpl::create_graph(device_handle)) {
56 }
57
58 const auto& get_device_handle() const {
59 return graph.get_device_handle();
60 }
61};
62
68template <Kokkos::ExecutionSpace Exec, stdexec::receiver Rcvr>
72
74
75 constexpr explicit OpStateBase(Rcvr rcvr) noexcept(std::is_nothrow_constructible_v<completion_signal_t, Rcvr&&>)
76 : completion_signal(std::move(rcvr)) {
77 }
78
79 void complete(stdexec::set_value_t) noexcept {
80 completion_signal.propagate(stdexec::set_value);
81 }
82
83 template <typename Error>
84 void complete(stdexec::set_error_t, Error&& error) noexcept {
85 completion_signal.propagate(stdexec::set_error, std::forward<Error>(error));
86 }
87
88 void complete(stdexec::set_stopped_t) noexcept {
89 completion_signal.propagate(stdexec::set_stopped);
90 }
91};
92
94template <typename Predecessor, Closure FirstClosure, Closure... RestOfClosures>
95requires NodeRef<std::remove_cvref_t<Predecessor>>
96static auto add_nodes(Predecessor&& predecessor, FirstClosure&& clsr, RestOfClosures&&... clsrs) {
97 auto node = std::forward<FirstClosure>(clsr).add_node(std::forward<Predecessor>(predecessor));
98 if constexpr (sizeof...(RestOfClosures) == 0) {
99 return node;
100 } else {
101 return add_nodes(std::move(node), std::forward<RestOfClosures>(clsrs)...);
102 }
103}
104
106template <stdexec::sender Sndr, stdexec::receiver Rcvr, Closure FirstClosure, Closure... RestOfClosures>
108 : public Impl::Immovable
109 , public OpStateBase<typename FirstClosure::execution_space, Rcvr> {
110 using operation_state_concept = stdexec::operation_state_tag;
111
112 using execution_space = typename FirstClosure::execution_space;
113 using device_handle_t = typename FirstClosure::device_handle_t;
114
116 static_assert((std::same_as<typename RestOfClosures::execution_space, execution_space> && ...));
117
119 using inner_opstate_t = stdexec::connect_result_t<Sndr, rcvr_t>;
123
124 static constexpr bool after_root = std::same_as<graph_composition_policy_t, GraphComposition::Create>;
125
126 using node_t = decltype(add_nodes(
127 std::declval<predecessor_t>(),
128 std::declval<FirstClosure>(),
129 std::declval<RestOfClosures>()...));
130
134
// Construct the operation state.
//  - rcvr is moved into the OpStateBase base.
//  - sndr is forwarded to stdexec::connect together with a receiver built from `this`.
//  - state is built from the device handle property read from clsr.node_props.
//  - node is the result of add_nodes, called with this->get_predecessor() and the moved closures.
// Declared noexcept(false). With KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING defined, the graph composition
// policy and the inner operation state type are logged.
// NOTE(review): the member declarations are not visible here. Members are initialized in declaration
// order, and the `node` initializer reads clsr after `state` has read clsr.node_props and presumably
// reads `state`/`inner_opstate` through get_predecessor() -- confirm that the declaration order is
// inner_opstate, state, node.
136 constexpr OpState(
137 Sndr&& sndr, // NOLINT(cppcoreguidelines-rvalue-reference-param-not-moved)
138 Rcvr rcvr,
139 FirstClosure clsr,
140 RestOfClosures... clsrs) noexcept(false)
141 : OpStateBase<execution_space, Rcvr>(std::move(rcvr))
142 , inner_opstate(stdexec::connect(std::forward<Sndr>(sndr), rcvr_t{this}))
143 , state{Kokkos::Impl::get_property<device_handle_t>(clsr.node_props)}
144 , node{add_nodes(this->get_predecessor(), std::move(clsr), std::move(clsrs)...)} {
145#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
146 PLOG_INFO << "Operation state graph composition policy is "
147 << Kokkos::Impl::TypeInfo<graph_composition_policy_t>::name()
148 << " and the inner operation state is of type " << Kokkos::Impl::TypeInfo<inner_opstate_t>::name()
149 << '.';
150#endif
151 }
152
158 if constexpr (after_root) {
159#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
160 PLOG_INFO << "The predecessor is the root node of graph " << get_graph_impl_ptr(state.graph.root_node())
161 << '.';
162#endif
163 return state.graph.root_node();
164 } else {
165#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
166 PLOG_INFO << "The predecessor is the node " << get_node_ptr(inner_opstate.query(get_node)) << " of graph "
167 << get_graph_impl_ptr(inner_opstate.query(get_node)) << '.';
168#endif
169 return inner_opstate.query(get_node);
170 }
171 }
172
// Value-completion handler of this operation state.
// When after_root is true (the Create policy), the graph held in `state` is submitted on the `m_exec`
// member of the state's device handle, and that same `m_exec` is then fenced with a ": sync_wait"
// dispatch label. With debug logging enabled the submission is logged first.
// NOTE(review): in the lines visible here nothing is forwarded to the completion signal, and nothing
// is done when after_root is false. The listing skips original lines 181 and 184, so confirm against
// the full source whether the receiver is completed elsewhere.
173 void complete(stdexec::set_value_t) noexcept {
174 if constexpr (after_root) {
175#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
176 PLOG_INFO << "Submitting graph " << get_graph_impl_ptr(state.graph.root_node()) << " on "
177 << Kokkos::Tools::Experimental::device_id(state.get_device_handle().m_exec) << '.';
178#endif
179 submit_graph(state.graph, state.get_device_handle().m_exec);
180
182 state.get_device_handle().m_exec.fence(std::string(Impl::dispatch_label<execution_space, ": sync_wait">()));
183 }
185 }
186
187 const auto& query(get_node_t) const & noexcept {
188 return node;
189 }
190
191 const auto& query(get_graph_t) const & noexcept {
192 if constexpr (after_root) {
193 return state.graph;
194 } else {
195 return inner_opstate.query(get_graph);
196 }
197 }
198
199 void start() & noexcept {
200 stdexec::start(inner_opstate);
201 }
202
203 [[nodiscard]]
204 constexpr auto get_env() const noexcept -> stdexec::env_of_t<Rcvr> {
205 return stdexec::get_env(this->completion_signal.rcvr);
206 }
207};
208
209template <typename Sndr, typename Rcvr, typename... Clsrs>
211
212template <typename Sndr, typename Rcvr, typename... Clsrs>
213using opstate_t = typename make_opstate_t<Sndr, Rcvr, Clsrs...>::type;
214
// Expands to an rvalue-qualified member template `connect(Rcvr)` returning opstate_t<Sndr, Rcvr, closure_t>.
// The operation state is produced by calling a default-constructed make_opstate_t<Sndr, Rcvr, closure_t>
// with the forwarded `sndr`, the moved `rcvr` and the moved `clsr`; the noexcept specification mirrors that call.
// The expansion refers to `Sndr`, `closure_t`, `sndr` and `clsr`, so these must be in scope at the point of use.
215#define KOKKOS_EXECUTION_GRAPH_OPERATION_STATE_CONNECT \
216 template <stdexec::receiver Rcvr> \
217 constexpr auto connect(Rcvr rcvr) && noexcept(noexcept(make_opstate_t<Sndr, Rcvr, closure_t>{}( \
218 std::declval<Sndr>(), std::declval<Rcvr>(), std::declval<closure_t>()))) -> opstate_t<Sndr, Rcvr, closure_t> { \
219 return make_opstate_t<Sndr, Rcvr, closure_t>{}(std::forward<Sndr>(sndr), std::move(rcvr), std::move(clsr)); \
220 }
221
// Logging helper for node insertion.
// With KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING defined, it logs the node type label (_type_, a string literal
// that is concatenated into the message), the node pointer, the graph pointer, the predecessor pointer and
// the device id of the node's execution space. Otherwise it expands to nothing.
// Note that the enabled expansion already ends with a semicolon.
222#if defined(KOKKOS_EXECUTION_ENABLE_DEBUG_LOGGING)
223# define KOKKOS_EXECUTION_IMPL_GRAPH_ADD_NODE_DEBUG_LOGGING(_type_, _node_, _predecessor_) \
224 PLOG_INFO << "Adding '" _type_ "' node " << get_node_ptr(_node_) << " to graph " << get_graph_impl_ptr(_node_) \
225 << " after " << get_node_ptr(_predecessor_) << " on device " \
226 << Kokkos::Tools::Experimental::device_id(get_node_ptr(_node_)->get_device_handle().m_exec) << '.';
227#else
228# define KOKKOS_EXECUTION_IMPL_GRAPH_ADD_NODE_DEBUG_LOGGING(_type_, _node_, _predecessor_)
229#endif
230} // namespace Kokkos::Execution::GraphImpl
231
232#endif // KOKKOS_EXECUTION_GRAPH_OPERATION_STATE_HPP
typename make_opstate_t< Sndr, Rcvr, Clsrs... >::type opstate_t
auto * get_node_ptr(const NodeType &node) noexcept
Retrieve the raw node pointer.
Definition events.hpp:82
constexpr get_node_t get_node
Definition get_node.hpp:15
auto create_graph(const Kokkos::Impl::DeviceHandle< Exec > &device_handle, Args &&... args)
Create a graph and record the associated event with graph_create_event.
Definition events.hpp:100
Impl::MakeOpState< Domain, OpState >::Huddle< Sndr, Rcvr, Clsrs... > make_opstate_t
constexpr get_graph_t get_graph
Definition get_graph.hpp:17
auto * get_graph_impl_ptr(const NodeType &node) noexcept
Retrieve the raw graph pointer from a node.
Definition events.hpp:76
static auto add_nodes(Predecessor &&predecessor, FirstClosure &&clsr, RestOfClosures &&... clsrs)
Add all nodes as a sequence. Hence, only the first node may be added after the root node.
void submit_graph(const Kokkos::Experimental::Graph< Exec > &graph, const Exec &exec)
Submit a graph and record the associated event with graph_submit_event.
Definition events.hpp:168
consteval std::string_view dispatch_label() noexcept
View the dispatch label as a std::string_view.
Attach to the existing graph of the predecessor.
Definition get_graph.hpp:22
Create a new graph and attach after the root node.
Definition get_graph.hpp:25
Inspired by https://github.com/NVIDIA/stdexec/blob/8c5eedd0fcf9a8ebcdb75d988f72f88efcf64a37/include/s...
Definition get_graph.hpp:20
typename node_helper_t< Exec, Queryable >::type node_t
Definition get_graph.hpp:51
std::conditional_t< stdexec::__queryable_with< Queryable, get_node_t >, Attach, Create > policy_t
Use the Attach policy if Queryable is queryable with get_node_t.
Definition get_graph.hpp:29
void complete(stdexec::set_error_t, Error &&error) noexcept
Impl::CompletionSignal< sync_policy_t, Exec, Rcvr > completion_signal_t
void complete(stdexec::set_stopped_t) noexcept
Impl::SyncPolicy::PassThrough sync_policy_t
void complete(stdexec::set_value_t) noexcept
constexpr OpStateBase(Rcvr rcvr) noexcept(std::is_nothrow_constructible_v< completion_signal_t, Rcvr && >)
stdexec::connect_result_t< Sndr, rcvr_t > inner_opstate_t
const auto & query(get_node_t) const &noexcept
stdexec::operation_state_tag operation_state_concept
constexpr OpState(Sndr &&sndr, Rcvr rcvr, FirstClosure clsr, RestOfClosures... clsrs) noexcept(false)
State< graph_composition_policy_t, execution_space > state_t
typename FirstClosure::device_handle_t device_handle_t
GraphComposition::policy_t< inner_opstate_t > graph_composition_policy_t
GraphComposition::node_t< execution_space, inner_opstate_t > predecessor_t
constexpr auto get_env() const noexcept -> stdexec::env_of_t< Rcvr >
Impl::Receiver< OpState, stdexec::env_of_t< Rcvr > > rcvr_t
Ensure that all closures are on the same execution space type.
typename FirstClosure::execution_space execution_space
const auto & query(get_graph_t) const &noexcept
decltype(add_nodes( std::declval< predecessor_t >(), std::declval< FirstClosure >(), std::declval< RestOfClosures >()...)) node_t
void complete(stdexec::set_value_t) noexcept
State(const Kokkos::Impl::DeviceHandle< Exec > &device_handle)
Receiver for an object parent_op that implements complete.
Definition receiver.hpp:12