Skip to content

Commit

Permalink
Add compatibility for sender/receiver customization points for HPX 1.…
Browse files Browse the repository at this point in the history
…7.X and HPX master (#428)

This allows building DLA-Future with HPX 1.7.X and master (that was not the case with #371). Two reasons:

- In 1.7.X sender and receiver customizations could be done as either member functions (these have priority) or tag_dispatch. On master we've made tag_dispatch the only customization mechanism (simplifies internals).
- In 1.7.X we made the bad decision of renaming tag_invoke to tag_dispatch (because of concerns of us using it in non-intended ways). We realized this was a mistake and renamed it back to tag_invoke on master.

So this PR changes to use tag_invoke/dispatch as the customization mechanism since that works on 1.7.X and master. Additionally, it puts a little local compatibility definition for tag_invoke/dispatch so that the correct function name is used depending on the HPX version.

Finally, because of the move to tag_invoke/dispatch the customization for set_value requires SFINAE to not even be considered as an overload for set_error and set_done.
  • Loading branch information
msimberg authored Nov 22, 2021
1 parent 1e237a4 commit 528d645
Showing 1 changed file with 64 additions and 39 deletions.
103 changes: 64 additions & 39 deletions include/dlaf/sender/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

#include <hpx/local/execution.hpp>
#include <hpx/local/unwrap.hpp>
#include <hpx/version.hpp>

#include "dlaf/init.h"
#include "dlaf/sender/policy.h"
#include "dlaf/sender/traits.h"
#include "dlaf/sender/typelist.h"
#include "dlaf/sender/when_all_lift.h"
#include "dlaf/types.h"
Expand All @@ -28,6 +30,15 @@
#include "dlaf/cusolver/handle_pool.h"
#endif

// This is a compatibility layer for HPX's tag_invoke. In 1.7.X it was
// mistakenly renamed to tag_dispatch. In 1.8.0 onwards it is called tag_invoke
// again.
#if HPX_VERSION_FULL >= 0x10800
#define DLAF_TAG_INVOKE tag_invoke
#else
#define DLAF_TAG_INVOKE tag_dispatch
#endif

namespace dlaf {
namespace internal {
/// DLAF-specific transform, templated on a backend. This, together with
Expand Down Expand Up @@ -63,32 +74,39 @@ struct Transform<Backend::GPU> {
std::decay_t<S> s;
std::decay_t<F> f;

using unwrapping_function_type = decltype(hpx::unwrapping(std::declval<std::decay_t<F>>()));
template <typename... Ts>
static constexpr bool is_cuda_stream_invocable =
std::is_invocable_v<unwrapping_function_type, std::decay_t<Ts>&..., cudaStream_t>;
template <typename... Ts>
static constexpr bool is_cublas_handle_invocable =
std::is_invocable_v<unwrapping_function_type, cublasHandle_t, std::decay_t<Ts>&...>;
template <typename... Ts>
static constexpr bool is_cusolver_handle_invocable =
std::is_invocable_v<unwrapping_function_type, cusolverDnHandle_t, std::decay_t<Ts>&...>;
template <typename... Ts>
static constexpr bool is_gpu_invocable =
is_cuda_stream_invocable<Ts...> || is_cublas_handle_invocable<Ts...> ||
is_cusolver_handle_invocable<Ts...>;

template <typename G, typename... Us>
static auto call_helper(cudaStream_t stream, cublasHandle_t cublas_handle,
cusolverDnHandle_t cusolver_handle, G&& g, Us&... us) {
using unwrapping_function_type = decltype(hpx::unwrapping(std::forward<G>(g)));
constexpr bool is_cuda_stream_invocable =
std::is_invocable_v<unwrapping_function_type, Us&..., cudaStream_t>;
constexpr bool is_cublas_handle_invocable =
std::is_invocable_v<unwrapping_function_type, cublasHandle_t, Us&...>;
constexpr bool is_cusolver_handle_invocable =
std::is_invocable_v<unwrapping_function_type, cusolverDnHandle_t, Us&...>;
static_assert(is_cuda_stream_invocable || is_cublas_handle_invocable ||
is_cusolver_handle_invocable,
"function passed to transform<GPU> must be invocable with a cublasStream_t as the "
static_assert(is_gpu_invocable<Us...>,
"function passed to transform<GPU> must be invocable with a cudaStream_t as the"
"last argument or a cublasHandle_t/cusolverDnHandle_t as the first argument");

if constexpr (is_cuda_stream_invocable) {
if constexpr (is_cuda_stream_invocable<Us...>) {
(void) cublas_handle;
(void) cusolver_handle;
return std::invoke(hpx::unwrapping(std::forward<G>(g)), us..., stream);
}
else if constexpr (is_cublas_handle_invocable) {
else if constexpr (is_cublas_handle_invocable<Us...>) {
(void) cusolver_handle;
(void) stream;
return std::invoke(hpx::unwrapping(std::forward<G>(g)), cublas_handle, us...);
}
else if constexpr (is_cusolver_handle_invocable) {
else if constexpr (is_cusolver_handle_invocable<Us...>) {
(void) cublas_handle;
(void) stream;
return std::invoke(hpx::unwrapping(std::forward<G>(g)), cusolver_handle, us...);
Expand Down Expand Up @@ -128,45 +146,48 @@ struct Transform<Backend::GPU> {
std::decay_t<F> f;

template <typename E>
void set_error(E&& e) && noexcept {
hpx::execution::experimental::set_error(std::move(r), std::forward<E>(e));
friend void DLAF_TAG_INVOKE(hpx::execution::experimental::set_error_t, GPUTransformReceiver&& r,
E&& e) noexcept {
hpx::execution::experimental::set_error(std::move(r.r), std::forward<E>(e));
}

void set_done() && noexcept {
friend void DLAF_TAG_INVOKE(hpx::execution::experimental::set_done_t,
GPUTransformReceiver&& r) noexcept {
hpx::execution::experimental::set_done(std::move(r));
}

template <typename... Ts>
void set_value(Ts&&... ts) noexcept {
template <typename... Ts, typename Enable = std::enable_if_t<is_gpu_invocable<Ts...>>>
friend auto DLAF_TAG_INVOKE(hpx::execution::experimental::set_value_t, GPUTransformReceiver&& r,
Ts&&... ts) {
try {
cudaStream_t stream = stream_pool.getNextStream();
cublasHandle_t cublas_handle = cublas_handle_pool.getNextHandle(stream);
cusolverDnHandle_t cusolver_handle = cusolver_handle_pool.getNextHandle(stream);
cudaStream_t stream = r.stream_pool.getNextStream();
cublasHandle_t cublas_handle = r.cublas_handle_pool.getNextHandle(stream);
cusolverDnHandle_t cusolver_handle = r.cusolver_handle_pool.getNextHandle(stream);

// NOTE: We do not forward ts because we keep the pack alive longer in
// the continuation.
if constexpr (std::is_void_v<decltype(
call_helper(stream, cublas_handle, cusolver_handle, std::move(f), ts...))>) {
call_helper(stream, cublas_handle, cusolver_handle, std::move(f), ts...);
if constexpr (std::is_void_v<decltype(call_helper(stream, cublas_handle, cusolver_handle,
std::move(r.f), ts...))>) {
call_helper(stream, cublas_handle, cusolver_handle, std::move(r.f), ts...);
hpx::cuda::experimental::detail::add_event_callback(
[r = std::move(r),
[r = std::move(r.r),
keep_alive =
std::make_tuple(std::forward<Ts>(ts)..., std::move(stream_pool),
std::move(cublas_handle_pool),
std::move(cusolver_handle_pool))](cudaError_t status) mutable {
std::make_tuple(std::forward<Ts>(ts)..., std::move(r.stream_pool),
std::move(r.cublas_handle_pool),
std::move(r.cusolver_handle_pool))](cudaError_t status) mutable {
DLAF_CUDA_CALL(status);
hpx::execution::experimental::set_value(std::move(r));
},
stream);
}
else {
auto res = call_helper(stream, cublas_handle, cusolver_handle, std::move(f), ts...);
auto res = call_helper(stream, cublas_handle, cusolver_handle, std::move(r.f), ts...);
hpx::cuda::experimental::detail::add_event_callback(
[r = std::move(r), res = std::move(res),
[r = std::move(r.r), res = std::move(res),
keep_alive =
std::make_tuple(std::forward<Ts>(ts)..., std::move(stream_pool),
std::move(cublas_handle_pool),
std::move(cusolver_handle_pool))](cudaError_t status) mutable {
std::make_tuple(std::forward<Ts>(ts)..., std::move(r.stream_pool),
std::move(r.cublas_handle_pool),
std::move(r.cusolver_handle_pool))](cudaError_t status) mutable {
DLAF_CUDA_CALL(status);
hpx::execution::experimental::set_value(std::move(r), std::move(res));
},
Expand All @@ -180,13 +201,15 @@ struct Transform<Backend::GPU> {
};

template <typename R>
auto connect(R&& r) && {
return hpx::execution::experimental::connect(std::move(s),
GPUTransformReceiver<R>{stream_pool,
cublas_handle_pool,
cusolver_handle_pool,
friend auto DLAF_TAG_INVOKE(hpx::execution::experimental::connect_t, GPUTransformSender&& s, R&& r) {
return hpx::execution::experimental::connect(std::move(s.s),
GPUTransformReceiver<R>{std::move(s.stream_pool),
std::move(
s.cublas_handle_pool),
std::move(
s.cusolver_handle_pool),
std::forward<R>(r),
std::move(f)});
std::move(s.f)});
}
};

Expand Down Expand Up @@ -227,3 +250,5 @@ void transformLiftDetach(const Policy<B> policy, F&& f, Ts&&... ts) {
}
}
}

#undef DLAF_TAG_INVOKE

0 comments on commit 528d645

Please sign in to comment.