Skip to content

Commit

Permalink
elementwise_min|max reduction op (#3341)
Browse files Browse the repository at this point in the history
`per_v_transform_reduce_incoming|outgoing_e` currently supports reduction operators that can be mapped to a raft::comms reduction operator (which is based on NCCL reduction). It also currently accepts a min|max reduction op on thrust::tuple values but actually performs an elementwise min|max, which can be confusing to users. This PR updates `per_v_transform_reduce_incoming|outgoing_e` to accept min|max reduction operators only when the value type is scalar (otherwise a static_assert will fail). If the value type is thrust::tuple, users need to pass an elementwise min|max operator instead; this makes it clearer to users that the primitive performs an elementwise min|max reduction.

Authors:
  - Seunghwa Kang (https://github.com/seunghwak)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Joseph Nke (https://github.com/jnke2016)
  - Naim (https://github.com/naimnv)

URL: #3341
  • Loading branch information
seunghwak authored Mar 17, 2023
1 parent 759a24e commit ce00167
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
ReduceOp reduce_op,
VertexValueOutputIterator vertex_value_output_first)
{
static_assert(ReduceOp::pure_function || reduce_op::has_compatible_raft_comms_op_v<ReduceOp> ||
static_assert(ReduceOp::pure_function && reduce_op::has_compatible_raft_comms_op_v<ReduceOp> &&
reduce_op::has_identity_element_v<ReduceOp>); // current restriction, to support
// general reduction, we may need to
// take a less efficient code path
Expand Down
48 changes: 26 additions & 22 deletions cpp/src/prims/property_op_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -123,60 +123,62 @@ struct atomic_add_thrust_tuple_impl<Iterator, TupleType, I, I> {
};

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_min_impl(
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_min_impl(
thrust::detail::any_assign& /* dereferencing thrust::discard_iterator results in this type */ lhs,
T const& rhs)
{
// no-op
}

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_min_impl(T& lhs,
T const& rhs)
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_min_impl(
T& lhs, T const& rhs)
{
atomicMin(&lhs, rhs);
}

template <typename Iterator, typename TupleType, size_t I, size_t N>
struct atomic_min_thrust_tuple_impl {
struct elementwise_atomic_min_thrust_tuple_impl {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const
{
atomic_min_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)), thrust::get<I>(value));
atomic_min_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
elementwise_atomic_min_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)),
thrust::get<I>(value));
elementwise_atomic_min_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
}
};

template <typename Iterator, typename TupleType, size_t I>
struct atomic_min_thrust_tuple_impl<Iterator, TupleType, I, I> {
struct elementwise_atomic_min_thrust_tuple_impl<Iterator, TupleType, I, I> {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const {}
};

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_max_impl(
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_max_impl(
thrust::detail::any_assign& /* dereferencing thrust::discard_iterator results in this type */ lhs,
T const& rhs)
{
// no-op
}

template <typename T>
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> atomic_max_impl(T& lhs,
T const& rhs)
__device__ std::enable_if_t<std::is_arithmetic<T>::value, void> elementwise_atomic_max_impl(
T& lhs, T const& rhs)
{
atomicMax(&lhs, rhs);
}

template <typename Iterator, typename TupleType, size_t I, size_t N>
struct atomic_max_thrust_tuple_impl {
struct elementwise_atomic_max_thrust_tuple_impl {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const
{
atomic_max_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)), thrust::get<I>(value));
atomic_max_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
elementwise_atomic_max_impl(thrust::raw_reference_cast(thrust::get<I>(*iter)),
thrust::get<I>(value));
elementwise_atomic_max_thrust_tuple_impl<Iterator, TupleType, I + 1, N>().compute(iter, value);
}
};

template <typename Iterator, typename TupleType, size_t I>
struct atomic_max_thrust_tuple_impl<Iterator, TupleType, I, I> {
struct elementwise_atomic_max_thrust_tuple_impl<Iterator, TupleType, I, I> {
__device__ constexpr void compute(Iterator iter, TupleType const& value) const {}
};

Expand Down Expand Up @@ -292,7 +294,7 @@ __device__

template <typename Iterator, typename T>
__device__ std::enable_if_t<thrust::detail::is_discard_iterator<Iterator>::value, void>
atomic_min_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_min_edge_op_result(Iterator iter, T const& value)
{
// no-op
}
Expand All @@ -302,7 +304,7 @@ __device__
std::enable_if_t<std::is_same<typename thrust::iterator_traits<Iterator>::value_type, T>::value &&
std::is_arithmetic<T>::value,
void>
atomic_min_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_min_edge_op_result(Iterator iter, T const& value)
{
atomicMin(&(thrust::raw_reference_cast(*iter)), value);
}
Expand All @@ -312,17 +314,18 @@ __device__
std::enable_if_t<is_thrust_tuple<typename thrust::iterator_traits<Iterator>::value_type>::value &&
is_thrust_tuple<T>::value,
void>
atomic_min_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_min_edge_op_result(Iterator iter, T const& value)
{
static_assert(thrust::tuple_size<typename thrust::iterator_traits<Iterator>::value_type>::value ==
thrust::tuple_size<T>::value);
size_t constexpr tuple_size = thrust::tuple_size<T>::value;
detail::atomic_min_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(iter, value);
detail::elementwise_atomic_min_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(
iter, value);
}

template <typename Iterator, typename T>
__device__ std::enable_if_t<thrust::detail::is_discard_iterator<Iterator>::value, void>
atomic_max_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_max_edge_op_result(Iterator iter, T const& value)
{
// no-op
}
Expand All @@ -332,7 +335,7 @@ __device__
std::enable_if_t<std::is_same<typename thrust::iterator_traits<Iterator>::value_type, T>::value &&
std::is_arithmetic<T>::value,
void>
atomic_max_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_max_edge_op_result(Iterator iter, T const& value)
{
atomicMax(&(thrust::raw_reference_cast(*iter)), value);
}
Expand All @@ -342,12 +345,13 @@ __device__
std::enable_if_t<is_thrust_tuple<typename thrust::iterator_traits<Iterator>::value_type>::value &&
is_thrust_tuple<T>::value,
void>
atomic_max_edge_op_result(Iterator iter, T const& value)
elementwise_atomic_max_edge_op_result(Iterator iter, T const& value)
{
static_assert(thrust::tuple_size<typename thrust::iterator_traits<Iterator>::value_type>::value ==
thrust::tuple_size<T>::value);
size_t constexpr tuple_size = thrust::tuple_size<T>::value;
detail::atomic_max_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(iter, value);
detail::elementwise_atomic_max_thrust_tuple_impl<Iterator, T, size_t{0}, tuple_size>().compute(
iter, value);
}

} // namespace cugraph
136 changes: 126 additions & 10 deletions cpp/src/prims/reduce_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,37 @@

#include <prims/property_op_utils.cuh>

#include <cugraph/utilities/thrust_tuple_utils.hpp>

#include <raft/core/comms.hpp>

#include <thrust/functional.h>

#include <utility>

namespace cugraph {
namespace reduce_op {

namespace detail {

// Builds the elementwise minimum of two thrust tuples of arithmetic types.
// The index pack Is is expanded once per tuple slot: a slot keeps the element of lhs when it
// compares strictly less (operator <) than the element of rhs, and the element of rhs otherwise.
// Participates in overload resolution only if cugraph::is_thrust_tuple_of_arithmetic<T> holds.
template <typename T, std::size_t... Is>
__host__ __device__ auto elementwise_thrust_min(T lhs, T rhs, std::index_sequence<Is...>)
-> std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value, T>
{
return thrust::make_tuple((thrust::get<Is>(lhs) < thrust::get<Is>(rhs)
? thrust::get<Is>(lhs)
: thrust::get<Is>(rhs))...);
}

// Builds the elementwise maximum of two thrust tuples of arithmetic types.
// The index pack Is is expanded once per tuple slot: a slot keeps the element of rhs when the
// element of lhs compares strictly less (operator <) than it, and the element of lhs otherwise.
// Participates in overload resolution only if cugraph::is_thrust_tuple_of_arithmetic<T> holds.
template <typename T, std::size_t... Is>
__host__ __device__ auto elementwise_thrust_max(T lhs, T rhs, std::index_sequence<Is...>)
-> std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value, T>
{
return thrust::make_tuple((thrust::get<Is>(lhs) < thrust::get<Is>(rhs)
? thrust::get<Is>(rhs)
: thrust::get<Is>(lhs))...);
}

} // namespace detail

// Guidance on writing a custom reduction operator.
// 1. It is required to add an "using value_type = type_of_the_reduced_values" statement.
// 2. A custom reduction operator MUST be side-effect free. We use thrust::reduce internally to
Expand Down Expand Up @@ -52,8 +76,8 @@ struct null {
using value_type = void;
};

// Binary reduction operator selecting any of the two input arguments, T should be arithmetic types
// or thrust tuple of arithmetic types.
// Binary reduction operator selecting any of the two input arguments, T should be an arithmetic
// type or a thrust tuple of arithmetic types.
template <typename T>
struct any {
using value_type = T;
Expand All @@ -62,10 +86,13 @@ struct any {
__host__ __device__ T operator()(T const& lhs, T const& rhs) const { return lhs; }
};

// Primary template of reduce_op::minimum. It is declared but not defined here; the second
// template parameter (Enable) exists so that std::enable_if_t-based partial specializations can
// select an implementation by the category of T.
template <typename T, typename Enable = void>
struct minimum;

// Binary reduction operator selecting the minimum element of the two input arguments (using
// operator <), T should be arithmetic types or thrust tuple of arithmetic types.
// operator <), a compatible raft comms op exists if T is an arithmetic type.
template <typename T>
struct minimum {
struct minimum<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MIN;
Expand All @@ -77,10 +104,55 @@ struct minimum {
}
};

// Partial specialization of reduce_op::minimum for thrust tuples of arithmetic types.
// The two tuples are compared as a whole with operator < (not element by element) and the smaller
// one is returned; rhs is returned when lhs is not less than rhs. This specialization defines no
// compatible_raft_comms_op member, i.e. a compatible raft comms op does not exist when T is a
// thrust::tuple type.
template <typename T>
struct minimum<T, std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
inline static T const identity_element = max_identity_element<T>();

__host__ __device__ T operator()(T const& lhs, T const& rhs) const
{
if (lhs < rhs) { return lhs; }
return rhs;
}
};

// Elementwise minimum reduction operator. T must be an arithmetic type or a thrust tuple of
// arithmetic types (enforced by the static_assert in the struct body).
// - Arithmetic T: returns the smaller of the two arguments (operator <); this is identical to
//   reduce_op::minimum.
// - Thrust tuple T: every element of the result is the smaller (operator <) of the corresponding
//   elements of the two arguments.
template <typename T>
struct elementwise_minimum {
static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic<T>::value);

using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MIN;
inline static T const identity_element = max_identity_element<T>();

// Scalar overload: enabled only when T is an arithmetic type.
template <typename U = T>
__host__ __device__ auto operator()(T const& lhs, T const& rhs) const
-> std::enable_if_t<std::is_arithmetic_v<U>, T>
{
if (lhs < rhs) { return lhs; }
return rhs;
}

// Tuple overload: enabled only when T is a thrust tuple of arithmetic types; the per-element
// selection is delegated to detail::elementwise_thrust_min with one index per tuple element.
template <typename U = T>
__host__ __device__ auto operator()(T const& lhs, T const& rhs) const
-> std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<U>::value, T>
{
constexpr auto num_elements = thrust::tuple_size<T>::value;
return detail::elementwise_thrust_min(lhs, rhs, std::make_index_sequence<num_elements>());
}
};

// Primary template of reduce_op::maximum. It is declared but not defined here; the second
// template parameter (Enable) exists so that std::enable_if_t-based partial specializations can
// select an implementation by the category of T.
template <typename T, typename Enable = void>
struct maximum;

// Binary reduction operator selecting the maximum element of the two input arguments (using
// operator <), T should be arithmetic types or thrust tuple of arithmetic types.
// operator <), a compatible raft comms op exists if T is an arithmetic type.
template <typename T>
struct maximum {
struct maximum<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MAX;
Expand All @@ -92,10 +164,54 @@ struct maximum {
}
};

// Binary reduction operator summing the two input arguments, T should be arithmetic types or thrust
// tuple of arithmetic types.
// Partial specialization of reduce_op::maximum for thrust tuples of arithmetic types.
// The two tuples are compared as a whole with operator < (not element by element) and the larger
// one is returned; lhs is returned when lhs is not less than rhs. This specialization defines no
// compatible_raft_comms_op member, i.e. a compatible raft comms op does not exist when T is a
// thrust::tuple type.
template <typename T>
struct maximum<T, std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<T>::value>> {
using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
inline static T const identity_element = min_identity_element<T>();

__host__ __device__ T operator()(T const& lhs, T const& rhs) const
{
if (lhs < rhs) { return rhs; }
return lhs;
}
};

// Elementwise maximum reduction operator. T must be an arithmetic type or a thrust tuple of
// arithmetic types (enforced by the static_assert in the struct body).
// - Arithmetic T: returns the larger of the two arguments (decided with operator <); this is
//   identical to reduce_op::maximum.
// - Thrust tuple T: every element of the result is the larger (decided with operator <) of the
//   corresponding elements of the two arguments.
template <typename T>
struct elementwise_maximum {
static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic<T>::value);

using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MAX;
inline static T const identity_element = min_identity_element<T>();

// Scalar overload: enabled only when T is an arithmetic type.
template <typename U = T>
__host__ __device__ auto operator()(T const& lhs, T const& rhs) const
-> std::enable_if_t<std::is_arithmetic_v<U>, T>
{
if (lhs < rhs) { return rhs; }
return lhs;
}

// Tuple overload: enabled only when T is a thrust tuple of arithmetic types; the per-element
// selection is delegated to detail::elementwise_thrust_max with one index per tuple element.
template <typename U = T>
__host__ __device__ auto operator()(T const& lhs, T const& rhs) const
-> std::enable_if_t<cugraph::is_thrust_tuple_of_arithmetic<U>::value, T>
{
constexpr auto num_elements = thrust::tuple_size<T>::value;
return detail::elementwise_thrust_max(lhs, rhs, std::make_index_sequence<num_elements>());
}
};

// Binary reduction operator summing the two input arguments, T should be an arithmetic type or a
// thrust tuple of arithmetic types.
template <typename T>
struct plus {
static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic<T>::value);

using value_type = T;
static constexpr bool pure_function = true; // this can be called in any process
static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::SUM;
Expand Down Expand Up @@ -146,9 +262,9 @@ __device__ std::enable_if_t<has_compatible_raft_comms_op_v<ReduceOp>, void> atom
if constexpr (ReduceOp::compatible_raft_comms_op == raft::comms::op_t::SUM) {
atomic_add_edge_op_result(iter, value);
} else if constexpr (ReduceOp::compatible_raft_comms_op == raft::comms::op_t::MIN) {
atomic_min_edge_op_result(iter, value);
elementwise_atomic_min_edge_op_result(iter, value);
} else {
atomic_max_edge_op_result(iter, value);
elementwise_atomic_max_edge_op_result(iter, value);
}
}

Expand Down
Loading

0 comments on commit ce00167

Please sign in to comment.