diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 697e637ae57..de602fa01bd 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -461,6 +461,10 @@ if(BUILD_CUGRAPH_MG_TESTS) ########################################################################################### # - MG PRIMS TRANSFORM_REDUCE_E tests ----------------------------------------------------- ConfigureTestMG(MG_TRANSFORM_REDUCE_E_TEST prims/mg_transform_reduce_e.cu) + + ########################################################################################### + # - MG PRIMS COUNT_IF_E tests ------------------------------------------------------------- + ConfigureTestMG(MG_COUNT_IF_E_TEST prims/mg_count_if_e.cu) else() message(FATAL_ERROR "OpenMPI NOT found, cannot build MG tests.") endif() diff --git a/cpp/tests/prims/mg_count_if_e.cu b/cpp/tests/prims/mg_count_if_e.cu new file mode 100644 index 00000000000..d44d0ae95da --- /dev/null +++ b/cpp/tests/prims/mg_count_if_e.cu @@ -0,0 +1,373 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +template +struct property_type { + using type = std::conditional_t<(sizeof...(Args) > 1), + thrust::tuple, + typename thrust::tuple_element<0, thrust::tuple>::type>; +}; + +template +struct property_transform + : public thrust::unary_function::type> { + int mod{}; + property_transform(int mod_count) : mod(mod_count) {} + + template ::type> + constexpr __device__ + typename std::enable_if_t::value, type> + operator()(const vertex_t& val) + { + cuco::detail::MurmurHash3_32 hash_func{}; + auto value = hash_func(val) % mod; + return thrust::make_tuple(static_cast(value)...); + } + + template ::type> + constexpr __device__ typename std::enable_if_t::value, type> operator()( + const vertex_t& val) + { + cuco::detail::MurmurHash3_32 hash_func{}; + auto value = hash_func(val) % mod; + return static_cast(value); + } +}; + +template typename Tuple, typename... Args> +struct property_transform> : public property_transform { +}; + +template +struct generate_impl { + private: + using type = typename property_type::type; + using property_buffer_type = std::conditional_t< + (sizeof...(Args) > 1), + std::tuple...>, + rmm::device_uvector>::type>>; + + public: + static thrust::tuple initial_value(int init) + { + return thrust::make_tuple(static_cast(init)...); + } + template + static auto vertex_property(rmm::device_uvector& labels, + int hash_bin_count, + raft::handle_t const& handle) + { + auto data = cugraph::allocate_dataframe_buffer(labels.size(), handle.get_stream()); + auto zip = cugraph::get_dataframe_buffer_begin(data); + thrust::transform(handle.get_thrust_policy(), + labels.begin(), + labels.end(), + zip, + property_transform(hash_bin_count)); + return data; + } + template + static auto vertex_property(thrust::counting_iterator begin, + thrust::counting_iterator end, + int hash_bin_count, + raft::handle_t const& handle) + { + auto length = thrust::distance(begin, end); + auto data = cugraph::allocate_dataframe_buffer(length, handle.get_stream()); + auto zip = cugraph::get_dataframe_buffer_begin(data); + thrust::transform(handle.get_thrust_policy(), + begin, + end, + zip, + property_transform(hash_bin_count)); + return data; + } + + template + static auto column_property(raft::handle_t const& handle, + graph_view_type const& graph_view, + property_buffer_type& property) + { + auto output_property = cugraph::col_properties_t(handle, graph_view); + copy_to_adj_matrix_col( + handle, graph_view, cugraph::get_dataframe_buffer_begin(property), output_property); + return output_property; + } + + template + static auto row_property(raft::handle_t const& handle, + graph_view_type const& graph_view, + property_buffer_type& property) + { + auto output_property = cugraph::row_properties_t(handle, graph_view); + copy_to_adj_matrix_row( + handle, graph_view, cugraph::get_dataframe_buffer_begin(property), output_property); + return output_property; + } +}; + +template +struct generate : public generate_impl { + static T initial_value(int init) { return static_cast(init); } +}; +template +struct generate> : public generate_impl { +}; + +struct Prims_Usecase { + bool check_correctness{true}; + bool test_weighted{false}; +}; + +template +class Tests_MG_TransformCountIfE + : public ::testing::TestWithParam> { + public: + Tests_MG_TransformCountIfE() {} + static void SetupTestCase() {} + static void TearDownTestCase() {} + + virtual void SetUp() {} + virtual void TearDown() {} + + // Verify the results of count_if_e primitive + template + void run_current_test(Prims_Usecase const& prims_usecase, input_usecase_t const& input_usecase) + { + // 1. initialize handle + + raft::handle_t handle{}; + HighResClock hr_clock{}; + + raft::comms::initialize_mpi_comms(&handle, MPI_COMM_WORLD); + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto const comm_rank = comm.get_rank(); + + auto row_comm_size = static_cast(sqrt(static_cast(comm_size))); + while (comm_size % row_comm_size != 0) { + --row_comm_size; + } + cugraph::partition_2d::subcomm_factory_t + subcomm_factory(handle, row_comm_size); + + // 2. create MG graph + + if (cugraph::test::g_perf) { + CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle.get_comms().barrier(); + hr_clock.start(); + } + auto [mg_graph, d_mg_renumber_map_labels] = + cugraph::test::construct_graph( + handle, input_usecase, prims_usecase.test_weighted, true); + + if (cugraph::test::g_perf) { + CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle.get_comms().barrier(); + double elapsed_time{0.0}; + hr_clock.stop(&elapsed_time); + std::cout << "MG construct_graph took " << elapsed_time * 1e-6 << " s.\n"; + } + + auto mg_graph_view = mg_graph.view(); + + // 3. run MG count_if_e + + const int hash_bin_count = 5; + const int initial_value = 0; + + auto property_initial_value = generate::initial_value(initial_value); + using property_t = decltype(property_initial_value); + auto vertex_property_data = + generate::vertex_property((*d_mg_renumber_map_labels), hash_bin_count, handle); + auto col_prop = + generate::column_property(handle, mg_graph_view, vertex_property_data); + auto row_prop = generate::row_property(handle, mg_graph_view, vertex_property_data); + + if (cugraph::test::g_perf) { + CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle.get_comms().barrier(); + hr_clock.start(); + } + + auto result = count_if_e( + handle, + mg_graph_view, + row_prop.device_view(), + col_prop.device_view(), + [] __device__(auto row, auto col, weight_t wt, auto row_property, auto col_property) { + return row_property < col_property; + }); + if (cugraph::test::g_perf) { + CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle.get_comms().barrier(); + double elapsed_time{0.0}; + hr_clock.stop(&elapsed_time); + std::cout << "MG count if e took " << elapsed_time * 1e-6 << " s.\n"; + } + + //// 4. compare SG & MG results + + if (prims_usecase.check_correctness) { + cugraph::graph_t sg_graph(handle); + std::tie(sg_graph, std::ignore) = + cugraph::test::construct_graph( + handle, input_usecase, prims_usecase.test_weighted, false); + auto sg_graph_view = sg_graph.view(); + + auto sg_vertex_property_data = generate::vertex_property( + thrust::make_counting_iterator(sg_graph_view.get_local_vertex_first()), + thrust::make_counting_iterator(sg_graph_view.get_local_vertex_last()), + hash_bin_count, + handle); + auto sg_col_prop = + generate::column_property(handle, sg_graph_view, sg_vertex_property_data); + auto sg_row_prop = + generate::row_property(handle, sg_graph_view, sg_vertex_property_data); + + auto expected_result = count_if_e( + handle, + sg_graph_view, + sg_row_prop.device_view(), + sg_col_prop.device_view(), + [] __device__(auto row, auto col, weight_t wt, auto row_property, auto col_property) { + return row_property < col_property; + }); + ASSERT_TRUE(expected_result == result); + } + } +}; + +using Tests_MG_TransformCountIfE_File = Tests_MG_TransformCountIfE; +using Tests_MG_TransformCountIfE_Rmat = Tests_MG_TransformCountIfE; + +TEST_P(Tests_MG_TransformCountIfE_File, CheckInt32Int32FloatTupleIntFloatTransposeFalse) +{ + auto param = GetParam(); + run_current_test, false>(std::get<0>(param), + std::get<1>(param)); +} + +TEST_P(Tests_MG_TransformCountIfE_Rmat, CheckInt32Int32FloatTupleIntFloatTransposeFalse) +{ + auto param = GetParam(); + run_current_test, false>( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MG_TransformCountIfE_File, CheckInt32Int32FloatTupleIntFloatTransposeTrue) +{ + auto param = GetParam(); + run_current_test, true>(std::get<0>(param), + std::get<1>(param)); +} + +TEST_P(Tests_MG_TransformCountIfE_Rmat, CheckInt32Int32FloatTupleIntFloatTransposeTrue) +{ + auto param = GetParam(); + run_current_test, true>( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MG_TransformCountIfE_File, CheckInt32Int32FloatTransposeFalse) +{ + auto param = GetParam(); + run_current_test(std::get<0>(param), std::get<1>(param)); +} + +TEST_P(Tests_MG_TransformCountIfE_Rmat, CheckInt32Int32FloatTransposeFalse) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MG_TransformCountIfE_File, CheckInt32Int32FloatTransposeTrue) +{ + auto param = GetParam(); + run_current_test(std::get<0>(param), std::get<1>(param)); +} + +TEST_P(Tests_MG_TransformCountIfE_Rmat, CheckInt32Int32FloatTransposeTrue) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), + cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_MG_TransformCountIfE_File, + ::testing::Combine( + ::testing::Values(Prims_Usecase{true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_MG_TransformCountIfE_Rmat, + ::testing::Combine(::testing::Values(Prims_Usecase{true}), + ::testing::Values(cugraph::test::Rmat_Usecase( + 10, 16, 0.57, 0.19, 0.19, 0, false, false, 0, true)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_large_test, + Tests_MG_TransformCountIfE_Rmat, + ::testing::Combine(::testing::Values(Prims_Usecase{false}), + ::testing::Values(cugraph::test::Rmat_Usecase( + 20, 32, 0.57, 0.19, 0.19, 0, false, false, 0, true)))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN()