@@ -183,18 +183,42 @@ namespace stdexec {
183183
184184 struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ { };
185185
186- template <class _ErrorsVariant , class _ValuesTuple , class _StopToken , bool _SendsStopped>
186+ template <class _State , class _Receiver >
187+ struct __forward_stop_request {
188+ void operator ()() const noexcept {
189+ // Temporarily increment the count to avoid concurrent/recursive arrivals to
190+ // pull the rug under our feet. Relaxed memory order is fine here.
191+ __state_->__count_ .fetch_add (1 , __std::memory_order_relaxed);
192+
193+ __state_t __expected = __started;
194+ // Transition to the "stopped" state if and only if we're in the
195+ // "started" state. (If this fails, it's because we're in an
196+ // error state, which trumps cancellation.)
197+ if (__state_->__state_ .compare_exchange_strong (__expected, __stopped)) {
198+ __state_->__stop_source_ .request_stop ();
199+ }
200+
201+ // Arrive in order to decrement the count again and complete if needed.
202+ __state_->__arrive (*__rcvr_);
203+ }
204+
205+ _State* __state_;
206+ _Receiver* __rcvr_;
207+ };
208+
209+ template <class _ErrorsVariant , class _ValuesTuple , class _Receiver , bool _SendsStopped>
187210 struct __when_all_state {
188- using __stop_callback_t = stop_callback_for_t <_StopToken, __forward_stop_request>;
211+ using __stop_callback_t = stop_callback_for_t <
212+ stop_token_of_t <env_of_t <_Receiver>>,
213+ __forward_stop_request<__when_all_state, _Receiver>
214+ >;
189215
190- template <class _Receiver >
191216 void __arrive (_Receiver& __rcvr) noexcept {
192- if (1 == __count_.fetch_sub (1 )) {
217+ if (1 == __count_.fetch_sub (1 , __std::memory_order_acq_rel )) {
193218 __complete (__rcvr);
194219 }
195220 }
196221
197- template <class _Receiver >
198222 void __complete (_Receiver& __rcvr) noexcept {
199223 // Stop callback is no longer needed. Destroy it.
200224 __on_stop_.reset ();
@@ -283,24 +307,25 @@ namespace stdexec {
283307 }
284308 };
285309
286- template <class _Env >
287- static auto __mk_state_fn (const _Env&) noexcept {
288- return []<__max1_sender<__env_t <_Env>>... _Child>(__ignore, __ignore, _Child&&...) {
289- using _Traits = __traits<_Env, _Child...>;
310+ template <class _Receiver >
311+ static auto __mk_state_fn (const _Receiver&) noexcept {
312+ using __env_of_t = env_of_t <_Receiver>;
313+ return []<__max1_sender<__env_t <__env_of_t >>... _Child>(__ignore, __ignore, _Child&&...) {
314+ using _Traits = __traits<__env_of_t , _Child...>;
290315 using _ErrorsVariant = _Traits::__errors_variant;
291316 using _ValuesTuple = _Traits::__values_tuple;
292317 using _State = __when_all_state<
293318 _ErrorsVariant,
294319 _ValuesTuple,
295- stop_token_of_t <_Env> ,
296- (sends_stopped<_Child, _Env > || ...)
320+ _Receiver ,
321+ (sends_stopped<_Child, __env_of_t > || ...)
297322 >;
298323 return _State{sizeof ...(_Child)};
299324 };
300325 }
301326
302- template <class _Env >
303- using __mk_state_fn_t = decltype (__when_all::__mk_state_fn(__declval<_Env >()));
327+ template <class _Receiver >
328+ using __mk_state_fn_t = decltype (__when_all::__mk_state_fn(__declval<_Receiver >()));
304329
305330 struct when_all_t {
306331 template <sender... _Senders>
@@ -340,9 +365,9 @@ namespace stdexec {
340365
341366 static constexpr auto get_state =
342367 []<class _Self , class _Receiver >(_Self&& __self, _Receiver& __rcvr)
343- -> __sexpr_apply_result_t <_Self, __mk_state_fn_t <env_of_t < _Receiver> >> {
368+ -> __sexpr_apply_result_t <_Self, __mk_state_fn_t <_Receiver>> {
344369 return __sexpr_apply (
345- static_cast <_Self&&>(__self), __when_all::__mk_state_fn (stdexec::get_env ( __rcvr) ));
370+ static_cast <_Self&&>(__self), __when_all::__mk_state_fn (__rcvr));
346371 };
347372
348373 static constexpr auto start = []<class _State , class _Receiver , class ... _Operations>(
@@ -351,7 +376,7 @@ namespace stdexec {
351376 _Operations&... __child_ops) noexcept -> void {
352377 // register stop callback:
353378 __state.__on_stop_ .emplace (
354- get_stop_token (stdexec::get_env (__rcvr)), __forward_stop_request{ __state. __stop_source_ });
379+ get_stop_token (stdexec::get_env (__rcvr)), __forward_stop_request<_State, _Receiver>{& __state, &__rcvr });
355380 (stdexec::start (__child_ops), ...);
356381 if constexpr (sizeof ...(__child_ops) == 0 ) {
357382 __state.__complete (__rcvr);
0 commit comments