Skip to content

Commit 36a92fd

Browse files
Fix: when_all cancelled from environment (#1718)
* Fix: when_all cancelled from environment In the case when when_all is being stopped from a stop source in the environment, the stop_source was triggered multiple times leading to a use after free bug. This patch fixes this by performing the following steps: - Inside the stop callback, increment the reference count. - Try setting the state to stopped, only if stopped, trigger the internal stop_source which cancels the children. - Arrive to decrement the count again and handle any concurrent completions if needed. * use atomics from `stdexec::__std::` instead of `std::` --------- Co-authored-by: Eric Niebler <[email protected]>
1 parent f146a8f commit 36a92fd

File tree

2 files changed

+54
-16
lines changed

2 files changed

+54
-16
lines changed

include/stdexec/__detail/__when_all.hpp

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -183,18 +183,42 @@ namespace stdexec {
183183

184184
struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ { };
185185

186-
template <class _ErrorsVariant, class _ValuesTuple, class _StopToken, bool _SendsStopped>
186+
template <class _State, class _Receiver>
187+
struct __forward_stop_request {
188+
void operator()() const noexcept {
189+
// Temporarily increment the count to avoid concurrent/recursive arrivals to
190+
// pull the rug under our feet. Relaxed memory order is fine here.
191+
__state_->__count_.fetch_add(1, __std::memory_order_relaxed);
192+
193+
__state_t __expected = __started;
194+
// Transition to the "stopped" state if and only if we're in the
195+
// "started" state. (If this fails, it's because we're in an
196+
// error state, which trumps cancellation.)
197+
if (__state_->__state_.compare_exchange_strong(__expected, __stopped)) {
198+
__state_->__stop_source_.request_stop();
199+
}
200+
201+
// Arrive in order to decrement the count again and complete if needed.
202+
__state_->__arrive(*__rcvr_);
203+
}
204+
205+
_State* __state_;
206+
_Receiver* __rcvr_;
207+
};
208+
209+
template <class _ErrorsVariant, class _ValuesTuple, class _Receiver, bool _SendsStopped>
187210
struct __when_all_state {
188-
using __stop_callback_t = stop_callback_for_t<_StopToken, __forward_stop_request>;
211+
using __stop_callback_t = stop_callback_for_t<
212+
stop_token_of_t<env_of_t<_Receiver>>,
213+
__forward_stop_request<__when_all_state, _Receiver>
214+
>;
189215

190-
template <class _Receiver>
191216
void __arrive(_Receiver& __rcvr) noexcept {
192-
if (1 == __count_.fetch_sub(1)) {
217+
if (1 == __count_.fetch_sub(1, __std::memory_order_acq_rel)) {
193218
__complete(__rcvr);
194219
}
195220
}
196221

197-
template <class _Receiver>
198222
void __complete(_Receiver& __rcvr) noexcept {
199223
// Stop callback is no longer needed. Destroy it.
200224
__on_stop_.reset();
@@ -283,24 +307,25 @@ namespace stdexec {
283307
}
284308
};
285309

286-
template <class _Env>
287-
static auto __mk_state_fn(const _Env&) noexcept {
288-
return []<__max1_sender<__env_t<_Env>>... _Child>(__ignore, __ignore, _Child&&...) {
289-
using _Traits = __traits<_Env, _Child...>;
310+
template <class _Receiver>
311+
static auto __mk_state_fn(const _Receiver&) noexcept {
312+
using __env_of_t = env_of_t<_Receiver>;
313+
return []<__max1_sender<__env_t<__env_of_t>>... _Child>(__ignore, __ignore, _Child&&...) {
314+
using _Traits = __traits<__env_of_t, _Child...>;
290315
using _ErrorsVariant = _Traits::__errors_variant;
291316
using _ValuesTuple = _Traits::__values_tuple;
292317
using _State = __when_all_state<
293318
_ErrorsVariant,
294319
_ValuesTuple,
295-
stop_token_of_t<_Env>,
296-
(sends_stopped<_Child, _Env> || ...)
320+
_Receiver,
321+
(sends_stopped<_Child, __env_of_t> || ...)
297322
>;
298323
return _State{sizeof...(_Child)};
299324
};
300325
}
301326

302-
template <class _Env>
303-
using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Env>()));
327+
template <class _Receiver>
328+
using __mk_state_fn_t = decltype(__when_all::__mk_state_fn(__declval<_Receiver>()));
304329

305330
struct when_all_t {
306331
template <sender... _Senders>
@@ -340,9 +365,9 @@ namespace stdexec {
340365

341366
static constexpr auto get_state =
342367
[]<class _Self, class _Receiver>(_Self&& __self, _Receiver& __rcvr)
343-
-> __sexpr_apply_result_t<_Self, __mk_state_fn_t<env_of_t<_Receiver>>> {
368+
-> __sexpr_apply_result_t<_Self, __mk_state_fn_t<_Receiver>> {
344369
return __sexpr_apply(
345-
static_cast<_Self&&>(__self), __when_all::__mk_state_fn(stdexec::get_env(__rcvr)));
370+
static_cast<_Self&&>(__self), __when_all::__mk_state_fn(__rcvr));
346371
};
347372

348373
static constexpr auto start = []<class _State, class _Receiver, class... _Operations>(
@@ -351,7 +376,7 @@ namespace stdexec {
351376
_Operations&... __child_ops) noexcept -> void {
352377
// register stop callback:
353378
__state.__on_stop_.emplace(
354-
get_stop_token(stdexec::get_env(__rcvr)), __forward_stop_request{__state.__stop_source_});
379+
get_stop_token(stdexec::get_env(__rcvr)), __forward_stop_request<_State, _Receiver>{&__state, &__rcvr});
355380
(stdexec::start(__child_ops), ...);
356381
if constexpr (sizeof...(__child_ops) == 0) {
357382
__state.__complete(__rcvr);

test/stdexec/algos/adaptors/test_when_all.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
#include <catch2/catch.hpp>
1818
#include <stdexec/execution.hpp>
1919
#include <exec/env.hpp>
20+
#include <exec/async_scope.hpp>
2021
#include <test_common/schedulers.hpp>
2122
#include <test_common/receivers.hpp>
23+
#include <test_common/senders.hpp>
2224
#include <test_common/type_helpers.hpp>
2325

2426
namespace ex = stdexec;
@@ -367,4 +369,15 @@ namespace {
367369
auto op = ex::connect(snd, expect_void_receiver{});
368370
ex::start(op);
369371
}
372+
373+
374+
375+
TEST_CASE("when_all handles stop requests from the environment correctly", "[adaptors][when_all") {
376+
auto snd = ex::when_all(completes_if(false), completes_if(false));
377+
378+
exec::async_scope scope;
379+
scope.spawn(snd);
380+
scope.request_stop();
381+
ex::sync_wait(scope.on_empty());
382+
}
370383
} // namespace

0 commit comments

Comments
 (0)