From ce00167822e971021665b4b2f85207c1d5ee8d61 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang <45857425+seunghwak@users.noreply.github.com> Date: Fri, 17 Mar 2023 13:08:40 -0700 Subject: [PATCH] elementwise_min|max reduction op (#3341) `per_v_transform_reduce_incoming|outgoing_e` currently supports reduction operators that can be mapped to a raft::comms reduction operator (which is based on NCCL reduction). `per_v_transform_reduce_incoming|outgoing_e` currently accepts a min|max reduction op on thrust::tuple values but actually performs an elementwise min|max. This can be confusing to users. This PR updates `per_v_transform_reduce_incoming|outgoing_e` to take min|max reduction operators only when the value type is scalar (otherwise a static_assert will fail). If the value type is thrust::tuple, users need to pass an elementwise min|max operator instead (this will make it clearer to the users that the primitive will perform elementwise min|max reduction). Authors: - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Joseph Nke (https://github.com/jnke2016) - Naim (https://github.com/naimnv) URL: https://github.com/rapidsai/cugraph/pull/3341 --- ...v_transform_reduce_incoming_outgoing_e.cuh | 2 +- cpp/src/prims/property_op_utils.cuh | 48 ++++--- cpp/src/prims/reduce_op.cuh | 136 ++++++++++++++++-- ..._v_transform_reduce_incoming_outgoing_e.cu | 36 ++--- 4 files changed, 171 insertions(+), 51 deletions(-) diff --git a/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh b/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh index 2181d7831f6..bf8baf03c80 100644 --- a/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh +++ b/cpp/src/prims/per_v_transform_reduce_incoming_outgoing_e.cuh @@ -482,7 +482,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, ReduceOp reduce_op, VertexValueOutputIterator vertex_value_output_first) { - static_assert(ReduceOp::pure_function || 
reduce_op::has_compatible_raft_comms_op_v || + static_assert(ReduceOp::pure_function && reduce_op::has_compatible_raft_comms_op_v && reduce_op::has_identity_element_v); // current restriction, to support // general reduction, we may need to // take a less efficient code path diff --git a/cpp/src/prims/property_op_utils.cuh b/cpp/src/prims/property_op_utils.cuh index 5b350d2f083..a55dbfbe5ba 100644 --- a/cpp/src/prims/property_op_utils.cuh +++ b/cpp/src/prims/property_op_utils.cuh @@ -123,7 +123,7 @@ struct atomic_add_thrust_tuple_impl { }; template -__device__ std::enable_if_t::value, void> atomic_min_impl( +__device__ std::enable_if_t::value, void> elementwise_atomic_min_impl( thrust::detail::any_assign& /* dereferencing thrust::discard_iterator results in this type */ lhs, T const& rhs) { @@ -131,28 +131,29 @@ __device__ std::enable_if_t::value, void> atomic_min_impl( } template -__device__ std::enable_if_t::value, void> atomic_min_impl(T& lhs, - T const& rhs) +__device__ std::enable_if_t::value, void> elementwise_atomic_min_impl( + T& lhs, T const& rhs) { atomicMin(&lhs, rhs); } template -struct atomic_min_thrust_tuple_impl { +struct elementwise_atomic_min_thrust_tuple_impl { __device__ constexpr void compute(Iterator iter, TupleType const& value) const { - atomic_min_impl(thrust::raw_reference_cast(thrust::get(*iter)), thrust::get(value)); - atomic_min_thrust_tuple_impl().compute(iter, value); + elementwise_atomic_min_impl(thrust::raw_reference_cast(thrust::get(*iter)), + thrust::get(value)); + elementwise_atomic_min_thrust_tuple_impl().compute(iter, value); } }; template -struct atomic_min_thrust_tuple_impl { +struct elementwise_atomic_min_thrust_tuple_impl { __device__ constexpr void compute(Iterator iter, TupleType const& value) const {} }; template -__device__ std::enable_if_t::value, void> atomic_max_impl( +__device__ std::enable_if_t::value, void> elementwise_atomic_max_impl( thrust::detail::any_assign& /* dereferencing thrust::discard_iterator results in 
this type */ lhs, T const& rhs) { @@ -160,23 +161,24 @@ __device__ std::enable_if_t::value, void> atomic_max_impl( } template -__device__ std::enable_if_t::value, void> atomic_max_impl(T& lhs, - T const& rhs) +__device__ std::enable_if_t::value, void> elementwise_atomic_max_impl( + T& lhs, T const& rhs) { atomicMax(&lhs, rhs); } template -struct atomic_max_thrust_tuple_impl { +struct elementwise_atomic_max_thrust_tuple_impl { __device__ constexpr void compute(Iterator iter, TupleType const& value) const { - atomic_max_impl(thrust::raw_reference_cast(thrust::get(*iter)), thrust::get(value)); - atomic_max_thrust_tuple_impl().compute(iter, value); + elementwise_atomic_max_impl(thrust::raw_reference_cast(thrust::get(*iter)), + thrust::get(value)); + elementwise_atomic_max_thrust_tuple_impl().compute(iter, value); } }; template -struct atomic_max_thrust_tuple_impl { +struct elementwise_atomic_max_thrust_tuple_impl { __device__ constexpr void compute(Iterator iter, TupleType const& value) const {} }; @@ -292,7 +294,7 @@ __device__ template __device__ std::enable_if_t::value, void> -atomic_min_edge_op_result(Iterator iter, T const& value) +elementwise_atomic_min_edge_op_result(Iterator iter, T const& value) { // no-op } @@ -302,7 +304,7 @@ __device__ std::enable_if_t::value_type, T>::value && std::is_arithmetic::value, void> - atomic_min_edge_op_result(Iterator iter, T const& value) + elementwise_atomic_min_edge_op_result(Iterator iter, T const& value) { atomicMin(&(thrust::raw_reference_cast(*iter)), value); } @@ -312,17 +314,18 @@ __device__ std::enable_if_t::value_type>::value && is_thrust_tuple::value, void> - atomic_min_edge_op_result(Iterator iter, T const& value) + elementwise_atomic_min_edge_op_result(Iterator iter, T const& value) { static_assert(thrust::tuple_size::value_type>::value == thrust::tuple_size::value); size_t constexpr tuple_size = thrust::tuple_size::value; - detail::atomic_min_thrust_tuple_impl().compute(iter, value); + 
detail::elementwise_atomic_min_thrust_tuple_impl().compute( + iter, value); } template __device__ std::enable_if_t::value, void> -atomic_max_edge_op_result(Iterator iter, T const& value) +elementwise_atomic_max_edge_op_result(Iterator iter, T const& value) { // no-op } @@ -332,7 +335,7 @@ __device__ std::enable_if_t::value_type, T>::value && std::is_arithmetic::value, void> - atomic_max_edge_op_result(Iterator iter, T const& value) + elementwise_atomic_max_edge_op_result(Iterator iter, T const& value) { atomicMax(&(thrust::raw_reference_cast(*iter)), value); } @@ -342,12 +345,13 @@ __device__ std::enable_if_t::value_type>::value && is_thrust_tuple::value, void> - atomic_max_edge_op_result(Iterator iter, T const& value) + elementwise_atomic_max_edge_op_result(Iterator iter, T const& value) { static_assert(thrust::tuple_size::value_type>::value == thrust::tuple_size::value); size_t constexpr tuple_size = thrust::tuple_size::value; - detail::atomic_max_thrust_tuple_impl().compute(iter, value); + detail::elementwise_atomic_max_thrust_tuple_impl().compute( + iter, value); } } // namespace cugraph diff --git a/cpp/src/prims/reduce_op.cuh b/cpp/src/prims/reduce_op.cuh index 4de0b4a698e..df3bfdf0ee2 100644 --- a/cpp/src/prims/reduce_op.cuh +++ b/cpp/src/prims/reduce_op.cuh @@ -18,13 +18,37 @@ #include +#include + #include #include +#include + namespace cugraph { namespace reduce_op { +namespace detail { + +template +__host__ __device__ std::enable_if_t::value, T> +elementwise_thrust_min(T lhs, T rhs, std::index_sequence) +{ + return thrust::make_tuple( + (thrust::get(lhs) < thrust::get(rhs) ? thrust::get(lhs) : thrust::get(rhs))...); +} + +template +__host__ __device__ std::enable_if_t::value, T> +elementwise_thrust_max(T lhs, T rhs, std::index_sequence) +{ + return thrust::make_tuple( + (thrust::get(lhs) < thrust::get(rhs) ? thrust::get(rhs) : thrust::get(lhs))...); +} + +} // namespace detail + // Guidance on writing a custom reduction operator. // 1. 
It is required to add an "using value_type = type_of_the_reduced_values" statement. // 2. A custom reduction operator MUST be side-effect free. We use thrust::reduce internally to @@ -52,8 +76,8 @@ struct null { using value_type = void; }; -// Binary reduction operator selecting any of the two input arguments, T should be arithmetic types -// or thrust tuple of arithmetic types. +// Binary reduction operator selecting any of the two input arguments, T should be an arithmetic +// type or a thrust tuple of arithmetic types. template struct any { using value_type = T; @@ -62,10 +86,13 @@ struct any { __host__ __device__ T operator()(T const& lhs, T const& rhs) const { return lhs; } }; +template +struct minimum; + // Binary reduction operator selecting the minimum element of the two input arguments (using -// operator <), T should be arithmetic types or thrust tuple of arithmetic types. +// operator <), a compatible raft comms op exists if T is an arithmetic type. template -struct minimum { +struct minimum>> { using value_type = T; static constexpr bool pure_function = true; // this can be called in any process static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MIN; @@ -77,10 +104,55 @@ struct minimum { } }; +// Binary reduction operator selecting the minimum element of the two input arguments (using +// operator <), a compatible raft comms op does not exist when T is a thrust::tuple type. +template +struct minimum::value>> { + using value_type = T; + static constexpr bool pure_function = true; // this can be called in any process + inline static T const identity_element = max_identity_element(); + + __host__ __device__ T operator()(T const& lhs, T const& rhs) const + { + return lhs < rhs ? 
lhs : rhs; + } +}; + +// Binary reduction operator selecting the minimum element of the two input arguments elementwise +// (using operator < for each element), T should be an arithmetic type (this is identical to +// reduce_op::minimum if T is an arithmetic type) or a thrust tuple of arithmetic types. +template +struct elementwise_minimum { + static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic::value); + + using value_type = T; + static constexpr bool pure_function = true; // this can be called in any process + static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MIN; + inline static T const identity_element = max_identity_element(); + + template + __host__ __device__ std::enable_if_t, T> operator()(T const& lhs, + T const& rhs) const + { + return lhs < rhs ? lhs : rhs; + } + + template + __host__ __device__ std::enable_if_t::value, T> + operator()(T const& lhs, T const& rhs) const + { + return detail::elementwise_thrust_min( + lhs, rhs, std::make_index_sequence::value>()); + } +}; + +template +struct maximum; + // Binary reduction operator selecting the maximum element of the two input arguments (using -// operator <), T should be arithmetic types or thrust tuple of arithmetic types. +// operator <), a compatible raft comms op exists if T is an arithmetic type. template -struct maximum { +struct maximum>> { using value_type = T; static constexpr bool pure_function = true; // this can be called in any process static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MAX; @@ -92,10 +164,54 @@ struct maximum { } }; -// Binary reduction operator summing the two input arguments, T should be arithmetic types or thrust -// tuple of arithmetic types. +// Binary reduction operator selecting the maximum element of the two input arguments (using +// operator <), a compatible raft comms op does not exist when T is a thrust::tuple type. 
+template +struct maximum::value>> { + using value_type = T; + static constexpr bool pure_function = true; // this can be called in any process + inline static T const identity_element = min_identity_element(); + + __host__ __device__ T operator()(T const& lhs, T const& rhs) const + { + return lhs < rhs ? rhs : lhs; + } +}; + +// Binary reduction operator selecting the maximum element of the two input arguments elementwise +// (using operator < for each element), T should be an arithmetic type (this is identical to +// reduce_op::maximum if T is an arithmetic type) or a thrust tuple of arithmetic types. +template +struct elementwise_maximum { + static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic::value); + + using value_type = T; + static constexpr bool pure_function = true; // this can be called in any process + static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::MAX; + inline static T const identity_element = min_identity_element(); + + template + __host__ __device__ std::enable_if_t, T> operator()(T const& lhs, + T const& rhs) const + { + return lhs < rhs ? rhs : lhs; + } + + template + __host__ __device__ std::enable_if_t::value, T> + operator()(T const& lhs, T const& rhs) const + { + return detail::elementwise_thrust_max( + lhs, rhs, std::make_index_sequence::value>()); + } +}; + +// Binary reduction operator summing the two input arguments, T should be an arithmetic type or a +// thrust tuple of arithmetic types. 
template struct plus { + static_assert(cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic::value); + using value_type = T; static constexpr bool pure_function = true; // this can be called in any process static constexpr raft::comms::op_t compatible_raft_comms_op = raft::comms::op_t::SUM; @@ -146,9 +262,9 @@ __device__ std::enable_if_t, void> atom if constexpr (ReduceOp::compatible_raft_comms_op == raft::comms::op_t::SUM) { atomic_add_edge_op_result(iter, value); } else if constexpr (ReduceOp::compatible_raft_comms_op == raft::comms::op_t::MIN) { - atomic_min_edge_op_result(iter, value); + elementwise_atomic_min_edge_op_result(iter, value); } else { - atomic_max_edge_op_result(iter, value); + elementwise_atomic_max_edge_op_result(iter, value); } } diff --git a/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu b/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu index a0d4baa5aea..b6ef8c701ef 100644 --- a/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu +++ b/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu @@ -214,9 +214,9 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE auto mg_dst_prop = cugraph::test::generate::dst_property( *handle_, mg_graph_view, mg_vertex_prop); - enum class reduction_type_t { PLUS, MINIMUM, MAXIMUM }; + enum class reduction_type_t { PLUS, ELEMWISE_MIN, ELEMWISE_MAX }; std::array reduction_types = { - reduction_type_t::PLUS, reduction_type_t::MINIMUM, reduction_type_t::MAXIMUM}; + reduction_type_t::PLUS, reduction_type_t::ELEMWISE_MIN, reduction_type_t::ELEMWISE_MAX}; std::vector(0, rmm::cuda_stream_view{}))> in_results{}; @@ -247,7 +247,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::reduce_op::plus{}, cugraph::get_dataframe_buffer_begin(in_results[i])); break; - case reduction_type_t::MINIMUM: + case reduction_type_t::ELEMWISE_MIN: per_v_transform_reduce_incoming_e(*handle_, mg_graph_view, mg_src_prop.view(), @@ -255,10 +255,10 @@ class 
Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::minimum{}, + cugraph::reduce_op::elementwise_minimum{}, cugraph::get_dataframe_buffer_begin(in_results[i])); break; - case reduction_type_t::MAXIMUM: + case reduction_type_t::ELEMWISE_MAX: per_v_transform_reduce_incoming_e(*handle_, mg_graph_view, mg_src_prop.view(), @@ -266,7 +266,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::maximum{}, + cugraph::reduce_op::elementwise_maximum{}, cugraph::get_dataframe_buffer_begin(in_results[i])); break; default: FAIL() << "should not be reached."; @@ -300,7 +300,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::reduce_op::plus{}, cugraph::get_dataframe_buffer_begin(out_results[i])); break; - case reduction_type_t::MINIMUM: + case reduction_type_t::ELEMWISE_MIN: per_v_transform_reduce_outgoing_e(*handle_, mg_graph_view, mg_src_prop.view(), @@ -308,10 +308,10 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::minimum{}, + cugraph::reduce_op::elementwise_minimum{}, cugraph::get_dataframe_buffer_begin(out_results[i])); break; - case reduction_type_t::MAXIMUM: + case reduction_type_t::ELEMWISE_MAX: per_v_transform_reduce_outgoing_e(*handle_, mg_graph_view, mg_src_prop.view(), @@ -319,7 +319,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::maximum{}, + cugraph::reduce_op::elementwise_maximum{}, cugraph::get_dataframe_buffer_begin(out_results[i])); break; default: FAIL() << "should not be reached."; @@ -371,7 +371,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::reduce_op::plus{}, cugraph::get_dataframe_buffer_begin(global_in_result)); break; - case 
reduction_type_t::MINIMUM: + case reduction_type_t::ELEMWISE_MIN: per_v_transform_reduce_incoming_e( *handle_, sg_graph_view, @@ -380,10 +380,10 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::minimum{}, + cugraph::reduce_op::elementwise_minimum{}, cugraph::get_dataframe_buffer_begin(global_in_result)); break; - case reduction_type_t::MAXIMUM: + case reduction_type_t::ELEMWISE_MAX: per_v_transform_reduce_incoming_e( *handle_, sg_graph_view, @@ -392,7 +392,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::maximum{}, + cugraph::reduce_op::elementwise_maximum{}, cugraph::get_dataframe_buffer_begin(global_in_result)); break; default: FAIL() << "should not be reached."; @@ -414,7 +414,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::reduce_op::plus{}, cugraph::get_dataframe_buffer_begin(global_out_result)); break; - case reduction_type_t::MINIMUM: + case reduction_type_t::ELEMWISE_MIN: per_v_transform_reduce_outgoing_e( *handle_, sg_graph_view, @@ -423,10 +423,10 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::minimum{}, + cugraph::reduce_op::elementwise_minimum{}, cugraph::get_dataframe_buffer_begin(global_out_result)); break; - case reduction_type_t::MAXIMUM: + case reduction_type_t::ELEMWISE_MAX: per_v_transform_reduce_outgoing_e( *handle_, sg_graph_view, @@ -435,7 +435,7 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE cugraph::edge_dummy_property_t{}.view(), e_op_t{}, property_initial_value, - cugraph::reduce_op::maximum{}, + cugraph::reduce_op::elementwise_maximum{}, cugraph::get_dataframe_buffer_begin(global_out_result)); break; default: FAIL() << "should not be reached.";