libs/capy/include/boost/capy/when_all.hpp

96.9% Lines (95/98) 91.1% Functions (431/473) 100.0% Branches (22/22)
libs/capy/include/boost/capy/when_all.hpp
Line Branch Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_awaitable.hpp>
16 #include <coroutine>
17 #include <boost/capy/ex/io_env.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Maps one task result type to a tuple fragment.

    Produces `std::tuple<>` when `T` is `void` and `std::tuple<T>`
    otherwise, so that concatenating the fragments drops every void
    entry.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

/** Tuple of the non-void types among `Ts`, in their original order.

    Void-returning tasks do not contribute a value to the result tuple,
    so they are removed here.

    Example: `filter_void_tuple_t<int, void, string>` is
    `std::tuple<int, string>`.
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Storage slot for the value produced by one when_all child.

    The slot starts empty. `set` stores a value and `get` moves it back
    out. Precondition for `get`: a value has been stored, because the
    optional is dereferenced without a check.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /// Stores `v`, taking ownership of it by move.
    void set(T v)
    {
        value_ = std::move(v);
    }

    /// Moves the stored value out of the slot (rvalue-only).
    T get() &&
    {
        return *std::move(value_);
    }
};

/** `void` children produce nothing, so their slot holds no state.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<std::coroutine_handle<>, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 4 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 std::coroutine_handle<> continuation_;
109 io_env const* caller_env_ = nullptr;
110
111 33 when_all_state()
112
1/1
✓ Branch 5 taken 33 times.
33 : remaining_count_(task_count)
113 {
114 33 }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 11 void capture_exception(std::exception_ptr ep)
121 {
122 11 bool expected = false;
123
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
11 if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 8 first_exception_ = ep;
126 11 }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 io_env env_;
142
143 78 when_all_runner get_return_object()
144 {
145 78 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
146 }
147
148 78 std::suspend_always initial_suspend() noexcept
149 {
150 78 return {};
151 }
152
153 78 auto final_suspend() noexcept
154 {
155 struct awaiter
156 {
157 promise_type* p_;
158
159 8 bool await_ready() const noexcept
160 {
161 8 return false;
162 }
163
164 8 std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
165 {
166 // Extract everything needed before self-destruction.
167 8 auto* state = p_->state_;
168 8 auto* counter = &state->remaining_count_;
169 8 auto* caller_env = state->caller_env_;
170 8 auto cont = state->continuation_;
171
172 8 h.destroy();
173
174 // If last runner, dispatch parent for symmetric transfer.
175 8 auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
176
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
8 if(remaining == 1)
177 4 return caller_env->executor.dispatch(cont);
178 4 return std::noop_coroutine();
179 }
180
181 void await_resume() const noexcept
182 {
183 }
184 };
185 78 return awaiter{this};
186 }
187
188 67 void return_void()
189 {
190 67 }
191
192 11 void unhandled_exception()
193 {
194 11 state_->capture_exception(std::current_exception());
195 // Request stop for sibling tasks
196 11 state_->stop_source_.request_stop();
197 11 }
198
199 template<class Awaitable>
200 struct transform_awaiter
201 {
202 std::decay_t<Awaitable> a_;
203 promise_type* p_;
204
205 78 bool await_ready()
206 {
207 78 return a_.await_ready();
208 }
209
210 78 decltype(auto) await_resume()
211 {
212 78 return a_.await_resume();
213 }
214
215 template<class Promise>
216 77 auto await_suspend(std::coroutine_handle<Promise> h)
217 {
218 77 return a_.await_suspend(h, &p_->env_);
219 }
220 };
221
222 template<class Awaitable>
223 78 auto await_transform(Awaitable&& a)
224 {
225 using A = std::decay_t<Awaitable>;
226 if constexpr (IoAwaitable<A>)
227 {
228 return transform_awaiter<Awaitable>{
229 156 std::forward<Awaitable>(a), this};
230 }
231 else
232 {
233 static_assert(sizeof(A) == 0, "requires IoAwaitable");
234 }
235 78 }
236 };
237
238 std::coroutine_handle<promise_type> h_;
239
240 78 explicit when_all_runner(std::coroutine_handle<promise_type> h)
241 78 : h_(h)
242 {
243 78 }
244
245 // Enable move for all clang versions - some versions need it
246 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
247
248 // Non-copyable
249 when_all_runner(when_all_runner const&) = delete;
250 when_all_runner& operator=(when_all_runner const&) = delete;
251 when_all_runner& operator=(when_all_runner&&) = delete;
252
253 78 auto release() noexcept
254 {
255 78 return std::exchange(h_, nullptr);
256 }
257 };
258
259 /** Create a runner coroutine for a single awaitable.
260
261 Awaitable is passed directly to ensure proper coroutine frame storage.
262 */
263 template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
264 when_all_runner<awaitable_result_t<Awaitable>, Ts...>
265
1/1
✓ Branch 1 taken 78 times.
78 make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
266 {
267 using T = awaitable_result_t<Awaitable>;
268 if constexpr (std::is_void_v<T>)
269 {
270 co_await std::move(inner);
271 }
272 else
273 {
274 std::get<Index>(state->results_).set(co_await std::move(inner));
275 }
276 156 }
277
278 /** Internal awaitable that launches all runner coroutines and waits.
279
280 This awaitable is used inside the when_all coroutine to handle
281 the concurrent execution of child awaitables.
282 */
283 template<IoAwaitable... Awaitables>
284 class when_all_launcher
285 {
286 using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
287
288 std::tuple<Awaitables...>* awaitables_;
289 state_type* state_;
290
291 public:
292 33 when_all_launcher(
293 std::tuple<Awaitables...>* awaitables,
294 state_type* state)
295 33 : awaitables_(awaitables)
296 33 , state_(state)
297 {
298 33 }
299
300 33 bool await_ready() const noexcept
301 {
302 33 return sizeof...(Awaitables) == 0;
303 }
304
305 33 std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
306 {
307 33 state_->continuation_ = continuation;
308 33 state_->caller_env_ = caller_env;
309
310 // Forward parent's stop requests to children
311
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 25 times.
33 if(caller_env->stop_token.stop_possible())
312 {
313 16 state_->parent_stop_callback_.emplace(
314 8 caller_env->stop_token,
315 8 typename state_type::stop_callback_fn{&state_->stop_source_});
316
317
2/2
✓ Branch 1 taken 4 times.
✓ Branch 2 taken 4 times.
8 if(caller_env->stop_token.stop_requested())
318 4 state_->stop_source_.request_stop();
319 }
320
321 // CRITICAL: If the last task finishes synchronously then the parent
322 // coroutine resumes, destroying its frame, and destroying this object
323 // prior to the completion of await_suspend. Therefore, await_suspend
324 // must ensure `this` cannot be referenced after calling `launch_one`
325 // for the last time.
326 33 auto token = state_->stop_source_.get_token();
327 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
328
2/2
✓ Branch 2 taken 4 times.
✓ Branch 6 taken 4 times.
4 (..., launch_one<Is>(caller_env->executor, token));
329
2/2
✓ Branch 1 taken 29 times.
✓ Branch 1 taken 4 times.
33 }(std::index_sequence_for<Awaitables...>{});
330
331 // Let signal_completion() handle resumption
332 66 return std::noop_coroutine();
333 33 }
334
335 33 void await_resume() const noexcept
336 {
337 // Results are extracted by the when_all coroutine from state
338 33 }
339
340 private:
341 template<std::size_t I>
342 78 void launch_one(executor_ref caller_ex, std::stop_token token)
343 {
344
1/1
✓ Branch 2 taken 78 times.
78 auto runner = make_when_all_runner<I>(
345 78 std::move(std::get<I>(*awaitables_)), state_);
346
347 78 auto h = runner.release();
348 78 h.promise().state_ = state_;
349 78 h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->allocator};
350
351 78 std::coroutine_handle<> ch{h};
352 78 state_->runner_handles_[I] = ch;
353
1/1
✓ Branch 1 taken 78 times.
78 state_->caller_env_->executor.post(ch);
354 156 }
355 };
356
357 /** Compute the result type for when_all.
358
359 Returns void when all tasks are void (P2300 aligned),
360 otherwise returns a tuple with void types filtered out.
361 */
362 template<typename... Ts>
363 using when_all_result_t = std::conditional_t<
364 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
365 void,
366 filter_void_tuple_t<Ts...>>;
367
368 /** Helper to extract a single result, returning empty tuple for void.
369 This is a separate function to work around a GCC-11 ICE that occurs
370 when using nested immediately-invoked lambdas with pack expansion.
371 */
372 template<std::size_t I, typename... Ts>
373 55 auto extract_single_result(when_all_state<Ts...>& state)
374 {
375 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
376 if constexpr (std::is_void_v<T>)
377 3 return std::tuple<>();
378 else
379
1/1
✓ Branch 4 taken 52 times.
52 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
380 }
381
382 /** Extract results from state, filtering void types.
383 */
384 template<typename... Ts>
385 23 auto extract_results(when_all_state<Ts...>& state)
386 {
387 23 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
388
3/3
✓ Branch 1 taken 4 times.
✓ Branch 4 taken 4 times.
✓ Branch 7 taken 4 times.
4 return std::tuple_cat(extract_single_result<Is>(state)...);
389
1/1
✓ Branch 1 taken 23 times.
46 }(std::index_sequence_for<Ts...>{});
390 }
391
392 } // namespace detail
393
394 /** Execute multiple awaitables concurrently and collect their results.
395
396 Launches all awaitables simultaneously and waits for all to complete
397 before returning. Results are collected in input order. If any
398 awaitable throws, cancellation is requested for siblings and the first
399 exception is rethrown after all awaitables complete.
400
401 @li All child awaitables run concurrently on the caller's executor
402 @li Results are returned as a tuple in input order
403 @li Void-returning awaitables do not contribute to the result tuple
404 @li If all awaitables return void, `when_all` returns `task<void>`
405 @li First exception wins; subsequent exceptions are discarded
406 @li Stop is requested for siblings on first error
407 @li Completes only after all children have finished
408
409 @par Thread Safety
410 The returned task must be awaited from a single execution context.
411 Child awaitables execute concurrently but complete through the caller's
412 executor.
413
414 @param awaitables The awaitables to execute concurrently. Each must
415 satisfy @ref IoAwaitable and is consumed (moved-from) when
416 `when_all` is awaited.
417
418 @return A task yielding a tuple of non-void results. Returns
419 `task<void>` when all input awaitables return void.
420
421 @par Example
422
423 @code
424 task<> example()
425 {
426 // Concurrent fetch, results collected in order
427 auto [user, posts] = co_await when_all(
428 fetch_user( id ), // task<User>
429 fetch_posts( id ) // task<std::vector<Post>>
430 );
431
432 // Void awaitables don't contribute to result
433 co_await when_all(
434 log_event( "start" ), // task<void>
435 notify_user( id ) // task<void>
436 );
437 // Returns task<void>, no result tuple
438 }
439 @endcode
440
441 @see IoAwaitable, task
442 */
443 template<IoAwaitable... As>
444
1/1
✓ Branch 1 taken 33 times.
33 [[nodiscard]] auto when_all(As... awaitables)
445 -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
446 {
447 using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
448
449 // State is stored in the coroutine frame, using the frame allocator
450 detail::when_all_state<detail::awaitable_result_t<As>...> state;
451
452 // Store awaitables in the frame
453 std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
454
455 // Launch all awaitables and wait for completion
456 co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
457
458 // Propagate first exception if any.
459 // Safe without explicit acquire: capture_exception() is sequenced-before
460 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
461 // last task's decrement that resumes this coroutine.
462 if(state.first_exception_)
463 std::rethrow_exception(state.first_exception_);
464
465 // Extract and return results
466 if constexpr (std::is_void_v<result_type>)
467 co_return;
468 else
469 co_return detail::extract_results(state);
470 66 }
471
472 /// Compute the result type of `when_all` for the given task types.
473 template<typename... Ts>
474 using when_all_result_type = detail::when_all_result_t<Ts...>;
475
476 } // namespace capy
477 } // namespace boost
478
479 #endif
480