Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cudf::rolling_window SUM support for decimal32 and decimal64 #7147

Merged
merged 7 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,15 @@ struct target_type_impl<
using type = int64_t;
};

// Summing fixed_point numbers, always use the decimal64 accumulator
template <typename Source, aggregation::Kind k>
struct target_type_impl<
Source,
k,
std::enable_if_t<cudf::is_fixed_point<Source>() && (k == aggregation::SUM)>> {
using type = numeric::decimal64;
};

// Summing/Multiplying float/doubles, use same type accumulator
template <typename Source, aggregation::Kind k>
struct target_type_impl<
Expand Down
6 changes: 5 additions & 1 deletion cpp/src/rolling/rolling_detail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ static constexpr bool is_rolling_supported()

return is_valid_numeric_agg;

} else if (cudf::is_timestamp<ColumnType>() || cudf::is_fixed_point<ColumnType>()) {
} else if (cudf::is_timestamp<ColumnType>()) {
return (op == aggregation::MIN) or (op == aggregation::MAX) or
(op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or
(op == aggregation::ROW_NUMBER) or (op == aggregation::LEAD) or (op == aggregation::LAG);
} else if (cudf::is_fixed_point<ColumnType>()) {
return (op == aggregation::SUM) or (op == aggregation::MIN) or (op == aggregation::MAX) or
(op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or
(op == aggregation::ROW_NUMBER) or (op == aggregation::LEAD) or (op == aggregation::LAG);
} else if (std::is_same<ColumnType, cudf::string_view>()) {
return (op == aggregation::MIN) or (op == aggregation::MAX) or
(op == aggregation::COUNT_VALID) or (op == aggregation::COUNT_ALL) or
Expand Down
19 changes: 15 additions & 4 deletions cpp/tests/rolling/rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,13 +1005,15 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLeadNulls)
{
using namespace numeric;
using namespace cudf;
using decimalXX = TypeParam;
using RepType = cudf::device_storage_type_t<decimalXX>;
using fp_wrapper = cudf::test::fixed_point_column_wrapper<RepType>;
using fw_wrapper = cudf::test::fixed_width_column_wrapper<size_type>;
using decimalXX = TypeParam;
using RepType = cudf::device_storage_type_t<decimalXX>;
using fp_wrapper = cudf::test::fixed_point_column_wrapper<RepType>;
using fp64_wrapper = cudf::test::fixed_point_column_wrapper<int64_t>;
using fw_wrapper = cudf::test::fixed_width_column_wrapper<size_type>;

auto const scale = scale_type{-1};
auto const input = fp_wrapper{{42, 1729, 55, 343, 1, 2}, {1, 0, 1, 0, 1, 1}, scale};
auto const expected_sum = fp64_wrapper{{42, 97, 55, 56, 3, 3}, {1, 1, 1, 1, 1, 1}, scale};
auto const expected_min = fp_wrapper{{42, 42, 55, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, scale};
auto const expected_max = fp_wrapper{{42, 55, 55, 55, 2, 2}, {1, 1, 1, 1, 1, 1}, scale};
auto const expected_lag = fp_wrapper{{0, 42, 1729, 55, 343, 1}, {0, 1, 0, 1, 0, 1}, scale};
Expand All @@ -1020,6 +1022,7 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLeadNulls)
auto const expected_count_all = fw_wrapper{{2, 3, 3, 3, 3, 2}, {1, 1, 1, 1, 1, 1}};
auto const expected_rowno = fw_wrapper{{1, 2, 2, 2, 2, 2}, {1, 1, 1, 1, 1, 1}};

auto const sum = rolling_window(input, 2, 1, 1, make_sum_aggregation());
auto const min = rolling_window(input, 2, 1, 1, make_min_aggregation());
auto const max = rolling_window(input, 2, 1, 1, make_max_aggregation());
auto const lag = rolling_window(input, 2, 1, 1, make_lag_aggregation(1));
Expand All @@ -1028,13 +1031,21 @@ TYPED_TEST(FixedPointTests, MinMaxCountLagLeadNulls)
auto const all = rolling_window(input, 2, 1, 1, make_count_aggregation(null_policy::INCLUDE));
auto const rowno = rolling_window(input, 2, 1, 1, make_row_number_aggregation());

CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_sum, sum->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_min, min->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_max, max->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lag, lag->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lead, lead->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count_val, valid->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_count_all, all->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_rowno, rowno->view());

EXPECT_THROW(rolling_window(input, 2, 1, 1, make_product_aggregation()), cudf::logic_error);
EXPECT_THROW(rolling_window(input, 2, 1, 1, make_mean_aggregation()), cudf::logic_error);
EXPECT_THROW(rolling_window(input, 2, 1, 1, make_variance_aggregation()), cudf::logic_error);
EXPECT_THROW(rolling_window(input, 2, 1, 1, make_std_aggregation()), cudf::logic_error);
EXPECT_THROW(rolling_window(input, 2, 1, 1, make_sum_of_squares_aggregation()),
cudf::logic_error);
}

CUDF_TEST_PROGRAM_MAIN()