Skip to content

Commit

Permalink
Fix null_equality config of rolling_collect_set (#8415)
Browse files Browse the repository at this point in the history
Fix #8405, and add some tests for various `null_equality` and `nan_equality`.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - MithunR (https://github.com/mythrocks)

URL: #8415
  • Loading branch information
sperlingxx authored Jun 3, 2021
1 parent f24c6b4 commit d8fbb19
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 16 deletions.
7 changes: 2 additions & 5 deletions cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -746,11 +746,8 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
stream,
rmm::mr::get_current_device_resource());

result = lists::detail::drop_list_duplicates(lists_column_view(collected_list->view()),
null_equality::EQUAL,
nan_equality::UNEQUAL,
stream,
mr);
result = lists::detail::drop_list_duplicates(
lists_column_view(collected_list->view()), agg._nulls_equal, agg._nans_equal, stream, mr);
}

std::unique_ptr<column> get_result()
Expand Down
41 changes: 41 additions & 0 deletions cpp/tests/groupby/collect_set_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,47 @@ TEST_F(CollectSetTest, StringInput)
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
}

TEST_F(CollectSetTest, FloatsWithNaN)
{
COL_K keys{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
cudf::test::fixed_width_column_wrapper<float> vals{
{1.0f, 1.0f, -2.3e-5f, -2.3e-5f, 2.3e5f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f},
{true, true, true, true, true, true, true, true, true, true, false, false}};
COL_K keys_expected{1};
// null equal with nan unequal
cudf::test::lists_column_wrapper<float> vals_expected{
{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f},
VALIDITY{true, true, true, true, true, true, true, false}},
};
test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set());
// null unequal with nan unequal
vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f},
VALIDITY{true, true, true, true, true, true, true, false, false}}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal());
// null exclude with nan unequal
vals_expected = {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN}};
test_single_agg(
keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude());
// null equal with nan equal
vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, NAN, 0.0f}, VALIDITY{true, true, true, true, false}}};
test_single_agg(keys,
vals,
keys_expected,
vals_expected,
cudf::make_collect_set_aggregation(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL));
// null unequal with nan equal
vals_expected = {
{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f, 0.0f}, VALIDITY{true, true, true, true, false, false}}};
test_single_agg(keys,
vals,
keys_expected,
vals_expected,
cudf::make_collect_set_aggregation(
null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL));
}

TYPED_TEST(CollectSetTypedTest, CollectWithNulls)
{
// Just use an arbitrary value to store null entries
Expand Down
185 changes: 174 additions & 11 deletions cpp/tests/rolling/collect_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1661,16 +1661,16 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls)

using T = TypeParam;

auto const group_column = fixed_width_column_wrapper<int32_t>{1, 1, 1, 1, 1, 2, 2, 2, 2};
auto const group_column = fixed_width_column_wrapper<int32_t>{1, 1, 1, 1, 1, 2, 2, 2, 2, 2};
auto const input_column = fixed_width_column_wrapper<T, int32_t>{
{10, 11, 12, 13, 13, 20, 21, 21, 23}, {1, 0, 0, 1, 1, 1, 0, 1, 1}};
{10, 0, 0, 13, 13, 20, 21, 0, 0, 23}, {1, 0, 0, 1, 1, 1, 1, 0, 0, 1}};

auto const preceding = 2;
auto const following = 1;
auto const min_periods = 1;

{
// Nulls included.
// Nulls included and nulls are equal.
auto const result =
grouped_rolling_window(table_view{std::vector<column_view>{group_column}},
input_column,
Expand All @@ -1679,10 +1679,78 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls)
min_periods,
*make_collect_set_aggregation<rolling_aggregation>());
// Null values are sorted to the tails of lists (sets)
auto expected_child = fixed_width_column_wrapper<T, int32_t>{
{10, 11, 10, 11, 13, 11, 13, 12, 13, 20, 21, 20, 21, 21, 21, 23, 21, 21, 23},
{1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1}};
auto expected_offsets = fixed_width_column_wrapper<int32_t>{0, 2, 4, 6, 8, 9, 11, 14, 17, 19};
auto expected_child = fixed_width_column_wrapper<T, int32_t>{{
10, 0, // row 0
10, 0, // row 1
13, 0, // row 2
13, 0, // row 3
13, // row 4
20, 21, // row 5
20, 21, 0, // row 6
21, 0, // row 7
23, 0, // row 8
23, 0, // row 9
},
{
1, 0, // row 0
1, 0, // row 1
1, 0, // row 2
1, 0, // row 3
1, // row 4
1, 1, // row 5
1, 1, 0, // row 6
1, 0, // row 7
1, 0, // row 8
1, 0 // row 9
}};
auto expected_offsets =
fixed_width_column_wrapper<int32_t>{0, 2, 4, 6, 8, 9, 11, 14, 16, 18, 20};

auto expected_result = make_lists_column(static_cast<column_view>(group_column).size(),
expected_offsets.release(),
expected_child.release(),
0,
{});

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());
}

{
// Nulls included and nulls are NOT equal.
auto const result = grouped_rolling_window(table_view{std::vector<column_view>{group_column}},
input_column,
preceding,
following,
min_periods,
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::INCLUDE, null_equality::UNEQUAL));
// Null values are sorted to the tails of lists (sets)
auto expected_child = fixed_width_column_wrapper<T, int32_t>{{
10, 0, // row 0
10, 0, 0, // row 1
13, 0, 0, // row 2
13, 0, // row 3
13, // row 4
20, 21, // row 5
20, 21, 0, // row 6
21, 0, 0, // row 7
23, 0, 0, // row 8
23, 0 // row 9
},
{
1, 0, // row 0
1, 0, 0, // row 1
1, 0, 0, // row 2
1, 0, // row 3
1, // row 4
1, 1, // row 5
1, 1, 0, // row 6
1, 0, 0, // row 7
1, 0, 0, // row 8
1, 0 // row 9
}};
auto expected_offsets =
fixed_width_column_wrapper<int32_t>{0, 2, 5, 8, 10, 11, 13, 16, 19, 22, 24};

auto expected_result = make_lists_column(static_cast<column_view>(group_column).size(),
expected_offsets.release(),
Expand All @@ -1703,10 +1771,22 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls)
min_periods,
*make_collect_set_aggregation<rolling_aggregation>(null_policy::EXCLUDE));

auto expected_child =
fixed_width_column_wrapper<T, int32_t>{10, 10, 13, 13, 13, 20, 20, 21, 21, 23, 21, 23};

auto expected_offsets = fixed_width_column_wrapper<int32_t>{0, 1, 2, 3, 4, 5, 6, 8, 10, 12};
auto expected_child = fixed_width_column_wrapper<T, int32_t>{
10, // row 0
10, // row 1
13, // row 2
13, // row 3
13, // row 4
20,
21, // row 5
20,
21, // row 6
21, // row 7
23, // row 8
23 // row 9
};

auto expected_offsets = fixed_width_column_wrapper<int32_t>{0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12};

auto expected_result = make_lists_column(static_cast<column_view>(group_column).size(),
expected_offsets.release(),
Expand Down Expand Up @@ -1957,6 +2037,68 @@ TEST_F(CollectSetTest, BoolGroupedRollingWindow)
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());
}

TEST_F(CollectSetTest, FloatGroupedRollingWindowWithNaNs)
{
using namespace cudf;
using namespace cudf::test;

auto const group_column = fixed_width_column_wrapper<int32_t>{1, 1, 1, 1, 1, 2, 2, 2, 2};
auto const input_column = fixed_width_column_wrapper<double>{
{1.23, 0.2341, 0.2341, -5.23e9, std::nan("1"), 1.1, std::nan("1"), std::nan("1"), 0.0},
{true, true, true, true, true, true, true, true, false}};

auto const preceding = 2;
auto const following = 1;
auto const min_periods = 1;
// test on nan_equality::UNEQUAL
auto const result = grouped_rolling_window(table_view{std::vector<column_view>{group_column}},
input_column,
preceding,
following,
min_periods,
*make_collect_set_aggregation<rolling_aggregation>());

auto const expected_result = lists_column_wrapper<double>{
{{0.2341, 1.23}, std::initializer_list<bool>{true, true}},
{{0.2341, 1.23}, std::initializer_list<bool>{true, true}},
{{-5.23e9, 0.2341}, std::initializer_list<bool>{true, true}},
{{-5.23e9, 0.2341, std::nan("1")}, std::initializer_list<bool>{true, true, true}},
{{-5.23e9, std::nan("1")}, std::initializer_list<bool>{true, true}},
{{1.1, std::nan("1")}, std::initializer_list<bool>{true, true}},
{{1.1, std::nan("1"), std::nan("1")}, std::initializer_list<bool>{true, true, true}},
{{std::nan("1"), std::nan("1"), 0.0}, std::initializer_list<bool>{true, true, false}},
{{std::nan("1"), 0.0},
std::initializer_list<bool>{
true, false}}}.release();

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view());

// test on nan_equality::ALL_EQUAL
auto const result_nan_equal =
grouped_rolling_window(table_view{std::vector<column_view>{group_column}},
input_column,
preceding,
following,
min_periods,
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL));

auto const expected_result_nan_equal = lists_column_wrapper<double>{
{{0.2341, 1.23}, std::initializer_list<bool>{true, true}},
{{0.2341, 1.23}, std::initializer_list<bool>{true, true}},
{{-5.23e9, 0.2341}, std::initializer_list<bool>{true, true}},
{{-5.23e9, 0.2341, std::nan("1")}, std::initializer_list<bool>{true, true, true}},
{{-5.23e9, std::nan("1")}, std::initializer_list<bool>{true, true}},
{{1.1, std::nan("1")}, std::initializer_list<bool>{true, true}},
{{1.1, std::nan("1")}, std::initializer_list<bool>{true, true}},
{{std::nan("1"), 0.0}, std::initializer_list<bool>{true, false}},
{{std::nan("1"), 0.0},
std::initializer_list<bool>{true,
false}}}.release();

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_nan_equal->view(), result_nan_equal->view());
}

TEST_F(CollectSetTest, BasicRollingWindowWithNaNs)
{
using namespace cudf;
Expand Down Expand Up @@ -2002,6 +2144,27 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs)
*make_collect_set_aggregation<rolling_aggregation>(null_policy::EXCLUDE));

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());

auto const expected_result_for_nan_equal =
lists_column_wrapper<double>{
{0.2341, 1.23},
{0.2341, 1.23, std::nan("1")},
{0.2341, std::nan("1")},
{-5.23e9, std::nan("1")},
{-5.23e9, std::nan("1")},
}
.release();

auto const result_with_nan_equal =
rolling_window(input_column,
2,
1,
1,
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL));

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_for_nan_equal->view(),
result_with_nan_equal->view());
}

TEST_F(CollectSetTest, ListTypeRollingWindow)
Expand Down

0 comments on commit d8fbb19

Please sign in to comment.