From f6543f635a95c481046c5f3bec777139590705a2 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang <45857425+seunghwak@users.noreply.github.com> Date: Thu, 3 Aug 2023 14:56:56 -0700 Subject: [PATCH] Change the renumber_sampled_edgelist function behavior. (#3762) There was a misalignment between the `renumber_sampled_edgelist` function behavior and what PyG and DGL need. This PR fixes this. Authors: - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Alex Barghi (https://github.com/alexbarghi-nv) - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/cugraph/pull/3762 --- cpp/include/cugraph/graph_functions.hpp | 25 +- cpp/src/c_api/uniform_neighbor_sampling.cpp | 2 +- .../renumber_sampled_edgelist_impl.cuh | 680 +++++++++--------- .../sampling/renumber_sampled_edgelist_sg.cu | 4 +- .../renumber_sampled_edgelist_test.cu | 245 ++++--- 5 files changed, 521 insertions(+), 435 deletions(-) diff --git a/cpp/include/cugraph/graph_functions.hpp b/cpp/include/cugraph/graph_functions.hpp index caffef60076..200ee725b7a 100644 --- a/cpp/include/cugraph/graph_functions.hpp +++ b/cpp/include/cugraph/graph_functions.hpp @@ -922,15 +922,16 @@ rmm::device_uvector select_random_vertices( * This function renumbers sampling function (e.g. uniform_neighbor_sample) outputs satisfying the * following requirements. * - * 1. Say @p edgelist_srcs has N unique vertices. These N unique vertices will be mapped to [0, N). - * 2. Among the N unique vertices, an original vertex with a smaller attached hop number will be - * renumbered to a smaller vertex ID than any other original vertices with a larger attached hop - * number (if @p edgelist_hops.has_value() is true). If a single vertex is attached to multiple hop - * numbers, the minimum hop number is used. - * 3. Say @p edgelist_dsts has M unique vertices that appear only in @p edgelist_dsts (the set of M - * unique vertices does not include any vertices that appear in @p edgelist_srcs). Then, these M - * unique vertices will be mapped to [N, N + M). - * 4. If label_offsets.has_value() is ture, edge lists for different labels will be renumbered + * 1. If @p edgelist_hops is valid, we can consider (vertex ID, flag=src, hop) triplets for each + * vertex ID in @p edgelist_srcs and (vertex ID, flag=dst, hop) triplets for each vertex ID in @p + * edgelist_dsts. From these triplets, we can find the minimum (hop, flag) pairs for every unique + * vertex ID (hop is the primary key and flag is the secondary key, flag=src is considered smaller + * than flag=dst if hop numbers are same). Vertex IDs with smaller (hop, flag) pairs precede vertex + * IDs with larger (hop, flag) pairs in renumbering. Ordering can be arbitrary among the vertices + * with the same (hop, flag) pairs. + * 2. If @p edgelist_hops is invalid, unique vertex IDs in @p edgelist_srcs precede vertex IDs that + * appear only in @p edgelist_dsts. + * 3. If label_offsets.has_value() is ture, edge lists for different labels will be renumbered * separately. * * This function is single-GPU only (we are not aware of any practical multi-GPU use cases). @@ -940,10 +941,10 @@ rmm::device_uvector select_random_vertices( * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and * handles to various CUDA libraries) to run graph algorithms. * @param edgelist_srcs A vector storing original edgelist source vertices. - * @param edgelist_hops An optional pointer to the array storing hops for each edge list source - * vertices (size = @p edgelist_srcs.size()). * @param edgelist_dsts A vector storing original edgelist destination vertices (size = @p * edgelist_srcs.size()). + * @param edgelist_hops An optional pointer to the array storing hops for each edge list (source, + * destination) pairs (size = @p edgelist_srcs.size() if valid). * @param label_offsets An optional tuple of unique labels and the input edge list (@p * edgelist_srcs, @p edgelist_hops, and @p edgelist_dsts) offsets for the labels (siez = # unique * labels + 1). @@ -962,8 +963,8 @@ std::tuple, renumber_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, - std::optional> edgelist_hops, rmm::device_uvector&& edgelist_dsts, + std::optional> edgelist_hops, std::optional, raft::device_span>> label_offsets, bool do_expensive_check = false); diff --git a/cpp/src/c_api/uniform_neighbor_sampling.cpp b/cpp/src/c_api/uniform_neighbor_sampling.cpp index ff6a6c49437..caaba8e9c8d 100644 --- a/cpp/src/c_api/uniform_neighbor_sampling.cpp +++ b/cpp/src/c_api/uniform_neighbor_sampling.cpp @@ -236,9 +236,9 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct std::tie(src, dst, renumber_map, renumber_map_offsets) = cugraph::renumber_sampled_edgelist( handle_, std::move(src), + std::move(dst), hop ? std::make_optional(raft::device_span{hop->data(), hop->size()}) : std::nullopt, - std::move(dst), std::make_optional(std::make_tuple( raft::device_span{edge_label->data(), edge_label->size()}, raft::device_span{offsets->data(), offsets->size()})), diff --git a/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh b/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh index a4a6d64029a..6fdb1c887f2 100644 --- a/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh +++ b/cpp/src/sampling/renumber_sampled_edgelist_impl.cuh @@ -45,260 +45,70 @@ namespace cugraph { namespace { +// output sorted by (primary key:label_index, secondary key:vertex) template -std::tuple, std::optional>> -compute_renumber_map(raft::handle_t const& handle, - raft::device_span edgelist_srcs, - std::optional> edgelist_hops, - raft::device_span edgelist_dsts, - std::optional> label_offsets) +std::tuple> /* label indices */, + rmm::device_uvector /* vertices */, + std::optional> /* minimum hops for the vertices */, + std::optional> /* label offsets for the output */> +compute_min_hop_for_unique_label_vertex_pairs( + raft::handle_t const& handle, + raft::device_span vertices, + std::optional> hops, + std::optional> label_indices, + std::optional> label_offsets) { auto approx_edges_to_sort_per_iteration = static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 20) /* tuning parameter */; // for segmented sort - std::optional> edgelist_label_indices{std::nullopt}; - if (label_offsets) { - edgelist_label_indices = - detail::expand_sparse_offsets(*label_offsets, label_index_t{0}, handle.get_stream()); - } + if (label_indices) { + auto num_labels = (*label_offsets).size() - 1; - std::optional> unique_label_src_pair_label_indices{ - std::nullopt}; - rmm::device_uvector unique_label_src_pair_vertices( - 0, handle.get_stream()); // sorted by (label, hop, src) - std::optional> sorted_srcs{ - std::nullopt}; // sorted by (label, src), relevant only when edgelist_hops is valid - { - if (label_offsets) { - rmm::device_uvector label_indices((*edgelist_label_indices).size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*edgelist_label_indices).begin(), - (*edgelist_label_indices).end(), - label_indices.begin()); + rmm::device_uvector tmp_label_indices((*label_indices).size(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + (*label_indices).begin(), + (*label_indices).end(), + tmp_label_indices.begin()); - if (edgelist_hops) { - rmm::device_uvector srcs(edgelist_srcs.size(), handle.get_stream()); - thrust::copy( - handle.get_thrust_policy(), edgelist_srcs.begin(), edgelist_srcs.end(), srcs.begin()); - - rmm::device_uvector hops((*edgelist_hops).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*edgelist_hops).begin(), - (*edgelist_hops).end(), - hops.begin()); - auto triplet_first = - thrust::make_zip_iterator(label_indices.begin(), srcs.begin(), hops.begin()); - thrust::sort(handle.get_thrust_policy(), triplet_first, triplet_first + srcs.size()); - auto num_uniques = static_cast( - thrust::distance(triplet_first, - thrust::unique(handle.get_thrust_policy(), - triplet_first, - triplet_first + srcs.size(), - [] __device__(auto lhs, auto rhs) { - return (thrust::get<0>(lhs) == thrust::get<0>(rhs)) && - (thrust::get<1>(lhs) == thrust::get<1>(rhs)); - }))); - label_indices.resize(num_uniques, handle.get_stream()); - srcs.resize(num_uniques, handle.get_stream()); - hops.resize(num_uniques, handle.get_stream()); - label_indices.shrink_to_fit(handle.get_stream()); - srcs.shrink_to_fit(handle.get_stream()); - hops.shrink_to_fit(handle.get_stream()); - - auto num_labels = (*label_offsets).size() - 1; - rmm::device_uvector tmp_label_offsets(num_labels + 1, handle.get_stream()); - tmp_label_offsets.set_element_to_zero_async(0, handle.get_stream()); - thrust::upper_bound(handle.get_thrust_policy(), - label_indices.begin(), - label_indices.end(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(num_labels), - tmp_label_offsets.begin() + 1); - - unique_label_src_pair_label_indices = std::move(label_indices); - sorted_srcs = rmm::device_uvector(srcs.size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), srcs.begin(), srcs.end(), (*sorted_srcs).begin()); - - rmm::device_uvector segment_sorted_srcs(srcs.size(), handle.get_stream()); - - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - - auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_edge_chunks( - handle, - tmp_label_offsets.data(), - static_cast(tmp_label_offsets.size() - 1), - hops.size(), - approx_edges_to_sort_per_iteration); - auto num_chunks = h_label_offsets.size() - 1; - size_t max_chunk_size{0}; - for (size_t i = 0; i < num_chunks; ++i) { - max_chunk_size = std::max(max_chunk_size, - static_cast(h_edge_offsets[i + 1] - h_edge_offsets[i])); - } - rmm::device_uvector segment_sorted_hops(max_chunk_size, handle.get_stream()); - - for (size_t i = 0; i < num_chunks; ++i) { - size_t tmp_storage_bytes{0}; - - auto offset_first = - thrust::make_transform_iterator(tmp_label_offsets.data() + h_label_offsets[i], - detail::shift_left_t{h_edge_offsets[i]}); - cub::DeviceSegmentedSort::SortPairs(static_cast(nullptr), - tmp_storage_bytes, - hops.begin() + h_edge_offsets[i], - segment_sorted_hops.begin(), - srcs.begin() + h_edge_offsets[i], - segment_sorted_srcs.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - - cub::DeviceSegmentedSort::SortPairs(d_tmp_storage.data(), - tmp_storage_bytes, - hops.begin() + h_edge_offsets[i], - segment_sorted_hops.begin(), - srcs.begin() + h_edge_offsets[i], - segment_sorted_srcs.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - } + rmm::device_uvector tmp_vertices(0, handle.get_stream()); + std::optional> tmp_hops{std::nullopt}; - unique_label_src_pair_vertices = std::move(segment_sorted_srcs); - } else { - rmm::device_uvector segment_sorted_srcs(edgelist_srcs.size(), - handle.get_stream()); - - rmm::device_uvector d_tmp_storage(0, handle.get_stream()); - - auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_edge_chunks( - handle, - (*label_offsets).data(), - static_cast((*label_offsets).size() - 1), - edgelist_srcs.size(), - approx_edges_to_sort_per_iteration); - auto num_chunks = h_label_offsets.size() - 1; - - for (size_t i = 0; i < num_chunks; ++i) { - size_t tmp_storage_bytes{0}; - - auto offset_first = - thrust::make_transform_iterator((*label_offsets).data() + h_label_offsets[i], - detail::shift_left_t{h_edge_offsets[i]}); - cub::DeviceSegmentedSort::SortKeys(static_cast(nullptr), - tmp_storage_bytes, - edgelist_srcs.begin() + h_edge_offsets[i], - segment_sorted_srcs.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - - if (tmp_storage_bytes > d_tmp_storage.size()) { - d_tmp_storage = rmm::device_uvector(tmp_storage_bytes, handle.get_stream()); - } - - cub::DeviceSegmentedSort::SortKeys(d_tmp_storage.data(), - tmp_storage_bytes, - edgelist_srcs.begin() + h_edge_offsets[i], - segment_sorted_srcs.begin() + h_edge_offsets[i], - h_edge_offsets[i + 1] - h_edge_offsets[i], - h_label_offsets[i + 1] - h_label_offsets[i], - offset_first, - offset_first + 1, - handle.get_stream()); - } - d_tmp_storage.resize(0, handle.get_stream()); - d_tmp_storage.shrink_to_fit(handle.get_stream()); - - auto pair_first = - thrust::make_zip_iterator(label_indices.begin(), segment_sorted_srcs.begin()); - auto num_uniques = static_cast(thrust::distance( - pair_first, - thrust::unique( - handle.get_thrust_policy(), pair_first, pair_first + label_indices.size()))); - label_indices.resize(num_uniques, handle.get_stream()); - segment_sorted_srcs.resize(num_uniques, handle.get_stream()); - label_indices.shrink_to_fit(handle.get_stream()); - segment_sorted_srcs.shrink_to_fit(handle.get_stream()); - - unique_label_src_pair_label_indices = std::move(label_indices); - unique_label_src_pair_vertices = std::move(segment_sorted_srcs); - } - } else { - rmm::device_uvector srcs(edgelist_srcs.size(), handle.get_stream()); + if (hops) { + tmp_vertices.resize(vertices.size(), handle.get_stream()); thrust::copy( - handle.get_thrust_policy(), edgelist_srcs.begin(), edgelist_srcs.end(), srcs.begin()); - - if (edgelist_hops) { - rmm::device_uvector hops((*edgelist_hops).size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*edgelist_hops).begin(), - (*edgelist_hops).end(), - hops.begin()); - - auto pair_first = thrust::make_zip_iterator( - srcs.begin(), hops.begin()); // src is a primary key, hop is a secondary key - thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + srcs.size()); - srcs.resize( - thrust::distance(srcs.begin(), - thrust::get<0>(thrust::unique_by_key( - handle.get_thrust_policy(), srcs.begin(), srcs.end(), hops.begin()))), - handle.get_stream()); - hops.resize(srcs.size(), handle.get_stream()); - - sorted_srcs = rmm::device_uvector(srcs.size(), handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), srcs.begin(), srcs.end(), (*sorted_srcs).begin()); - - thrust::sort_by_key(handle.get_thrust_policy(), hops.begin(), hops.end(), srcs.begin()); - } else { - thrust::sort(handle.get_thrust_policy(), srcs.begin(), srcs.end()); - srcs.resize( - thrust::distance(srcs.begin(), - thrust::unique(handle.get_thrust_policy(), srcs.begin(), srcs.end())), - handle.get_stream()); - srcs.shrink_to_fit(handle.get_stream()); - } - - unique_label_src_pair_vertices = std::move(srcs); - } - } - - std::optional> unique_label_dst_pair_label_indices{ - std::nullopt}; - rmm::device_uvector unique_label_dst_pair_vertices(0, handle.get_stream()); - { - rmm::device_uvector dsts(edgelist_dsts.size(), handle.get_stream()); - thrust::copy( - handle.get_thrust_policy(), edgelist_dsts.begin(), edgelist_dsts.end(), dsts.begin()); - if (label_offsets) { - rmm::device_uvector label_indices((*edgelist_label_indices).size(), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - (*edgelist_label_indices).begin(), - (*edgelist_label_indices).end(), - label_indices.begin()); - - rmm::device_uvector segment_sorted_dsts(dsts.size(), handle.get_stream()); + handle.get_thrust_policy(), vertices.begin(), vertices.end(), tmp_vertices.begin()); + tmp_hops = rmm::device_uvector((*hops).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), (*hops).begin(), (*hops).end(), (*tmp_hops).begin()); + + auto triplet_first = thrust::make_zip_iterator( + tmp_label_indices.begin(), tmp_vertices.begin(), (*tmp_hops).begin()); + thrust::sort( + handle.get_thrust_policy(), triplet_first, triplet_first + tmp_label_indices.size()); + auto key_first = thrust::make_zip_iterator(tmp_label_indices.begin(), tmp_vertices.begin()); + auto num_uniques = static_cast( + thrust::distance(key_first, + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + key_first, + key_first + tmp_label_indices.size(), + (*tmp_hops).begin())))); + tmp_label_indices.resize(num_uniques, handle.get_stream()); + tmp_vertices.resize(num_uniques, handle.get_stream()); + (*tmp_hops).resize(num_uniques, handle.get_stream()); + tmp_label_indices.shrink_to_fit(handle.get_stream()); + tmp_vertices.shrink_to_fit(handle.get_stream()); + (*tmp_hops).shrink_to_fit(handle.get_stream()); + } else { + rmm::device_uvector segment_sorted_vertices(vertices.size(), handle.get_stream()); rmm::device_uvector d_tmp_storage(0, handle.get_stream()); auto [h_label_offsets, h_edge_offsets] = detail::compute_offset_aligned_edge_chunks(handle, (*label_offsets).data(), - static_cast((*label_offsets).size() - 1), - dsts.size(), + num_labels, + vertices.size(), approx_edges_to_sort_per_iteration); auto num_chunks = h_label_offsets.size() - 1; @@ -310,8 +120,8 @@ compute_renumber_map(raft::handle_t const& handle, detail::shift_left_t{h_edge_offsets[i]}); cub::DeviceSegmentedSort::SortKeys(static_cast(nullptr), tmp_storage_bytes, - dsts.begin() + h_edge_offsets[i], - segment_sorted_dsts.begin() + h_edge_offsets[i], + vertices.begin() + h_edge_offsets[i], + segment_sorted_vertices.begin() + h_edge_offsets[i], h_edge_offsets[i + 1] - h_edge_offsets[i], h_label_offsets[i + 1] - h_label_offsets[i], offset_first, @@ -324,121 +134,329 @@ compute_renumber_map(raft::handle_t const& handle, cub::DeviceSegmentedSort::SortKeys(d_tmp_storage.data(), tmp_storage_bytes, - dsts.begin() + h_edge_offsets[i], - segment_sorted_dsts.begin() + h_edge_offsets[i], + vertices.begin() + h_edge_offsets[i], + segment_sorted_vertices.begin() + h_edge_offsets[i], h_edge_offsets[i + 1] - h_edge_offsets[i], h_label_offsets[i + 1] - h_label_offsets[i], offset_first, offset_first + 1, handle.get_stream()); } - dsts.resize(0, handle.get_stream()); d_tmp_storage.resize(0, handle.get_stream()); - dsts.shrink_to_fit(handle.get_stream()); d_tmp_storage.shrink_to_fit(handle.get_stream()); auto pair_first = - thrust::make_zip_iterator(label_indices.begin(), segment_sorted_dsts.begin()); + thrust::make_zip_iterator(tmp_label_indices.begin(), segment_sorted_vertices.begin()); auto num_uniques = static_cast(thrust::distance( pair_first, - thrust::unique(handle.get_thrust_policy(), pair_first, pair_first + label_indices.size()))); - label_indices.resize(num_uniques, handle.get_stream()); - segment_sorted_dsts.resize(num_uniques, handle.get_stream()); - label_indices.shrink_to_fit(handle.get_stream()); - segment_sorted_dsts.shrink_to_fit(handle.get_stream()); - - unique_label_dst_pair_label_indices = std::move(label_indices); - unique_label_dst_pair_vertices = std::move(segment_sorted_dsts); + thrust::unique( + handle.get_thrust_policy(), pair_first, pair_first + tmp_label_indices.size()))); + tmp_label_indices.resize(num_uniques, handle.get_stream()); + segment_sorted_vertices.resize(num_uniques, handle.get_stream()); + tmp_label_indices.shrink_to_fit(handle.get_stream()); + segment_sorted_vertices.shrink_to_fit(handle.get_stream()); + + tmp_vertices = std::move(segment_sorted_vertices); + } + + rmm::device_uvector tmp_label_offsets(num_labels + 1, handle.get_stream()); + tmp_label_offsets.set_element_to_zero_async(0, handle.get_stream()); + thrust::upper_bound(handle.get_thrust_policy(), + tmp_label_indices.begin(), + tmp_label_indices.end(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(num_labels), + tmp_label_offsets.begin() + 1); + + return std::make_tuple(std::move(tmp_label_indices), + std::move(tmp_vertices), + std::move(tmp_hops), + std::move(tmp_label_offsets)); + } else { + rmm::device_uvector tmp_vertices(vertices.size(), handle.get_stream()); + thrust::copy( + handle.get_thrust_policy(), vertices.begin(), vertices.end(), tmp_vertices.begin()); + + if (hops) { + rmm::device_uvector tmp_hops((*hops).size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), (*hops).begin(), (*hops).end(), tmp_hops.begin()); + + auto pair_first = thrust::make_zip_iterator( + tmp_vertices.begin(), tmp_hops.begin()); // vertex is a primary key, hop is a secondary key + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + tmp_vertices.size()); + tmp_vertices.resize( + thrust::distance(tmp_vertices.begin(), + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + tmp_vertices.begin(), + tmp_vertices.end(), + tmp_hops.begin()))), + handle.get_stream()); + tmp_hops.resize(tmp_vertices.size(), handle.get_stream()); + + return std::make_tuple( + std::nullopt, std::move(tmp_vertices), std::move(tmp_hops), std::nullopt); } else { - thrust::sort(handle.get_thrust_policy(), dsts.begin(), dsts.end()); - dsts.resize( - thrust::distance(dsts.begin(), - thrust::unique(handle.get_thrust_policy(), dsts.begin(), dsts.end())), + thrust::sort(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end()); + tmp_vertices.resize( + thrust::distance( + tmp_vertices.begin(), + thrust::unique(handle.get_thrust_policy(), tmp_vertices.begin(), tmp_vertices.end())), handle.get_stream()); - dsts.shrink_to_fit(handle.get_stream()); + tmp_vertices.shrink_to_fit(handle.get_stream()); - unique_label_dst_pair_vertices = std::move(dsts); + return std::make_tuple(std::nullopt, std::move(tmp_vertices), std::nullopt, std::nullopt); } } +} + +template +std::tuple, std::optional>> +compute_renumber_map(raft::handle_t const& handle, + raft::device_span edgelist_srcs, + raft::device_span edgelist_dsts, + std::optional> edgelist_hops, + std::optional> label_offsets) +{ + auto approx_edges_to_sort_per_iteration = + static_cast(handle.get_device_properties().multiProcessorCount) * + (1 << 20) /* tuning parameter */; // for segmented sort + + std::optional> edgelist_label_indices{std::nullopt}; + if (label_offsets) { + edgelist_label_indices = + detail::expand_sparse_offsets(*label_offsets, label_index_t{0}, handle.get_stream()); + } + + auto [unique_label_src_pair_label_indices, + unique_label_src_pair_vertices, + unique_label_src_pair_hops, + unique_label_src_pair_label_offsets] = + compute_min_hop_for_unique_label_vertex_pairs( + handle, + edgelist_srcs, + edgelist_hops, + edgelist_label_indices ? std::make_optional>( + (*edgelist_label_indices).data(), (*edgelist_label_indices).size()) + : std::nullopt, + label_offsets); + + auto [unique_label_dst_pair_label_indices, + unique_label_dst_pair_vertices, + unique_label_dst_pair_hops, + unique_label_dst_pair_label_offsets] = + compute_min_hop_for_unique_label_vertex_pairs( + handle, + edgelist_dsts, + edgelist_hops, + edgelist_label_indices ? std::make_optional>( + (*edgelist_label_indices).data(), (*edgelist_label_indices).size()) + : std::nullopt, + label_offsets); edgelist_label_indices = std::nullopt; if (label_offsets) { - auto label_src_pair_first = thrust::make_zip_iterator( - (*unique_label_src_pair_label_indices).begin(), - edgelist_hops ? (*sorted_srcs).begin() : unique_label_src_pair_vertices.begin()); - auto label_dst_pair_first = thrust::make_zip_iterator( - (*unique_label_dst_pair_label_indices).begin(), unique_label_dst_pair_vertices.begin()); - rmm::device_uvector output_label_indices( - (*unique_label_dst_pair_label_indices).size(), handle.get_stream()); - rmm::device_uvector output_vertices((*unique_label_dst_pair_label_indices).size(), - handle.get_stream()); - auto output_label_dst_pair_first = - thrust::make_zip_iterator(output_label_indices.begin(), output_vertices.begin()); - auto output_label_dst_pair_last = - thrust::set_difference(handle.get_thrust_policy(), - label_dst_pair_first, - label_dst_pair_first + (*unique_label_dst_pair_label_indices).size(), - label_src_pair_first, - label_src_pair_first + (*unique_label_src_pair_label_indices).size(), - output_label_dst_pair_first); - - sorted_srcs = std::nullopt; - output_label_indices.resize( - thrust::distance(output_label_dst_pair_first, output_label_dst_pair_last), - handle.get_stream()); - output_vertices.resize(output_label_indices.size(), handle.get_stream()); - output_label_indices.shrink_to_fit(handle.get_stream()); - output_vertices.shrink_to_fit(handle.get_stream()); - unique_label_dst_pair_label_indices = std::move(output_label_indices); - unique_label_dst_pair_vertices = std::move(output_vertices); + auto num_labels = (*label_offsets).size() - 1; - rmm::device_uvector merged_label_indices( + rmm::device_uvector renumber_map(0, handle.get_stream()); + rmm::device_uvector renumber_map_label_indices(0, handle.get_stream()); + + renumber_map.reserve( (*unique_label_src_pair_label_indices).size() + (*unique_label_dst_pair_label_indices).size(), handle.get_stream()); - rmm::device_uvector merged_vertices(merged_label_indices.size(), handle.get_stream()); - auto label_src_triplet_first = - thrust::make_zip_iterator((*unique_label_src_pair_label_indices).begin(), - thrust::make_constant_iterator(uint8_t{0}), - unique_label_src_pair_vertices.begin()); - auto label_dst_triplet_first = - thrust::make_zip_iterator((*unique_label_dst_pair_label_indices).begin(), - thrust::make_constant_iterator(uint8_t{1}), - unique_label_dst_pair_vertices.begin()); - thrust::merge( - handle.get_thrust_policy(), - label_src_triplet_first, - label_src_triplet_first + (*unique_label_src_pair_label_indices).size(), - label_dst_triplet_first, - label_dst_triplet_first + (*unique_label_dst_pair_label_indices).size(), - thrust::make_zip_iterator( - merged_label_indices.begin(), thrust::make_discard_iterator(), merged_vertices.begin())); - - return std::make_tuple(std::move(merged_vertices), std::move(merged_label_indices)); + renumber_map_label_indices.reserve(renumber_map.capacity(), handle.get_stream()); + + auto num_chunks = (edgelist_srcs.size() + (approx_edges_to_sort_per_iteration - 1)) / + approx_edges_to_sort_per_iteration; + auto chunk_size = (num_chunks > 0) ? ((num_labels + (num_chunks - 1)) / num_chunks) : 0; + + size_t copy_offset{0}; + for (size_t i = 0; i < num_chunks; ++i) { + auto src_start_offset = + (*unique_label_src_pair_label_offsets).element(chunk_size * i, handle.get_stream()); + auto src_end_offset = + (*unique_label_src_pair_label_offsets) + .element(std::min(chunk_size * (i + 1), num_labels), handle.get_stream()); + auto dst_start_offset = + (*unique_label_dst_pair_label_offsets).element(chunk_size * i, handle.get_stream()); + auto dst_end_offset = + (*unique_label_dst_pair_label_offsets) + .element(std::min(chunk_size * (i + 1), num_labels), handle.get_stream()); + + rmm::device_uvector merged_label_indices( + (src_end_offset - src_start_offset) + (dst_end_offset - dst_start_offset), + handle.get_stream()); + rmm::device_uvector merged_vertices(merged_label_indices.size(), + handle.get_stream()); + rmm::device_uvector merged_flags(merged_label_indices.size(), handle.get_stream()); + + if (edgelist_hops) { + rmm::device_uvector merged_hops(merged_label_indices.size(), handle.get_stream()); + auto src_quad_first = + thrust::make_zip_iterator((*unique_label_src_pair_label_indices).begin(), + unique_label_src_pair_vertices.begin(), + (*unique_label_src_pair_hops).begin(), + thrust::make_constant_iterator(int8_t{0})); + auto dst_quad_first = + thrust::make_zip_iterator((*unique_label_dst_pair_label_indices).begin(), + unique_label_dst_pair_vertices.begin(), + (*unique_label_dst_pair_hops).begin(), + thrust::make_constant_iterator(int8_t{1})); + thrust::merge(handle.get_thrust_policy(), + src_quad_first + src_start_offset, + src_quad_first + src_end_offset, + dst_quad_first + dst_start_offset, + dst_quad_first + dst_end_offset, + thrust::make_zip_iterator(merged_label_indices.begin(), + merged_vertices.begin(), + merged_hops.begin(), + merged_flags.begin())); + + auto unique_key_first = + thrust::make_zip_iterator(merged_label_indices.begin(), merged_vertices.begin()); + merged_label_indices.resize( + thrust::distance( + unique_key_first, + thrust::get<0>(thrust::unique_by_key( + handle.get_thrust_policy(), + unique_key_first, + unique_key_first + merged_label_indices.size(), + thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))), + handle.get_stream()); + merged_vertices.resize(merged_label_indices.size(), handle.get_stream()); + merged_hops.resize(merged_label_indices.size(), handle.get_stream()); + merged_flags.resize(merged_label_indices.size(), handle.get_stream()); + auto sort_key_first = thrust::make_zip_iterator( + merged_label_indices.begin(), merged_hops.begin(), merged_flags.begin()); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_label_indices.size(), + merged_vertices.begin()); + } else { + auto src_triplet_first = + thrust::make_zip_iterator((*unique_label_src_pair_label_indices).begin(), + unique_label_src_pair_vertices.begin(), + thrust::make_constant_iterator(int8_t{0})); + auto dst_triplet_first = + thrust::make_zip_iterator((*unique_label_dst_pair_label_indices).begin(), + unique_label_dst_pair_vertices.begin(), + thrust::make_constant_iterator(int8_t{1})); + thrust::merge( + handle.get_thrust_policy(), + src_triplet_first + src_start_offset, + src_triplet_first + src_end_offset, + dst_triplet_first + dst_start_offset, + dst_triplet_first + dst_end_offset, + thrust::make_zip_iterator( + merged_label_indices.begin(), merged_vertices.begin(), merged_flags.begin())); + + auto unique_key_first = + thrust::make_zip_iterator(merged_label_indices.begin(), merged_vertices.begin()); + merged_label_indices.resize( + thrust::distance( + unique_key_first, + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + unique_key_first, + unique_key_first + merged_label_indices.size(), + merged_flags.begin()))), + handle.get_stream()); + merged_vertices.resize(merged_label_indices.size(), handle.get_stream()); + merged_flags.resize(merged_label_indices.size(), handle.get_stream()); + auto sort_key_first = + thrust::make_zip_iterator(merged_label_indices.begin(), merged_flags.begin()); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_label_indices.size(), + merged_vertices.begin()); + } + + renumber_map.resize(copy_offset + merged_vertices.size(), handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + merged_vertices.begin(), + merged_vertices.end(), + renumber_map.begin() + copy_offset); + renumber_map_label_indices.resize(copy_offset + merged_label_indices.size(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + merged_label_indices.begin(), + merged_label_indices.end(), + renumber_map_label_indices.begin() + copy_offset); + + copy_offset += merged_vertices.size(); + } + + renumber_map.shrink_to_fit(handle.get_stream()); + renumber_map_label_indices.shrink_to_fit(handle.get_stream()); + + return std::make_tuple(std::move(renumber_map), std::move(renumber_map_label_indices)); } else { - rmm::device_uvector output_vertices(unique_label_dst_pair_vertices.size(), - handle.get_stream()); - auto output_last = thrust::set_difference( - handle.get_thrust_policy(), - unique_label_dst_pair_vertices.begin(), - unique_label_dst_pair_vertices.end(), - edgelist_hops ? (*sorted_srcs).begin() : unique_label_src_pair_vertices.begin(), - edgelist_hops ? (*sorted_srcs).end() : unique_label_src_pair_vertices.end(), - output_vertices.begin()); - - sorted_srcs = std::nullopt; - - auto num_unique_srcs = unique_label_src_pair_vertices.size(); - auto renumber_map = std::move(unique_label_src_pair_vertices); - renumber_map.resize( - renumber_map.size() + thrust::distance(output_vertices.begin(), output_last), - handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - output_vertices.begin(), - output_last, - renumber_map.begin() + num_unique_srcs); + if (edgelist_hops) { + rmm::device_uvector merged_vertices( + unique_label_src_pair_vertices.size() + unique_label_dst_pair_vertices.size(), + handle.get_stream()); + rmm::device_uvector merged_hops(merged_vertices.size(), handle.get_stream()); + rmm::device_uvector merged_flags(merged_vertices.size(), handle.get_stream()); + auto src_triplet_first = thrust::make_zip_iterator(unique_label_src_pair_vertices.begin(), + (*unique_label_src_pair_hops).begin(), + thrust::make_constant_iterator(int8_t{0})); + auto dst_triplet_first = thrust::make_zip_iterator(unique_label_dst_pair_vertices.begin(), + (*unique_label_dst_pair_hops).begin(), + thrust::make_constant_iterator(int8_t{1})); + thrust::merge(handle.get_thrust_policy(), + src_triplet_first, + src_triplet_first + unique_label_src_pair_vertices.size(), + dst_triplet_first, + dst_triplet_first + unique_label_dst_pair_vertices.size(), + thrust::make_zip_iterator( + merged_vertices.begin(), merged_hops.begin(), merged_flags.begin())); + + unique_label_src_pair_vertices.resize(0, handle.get_stream()); + unique_label_src_pair_vertices.shrink_to_fit(handle.get_stream()); + unique_label_src_pair_hops = std::nullopt; + unique_label_dst_pair_vertices.resize(0, handle.get_stream()); + unique_label_dst_pair_vertices.shrink_to_fit(handle.get_stream()); + unique_label_dst_pair_hops = std::nullopt; + + merged_vertices.resize( + thrust::distance(merged_vertices.begin(), + thrust::get<0>(thrust::unique_by_key( + handle.get_thrust_policy(), + merged_vertices.begin(), + merged_vertices.end(), + thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))), + handle.get_stream()); + merged_hops.resize(merged_vertices.size(), handle.get_stream()); + merged_flags.resize(merged_vertices.size(), handle.get_stream()); + + auto sort_key_first = thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin()); + thrust::sort_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_hops.size(), + merged_vertices.begin()); + + return std::make_tuple(std::move(merged_vertices), std::nullopt); + } else { + rmm::device_uvector output_vertices(unique_label_dst_pair_vertices.size(), + handle.get_stream()); + auto output_last = thrust::set_difference(handle.get_thrust_policy(), + unique_label_dst_pair_vertices.begin(), + unique_label_dst_pair_vertices.end(), + unique_label_src_pair_vertices.begin(), + unique_label_src_pair_vertices.end(), + output_vertices.begin()); + + auto num_unique_srcs = unique_label_src_pair_vertices.size(); + auto renumber_map = std::move(unique_label_src_pair_vertices); + renumber_map.resize( + renumber_map.size() + thrust::distance(output_vertices.begin(), output_last), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + output_vertices.begin(), + output_last, + renumber_map.begin() + num_unique_srcs); - return std::make_tuple(std::move(renumber_map), std::nullopt); + return std::make_tuple(std::move(renumber_map), std::nullopt); + } } } @@ -452,8 +470,8 @@ std::tuple, renumber_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, - std::optional> edgelist_hops, rmm::device_uvector&& edgelist_dsts, + std::optional> edgelist_hops, std::optional, raft::device_span>> label_offsets, bool do_expensive_check) @@ -504,8 +522,8 @@ renumber_sampled_edgelist( auto [renumber_map, renumber_map_label_indices] = compute_renumber_map( handle, raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), - edgelist_hops, raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), + edgelist_hops, label_offsets ? std::make_optional>(std::get<1>(*label_offsets)) : std::nullopt); diff --git a/cpp/src/sampling/renumber_sampled_edgelist_sg.cu b/cpp/src/sampling/renumber_sampled_edgelist_sg.cu index 629fa45e1f9..46e2264a0c1 100644 --- a/cpp/src/sampling/renumber_sampled_edgelist_sg.cu +++ b/cpp/src/sampling/renumber_sampled_edgelist_sg.cu @@ -27,8 +27,8 @@ template std::tuple, renumber_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, - std::optional> edgelist_hops, rmm::device_uvector&& edgelist_dsts, + std::optional> edgelist_hops, std::optional, raft::device_span>> label_offsets, bool do_expensive_check); @@ -40,8 +40,8 @@ template std::tuple, renumber_sampled_edgelist( raft::handle_t const& handle, rmm::device_uvector&& edgelist_srcs, - std::optional> edgelist_hops, rmm::device_uvector&& edgelist_dsts, + std::optional> edgelist_hops, std::optional, raft::device_span>> label_offsets, bool do_expensive_check); diff --git a/cpp/tests/sampling/renumber_sampled_edgelist_test.cu b/cpp/tests/sampling/renumber_sampled_edgelist_test.cu index 6d944314605..96c8d6173e7 100644 --- a/cpp/tests/sampling/renumber_sampled_edgelist_test.cu +++ b/cpp/tests/sampling/renumber_sampled_edgelist_test.cu @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -25,9 +26,12 @@ #include +#include #include #include +#include #include +#include #include #include @@ -147,10 +151,10 @@ class Tests_RenumberSampledEdgelist cugraph::renumber_sampled_edgelist( handle, std::move(renumbered_edgelist_srcs), + std::move(renumbered_edgelist_dsts), edgelist_hops ? std::make_optional>( (*edgelist_hops).data(), (*edgelist_hops).size()) : std::nullopt, - std::move(renumbered_edgelist_dsts), label_offsets ? std::make_optional< std::tuple, raft::device_span>>( @@ -173,6 +177,8 @@ class Tests_RenumberSampledEdgelist size_t edgelist_end_offset = label_offsets ? std::get<1>(*label_offsets).element(i + 1, handle.get_stream()) : usecase.num_sampled_edges; + if (edgelist_start_offset == edgelist_end_offset) continue; + auto this_label_org_edgelist_srcs = raft::device_span(org_edgelist_srcs.data() + edgelist_start_offset, edgelist_end_offset - edgelist_start_offset); @@ -229,11 +235,11 @@ class Tests_RenumberSampledEdgelist }); ASSERT_TRUE(num_renumber_errors == 0) << "Renumber error in edge list destinations."; - // check the invariants in renumber_map (1. vertices appeared in edge list sources should - // have a smaller renumbered vertex ID than the vertices appear only in edge list - // destinations, 2. edge list source vertices with a smaller minimum hop number should have - // a smaller renumbered vertex ID than the edge list source vertices with a larger hop - // number) + // Check the invariants in renumber_map + // Say we found the minimum (primary key:hop, secondary key:flag) pairs for every unique + // vertices, where flag is 0 for sources and 1 for destinations. Then, vertices with smaller + // (hop, flag) pairs should be renumbered to smaller numbers than vertices with larger (hop, + // flag) pairs. rmm::device_uvector unique_srcs(this_label_org_edgelist_srcs.size(), handle.get_stream()); @@ -277,27 +283,35 @@ class Tests_RenumberSampledEdgelist this_label_org_edgelist_dsts.begin(), this_label_org_edgelist_dsts.end(), unique_dsts.begin()); - thrust::sort(handle.get_thrust_policy(), unique_dsts.begin(), unique_dsts.end()); - unique_dsts.resize( - thrust::distance( - unique_dsts.begin(), - thrust::unique(handle.get_thrust_policy(), unique_dsts.begin(), unique_dsts.end())), - handle.get_stream()); + std::optional> unique_dst_hops = + this_label_edgelist_hops ? std::make_optional>( + (*this_label_edgelist_hops).size(), handle.get_stream()) + : std::nullopt; + if (this_label_edgelist_hops) { + thrust::copy(handle.get_thrust_policy(), + (*this_label_edgelist_hops).begin(), + (*this_label_edgelist_hops).end(), + (*unique_dst_hops).begin()); - unique_dsts.resize( - thrust::distance( - unique_dsts.begin(), - thrust::remove_if(handle.get_thrust_policy(), - unique_dsts.begin(), - unique_dsts.end(), - [sorted_unique_srcs = raft::device_span( - unique_srcs.data(), unique_srcs.size())] __device__(auto dst) { - return thrust::binary_search(thrust::seq, - sorted_unique_srcs.begin(), - sorted_unique_srcs.end(), - dst); - })), - handle.get_stream()); + auto pair_first = + thrust::make_zip_iterator(unique_dsts.begin(), (*unique_dst_hops).begin()); + thrust::sort(handle.get_thrust_policy(), pair_first, pair_first + unique_dsts.size()); + unique_dsts.resize( + thrust::distance(unique_dsts.begin(), + thrust::get<0>(thrust::unique_by_key(handle.get_thrust_policy(), + unique_dsts.begin(), + unique_dsts.end(), + (*unique_dst_hops).begin()))), + handle.get_stream()); + (*unique_dst_hops).resize(unique_dsts.size(), handle.get_stream()); + } else { + thrust::sort(handle.get_thrust_policy(), unique_dsts.begin(), unique_dsts.end()); + unique_dsts.resize( + thrust::distance( + unique_dsts.begin(), + thrust::unique(handle.get_thrust_policy(), unique_dsts.begin(), unique_dsts.end())), + handle.get_stream()); + } rmm::device_uvector sorted_org_vertices(this_label_renumber_map.size(), handle.get_stream()); @@ -316,51 +330,56 @@ class Tests_RenumberSampledEdgelist sorted_org_vertices.end(), matching_renumbered_vertices.begin()); - auto max_src_renumbered_vertex = thrust::transform_reduce( - handle.get_thrust_policy(), - unique_srcs.begin(), - unique_srcs.end(), - [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), - sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] __device__(vertex_t src) { - auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), src); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - }, - std::numeric_limits::lowest(), - thrust::maximum{}); - - auto min_dst_renumbered_vertex = thrust::transform_reduce( - handle.get_thrust_policy(), - unique_dsts.begin(), - unique_dsts.end(), - [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), - sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] __device__(vertex_t dst) { - auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), dst); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - }, - std::numeric_limits::max(), - thrust::minimum{}); - - ASSERT_TRUE(max_src_renumbered_vertex < min_dst_renumbered_vertex) - << "Invariants violated, a source vertex is renumbered to a non-smaller value than a " - "vertex that appear only in the edge list destinations."; - if (this_label_edgelist_hops) { + rmm::device_uvector merged_vertices(unique_srcs.size() + unique_dsts.size(), + handle.get_stream()); + rmm::device_uvector merged_hops(merged_vertices.size(), handle.get_stream()); + rmm::device_uvector merged_flags(merged_vertices.size(), handle.get_stream()); + + auto src_triplet_first = + thrust::make_zip_iterator(unique_srcs.begin(), + (*unique_src_hops).begin(), + thrust::make_constant_iterator(int8_t{0})); + auto dst_triplet_first = + thrust::make_zip_iterator(unique_dsts.begin(), + (*unique_dst_hops).begin(), + thrust::make_constant_iterator(int8_t{1})); + thrust::merge(handle.get_thrust_policy(), + src_triplet_first, + src_triplet_first + unique_srcs.size(), + dst_triplet_first, + dst_triplet_first + unique_dsts.size(), + thrust::make_zip_iterator( + merged_vertices.begin(), merged_hops.begin(), merged_flags.begin())); + merged_vertices.resize( + thrust::distance( + merged_vertices.begin(), + thrust::get<0>(thrust::unique_by_key( + handle.get_thrust_policy(), + merged_vertices.begin(), + merged_vertices.end(), + thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin())))), + handle.get_stream()); + merged_hops.resize(merged_vertices.size(), handle.get_stream()); + merged_flags.resize(merged_vertices.size(), handle.get_stream()); + + auto sort_key_first = + thrust::make_zip_iterator(merged_hops.begin(), merged_flags.begin()); thrust::sort_by_key(handle.get_thrust_policy(), - (*unique_src_hops).begin(), - (*unique_src_hops).end(), - unique_srcs.begin()); - rmm::device_uvector min_vertices(usecase.num_hops, handle.get_stream()); - rmm::device_uvector max_vertices(usecase.num_hops, handle.get_stream()); - auto unique_renumbered_src_first = thrust::make_transform_iterator( - unique_srcs.begin(), + sort_key_first, + sort_key_first + merged_hops.size(), + merged_vertices.begin()); + + auto num_unique_keys = thrust::count_if( + handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(merged_hops.size()), + cugraph::detail::is_first_in_run_t{sort_key_first}); + rmm::device_uvector min_vertices(num_unique_keys, handle.get_stream()); + rmm::device_uvector max_vertices(num_unique_keys, handle.get_stream()); + + auto renumbered_merged_vertex_first = thrust::make_transform_iterator( + merged_vertices.begin(), [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), sorted_org_vertices.size()), matching_renumbered_vertices = raft::device_span( @@ -372,32 +391,27 @@ class Tests_RenumberSampledEdgelist it)]; }); - auto this_label_num_unique_hops = static_cast( - thrust::distance(min_vertices.begin(), - thrust::get<1>(thrust::reduce_by_key(handle.get_thrust_policy(), - (*unique_src_hops).begin(), - (*unique_src_hops).end(), - unique_renumbered_src_first, - thrust::make_discard_iterator(), - min_vertices.begin(), - thrust::equal_to{}, - thrust::minimum{})))); - min_vertices.resize(this_label_num_unique_hops, handle.get_stream()); - thrust::reduce_by_key(handle.get_thrust_policy(), - (*unique_src_hops).begin(), - (*unique_src_hops).end(), - unique_renumbered_src_first, + sort_key_first, + sort_key_first + merged_hops.size(), + renumbered_merged_vertex_first, + thrust::make_discard_iterator(), + min_vertices.begin(), + thrust::equal_to>{}, + thrust::minimum{}); + thrust::reduce_by_key(handle.get_thrust_policy(), + sort_key_first, + sort_key_first + merged_hops.size(), + renumbered_merged_vertex_first, thrust::make_discard_iterator(), max_vertices.begin(), - thrust::equal_to{}, + thrust::equal_to>{}, thrust::maximum{}); - max_vertices.resize(this_label_num_unique_hops, handle.get_stream()); auto num_violations = thrust::count_if(handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{1}), - thrust::make_counting_iterator(this_label_num_unique_hops), + thrust::make_counting_iterator(min_vertices.size()), [min_vertices = raft::device_span(min_vertices.data(), min_vertices.size()), max_vertices = raft::device_span( @@ -406,8 +420,61 @@ class Tests_RenumberSampledEdgelist }); ASSERT_TRUE(num_violations == 0) - << "Invariant violated, a vertex with a smaller hop is renumbered to a non-smaller " - "value than a vertex with a larger hop."; + << "Invariant violated, a vertex with a smaller (hop,flag) pair is renumbered to a " + "larger value than a vertex with a larger (hop, flag) pair."; + } else { + unique_dsts.resize( + thrust::distance( + unique_dsts.begin(), + thrust::remove_if(handle.get_thrust_policy(), + unique_dsts.begin(), + unique_dsts.end(), + [sorted_unique_srcs = raft::device_span( + unique_srcs.data(), unique_srcs.size())] __device__(auto dst) { + return thrust::binary_search(thrust::seq, + sorted_unique_srcs.begin(), + sorted_unique_srcs.end(), + dst); + })), + handle.get_stream()); + + auto max_src_renumbered_vertex = thrust::transform_reduce( + handle.get_thrust_policy(), + unique_srcs.begin(), + unique_srcs.end(), + [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), + sorted_org_vertices.size()), + matching_renumbered_vertices = raft::device_span( + matching_renumbered_vertices.data(), + matching_renumbered_vertices.size())] __device__(vertex_t src) { + auto it = thrust::lower_bound( + thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), src); + return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), + it)]; + }, + std::numeric_limits::lowest(), + thrust::maximum{}); + + auto min_dst_renumbered_vertex = thrust::transform_reduce( + handle.get_thrust_policy(), + unique_dsts.begin(), + unique_dsts.end(), + [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), + sorted_org_vertices.size()), + matching_renumbered_vertices = raft::device_span( + matching_renumbered_vertices.data(), + matching_renumbered_vertices.size())] __device__(vertex_t dst) { + auto it = thrust::lower_bound( + thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), dst); + return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), + it)]; + }, + std::numeric_limits::max(), + thrust::minimum{}); + + ASSERT_TRUE(max_src_renumbered_vertex < min_dst_renumbered_vertex) + << "Invariants violated, a source vertex is renumbered to a non-smaller value than a " + "vertex that appear only in the edge list destinations."; } } }