Skip to content

Commit

Permalink
Create consume_rvalues separate from unwrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
aurianer committed Aug 15, 2023
1 parent ccd357b commit 90bda4a
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 19 deletions.
43 changes: 43 additions & 0 deletions include/dlaf/common/consume_rvalues.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//
// Distributed Linear Algebra with Future (DLAF)
//
// Copyright (c) 2018-2023, ETH Zurich
// All rights reserved.
//
// Please, refer to the LICENSE file in the root directory.
// SPDX-License-Identifier: BSD-3-Clause
//
#pragma once

/// @file

#include <tuple>
#include <type_traits>
#include <utility>

namespace dlaf::common::internal {
/// ConsumeRvalues is a callable object wrapper that consumes rvalues passed as arguments
/// after calling the wrapped callable.
template <typename F>
struct ConsumeRvalues {
std::decay_t<F> f;

template <typename... Ts>
auto operator()(Ts&&... ts) -> decltype(std::move(f)(std::forward<Ts>(ts)...)) {
using result_type = decltype(std::move(f)(std::forward<Ts>(ts)...));
if constexpr (std::is_void_v<result_type>) {
std::move(f)(std::forward<Ts>(ts)...);
std::tuple<Ts...>(std::forward<Ts>(ts)...);
}
else {
auto r = std::move(f)(std::forward<Ts>(ts)...);
std::tuple<Ts...>(std::forward<Ts>(ts)...);
return r;
}
}
};

template <typename F>
ConsumeRvalues(F&&) -> ConsumeRvalues<std::decay_t<F>>;

}
13 changes: 1 addition & 12 deletions include/dlaf/common/unwrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,7 @@ struct Unwrapping {
template <typename... Ts>
auto operator()(Ts&&... ts)
-> decltype(std::move(f)(Unwrapper<std::decay_t<Ts>>::unwrap(std::forward<Ts>(ts))...)) {
using result_type = decltype(std::move(f)(Unwrapper<std::decay_t<Ts>>::unwrap(std::forward<Ts>(ts))...));
if constexpr(std::is_void_v<result_type>)
{
std::move(f)(Unwrapper<std::decay_t<Ts>>::unwrap(std::forward<Ts>(ts))...);
std::tuple<Ts...>(std::forward<Ts>(ts)...);
}
else
{
auto r = std::move(f)(Unwrapper<std::decay_t<Ts>>::unwrap(std::forward<Ts>(ts))...);
std::tuple<Ts...>(std::forward<Ts>(ts)...);
return r;
}
return std::move(f)(Unwrapper<std::decay_t<Ts>>::unwrap(std::forward<Ts>(ts))...);
}
};

Expand Down
15 changes: 9 additions & 6 deletions include/dlaf/sender/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <pika/execution.hpp>

#include <dlaf/common/consume_rvalues.h>
#include <dlaf/common/unwrap.h>
#include <dlaf/init.h>
#include <dlaf/schedulers.h>
Expand Down Expand Up @@ -43,7 +44,7 @@ enum class TransformDispatchType { Plain, Blas, Lapack };
// allows choosing the priority.
//
// At its core, transform is a convenience wrapper around
// sender | transfer(with_priority(scheduler, priority)) | then(unwrapping(f)).
// sender | transfer(with_priority(scheduler, priority)) | then(ConsumeRvalues(unwrapping(f))).

/// Lazy transform. This does not submit the work and returns a sender.
template <TransformDispatchType Tag = TransformDispatchType::Plain, Backend B = Backend::MC,
Expand All @@ -56,8 +57,11 @@ template <TransformDispatchType Tag = TransformDispatchType::Plain, Backend B =
auto scheduler = getBackendScheduler<B>(policy.priority());
auto transfer_sender = transfer(std::forward<Sender>(sender), std::move(scheduler));

using dlaf::common::internal::ConsumeRvalues;
using dlaf::common::internal::Unwrapping;

if constexpr (B == Backend::MC) {
return then(std::move(transfer_sender), dlaf::common::internal::Unwrapping{std::forward<F>(f)});
return then(std::move(transfer_sender), ConsumeRvalues{Unwrapping{std::forward<F>(f)}});
}
else if constexpr (B == Backend::GPU) {
#if defined(DLAF_WITH_GPU)
Expand All @@ -67,16 +71,15 @@ template <TransformDispatchType Tag = TransformDispatchType::Plain, Backend B =

if constexpr (Tag == TransformDispatchType::Plain) {
return then_with_stream(std::move(transfer_sender),
dlaf::common::internal::Unwrapping{std::forward<F>(f)});
ConsumeRvalues{Unwrapping{std::forward<F>(f)}});
}
else if constexpr (Tag == TransformDispatchType::Blas) {
return then_with_cublas(std::move(transfer_sender),
dlaf::common::internal::Unwrapping{std::forward<F>(f)},
return then_with_cublas(std::move(transfer_sender), ConsumeRvalues{Unwrapping{std::forward<F>(f)}},
CUBLAS_POINTER_MODE_HOST);
}
else if constexpr (Tag == TransformDispatchType::Lapack) {
return then_with_cusolver(std::move(transfer_sender),
dlaf::common::internal::Unwrapping{std::forward<F>(f)});
ConsumeRvalues{Unwrapping{std::forward<F>(f)}});
}
else {
DLAF_STATIC_FAIL(
Expand Down
3 changes: 2 additions & 1 deletion include/dlaf/sender/transform_mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <type_traits>

#include <dlaf/common/consume_rvalues.h>
#include <dlaf/common/pipeline.h>
#include <dlaf/common/unwrap.h>
#include <dlaf/communication/communicator.h>
Expand Down Expand Up @@ -88,7 +89,7 @@ template <typename F, typename Sender,
namespace ex = pika::execution::experimental;

return ex::transfer(std::forward<Sender>(sender), dlaf::internal::getMPIScheduler()) |
ex::then(MPICallHelper{std::forward<F>(f)});
ex::then(dlaf::common::internal::ConsumeRvalues{MPICallHelper{std::forward<F>(f)}});
}

/// Fire-and-forget transformMPI. This submits the work and returns void.
Expand Down

0 comments on commit 90bda4a

Please sign in to comment.