OSSIA
Open Scenario System for Interactive Application
Loading...
Searching...
No Matches
graph_parallel_impl.hpp
1#pragma once
2#include <ossia/dataflow/exec_pool.hpp>
3#include <ossia/dataflow/graph_node.hpp>
4#include <ossia/detail/fmt.hpp>
5
6#include <smallfun.hpp>
7
8#include <atomic>
9#include <vector>
10
11namespace ossia
12{
13using task_function = smallfun::function<void(ossia::graph_node&), sizeof(void*) * 4>;
14
15class taskflow;
16class executor;
17
18// A graph node wrapped as a DAG task. Readiness is tracked by
19// m_remaining_dependencies: when it reaches zero the task is submitted to the
20// shared pool (or run inline). Memory ordering is carried entirely by the
21// acq_rel RMW chain on that counter, so no explicit fences are needed: the
22// thread observing the 1->0 transition happens-after every predecessor's
23// decrement, hence sees all predecessor buffer writes.
24class task : public ossia::pool_task
25{
26public:
27 task() = default;
28 task(const task&) = delete;
29 task(task&& other) noexcept
30 : pool_task{other}
31 , m_taskId{other.m_taskId}
32 , m_dependencies{other.m_dependencies}
33 , m_remaining_dependencies{other.m_remaining_dependencies.load()}
34 , m_node{other.m_node}
35 , m_exec{other.m_exec}
36 , m_precedes{std::move(other.m_precedes)}
37 {
38 other.m_precedes.clear();
39 }
40 task& operator=(const task&) = delete;
41 task& operator=(task&& other) noexcept
42 {
43 static_cast<pool_task&>(*this) = other;
44 m_taskId = other.m_taskId;
45 m_dependencies = other.m_dependencies;
46 m_remaining_dependencies = other.m_remaining_dependencies.load();
47 m_node = other.m_node;
48 m_exec = other.m_exec;
49 m_precedes = std::move(other.m_precedes);
50 other.m_precedes.clear();
51 return *this;
52 }
53
54 task(ossia::graph_node& node)
55 : m_node{&node}
56 {
57 }
58
59 void precede(task& other)
60 {
61 m_precedes.push_back(other.m_taskId);
62 other.m_dependencies++;
63 }
64
65private:
66 friend class taskflow;
67 friend class executor;
68
69 int m_taskId{};
70 int m_dependencies{0};
71 std::atomic_int m_remaining_dependencies{};
72
73 ossia::graph_node* m_node{};
74 executor* m_exec{};
75 ossia::small_pod_vector<int, 4> m_precedes;
76};
77
78class taskflow
79{
80public:
81 void clear() { m_tasks.clear(); }
82
83 void reserve(std::size_t sz) { m_tasks.reserve(sz); }
84
85 task* emplace(ossia::graph_node& node)
86 {
87 const int taskId = m_tasks.size();
88 auto& last = m_tasks.emplace_back(node);
89 last.m_taskId = taskId;
90 return &last;
91 }
92
93private:
94 friend class executor;
95
96 std::vector<task> m_tasks;
97};
98
99// Drives a taskflow over the process-wide ossia::task_pool. Holds no threads of
100// its own; rebuilding the graph no longer churns the worker set.
101class executor
102{
103public:
104 executor() = default;
105
106 void set_task_executor(task_function f) { m_func = std::move(f); }
107
108 void run(taskflow& tf)
109 {
110 m_tf = &tf;
111 if(tf.m_tasks.empty())
112 return;
113
114 auto& pool = ossia::task_pool::instance();
115
116 m_batch.remaining.store(
117 static_cast<int>(tf.m_tasks.size()), std::memory_order_relaxed);
118
119 for(auto& task : tf.m_tasks)
120 {
121 task.execute = &executor::run_task;
122 task.m_exec = this;
123 task.m_remaining_dependencies.store(
124 task.m_dependencies, std::memory_order_relaxed);
125 }
126
127 for(auto& task : tf.m_tasks)
128 {
129 if(task.m_dependencies == 0)
130 enqueue(task);
131 }
132
133 pool.corun_until(m_batch);
134 }
135
136private:
137 // pool_task entry point: recover the owning executor from the task itself.
138 static void run_task(ossia::pool_task* pt) noexcept
139 {
140 auto* t = static_cast<task*>(pt);
141 t->m_exec->execute(*t);
142 }
143
144 void enqueue(task& t)
145 {
146 auto& pool = ossia::task_pool::instance();
147 if(t.m_node->not_threadable())
148 pool.submit_owner_only(&t);
149 else
150 pool.submit(&t);
151 }
152
153 void execute(task& t) noexcept
154 {
155 try
156 {
157 m_func(*t.m_node);
158 }
159 catch(const std::exception& e)
160 {
161 fmt::print(
162 stderr, "ossia: error executing node '{}': {}\n", t.m_node->label(),
163 e.what());
164 }
165 catch(...)
166 {
167 fmt::print(
168 stderr, "ossia: error executing node '{}'\n", t.m_node->label());
169 }
170
171 process_done(t);
172 }
173
174 void process_done(task& t) noexcept
175 {
176 auto& pool = ossia::task_pool::instance();
177 for(int taskId : t.m_precedes)
178 {
179 auto& next = m_tf->m_tasks[taskId];
180 // acq_rel: the decrement that observes 1->0 happens-after all sibling
181 // decrements, so `next` sees every predecessor's writes when it runs.
182 if(next.m_remaining_dependencies.fetch_sub(1, std::memory_order_acq_rel)
183 == 1)
184 {
185 enqueue(next);
186 }
187 }
188 pool.finish(m_batch);
189 }
190
191 task_function m_func;
192 taskflow* m_tf{};
193 ossia::task_batch m_batch;
194};
195}
196
197#include <ossia/dataflow/graph/graph_static.hpp>
198#include <ossia/detail/hash_map.hpp>
199namespace ossia
200{
201struct custom_parallel_exec;
202template <typename Impl>
203struct custom_parallel_update
204{
205public:
206 std::shared_ptr<ossia::logger_type> logger;
207 std::shared_ptr<bench_map> perf_map;
208
209 template <typename Graph_T>
210 custom_parallel_update(Graph_T& g, const ossia::graph_setup_options& opt)
211 : impl{g, opt}
212 {
213 }
214
215 void update_graph(
216 ossia::node_map& nodes, const std::vector<graph_node*>& topo_order,
217 ossia::graph_t& graph)
218 {
219 flow_nodes.clear();
220 flow_graph.clear();
221
222 flow_graph.reserve(nodes.size());
223
224 if(logger)
225 {
226 if(perf_map)
227 {
228 executor.set_task_executor(
229 node_exec_logger_bench{cur_state, *perf_map, *logger});
230 for(auto node : topo_order)
231 {
232 (*perf_map)[node] = std::nullopt;
233 flow_nodes[node] = flow_graph.emplace(*node);
234 }
235 }
236 else
237 {
238 executor.set_task_executor(node_exec_logger{cur_state, *logger});
239 for(auto node : topo_order)
240 {
241 flow_nodes[node] = flow_graph.emplace(*node);
242 }
243 }
244 }
245 else
246 {
247 executor.set_task_executor(node_exec{cur_state});
248 for(auto node : topo_order)
249 {
250 flow_nodes[node] = flow_graph.emplace(*node);
251 }
252 }
253
254 for(auto [ei, ei_end] = boost::edges(graph); ei != ei_end; ++ei)
255 {
256 auto edge = *ei;
257 auto& n1 = graph[edge.m_source];
258 auto& n2 = graph[edge.m_target];
259
260 auto& sender = flow_nodes[n2.get()];
261 auto& receiver = flow_nodes[n1.get()];
262 sender->precede(*receiver);
263 }
264 }
265
266 template <typename Graph_T, typename DevicesT>
267 void operator()(Graph_T& g, const DevicesT& devices)
268 {
269 impl(g, devices);
270 update_graph(g.m_nodes, g.m_all_nodes, impl.m_sub_graph);
271 }
272
273private:
274 friend struct custom_parallel_exec;
275
276 Impl impl;
277 execution_state* cur_state{};
278
279 ossia::taskflow flow_graph;
280 ossia::executor executor;
281 ossia::hash_map<graph_node*, ossia::task*> flow_nodes;
282};
283
284struct custom_parallel_exec
285{
286 template <typename Graph_T>
287 custom_parallel_exec(Graph_T&)
288 {
289 }
290
291 template <typename Graph_T, typename Impl>
292 void operator()(
293 Graph_T& g, custom_parallel_update<Impl>& self, ossia::execution_state& e,
294 const std::vector<ossia::graph_node*>&)
295 {
296 self.cur_state = &e;
297 self.executor.run(self.flow_graph);
298 }
299};
300
301using custom_parallel_tc_graph
302 = graph_static<custom_parallel_update<tc_update<fast_tc>>, custom_parallel_exec>;
303}
Definition git_info.h:7
spdlog::logger & logger() noexcept
Where the errors will be logged. Default is stderr.
Definition context.cpp:118