Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions include/nvexec/stream/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -532,28 +532,21 @@ namespace nv::execution
{}

template <class... Args>
STDEXEC_ATTRIBUTE(host, device)
STDEXEC_ATTRIBUTE(device)
void set_value(Args&&... args) noexcept
{
using tuple_t = decayed_tuple_t<set_value_t, Args...>;
variant_->template emplace<tuple_t>(set_value_t(), static_cast<Args&&>(args)...);
producer_(task_);
}

STDEXEC_ATTRIBUTE(host, device) void set_stopped() noexcept
{
using tuple_t = decayed_tuple_t<set_stopped_t>;
variant_->template emplace<tuple_t>(set_stopped_t());
producer_(task_);
}

template <class Error>
STDEXEC_ATTRIBUTE(host, device)
STDEXEC_ATTRIBUTE(device)
void set_error(Error&& err) noexcept
{
if constexpr (__decays_to<Error, std::exception_ptr>)
{
// What is `exception_ptr` but death pending
// What is `exception_ptr` but death pending?
using tuple_t = decayed_tuple_t<set_error_t, cudaError_t>;
variant_->template emplace<tuple_t>(STDEXEC::set_error, cudaErrorUnknown);
}
Expand All @@ -565,6 +558,15 @@ namespace nv::execution
producer_(task_);
}

STDEXEC_ATTRIBUTE(device)
void set_stopped() noexcept
{
using tuple_t = decayed_tuple_t<set_stopped_t>;
variant_->template emplace<tuple_t>(set_stopped_t());
producer_(task_);
}

[[nodiscard]]
auto get_env() const noexcept -> Env const &
{
return *env_;
Expand Down
15 changes: 5 additions & 10 deletions include/nvexec/stream/repeat_n.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,8 @@ namespace nv::execution::_strm
{
using operation_state_concept = STDEXEC::operation_state_t;

using scheduler_t =
STDEXEC::__result_of<STDEXEC::get_completion_scheduler<STDEXEC::set_value_t>,
STDEXEC::env_of_t<Sender>,
STDEXEC::env_of_t<Receiver>>;

using inner_sender_t =
STDEXEC::__result_of<exec::sequence, STDEXEC::schedule_result_t<scheduler_t&>, Sender&>;
using scheduler_t = __completion_scheduler_of_t<set_value_t, Sender, env_of_t<Receiver>>;
using inner_sender_t = STDEXEC::__result_of<STDEXEC::starts_on, scheduler_t, Sender&>;
using inner_opstate_t = STDEXEC::connect_result_t<inner_sender_t, receiver<opstate>>;

explicit opstate(Sender&& sndr, Receiver rcvr, std::size_t count, scheduler_t sched)
Expand All @@ -91,9 +86,9 @@ namespace nv::execution::_strm

auto& _connect()
{
inner_opstate_.__emplace_from(STDEXEC::connect,
exec::sequence(STDEXEC::schedule(sched_), sndr_),
receiver{*this});
return inner_opstate_.__emplace_from(STDEXEC::connect,
STDEXEC::starts_on(sched_, sndr_),
receiver{*this});
}

template <class Tag, class... Args>
Expand Down
39 changes: 39 additions & 0 deletions include/nvexec/stream/starts_on.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2026 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// clang-format Language: Cpp

#pragma once

#include "../../stdexec/execution.hpp"
#include "../../stdexec/functional.hpp"

#include "let_xxx.cuh"

#include "common.cuh"

namespace nv::execution::_strm
{
// Specialization of `transform_sender_for` for the `starts_on` algorithm tag.
//
// The call operator rewrites a `starts_on(sched, sndr)` expression into
// `let_value(schedule(sched), __always(sndr))`.
template <>
struct transform_sender_for<STDEXEC::starts_on_t>
{
// Parameters:
//   (unnamed) Env const&  - environment argument; not used by this transform.
//   STDEXEC::starts_on_t  - tag argument selecting this overload; not used otherwise.
//   sched                 - scheduler; passed to `STDEXEC::schedule` as an lvalue
//                           (it is not forwarded).
//   sndr                  - sender; forwarded with its original value category
//                           into `STDEXEC::__always`.
// Returns: the sender produced by `STDEXEC::let_value`; the exact type is deduced.
template <class Env, STDEXEC::scheduler Scheduler, STDEXEC::sender Sender>
auto operator()(Env const &, STDEXEC::starts_on_t, Scheduler&& sched, Sender&& sndr) const
{
// NOTE(review): `__always` presumably yields a callable that ignores its
// arguments and returns the stored sender, so that `let_value` continues
// with `sndr` once the schedule sender completes - confirm against its
// definition in stdexec.
return STDEXEC::let_value(STDEXEC::schedule(sched),
STDEXEC::__always(static_cast<Sender&&>(sndr)));
}
};
} // namespace nv::execution::_strm
25 changes: 12 additions & 13 deletions include/nvexec/stream/then.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@ namespace nv::execution::_strm
requires std::invocable<Fun, __decay_t<Args>...>
void set_value(Args&&... args) noexcept
{
using result_t = std::invoke_result_t<Fun, __decay_t<Args>...>;
constexpr bool does_not_return_a_value = std::is_same_v<void, result_t>;
_strm::opstate_base<Receiver>& opstate = opstate_;
cudaStream_t stream = opstate.get_stream();
using result_t = std::invoke_result_t<Fun, __decay_t<Args>...>;
constexpr bool does_not_return_a_value = std::is_same_v<void, result_t>;
cudaStream_t stream = opstate_.get_stream();

if constexpr (does_not_return_a_value)
{
Expand All @@ -83,29 +82,29 @@ namespace nv::execution::_strm
if (cudaError_t status = STDEXEC_LOG_CUDA_API(cudaPeekAtLastError());
status == cudaSuccess)
{
opstate.propagate_completion_signal(STDEXEC::set_value);
opstate_.propagate_completion_signal(STDEXEC::set_value);
}
else
{
opstate.propagate_completion_signal(STDEXEC::set_error, std::move(status));
opstate_.propagate_completion_signal(STDEXEC::set_error, std::move(status));
}
}
else
{
using decayed_result_t = __decay_t<result_t>;
auto* d_result = static_cast<decayed_result_t*>(opstate.temp_storage_);
auto* d_result = static_cast<decayed_result_t*>(opstate_.temp_storage_);
_then_kernel_with_result<Args&&...>
<<<1, 1, 0, stream>>>(std::move(f_), d_result, static_cast<Args&&>(args)...);
opstate.defer_temp_storage_destruction(d_result);
opstate_.defer_temp_storage_destruction(d_result);

if (cudaError_t status = STDEXEC_LOG_CUDA_API(cudaPeekAtLastError());
status == cudaSuccess)
{
opstate.propagate_completion_signal(STDEXEC::set_value, std::move(*d_result));
opstate_.propagate_completion_signal(STDEXEC::set_value, std::move(*d_result));
}
else
{
opstate.propagate_completion_signal(STDEXEC::set_error, std::move(status));
opstate_.propagate_completion_signal(STDEXEC::set_error, std::move(status));
}
}
}
Expand Down Expand Up @@ -185,7 +184,7 @@ namespace nv::execution::_strm
static_cast<Self&&>(self).sndr_,
static_cast<Receiver&&>(rcvr),
[&](_strm::opstate_base<Receiver>& stream_provider) -> receiver_t<Receiver>
{ return receiver_t<Receiver>(self.fun_, stream_provider); });
{ return receiver_t<Receiver>(static_cast<Self&&>(self).fun_, stream_provider); });
}
STDEXEC_EXPLICIT_THIS_END(connect)

Expand All @@ -209,11 +208,11 @@ namespace nv::execution::_strm
struct transform_sender_for<STDEXEC::then_t>
{
template <class Env, class Fn, class CvSender>
auto operator()(Env const &, __ignore, Fn fun, CvSender&& sndr) const
auto operator()(Env const &, __ignore, Fn&& fun, CvSender&& sndr) const
{
if constexpr (stream_completing_sender<CvSender, Env>)
{
using _sender_t = then_sender<__decay_t<CvSender>, Fn>;
using _sender_t = then_sender<__decay_t<CvSender>, __decay_t<Fn>>;
return _sender_t{static_cast<CvSender&&>(sndr), static_cast<Fn&&>(fun)};
}
else
Expand Down
1 change: 1 addition & 0 deletions include/nvexec/stream_context.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "stream/schedule_from.cuh" // IWYU pragma: export
#include "stream/split.cuh" // IWYU pragma: export
#include "stream/start_detached.cuh" // IWYU pragma: export
#include "stream/starts_on.cuh" // IWYU pragma: export
#include "stream/sync_wait.cuh" // IWYU pragma: export
#include "stream/then.cuh" // IWYU pragma: export
#include "stream/upon_error.cuh" // IWYU pragma: export
Expand Down
23 changes: 21 additions & 2 deletions include/stdexec/__detail/__bulk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,24 @@ namespace STDEXEC
template <class _Fun>
STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE __as_bulk_chunked_fn(_Fun) -> __as_bulk_chunked_fn<_Fun>;

//! @brief Attributes (environment) type built from a child sender of type @c _Child.
//!
//! Derives from @c env<__fwd_env_t<env_of_t<_Child>>> and is initialized from
//! @c __fwd_env(get_env(child)). On top of the inherited queries it declares a
//! @c query overload for @c __get_completion_behavior_t<set_value_t>.
template <class _Child>
struct __attrs : env<__fwd_env_t<env_of_t<_Child>>>
{
using __base_t = env<__fwd_env_t<env_of_t<_Child>>>;
// Keep the base class's `query` overloads visible; the `query` member declared
// below would otherwise hide them.
using __base_t::query;

//! @brief Initializes the base from the child's environment passed through @c __fwd_env.
//! @param __child The child sender whose environment is queried via @c get_env.
constexpr explicit __attrs(_Child const & __child) noexcept
: __base_t{__fwd_env(STDEXEC::get_env(__child))}
{}

//! @brief Answers the @c set_value completion-behavior query.
//!
//! The argument values are ignored; only the types @c _Child and @c _Env...
//! are used, as template arguments to @c __get_completion_behavior.
template <class... _Env>
STDEXEC_ATTRIBUTE(nodiscard, always_inline, host, device)
constexpr auto query(__get_completion_behavior_t<set_value_t>, _Env&&...) const noexcept
{
return STDEXEC::__get_completion_behavior<set_value_t, _Child, _Env...>();
}
};

template <class _AlgoTag>
struct __impl_base : __sexpr_defaults
{
Expand All @@ -252,9 +270,10 @@ namespace STDEXEC
using __shape_t = decltype(__decay_t<__data_of<_Sender>>::__shape_);

// Forward the child sender's environment (which contains the completion scheduler)
static constexpr auto __get_attrs = [](__ignore, __ignore, auto const & __child) noexcept
static constexpr auto __get_attrs = //
[]<class _Child>(__ignore, __ignore, _Child const & __child) noexcept
{
return __fwd_env(STDEXEC::get_env(__child));
return __attrs{__child};
};

template <class _Sender, class... _Env>
Expand Down
33 changes: 18 additions & 15 deletions include/stdexec/__detail/__continues_on.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ namespace STDEXEC
return false;
}

_Scheduler const & __sch_;
_Sender const & __sndr_;
_Scheduler __sch_;
env_of_t<_Sender> __attrs_;

public:
constexpr explicit __attrs(_Scheduler const & __sch, _Sender const & __sndr) noexcept
: __sch_(__sch)
, __sndr_(__sndr)
constexpr explicit __attrs(_Scheduler __sch, env_of_t<_Sender> __attrs) noexcept
: __sch_(static_cast<_Scheduler&&>(__sch))
, __attrs_(static_cast<env_of_t<_Sender>&&>(__attrs))
{}

//! @brief Queries the completion scheduler for a given @c _SetTag.
Expand Down Expand Up @@ -221,7 +221,7 @@ namespace STDEXEC
env_of_t<_Sender>,
__fwd_env_t<_Env>...>
{
return get_completion_scheduler<_SetTag>(get_env(__sndr_), __fwd_env(__env)...);
return get_completion_scheduler<_SetTag>(__attrs_, __fwd_env(__env)...);
}

//! @brief Queries the completion domain for a given @c _SetTag.
Expand Down Expand Up @@ -295,15 +295,16 @@ namespace STDEXEC
}

//! @brief Forwards other queries to the underlying sender's environment.
//! @pre @c _Tag is a forwarding query but not a completion query.
template <__forwarding_query _Tag, class... _Args>
requires(!__completion_query<_Tag>) && __queryable_with<env_of_t<_Sender>, _Tag, _Args...>
//! @pre @c _Query is a forwarding query but not a completion query.
template <__forwarding_query _Query, class... _Args>
requires(!__completion_query<_Query>)
&& __queryable_with<env_of_t<_Sender>, _Query, _Args...>
[[nodiscard]]
constexpr auto query(_Tag, _Args&&... __args) const
noexcept(__nothrow_queryable_with<env_of_t<_Sender>, _Tag, _Args...>)
-> __query_result_t<env_of_t<_Sender>, _Tag, _Args...>
constexpr auto query(_Query, _Args&&... __args) const
noexcept(__nothrow_queryable_with<env_of_t<_Sender>, _Query, _Args...>)
-> __query_result_t<env_of_t<_Sender>, _Query, _Args...>
{
return get_env(__sndr_).query(_Tag{}, static_cast<_Args&&>(__args)...);
return __attrs_.query(_Query(), static_cast<_Args&&>(__args)...);
}
};

Expand Down Expand Up @@ -346,9 +347,11 @@ namespace STDEXEC

public:
static constexpr auto __get_attrs =
[](__ignore, auto const & __data, auto const & __child) noexcept
[]<class _Scheduler, class _Child>(__ignore,
_Scheduler const & __data,
_Child const & __child) noexcept
{
return __attrs{__data, __child};
return __attrs<_Scheduler, _Child>{__data, STDEXEC::get_env(__child)};
};

template <class _Sender, class... _Env>
Expand Down
Loading
Loading