diff --git a/cpp/src/structure/detail/structure_utils.cuh b/cpp/src/structure/detail/structure_utils.cuh index 7630d5855a0..f0f729bce18 100644 --- a/cpp/src/structure/detail/structure_utils.cuh +++ b/cpp/src/structure/detail/structure_utils.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -524,35 +525,21 @@ std::tuple> mark_entries(raft::handle_t co return word; }); - // FIXME: use detail::count_set_bits - size_t bit_count = thrust::transform_reduce( - handle.get_thrust_policy(), - marked_entries.begin(), - marked_entries.end(), - [] __device__(auto word) { return __popc(word); }, - size_t{0}, - thrust::plus()); + size_t bit_count = detail::count_set_bits(handle, marked_entries.begin(), num_entries); return std::make_tuple(bit_count, std::move(marked_entries)); } template -rmm::device_uvector remove_flagged_elements(raft::handle_t const& handle, - rmm::device_uvector&& vector, - raft::device_span remove_flags, - size_t remove_count) +rmm::device_uvector keep_flagged_elements(raft::handle_t const& handle, + rmm::device_uvector&& vector, + raft::device_span keep_flags, + size_t keep_count) { - rmm::device_uvector result(vector.size() - remove_count, handle.get_stream()); - - thrust::copy_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(vector.size()), - thrust::make_transform_output_iterator(result.begin(), - indirection_t{vector.data()}), - [remove_flags] __device__(size_t i) { - return !(remove_flags[cugraph::packed_bool_offset(i)] & cugraph::packed_bool_mask(i)); - }); + rmm::device_uvector result(keep_count, handle.get_stream()); + + detail::copy_if_mask_set( + handle, vector.begin(), vector.end(), keep_flags.begin(), result.begin()); return result; } diff --git a/cpp/src/structure/remove_multi_edges_impl.cuh b/cpp/src/structure/remove_multi_edges_impl.cuh index ab6b1fba8eb..fdd3059f874 100644 --- a/cpp/src/structure/remove_multi_edges_impl.cuh +++ b/cpp/src/structure/remove_multi_edges_impl.cuh @@ -254,50 +254,47 @@ remove_multi_edges(raft::handle_t const& handle, } } - auto [multi_edge_count, multi_edges_to_delete] = - detail::mark_entries(handle, - edgelist_srcs.size(), - [d_edgelist_srcs = edgelist_srcs.data(), - d_edgelist_dsts = edgelist_dsts.data()] __device__(auto idx) { - return (idx > 0) && (d_edgelist_srcs[idx - 1] == d_edgelist_srcs[idx]) && - (d_edgelist_dsts[idx - 1] == d_edgelist_dsts[idx]); - }); - - if (multi_edge_count > 0) { - edgelist_srcs = detail::remove_flagged_elements( + auto [keep_count, keep_flags] = detail::mark_entries( + handle, + edgelist_srcs.size(), + [d_edgelist_srcs = edgelist_srcs.data(), + d_edgelist_dsts = edgelist_dsts.data()] __device__(auto idx) { + return !((idx > 0) && (d_edgelist_srcs[idx - 1] == d_edgelist_srcs[idx]) && + (d_edgelist_dsts[idx - 1] == d_edgelist_dsts[idx])); + }); + + if (keep_count < edgelist_srcs.size()) { + edgelist_srcs = detail::keep_flagged_elements( handle, std::move(edgelist_srcs), - raft::device_span{multi_edges_to_delete.data(), multi_edges_to_delete.size()}, - multi_edge_count); - edgelist_dsts = detail::remove_flagged_elements( + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); + edgelist_dsts = detail::keep_flagged_elements( handle, std::move(edgelist_dsts), - raft::device_span{multi_edges_to_delete.data(), multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_weights) - edgelist_weights = detail::remove_flagged_elements( + edgelist_weights = detail::keep_flagged_elements( handle, std::move(*edgelist_weights), - raft::device_span{multi_edges_to_delete.data(), - multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_ids) - edgelist_edge_ids = detail::remove_flagged_elements( + edgelist_edge_ids = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_ids), - raft::device_span{multi_edges_to_delete.data(), - multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_types) - edgelist_edge_types = detail::remove_flagged_elements( + edgelist_edge_types = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_types), - raft::device_span{multi_edges_to_delete.data(), - multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); } return std::make_tuple(std::move(edgelist_srcs), diff --git a/cpp/src/structure/remove_self_loops_impl.cuh b/cpp/src/structure/remove_self_loops_impl.cuh index 161ffeae28e..dafe26cd1c5 100644 --- a/cpp/src/structure/remove_self_loops_impl.cuh +++ b/cpp/src/structure/remove_self_loops_impl.cuh @@ -44,44 +44,44 @@ remove_self_loops(raft::handle_t const& handle, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types) { - auto [self_loop_count, self_loops_to_delete] = + auto [keep_count, keep_flags] = detail::mark_entries(handle, edgelist_srcs.size(), [d_srcs = edgelist_srcs.data(), d_dsts = edgelist_dsts.data()] __device__( - size_t i) { return d_srcs[i] == d_dsts[i]; }); + size_t i) { return d_srcs[i] != d_dsts[i]; }); - if (self_loop_count > 0) { - edgelist_srcs = detail::remove_flagged_elements( + if (keep_count < edgelist_srcs.size()) { + edgelist_srcs = detail::keep_flagged_elements( handle, std::move(edgelist_srcs), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); - edgelist_dsts = detail::remove_flagged_elements( + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); + edgelist_dsts = detail::keep_flagged_elements( handle, std::move(edgelist_dsts), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_weights) - edgelist_weights = detail::remove_flagged_elements( + edgelist_weights = detail::keep_flagged_elements( handle, std::move(*edgelist_weights), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_ids) - edgelist_edge_ids = detail::remove_flagged_elements( + edgelist_edge_ids = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_ids), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_types) - edgelist_edge_types = detail::remove_flagged_elements( + edgelist_edge_types = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_types), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); } return std::make_tuple(std::move(edgelist_srcs), diff --git a/cpp/tests/community/triangle_count_test.cpp b/cpp/tests/community/triangle_count_test.cpp index 836bab59457..592924c3c47 100644 --- a/cpp/tests/community/triangle_count_test.cpp +++ b/cpp/tests/community/triangle_count_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -232,7 +232,7 @@ class Tests_TriangleCount for (size_t i = 0; i < h_cugraph_vertices.size(); ++i) { auto v = h_cugraph_vertices[i]; auto count = h_cugraph_triangle_counts[i]; - ASSERT_TRUE(count == h_reference_triangle_counts[v]) + ASSERT_EQ(count, h_reference_triangle_counts[v]) << "Triangle count values do not match with the reference values."; } } diff --git a/cpp/tests/utilities/test_graphs.hpp b/cpp/tests/utilities/test_graphs.hpp index 16c9d3ed145..8cc87b26f1d 100644 --- a/cpp/tests/utilities/test_graphs.hpp +++ b/cpp/tests/utilities/test_graphs.hpp @@ -621,9 +621,25 @@ construct_graph(raft::handle_t const& handle, CUGRAPH_EXPECTS(d_src_v.size() <= static_cast(std::numeric_limits::max()), "Invalid template parameter: edge_t overflow."); - if (drop_self_loops) { remove_self_loops(handle, d_src_v, d_dst_v, d_weights_v); } + if (drop_self_loops) { + std::tie(d_src_v, d_dst_v, d_weights_v, std::ignore, std::ignore) = + cugraph::remove_self_loops(handle, + std::move(d_src_v), + std::move(d_dst_v), + std::move(d_weights_v), + std::nullopt, + std::nullopt); + } - if (drop_multi_edges) { sort_and_remove_multi_edges(handle, d_src_v, d_dst_v, d_weights_v); } + if (drop_multi_edges) { + std::tie(d_src_v, d_dst_v, d_weights_v, std::ignore, std::ignore) = + cugraph::remove_multi_edges(handle, + std::move(d_src_v), + std::move(d_dst_v), + std::move(d_weights_v), + std::nullopt, + std::nullopt); + } graph_t graph(handle); std::optional< diff --git a/cpp/tests/utilities/thrust_wrapper.cu b/cpp/tests/utilities/thrust_wrapper.cu index cb7e6f1bd66..2daf250b4a2 100644 --- a/cpp/tests/utilities/thrust_wrapper.cu +++ b/cpp/tests/utilities/thrust_wrapper.cu @@ -206,131 +206,5 @@ template void populate_vertex_ids(raft::handle_t const& handle, rmm::device_uvector& d_vertices_v, int64_t vertex_id_offset); -template -void remove_self_loops(raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */) -{ - if (d_weight_v) { - auto edge_first = thrust::make_zip_iterator( - thrust::make_tuple(d_src_v.begin(), d_dst_v.begin(), (*d_weight_v).begin())); - d_src_v.resize( - thrust::distance(edge_first, - thrust::remove_if( - handle.get_thrust_policy(), - edge_first, - edge_first + d_src_v.size(), - [] __device__(auto e) { return thrust::get<0>(e) == thrust::get<1>(e); })), - handle.get_stream()); - d_dst_v.resize(d_src_v.size(), handle.get_stream()); - (*d_weight_v).resize(d_src_v.size(), handle.get_stream()); - } else { - auto edge_first = - thrust::make_zip_iterator(thrust::make_tuple(d_src_v.begin(), d_dst_v.begin())); - d_src_v.resize( - thrust::distance(edge_first, - thrust::remove_if( - handle.get_thrust_policy(), - edge_first, - edge_first + d_src_v.size(), - [] __device__(auto e) { return thrust::get<0>(e) == thrust::get<1>(e); })), - handle.get_stream()); - d_dst_v.resize(d_src_v.size(), handle.get_stream()); - } - - d_src_v.shrink_to_fit(handle.get_stream()); - d_dst_v.shrink_to_fit(handle.get_stream()); - if (d_weight_v) { (*d_weight_v).shrink_to_fit(handle.get_stream()); } -} - -template void remove_self_loops( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template void remove_self_loops( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template void remove_self_loops( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template void remove_self_loops( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template -void sort_and_remove_multi_edges( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */) -{ - if (d_weight_v) { - auto edge_first = thrust::make_zip_iterator( - thrust::make_tuple(d_src_v.begin(), d_dst_v.begin(), (*d_weight_v).begin())); - thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size()); - d_src_v.resize( - thrust::distance(edge_first, - thrust::unique(handle.get_thrust_policy(), - edge_first, - edge_first + d_src_v.size(), - [] __device__(auto lhs, auto rhs) { - return (thrust::get<0>(lhs) == thrust::get<0>(rhs)) && - (thrust::get<1>(lhs) == thrust::get<1>(rhs)); - })), - handle.get_stream()); - d_dst_v.resize(d_src_v.size(), handle.get_stream()); - (*d_weight_v).resize(d_src_v.size(), handle.get_stream()); - } else { - auto edge_first = - thrust::make_zip_iterator(thrust::make_tuple(d_src_v.begin(), d_dst_v.begin())); - thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size()); - d_src_v.resize( - thrust::distance( - edge_first, - thrust::unique(handle.get_thrust_policy(), edge_first, edge_first + d_src_v.size())), - handle.get_stream()); - d_dst_v.resize(d_src_v.size(), handle.get_stream()); - } - - d_src_v.shrink_to_fit(handle.get_stream()); - d_dst_v.shrink_to_fit(handle.get_stream()); - if (d_weight_v) { (*d_weight_v).shrink_to_fit(handle.get_stream()); } -} - -template void sort_and_remove_multi_edges( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template void sort_and_remove_multi_edges( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template void sort_and_remove_multi_edges( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template void sort_and_remove_multi_edges( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - } // namespace test } // namespace cugraph diff --git a/cpp/tests/utilities/thrust_wrapper.hpp b/cpp/tests/utilities/thrust_wrapper.hpp index eead4dc268f..fb82d781198 100644 --- a/cpp/tests/utilities/thrust_wrapper.hpp +++ b/cpp/tests/utilities/thrust_wrapper.hpp @@ -46,18 +46,5 @@ void populate_vertex_ids(raft::handle_t const& handle, rmm::device_uvector& d_vertices_v /* [INOUT] */, vertex_t vertex_id_offset); -template -void remove_self_loops(raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - -template -void sort_and_remove_multi_edges( - raft::handle_t const& handle, - rmm::device_uvector& d_src_v /* [INOUT] */, - rmm::device_uvector& d_dst_v /* [INOUT] */, - std::optional>& d_weight_v /* [INOUT] */); - } // namespace test } // namespace cugraph