diff --git a/CHANGELOG.md b/CHANGELOG.md index 849f9891699..0be9a237926 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,7 @@ - PR #6644 Cover different CSV reader/writer options in benchmarks - PR #6651 Add cudf::dictionary::make_dictionary_pair_iterator - PR #6635 Add cudf::test::dictionary_column_wrapper class +- PR #6676 Add dictionary support to `cudf::quantile` - PR #6673 Parameterize avro and json benchmark - PR #6609 Support fixed-point decimal for HostColumnVector - PR #6705 Add nested type support to Java table serialization diff --git a/cpp/src/quantiles/quantile.cu b/cpp/src/quantiles/quantile.cu index 09a8b714819..0b1ae2560b8 100644 --- a/cpp/src/quantiles/quantile.cu +++ b/cpp/src/quantiles/quantile.cu @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -62,19 +64,29 @@ struct quantile_functor { return output; } - auto d_input = column_device_view::create(input); + auto d_input = column_device_view::create(input, stream); auto d_output = mutable_column_device_view::create(output->mutable_view()); rmm::device_vector q_device{q}; - auto sorted_data = thrust::make_permutation_iterator(input.data(), ordered_indices); - - thrust::transform(q_device.begin(), - q_device.end(), - d_output->template begin(), - [sorted_data, interp = interp, size = size] __device__(double q) { - return select_quantile_data(sorted_data, size, q, interp); - }); + if (!cudf::is_dictionary(input.type())) { + auto sorted_data = thrust::make_permutation_iterator(input.data(), ordered_indices); + thrust::transform(q_device.begin(), + q_device.end(), + d_output->template begin(), + [sorted_data, interp = interp, size = size] __device__(double q) { + return select_quantile_data(sorted_data, size, q, interp); + }); + } else { + auto sorted_data = thrust::make_permutation_iterator( + dictionary::detail::make_dictionary_iterator(*d_input), ordered_indices); + thrust::transform(q_device.begin(), + q_device.end(), + d_output->template begin(), + [sorted_data, interp = interp, size = size] __device__(double q) { + return select_quantile_data(sorted_data, size, q, interp); + }); + } if (input.nullable()) { auto sorted_validity = thrust::make_transform_iterator( @@ -113,7 +125,11 @@ std::unique_ptr quantile(column_view const& input, auto functor = quantile_functor{ ordered_indices, size, q, interp, retain_types, mr, stream}; - return type_dispatcher(input.type(), functor, input); + auto input_type = cudf::is_dictionary(input.type()) && !input.is_empty() + ? dictionary_column_view(input).keys().type() + : input.type(); + + return type_dispatcher(input_type, functor, input); } } // namespace detail diff --git a/cpp/tests/quantiles/quantile_test.cpp b/cpp/tests/quantiles/quantile_test.cpp index 9014ab27d18..63b55636b59 100644 --- a/cpp/tests/quantiles/quantile_test.cpp +++ b/cpp/tests/quantiles/quantile_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2020, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -448,6 +448,28 @@ TYPED_TEST(QuantileUnsupportedTypesTest, TestMultipleElements) EXPECT_THROW(cudf::quantile(input, {0}), cudf::logic_error); } +struct QuantileDictionaryTest : public BaseFixture { +}; + +TEST_F(QuantileDictionaryTest, TestValid) +{ + dictionary_column_wrapper col{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + fixed_width_column_wrapper indices{0, 2, 4, 6, 8, 1, 3, 5, 7, 9}; + + auto result = cudf::quantile(col, {0.5}, cudf::interpolation::LINEAR); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), fixed_width_column_wrapper{5.5}); + + result = cudf::quantile(col, {0.5}, cudf::interpolation::LINEAR, indices); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), fixed_width_column_wrapper{5.5}); + + result = cudf::quantile(col, {0.1, 0.2}, cudf::interpolation::HIGHER); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), fixed_width_column_wrapper{2.0, 3.0}); + + result = cudf::quantile(col, {0.25, 0.5, 0.75}, cudf::interpolation::MIDPOINT); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), + fixed_width_column_wrapper{3.5, 5.5, 7.5}); +}; + } // anonymous namespace CUDF_TEST_PROGRAM_MAIN()