diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 241dcadfa78..2a3ef6d1e14 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -198,6 +198,7 @@ add_library(cudf src/groupby/sort/group_argmin.cu src/groupby/sort/aggregate.cpp src/groupby/sort/group_collect.cu + src/groupby/sort/group_merge_lists.cu src/groupby/sort/group_count.cu src/groupby/sort/group_max.cu src/groupby/sort/group_min.cu diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index 2600926d363..5fab284d506 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -78,6 +78,8 @@ class aggregation { ROW_NUMBER, ///< get row-number of current index (relative to rolling window) COLLECT_LIST, ///< collect values into a list COLLECT_SET, ///< collect values into a list without duplicate entries + MERGE_LISTS, ///< merge multiple lists values into one list + MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries LEAD, ///< window function, accesses row at specified offset following current row LAG, ///< window function, accesses row at specified offset preceding current row PTX, ///< PTX UDF based reduction @@ -250,7 +252,7 @@ std::unique_ptr make_collect_list_aggregation( null_policy null_handling = null_policy::INCLUDE); /** - * @brief Factory to create a COLLECT_SET aggregation + * @brief Factory to create a COLLECT_SET aggregation. * * `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each * list, the duplicated entries are dropped out such that each entry appears only once. @@ -259,16 +261,53 @@ std::unique_ptr make_collect_list_aggregation( * of the list rows. 
* * @param null_handling Indicates whether to include/exclude nulls during collection - * @param nulls_equal Flag to specify whether null entries within each list should be considered - * equal - * @param nans_equal Flag to specify whether NaN values in floating point column should be - * considered equal + * @param nulls_equal Flag to specify whether null entries within each list should be considered + * equal. + * @param nans_equal Flag to specify whether NaN values in floating point column should be + * considered equal. */ template std::unique_ptr make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE, null_equality nulls_equal = null_equality::EQUAL, nan_equality nans_equal = nan_equality::UNEQUAL); +/** + * @brief Factory to create a MERGE_LISTS aggregation. + * + * Given a lists column, this aggregation merges all the lists corresponding to the same key value + * into one list. It is designed specifically to merge the partial results of multiple (distributed) + * groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. As such, it requires the + * input lists column to be non-nullable (the child column containing list entries is not subject + * to this requirement). + */ +template +std::unique_ptr make_merge_lists_aggregation(); + +/** + * @brief Factory to create a MERGE_SETS aggregation. + * + * Given a lists column, this aggregation firstly merges all the lists corresponding to the same key + * value into one list, then it drops all the duplicate entries in each list, producing a lists + * column containing non-repeated entries. + * + * This aggregation is designed specifically to merge the partial results of multiple (distributed) + * groupby `COLLECT_LIST` or `COLLECT_SET` aggregations into a final `COLLECT_SET` result. As such, + * it requires the input lists column to be non-nullable (the child column containing list entries + * is not subject to this requirement). 
+ * + * In practice, the input (partial results) to this aggregation should be generated by (distributed) + * `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily removing duplicate entries + * for the partial results. + * + * @param nulls_equal Flag to specify whether nulls within each list should be considered equal + * during dropping duplicate list entries. + * @param nans_equal Flag to specify whether NaN values in floating point column should be + * considered equal during dropping duplicate list entries. + */ +template +std::unique_ptr make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL, + nan_equality nans_equal = nan_equality::UNEQUAL); + /// Factory to create a LAG aggregation template std::unique_ptr make_lag_aggregation(size_type offset); diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index e230ce0b757..373d695a5b5 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -75,6 +75,10 @@ class simple_aggregations_collector { // Declares the interface for the simple data_type col_type, class collect_list_aggregation const& agg); virtual std::vector> visit(data_type col_type, class collect_set_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class merge_lists_aggregation const& agg); + virtual std::vector> visit(data_type col_type, + class merge_sets_aggregation const& agg); virtual std::vector> visit(data_type col_type, class lead_lag_aggregation const& agg); virtual std::vector> visit(data_type col_type, @@ -105,6 +109,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer virtual void visit(class row_number_aggregation const& agg); virtual void visit(class collect_list_aggregation const& agg); virtual void visit(class collect_set_aggregation const& agg); + virtual void visit(class merge_lists_aggregation const& agg); + 
virtual void visit(class merge_sets_aggregation const& agg); virtual void visit(class lead_lag_aggregation const& agg); virtual void visit(class udf_aggregation const& agg); }; @@ -627,6 +633,66 @@ class collect_set_aggregation final : public rolling_aggregation { } }; +/** + * @brief Derived aggregation class for specifying MERGE_LISTs aggregation + */ +class merge_lists_aggregation final : public aggregation { + public: + explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {} + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, *this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } +}; + +/** + * @brief Derived aggregation class for specifying MERGE_SETs aggregation + */ +class merge_sets_aggregation final : public aggregation { + public: + explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal) + : aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal) + { + } + + null_equality _nulls_equal; ///< whether to consider nulls as equal value + nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to + ///< floating point types) + + bool is_equal(aggregation const& _other) const override + { + if (!this->aggregation::is_equal(_other)) { return false; } + auto const& other = dynamic_cast(_other); + return (_nulls_equal == other._nulls_equal && _nans_equal == other._nans_equal); + } + + size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); } + + std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + std::vector> get_simple_aggregations( + data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override + { + return collector.visit(col_type, 
*this); + } + void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); } + + protected: + size_t hash_impl() const + { + return std::hash{}(static_cast(_nulls_equal) ^ static_cast(_nans_equal)); + } +}; + /** * @brief Derived aggregation class for specifying LEAD/LAG window aggregations */ @@ -904,6 +970,18 @@ struct target_type_impl { using type = cudf::list_view; }; +// Always use list for MERGE_LISTS +template +struct target_type_impl { + using type = cudf::list_view; +}; + +// Always use list for MERGE_SETS +template +struct target_type_impl { + using type = cudf::list_view; +}; + // Always use Source for LEAD template struct target_type_impl { @@ -1005,6 +1083,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin return f.template operator()(std::forward(args)...); case aggregation::COLLECT_SET: return f.template operator()(std::forward(args)...); + case aggregation::MERGE_LISTS: + return f.template operator()(std::forward(args)...); + case aggregation::MERGE_SETS: + return f.template operator()(std::forward(args)...); case aggregation::LEAD: return f.template operator()(std::forward(args)...); case aggregation::LAG: @@ -1107,4 +1189,4 @@ constexpr inline bool is_valid_aggregation() bool is_valid_aggregation(data_type source, aggregation::Kind k); } // namespace detail -} // namespace cudf \ No newline at end of file +} // namespace cudf diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index a878dbe1535..f0fd865f685 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -154,6 +154,18 @@ std::vector> simple_aggregations_collector::visit( return visit(col_type, static_cast(agg)); } +std::vector> simple_aggregations_collector::visit( + data_type col_type, merge_lists_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + +std::vector> simple_aggregations_collector::visit( + data_type col_type, 
merge_sets_aggregation const& agg) +{ + return visit(col_type, static_cast(agg)); +} + std::vector> simple_aggregations_collector::visit( data_type col_type, lead_lag_aggregation const& agg) { @@ -270,6 +282,16 @@ void aggregation_finalizer::visit(collect_set_aggregation const& agg) visit(static_cast(agg)); } +void aggregation_finalizer::visit(merge_lists_aggregation const& agg) +{ + visit(static_cast(agg)); +} + +void aggregation_finalizer::visit(merge_sets_aggregation const& agg) +{ + visit(static_cast(agg)); +} + void aggregation_finalizer::visit(lead_lag_aggregation const& agg) { visit(static_cast(agg)); @@ -471,6 +493,24 @@ template std::unique_ptr make_collect_set_aggregation( template std::unique_ptr make_collect_set_aggregation( null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal); +/// Factory to create a MERGE_LISTS aggregation +template +std::unique_ptr make_merge_lists_aggregation() +{ + return std::make_unique(); +} +template std::unique_ptr make_merge_lists_aggregation(); + +/// Factory to create a MERGE_SETS aggregation +template +std::unique_ptr make_merge_sets_aggregation(null_equality nulls_equal, + nan_equality nans_equal) +{ + return std::make_unique(nulls_equal, nans_equal); +} +template std::unique_ptr make_merge_sets_aggregation(null_equality, + nan_equality); + /// Factory to create a LAG aggregation template std::unique_ptr make_lag_aggregation(size_type offset) diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp index 9d8f145a7c9..5e202b9ef3f 100644 --- a/cpp/src/groupby/sort/aggregate.cpp +++ b/cpp/src/groupby/sort/aggregate.cpp @@ -366,36 +366,32 @@ void aggregate_result_functor::operator()(aggregation template <> void aggregate_result_functor::operator()(aggregation const& agg) { - auto null_handling = - dynamic_cast(agg)._null_handling; - agg.do_hash(); - - if (cache.has_result(col_idx, agg)) return; + if (cache.has_result(col_idx, agg)) { return; } + auto const null_handling = 
+ dynamic_cast(agg)._null_handling; auto result = detail::group_collect(get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), null_handling, stream, mr); - cache.add_result(col_idx, agg, std::move(result)); }; template <> void aggregate_result_functor::operator()(aggregation const& agg) { - auto const null_handling = - dynamic_cast(agg)._null_handling; - if (cache.has_result(col_idx, agg)) { return; } + auto const null_handling = + dynamic_cast(agg)._null_handling; auto const collect_result = detail::group_collect(get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), null_handling, stream, - mr); + rmm::mr::get_current_device_resource()); auto const nulls_equal = dynamic_cast(agg)._nulls_equal; auto const nans_equal = @@ -406,6 +402,78 @@ void aggregate_result_functor::operator()(aggregation lists::detail::drop_list_duplicates( lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr)); }; + +/** + * @brief Perform merging for the lists that correspond to the same key value. + * + * This aggregation is similar to `COLLECT_LIST` with the following differences: + * - It requires the input values to be a non-nullable lists column, and + * - The values (lists) corresponding to the same key will not result in a list of lists as output + * from `COLLECT_LIST`. Instead, those lists will result in a list generated by merging them + * together. + * + * In practice, this aggregation is used to merge the partial results of multiple (distributed) + * groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. Those distributed + * aggregations were executed on different values columns partitioned from the original values + * column, then their results were (vertically) concatenated before given as the values column for + * this aggregation. 
+ */ +template <> +void aggregate_result_functor::operator()(aggregation const& agg) +{ + if (cache.has_result(col_idx, agg)) { return; } + + cache.add_result( + col_idx, + agg, + detail::group_merge_lists( + get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr)); +}; + +/** + * @brief Perform merging for the lists corresponding to the same key value, then dropping duplicate + * list entries. + * + * This aggregation is similar to `COLLECT_SET` with the following differences: + * - It requires the input values to be a non-nullable lists column, and + * - The values (lists) corresponding to the same key will result in a list generated by merging + * them together then dropping duplicate entries. + * + * In practice, this aggregation is used to merge the partial results of multiple (distributed) + * groupby `COLLECT_LIST` or `COLLECT_SET` aggregations into a final `COLLECT_SET` result. Those + * distributed aggregations were executed on different values columns partitioned from the original + * values column, then their results were (vertically) concatenated before given as the values + * column for this aggregation. + * + * Firstly, this aggregation performs `MERGE_LISTS` to concatenate the input lists (corresponding to + * the same key) into intermediate lists, then it calls `lists::drop_list_duplicates` on them to + * remove duplicate list entries. As such, the input (partial results) to this aggregation should be + * generated by (distributed) `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily + * removing duplicate entries for the partial results. + * + * Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality` + * are needed for calling to `lists::drop_list_duplicates`. 
+ */ +template <> +void aggregate_result_functor::operator()(aggregation const& agg) +{ + if (cache.has_result(col_idx, agg)) { return; } + + auto const merged_result = detail::group_merge_lists(get_grouped_values(), + helper.group_offsets(stream), + helper.num_groups(stream), + stream, + rmm::mr::get_current_device_resource()); + auto const& merge_sets_agg = dynamic_cast(agg); + cache.add_result(col_idx, + agg, + lists::detail::drop_list_duplicates(lists_column_view(merged_result->view()), + merge_sets_agg._nulls_equal, + merge_sets_agg._nans_equal, + stream, + mr)); +}; + } // namespace detail // Sort-based groupby diff --git a/cpp/src/groupby/sort/group_merge_lists.cu b/cpp/src/groupby/sort/group_merge_lists.cu new file mode 100644 index 00000000000..3043d107635 --- /dev/null +++ b/cpp/src/groupby/sort/group_merge_lists.cu @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include + +#include + +namespace cudf { +namespace groupby { +namespace detail { +std::unique_ptr group_merge_lists(column_view const& values, + cudf::device_span group_offsets, + size_type num_groups, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_EXPECTS(values.type().id() == type_id::LIST, + "Input to `group_merge_lists` must be a lists column."); + CUDF_EXPECTS(!values.nullable(), + "Input to `group_merge_lists` must be a non-nullable lists column."); + + auto offsets_column = make_numeric_column( + data_type(type_to_id()), num_groups + 1, mask_state::UNALLOCATED, stream, mr); + + // Generate offsets of the output lists column by gathering from the provided group offsets and + // the input list offsets. + // + // For example: + // values = [[2, 1], [], [4, -1, -2], [], [, 4, ]] + // list_offsets = [0, 2, 2, 5, 5, 8] + // group_offsets = [0, 3, 5] + // + // then, the output offsets_column is [0, 5, 8]. + // + thrust::gather(rmm::exec_policy(stream), + group_offsets.begin(), + group_offsets.end(), + lists_column_view(values).offsets_begin(), + offsets_column->mutable_view().template begin()); + + // The child column of the output lists column is just copied from the input column. 
+ auto child_column = + std::make_unique(lists_column_view(values).get_sliced_child(stream), stream, mr); + + return make_lists_column(num_groups, + std::move(offsets_column), + std::move(child_column), + 0, + rmm::device_buffer{}, + stream, + mr); +} + +} // namespace detail +} // namespace groupby +} // namespace cudf diff --git a/cpp/src/groupby/sort/group_reductions.hpp b/cpp/src/groupby/sort/group_reductions.hpp index 7cc0aea8362..3390af29330 100644 --- a/cpp/src/groupby/sort/group_reductions.hpp +++ b/cpp/src/groupby/sort/group_reductions.hpp @@ -348,19 +348,19 @@ std::unique_ptr group_nth_element(column_view const& values, * * @code{.pseudo} * values = [2, 1, 4, -1, -2, , 4, ] - * group_offsets = [0, 3, 5, 7, 8] + * group_offsets = [0, 3, 5, 7, 8] * num_groups = 4 * - * group_collect = [[2, 1, 4], [-1, -2] [, 4], []] + * group_collect(...) = [[2, 1, 4], [-1, -2], [, 4], []] * @endcode * - * @param values Grouped values to collect - * @param group_offsets Offsets of groups' starting points within @p values - * @param num_groups Number of groups + * @param values Grouped values to collect. + * @param group_offsets Offsets of groups' starting points within @p values. + * @param num_groups Number of groups. * @param null_handling Exclude nulls while counting if null_policy::EXCLUDE, - * Include nulls if null_policy::INCLUDE. + * include nulls if null_policy::INCLUDE. * @param stream CUDA stream used for device memory operations and kernel launches. - * @param mr Device memory resource used to allocate the returned column's device memory + * @param mr Device memory resource used to allocate the returned column's device memory. */ std::unique_ptr group_collect(column_view const& values, cudf::device_span group_offsets, @@ -369,6 +369,29 @@ std::unique_ptr group_collect(column_view const& values, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr); +/** + * @brief Internal API to merge grouped lists into one list. 
+ * + * @code{.pseudo} + * values = [[2, 1], [], [4, -1, -2], [], [, 4, ]] + * group_offsets = [0, 3, 5] + * num_groups = 2 + * + * group_merge_lists(...) = [[2, 1, 4, -1, -2], [, 4, ]] + * @endcode + * + * @param values Grouped values (lists column) to collect. + * @param group_offsets Offsets of groups' starting points within @p values. + * @param num_groups Number of groups. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. + */ +std::unique_ptr group_merge_lists(column_view const& values, + cudf::device_span group_offsets, + size_type num_groups, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + /** @endinternal * */ diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b4aef085e03..7c061df6cc2 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -1,4 +1,4 @@ -#============================================================================= +#============================================================================= # Copyright (c) 2018-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -67,6 +67,8 @@ ConfigureTest(GROUPBY_TEST groupby/max_tests.cpp groupby/mean_tests.cpp groupby/median_tests.cpp + groupby/merge_lists_tests.cpp + groupby/merge_sets_tests.cpp groupby/min_scan_tests.cpp groupby/nth_element_tests.cpp groupby/nunique_tests.cpp diff --git a/cpp/tests/groupby/merge_lists_tests.cpp b/cpp/tests/groupby/merge_lists_tests.cpp new file mode 100644 index 00000000000..7851565d86a --- /dev/null +++ b/cpp/tests/groupby/merge_lists_tests.cpp @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace cudf::test::iterators; + +namespace { +constexpr bool print_all{false}; // For debugging +constexpr int32_t null{0}; // Mark for null elements + +using vcol_views = std::vector; + +auto merge_lists(vcol_views const& keys_cols, vcol_views const& values_cols) +{ + // Append all the keys and lists together. + auto const keys = cudf::concatenate(keys_cols); + auto const values = cudf::concatenate(values_cols); + + std::vector requests; + requests.emplace_back(cudf::groupby::aggregation_request()); + requests[0].values = *values; + requests[0].aggregations.emplace_back(cudf::make_merge_lists_aggregation()); + + auto gb_obj = cudf::groupby::groupby(cudf::table_view({*keys})); + auto result = gb_obj.aggregate(requests); + return std::make_pair(std::move(result.first->release()[0]), + std::move(result.second[0].results[0])); +} + +} // namespace + +template +struct GroupbyMergeListsTypedTest : public cudf::test::BaseFixture { +}; + +using FixedWidthTypesNotBool = cudf::test::Concat; +TYPED_TEST_CASE(GroupbyMergeListsTypedTest, FixedWidthTypesNotBool); + +TYPED_TEST(GroupbyMergeListsTypedTest, InvalidInput) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys = keys_col{1, 2, 3}; + + // The input lists column must NOT be nullable. 
+ auto const lists = lists_col{{lists_col{1}, lists_col{} /*NULL*/, lists_col{2}}, null_at(1)}; + EXPECT_THROW(merge_lists({keys}, {lists}), cudf::logic_error); + + // The input column must be a lists column. + auto const non_lists = keys_col{1, 2, 3, 4, 5}; + EXPECT_THROW(merge_lists({keys}, {non_lists}), cudf::logic_error); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, EmptyInput) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + // Keys and lists columns are all empty. + auto const keys = keys_col{}; + auto const lists0 = lists_col{{1, 2, 3}, {4, 5, 6}}; + auto const lists = cudf::empty_like(lists0); + + auto const [out_keys, out_lists] = merge_lists(vcol_views{keys}, vcol_views{*lists}); + auto const expected_keys = keys_col{}; + auto const expected_lists = cudf::empty_like(lists0); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, InputWithoutNull) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + {1, 2, 3}, // key = 1 + {4, 5, 6} // key = 2 + }; + auto const lists2 = lists_col{ + {10, 11}, // key = 1 + {11, 12} // key = 3 + }; + auto const lists3 = lists_col{ + {20, 21, 22}, // key = 2 + {23, 24, 25}, // key = 3 + {24, 25, 26} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + {1, 2, 3, 10, 11}, // key = 1 + {4, 5, 6, 20, 21, 22}, // key = 2 + {11, 12, 23, 24, 25}, // key = 3 + {24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, 
*out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, InputHasNulls) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + // Note that the null elements here are not sorted, while the results from current collect_list + // are sorted. + auto const lists1 = lists_col{ + lists_col{{1, null, 3}, null_at(1)}, // key = 1 + lists_col{4, 5, 6} // key = 2 + }; + auto const lists2 = lists_col{ + lists_col{10, 11}, // key = 1 + lists_col{{null, null, null}, all_nulls()} // key = 3 + }; + auto const lists3 = lists_col{ + lists_col{20, 21, 22}, // key = 2 + lists_col{{null, 24, null}, nulls_at({0, 2})}, // key = 3 + lists_col{{24, 25, 26}, no_nulls()} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + lists_col{{1, null, 3, 10, 11}, null_at(1)}, // key = 1 + lists_col{4, 5, 6, 20, 21, 22}, // key = 2 + lists_col{{null, null, null, null, 24, null}, nulls_at({0, 1, 2, 3, 5})}, // key = 3 + lists_col{24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, InputHasEmptyLists) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + {1, 2, 3}, // key = 1 + {} // key = 2 + }; + auto const lists2 = lists_col{ + {}, // key = 1 + {11, 12} // key = 3 + }; + auto const lists3 
= lists_col{ + {}, // key = 2 + {}, // key = 3 + {24, 25, 26} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + {1, 2, 3}, // key = 1 + {}, // key = 2 + {11, 12}, // key = 3 + {24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, InputHasNullsAndEmptyLists) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2, 3}; + auto const keys2 = keys_col{1, 3, 4}; + auto const keys3 = keys_col{2, 3, 4}; + + // Note that the null elements here are not sorted, while the results from current collect_list + // are sorted. + auto const lists1 = lists_col{ + lists_col{{1, null, 3}, null_at(1)}, // key = 1 + lists_col{}, // key = 2 + lists_col{4, 5} // key = 3 + }; + auto const lists2 = lists_col{ + lists_col{10, 11}, // key = 1 + lists_col{{null, null, null}, all_nulls()}, // key = 3 + lists_col{} // key = 4 + }; + auto const lists3 = lists_col{ + lists_col{20, 21, 22}, // key = 2 + lists_col{{null, 24, null}, nulls_at({0, 2})}, // key = 3 + lists_col{{24, 25, 26}, no_nulls()} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + lists_col{{1, null, 3, 10, 11}, null_at(1)}, // key = 1 + lists_col{20, 21, 22}, // key = 2 + lists_col{{4, 5, null, null, null, null, 24, null}, nulls_at({2, 3, 4, 5, 7})}, // key = 3 + lists_col{24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + 
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, InputHasListsOfLists) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + lists_col{lists_col{1, 2, 3}, lists_col{4}, lists_col{5, 6}}, // key = 1 + lists_col{lists_col{}, lists_col{7}} // key = 2 + }; + auto const lists2 = lists_col{ + lists_col{lists_col{}, lists_col{8, 9}}, // key = 1 + lists_col{lists_col{11}, lists_col{12, 13}} // key = 3 + }; + auto const lists3 = lists_col{ + lists_col{lists_col{14}, lists_col{15, 16, 17, 18}}, // key = 2 + lists_col{lists_col{}}, // key = 3 + lists_col{lists_col{17, 18, 19, 20, 21}, lists_col{18, 19, 20}} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + lists_col{ + lists_col{1, 2, 3}, lists_col{4}, lists_col{5, 6}, lists_col{}, lists_col{8, 9}}, // key = 1 + lists_col{lists_col{}, lists_col{7}, lists_col{14}, lists_col{15, 16, 17, 18}}, // key = 2 + lists_col{lists_col{11}, lists_col{12, 13}, lists_col{}}, // key = 3 + lists_col{lists_col{17, 18, 19, 20, 21}, lists_col{18, 19, 20}} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeListsTypedTest, SlicedColumnsInput) +{ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1_original = keys_col{1, 2, 4, 5, 6, 7, 8, 9, 10}; + auto const keys2_original = keys_col{0, 0, 1, 1, 1, 3, 4, 5, 6}; + auto const keys3_original = keys_col{0, 1, 2, 3, 4, 5, 6, 7, 
8}; + + auto const keys1 = cudf::slice(keys1_original, {0, 2})[0]; // { 1, 2 } + auto const keys2 = cudf::slice(keys2_original, {4, 6})[0]; // { 1, 3 } + auto const keys3 = cudf::slice(keys3_original, {2, 5})[0]; // { 2, 3, 4 } + + auto const lists1_original = lists_col{ + {10, 11, 12}, + {12, 13, 14}, + {1, 2, 3}, // key = 1 + {4, 5, 6} // key = 2 + }; + auto const lists2_original = lists_col{{1, 2}, + {10, 11}, // key = 1 + {11, 12}, // key = 3 + {13}, + {14}, + {15, 16}}; + auto const lists3_original = lists_col{{20, 21, 22}, // key = 2 + {23, 24, 25}, // key = 3 + {24, 25, 26}, // key = 4 + {1, 2, 3, 4, 5}}; + + auto const lists1 = cudf::slice(lists1_original, {2, 4})[0]; + auto const lists2 = cudf::slice(lists2_original, {1, 3})[0]; + auto const lists3 = cudf::slice(lists3_original, {0, 3})[0]; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + {1, 2, 3, 10, 11}, // key = 1 + {4, 5, 6, 20, 21, 22}, // key = 2 + {11, 12, 23, 24, 25}, // key = 3 + {24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +struct GroupbyMergeListsTest : public cudf::test::BaseFixture { /* Non-typed fixture used by the strings-column test that follows. */ +}; + +TEST_F(GroupbyMergeListsTest, StringsColumnInput) +{ /* Calls merge_lists (helper not visible in this chunk) on three string-keyed inputs whose lists hold strings, null entries and one empty list. Per the expected columns, the keys come out as apple, banana, dog, unknown, water melon, and the lists of each key are appended in input order with every null entry kept. */ + using strings_col = cudf::test::strings_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = strings_col{"dog", "unknown"}; + auto const keys2 = strings_col{"banana", "unknown", "dog"}; + auto const keys3 = strings_col{"apple", "dog", "water melon"}; + + auto const lists1 = lists_col{ + lists_col{"Poodle", "Golden Retriever", "Corgi"}, // key = "dog" + lists_col{{"Whale", "" /*NULL*/, "Polar Bear"}, null_at(1)} // key = "unknown" + }; + auto const lists2 = lists_col{ + lists_col{"Green", "Yellow"}, // key = "banana" 
+ lists_col{}, // key = "unknown" + lists_col{{"" /*NULL*/, "" /*NULL*/}, all_nulls()} // key = "dog" + }; + auto const lists3 = lists_col{ + lists_col{"Fuji", "Red Delicious"}, // key = "apple" + lists_col{{"" /*NULL*/, "German Shepherd", "" /*NULL*/}, nulls_at({0, 2})}, // key = "dog" + lists_col{{"Seeedless", "Mini"}, no_nulls()} // key = "water melon" + }; + + auto const [out_keys, out_lists] = + merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = strings_col{"apple", "banana", "dog", "unknown", "water melon"}; + auto const expected_lists = lists_col{ + lists_col{"Fuji", "Red Delicious"}, // key = "apple" + lists_col{"Green", "Yellow"}, // key = "banana" + lists_col{{ + "Poodle", + "Golden Retriever", + "Corgi", + "" /*NULL*/, + "" /*NULL*/, + "" /*NULL*/, + "German Shepherd", + "" /*NULL*/ + }, + nulls_at({3, 4, 5, 7})}, // key = "dog" + lists_col{{"Whale", "" /*NULL*/, "Polar Bear"}, null_at(1)}, // key = "unknown" + lists_col{{"Seeedless", "Mini"}, no_nulls()} // key = "water melon" + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} diff --git a/cpp/tests/groupby/merge_sets_tests.cpp b/cpp/tests/groupby/merge_sets_tests.cpp new file mode 100644 index 00000000000..1365245c8af --- /dev/null +++ b/cpp/tests/groupby/merge_sets_tests.cpp @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + /* NOTE(review): in this copy of the patch the include directives below carry no header names and several templates carry no argument lists (for example the std::vector alias and the column wrapper aliases); presumably they were stripped during extraction - restore them from the upstream file before applying. */ +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace cudf::test::iterators; + +namespace { +constexpr bool print_all{false}; // For debugging +constexpr int32_t null{0}; // Mark for null elements + +using vcol_views = std::vector; + +auto merge_sets(vcol_views const& keys_cols, vcol_views const& values_cols) +{ /* Test helper. Concatenates all key columns and all value columns, builds a single aggregation request on the concatenated values carrying the aggregation from cudf::make_merge_sets_aggregation(), and runs it through a groupby keyed on the concatenated keys. Returns a pair: the first column released from result.first (callers use it as the output keys) and the first result of the first request. */ + // Append all the keys and lists together. + auto const keys = cudf::concatenate(keys_cols); + auto const values = cudf::concatenate(values_cols); + + std::vector requests; + requests.emplace_back(cudf::groupby::aggregation_request()); + requests[0].values = *values; + requests[0].aggregations.emplace_back(cudf::make_merge_sets_aggregation()); + + auto gb_obj = cudf::groupby::groupby(cudf::table_view({*keys})); + auto result = gb_obj.aggregate(requests); + return std::make_pair(std::move(result.first->release()[0]), + std::move(result.second[0].results[0])); +} + +} // namespace + +template +struct GroupbyMergeSetsTypedTest : public cudf::test::BaseFixture { /* Typed fixture; the TYPED_TEST_CASE line below registers it with the FixedWidthTypesNotBool type list. */ +}; + +using FixedWidthTypesNotBool = cudf::test::Concat; +TYPED_TEST_CASE(GroupbyMergeSetsTypedTest, FixedWidthTypesNotBool); + +TYPED_TEST(GroupbyMergeSetsTypedTest, InvalidInput) +{ /* Expects cudf::logic_error from merge_sets in two cases: the values column is a lists column containing a null row, and the values column is not a lists column at all. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys = keys_col{1, 2, 3}; + + // The input lists column must NOT be nullable. + auto const lists = lists_col{{lists_col{1}, lists_col{} /*NULL*/, lists_col{2}}, null_at(1)}; + EXPECT_THROW(merge_sets({keys}, {lists}), cudf::logic_error); + + // The input column must be a lists column. 
+ auto const non_lists = keys_col{1, 2, 3, 4, 5}; + EXPECT_THROW(merge_sets({keys}, {non_lists}), cudf::logic_error); +} + +TYPED_TEST(GroupbyMergeSetsTypedTest, EmptyInput) +{ /* With an empty keys column and an empty lists column (made with cudf::empty_like from a non-empty one), the expected output is an empty keys column and an empty lists column. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + // Keys and lists columns are all empty. + auto const keys = keys_col{}; + auto const lists0 = lists_col{{1, 2, 3}, {4, 5, 6}}; + auto const lists = cudf::empty_like(lists0); + + auto const [out_keys, out_lists] = merge_sets(vcol_views{keys}, vcol_views{*lists}); + auto const expected_keys = keys_col{}; + auto const expected_lists = cudf::empty_like(lists0); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeSetsTypedTest, InputWithoutNull) +{ /* Three partial inputs without nulls. Lists that share a key are expected to be merged with overlapping entries appearing once, e.g. key 1 combines 1..6 and 4..9 into 1..9. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + {1, 2, 3, 4, 5, 6}, // key = 1 + {10, 11, 12, 13, 14, 15} // key = 2 + }; + auto const lists2 = lists_col{ + {4, 5, 6, 7, 8, 9}, // key = 1 + {20, 21, 22, 23, 24, 25} // key = 3 + }; + auto const lists3 = lists_col{ + {11, 12}, // key = 2 + {23, 24, 25, 26, 27, 28}, // key = 3 + {30, 31, 32} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + {1, 2, 3, 4, 5, 6, 7, 8, 9}, // key = 1 + {10, 11, 12, 13, 14, 15}, // key = 2 + {20, 21, 22, 23, 24, 25, 26, 27, 28}, // key = 3 + {30, 31, 32} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + 
+TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasNulls) +{ /* Input lists contain null elements. Each expected list holds the distinct non-null values of its key, followed by a single null when any list of that key contained a null (keys 1, 3 and 4); key 2 has no nulls and gets none. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + // Note that the null elements here are not sorted, while the results from current collect_list + // and collect_set are sorted. + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + lists_col{{1, null, null, null, 5, 6}, nulls_at({1, 2, 3})}, // key = 1 + lists_col{10, 11, 12, 13, 14, 15} // key = 2 + }; + auto const lists2 = lists_col{ + lists_col{{null, null, 6, 7, 8, 9}, nulls_at({0, 1})}, // key = 1 + lists_col{{null, 21, 22, 23, 24, 25}, null_at(0)} // key = 3 + }; + auto const lists3 = lists_col{ + lists_col{11, 12}, // key = 2 + lists_col{23, 24, 25, 26, 27, 28}, // key = 3 + lists_col{{30, null, 32}, null_at(1)} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + lists_col{{1, 5, 6, 7, 8, 9, null}, null_at(6)}, // key = 1 + lists_col{10, 11, 12, 13, 14, 15}, // key = 2 + lists_col{{21, 22, 23, 24, 25, 26, 27, 28, null}, null_at(8)}, // key = 3 + lists_col{{30, 32, null}, null_at(2)} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasEmptyLists) +{ /* Some input lists are empty. A key whose lists are all empty (key 2) is expected to yield an empty list, and empty lists add nothing to the other keys; repeated values (the 12s of key 3) are expected once. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = keys_col{1, 2}; + auto const keys2 = keys_col{1, 3}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + {1, 2, 3}, // key = 1 + {} // key = 2 + }; + auto const lists2 = lists_col{ + {0, 1, 2, 3, 4, 5}, // key = 1 + {11, 12, 12, 12, 12, 12} // key = 3 + 
}; + auto const lists3 = lists_col{ + {}, // key = 2 + {}, // key = 3 + {24, 25, 26} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + {0, 1, 2, 3, 4, 5}, // key = 1 + {}, // key = 2 + {11, 12}, // key = 3 + {24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasNullsAndEmptyLists) +{ /* Combines the two previous scenarios: inputs mix null elements with empty lists. Expected lists hold distinct non-null values followed by one null where the key had any null, and key 2 (only empty lists) stays empty. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + // Note that the null elements here are not sorted, while the results from current collect_list + // and collect_set are sorted. + auto const keys1 = keys_col{1, 2, 3}; + auto const keys2 = keys_col{1, 3, 4}; + auto const keys3 = keys_col{2, 3, 4}; + + auto const lists1 = lists_col{ + lists_col{{null, 1, 2, 3}, null_at(0)}, // key = 1 + lists_col{}, // key = 2 + lists_col{} // key = 3 + }; + auto const lists2 = lists_col{ + lists_col{0, 1, 2, 3, 4, 5}, // key = 1 + lists_col{{null, 11, null, 12, 12, 12, 12, 12}, nulls_at({0, 2})}, // key = 3 + lists_col{20} // key = 4 + }; + auto const lists3 = lists_col{ + lists_col{}, // key = 2 + lists_col{}, // key = 3 + lists_col{{24, 25, null, null, null, 26}, nulls_at({2, 3, 4})} // key = 4 + }; + + auto const [out_keys, out_lists] = + merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + lists_col{{0, 1, 2, 3, 4, 5, null}, null_at(6)}, // key = 1 + lists_col{}, // key = 2 + lists_col{{11, 12, null}, null_at(2)}, // key = 3 + lists_col{{20, 24, 25, 26, null}, null_at(4)} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + 
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +TYPED_TEST(GroupbyMergeSetsTypedTest, SlicedColumnsInput) +{ /* Inputs are views produced by cudf::slice over larger key and lists columns. Only the sliced rows are expected to contribute: rows outside the slice ranges do not appear in the expected keys or lists, and repeated values inside the sliced lists are expected once. */ + using keys_col = cudf::test::fixed_width_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1_original = keys_col{1, 2, 4, 5, 6, 7, 8, 9, 10}; + auto const keys2_original = keys_col{0, 0, 1, 1, 1, 3, 4, 5, 6}; + auto const keys3_original = keys_col{0, 1, 2, 3, 4, 5, 6, 7, 8}; + + auto const keys1 = cudf::slice(keys1_original, {0, 2})[0]; // { 1, 2 } + auto const keys2 = cudf::slice(keys2_original, {4, 6})[0]; // { 1, 3 } + auto const keys3 = cudf::slice(keys3_original, {2, 5})[0]; // { 2, 3, 4 } + + auto const lists1_original = lists_col{ + {10, 11, 12, 10, 11, 12, 10, 11, 12}, + {12, 13, 12, 13, 12, 13, 12, 13, 14}, + {1, 2, 3, 1, 2, 3, 1, 2, 3}, // key = 1 + {4, 5, 6, 4, 5, 6, 4, 5, 6} // key = 2 + }; + auto const lists2_original = lists_col{{1, 1, 1, 1, 1, 1, 1, 2}, + {10, 11, 11, 11, 11, 11, 12}, // key = 1 + {11, 12, 13, 12, 13, 12, 13, 12, 13, 14, 15}, // key = 3 + {13, 14, 15}, + {14, 15, 16}, + {15, 16}}; + auto const lists3_original = lists_col{{20, 21, 20, 21, 20, 21, 20, 21, 22}, // key = 2 + {23, 24, 25, 23, 24, 25}, // key = 3 + {24, 25, 26}, // key = 4 + {1, 2, 3, 4, 5}}; + + auto const lists1 = cudf::slice(lists1_original, {2, 4})[0]; + auto const lists2 = cudf::slice(lists2_original, {1, 3})[0]; + auto const lists3 = cudf::slice(lists3_original, {0, 3})[0]; + + auto const [out_keys, out_lists] = + merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = keys_col{1, 2, 3, 4}; + auto const expected_lists = lists_col{ + {1, 2, 3, 10, 11, 12}, // key = 1 + {4, 5, 6, 20, 21, 22}, // key = 2 + {11, 12, 13, 14, 15, 23, 24, 25}, // key = 3 + {24, 25, 26} // key = 4 + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +} + +struct 
GroupbyMergeSetsTest : public cudf::test::BaseFixture { /* Non-typed fixture used by the strings-column test that follows. */ +}; + +TEST_F(GroupbyMergeSetsTest, StringsColumnInput) +{ /* String keys with lists of strings that include null entries, an empty list and values repeated across inputs. The expected lists are written with distinct strings in ascending order and, where the key had any null, one null entry last. */ + using strings_col = cudf::test::strings_column_wrapper; + using lists_col = cudf::test::lists_column_wrapper; + + auto const keys1 = strings_col{"apple", "dog", "unknown"}; + auto const keys2 = strings_col{"banana", "unknown", "dog"}; + auto const keys3 = strings_col{"apple", "dog", "water melon"}; + + auto const lists1 = lists_col{ + lists_col{"Fuji", "Honey Bee"}, // key = "apple" + lists_col{"Poodle", "Golden Retriever", "Corgi"}, // key = "dog" + lists_col{{"Whale", "" /*NULL*/, "Polar Bear"}, null_at(1)} // key = "unknown" + }; + auto const lists2 = lists_col{ + lists_col{"Green", "Yellow"}, // key = "banana" + lists_col{}, // key = "unknown" + lists_col{{"" /*NULL*/, "" /*NULL*/, "" /*NULL*/}, all_nulls()} // key = "dog" + }; + auto const lists3 = lists_col{ + lists_col{"Fuji", "Red Delicious"}, // key = "apple" + lists_col{{"" /*NULL*/, "Corgi", "German Shepherd", "" /*NULL*/, "Golden Retriever"}, + nulls_at({0, 3})}, // key = "dog" + lists_col{{"Seeedless", "Mini"}, no_nulls()} // key = "water melon" + }; + + auto const [out_keys, out_lists] = + merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3}); + auto const expected_keys = strings_col{"apple", "banana", "dog", "unknown", "water melon"}; + auto const expected_lists = lists_col{ + lists_col{"Fuji", "Honey Bee", "Red Delicious"}, // key = "apple" + lists_col{"Green", "Yellow"}, // key = "banana" + lists_col{{ + "Corgi", "German Shepherd", "Golden Retriever", "Poodle", "" /*NULL*/ + }, + null_at(4)}, // key = "dog" + lists_col{{"Polar Bear", "Whale", "" /*NULL*/}, null_at(2)}, // key = "unknown" + lists_col{{"Mini", "Seeedless"}, no_nulls()} // key = "water melon" + }; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all); +}