Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement groupby MERGE_LISTS and MERGE_SETS aggregates #8436

Merged
merged 25 commits into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d540bae
Implement prototype for groupby merge_lists
ttnghia Jun 3, 2021
10c2ca4
Add a unit test for merge_lists
ttnghia Jun 3, 2021
8fb0a77
Rewrite doxygen for factory functions
ttnghia Jun 7, 2021
0c6b2cf
Add factory function for creating `MERGE_SETS`
ttnghia Jun 7, 2021
d74932c
Implement `MERGE_SETS`
ttnghia Jun 7, 2021
b9b3e40
Rename `detail::group_collect_merge` to `detail::group_merge_lists`
ttnghia Jun 7, 2021
a2c43a3
Add `merge_sets_tests.cpp`
ttnghia Jun 7, 2021
c62c320
Implement unit tests for `merge_lists`
ttnghia Jun 8, 2021
5fffe78
Add unit tests for `merge_lists`
ttnghia Jun 8, 2021
2557d3e
Finish unit tests for `merge_sets`
ttnghia Jun 8, 2021
b64298e
Rewrite examples in doxygen
ttnghia Jun 8, 2021
7677350
Fix `merge_sets_aggregation` initialization
ttnghia Jun 8, 2021
75b5f79
Merge branch 'branch-21.08' into groupby_merge_lists
ttnghia Jun 8, 2021
5eacf35
Update doxygen
ttnghia Jun 9, 2021
91b5641
Use `thrust::gather` instead of `thrust::transform`
ttnghia Jun 9, 2021
a4b04b3
Merge branch 'branch-21.08' into groupby_merge_lists
ttnghia Jun 15, 2021
5dc696b
Fix typo
ttnghia Jun 15, 2021
ecd6d71
Move doxygen to detail implementation
ttnghia Jun 15, 2021
81b0b6d
Revert "Merge branch 'branch-21.08' into groupby_merge_lists"
ttnghia Jun 15, 2021
6d8636a
Revert "Revert "Merge branch 'branch-21.08' into groupby_merge_lists""
ttnghia Jun 15, 2021
8fab02b
Merge branch 'branch-21.08' into groupby_merge_lists
ttnghia Jun 16, 2021
817e8bf
WIP
ttnghia Jun 17, 2021
ee26a04
Merge branch 'branch-21.08' into groupby_merge_lists
ttnghia Jun 17, 2021
2cf6afd
Fix unit tests
ttnghia Jun 17, 2021
078e90e
Fix headers
ttnghia Jun 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ add_library(cudf
src/groupby/sort/group_argmin.cu
src/groupby/sort/aggregate.cpp
src/groupby/sort/group_collect.cu
src/groupby/sort/group_merge_lists.cu
src/groupby/sort/group_count.cu
src/groupby/sort/group_max.cu
src/groupby/sort/group_min.cu
Expand Down
71 changes: 66 additions & 5 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class aggregation {
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
Expand Down Expand Up @@ -250,7 +252,7 @@ std::unique_ptr<Base> make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);

/**
* @brief Factory to create a COLLECT_SET aggregation
* @brief Factory to create a COLLECT_SET aggregation.
*
* `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each
* list, the duplicated entries are dropped out such that each entry appears only once.
Expand All @@ -259,16 +261,75 @@ std::unique_ptr<Base> make_collect_list_aggregation(
* of the list rows.
*
* @param null_handling Indicates whether to include/exclude nulls during collection
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal.
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal.
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);

/**
* @brief Factory to create a MERGE_LISTS aggregation.
*
* This aggregation is similar to `COLLECT_LIST` with the following differences:
* - It requires the input values to be a non-nullable lists column, and
* - The values (lists) corresponding to the same key will not result in a list of lists as output
* from `COLLECT_LIST`. Instead, those lists will result in a list generated by merging them
* together.
*
* In practice, this aggregation is used to merge the partial results of multiple (distributed)
* groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. Those distributed
* aggregations were executed on different values columns partitioned from the original values
* column, then their results were (vertically) concatenated before given as the values column for
* this aggregation.
*
* Note that this aggregation does not accept any null handling parameter as it is designed for
* simply merging lists of the input lists column. Because the output from `COLLECT_LIST` are
* non-nullable, the input lists column to this aggregation is also required to be non-nullable (but
* its child column containing list entries is not subject to this requirement).
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_lists_aggregation();

/**
* @brief Factory to create a MERGE_SETS aggregation.
*
* This aggregation is similar to `COLLECT_SET` with the following differences:
* - It requires the input values to be a non-nullable lists column, and
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
* - The values (lists) corresponding to the same key will result in a list generated by merging
* them together then dropping duplicate entries.
*
* In practice, this aggregation is used to merge the partial results of multiple (distributed)
* groupby `COLLECT_LIST/COLLECT_SET` aggregations into a final `COLLECT_SET` result. Those
* distributed aggregations were executed on different values columns partitioned from the original
* values column, then their results were (vertically) concatenated before given as the values
* column for this aggregation.
*
* This aggregation firstly calls `MERGE_LISTS` to merge the input lists into intermediate lists,
* then it calls `lists::drop_list_duplicates` on them to remove duplicate list entries. As such,
* the input (partial results) to this aggregation should be generated by (distributed)
* `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily removing duplicate entries
* for the partial results.
*
* Similar to `MERGE_LISTS`, this aggregation does not need the `null_policy` parameter and requires
* the input lists column to be non-nullable (its child column containing list entries is not
* subject to this requirement).
*
* Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality`
* are needed for calling to `lists::drop_list_duplicates`.
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
*
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal.
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal.
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);

/// Factory to create a LAG aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_lag_aggregation(size_type offset);
Expand Down
84 changes: 83 additions & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ class simple_aggregations_collector { // Declares the interface for the simple
data_type col_type, class collect_list_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class collect_set_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_lists_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_sets_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class lead_lag_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
Expand Down Expand Up @@ -105,6 +109,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class row_number_aggregation const& agg);
virtual void visit(class collect_list_aggregation const& agg);
virtual void visit(class collect_set_aggregation const& agg);
virtual void visit(class merge_lists_aggregation const& agg);
virtual void visit(class merge_sets_aggregation const& agg);
virtual void visit(class lead_lag_aggregation const& agg);
virtual void visit(class udf_aggregation const& agg);
};
Expand Down Expand Up @@ -627,6 +633,66 @@ class collect_set_aggregation final : public rolling_aggregation {
}
};

/**
* @brief Derived aggregation class for specifying MERGE_LISTs aggregation
*/
class merge_lists_aggregation final : public aggregation {
public:
explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {}

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<merge_lists_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying MERGE_SETs aggregation
*/
class merge_sets_aggregation final : public aggregation {
public:
explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal)
: aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal)
{
}

null_equality _nulls_equal; ///< whether to consider nulls as equal value
nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to
///< floating point types)

bool is_equal(aggregation const& _other) const override
{
if (!this->aggregation::is_equal(_other)) { return false; }
auto const& other = dynamic_cast<merge_sets_aggregation const&>(_other);
return (_nulls_equal == other._nulls_equal && _nans_equal == other._nans_equal);
}

size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); }

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<merge_sets_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }

protected:
size_t hash_impl() const
{
return std::hash<int>{}(static_cast<int>(_nulls_equal) ^ static_cast<int>(_nans_equal));
}
};

/**
* @brief Derived aggregation class for specifying LEAD/LAG window aggregations
*/
Expand Down Expand Up @@ -904,6 +970,18 @@ struct target_type_impl<Source, aggregation::COLLECT_SET> {
using type = cudf::list_view;
};

// Always use list for MERGE_LISTS
template <typename Source>
struct target_type_impl<Source, aggregation::MERGE_LISTS> {
using type = cudf::list_view;
};

// Always use list for MERGE_SETS
template <typename Source>
struct target_type_impl<Source, aggregation::MERGE_SETS> {
using type = cudf::list_view;
};

// Always use Source for LEAD
template <typename Source>
struct target_type_impl<Source, aggregation::LEAD> {
Expand Down Expand Up @@ -1005,6 +1083,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()<aggregation::COLLECT_LIST>(std::forward<Ts>(args)...);
case aggregation::COLLECT_SET:
return f.template operator()<aggregation::COLLECT_SET>(std::forward<Ts>(args)...);
case aggregation::MERGE_LISTS:
return f.template operator()<aggregation::MERGE_LISTS>(std::forward<Ts>(args)...);
case aggregation::MERGE_SETS:
return f.template operator()<aggregation::MERGE_SETS>(std::forward<Ts>(args)...);
case aggregation::LEAD:
return f.template operator()<aggregation::LEAD>(std::forward<Ts>(args)...);
case aggregation::LAG:
Expand Down Expand Up @@ -1107,4 +1189,4 @@ constexpr inline bool is_valid_aggregation()
bool is_valid_aggregation(data_type source, aggregation::Kind k);

} // namespace detail
} // namespace cudf
} // namespace cudf
40 changes: 40 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, merge_lists_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, merge_sets_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, lead_lag_aggregation const& agg)
{
Expand Down Expand Up @@ -270,6 +282,16 @@ void aggregation_finalizer::visit(collect_set_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(merge_lists_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(merge_sets_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(lead_lag_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
Expand Down Expand Up @@ -471,6 +493,24 @@ template std::unique_ptr<aggregation> make_collect_set_aggregation<aggregation>(
template std::unique_ptr<rolling_aggregation> make_collect_set_aggregation<rolling_aggregation>(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);

/// Factory to create a MERGE_LISTS aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_lists_aggregation()
{
return std::make_unique<detail::merge_lists_aggregation>();
}
template std::unique_ptr<aggregation> make_merge_lists_aggregation<aggregation>();

/// Factory to create a MERGE_SETS aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal,
nan_equality nans_equal)
{
return std::make_unique<detail::merge_sets_aggregation>(nulls_equal, nans_equal);
}
template std::unique_ptr<aggregation> make_merge_sets_aggregation<aggregation>(null_equality,
nan_equality);

/// Factory to create a LAG aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_lag_aggregation(size_type offset)
Expand Down
50 changes: 40 additions & 10 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,36 +366,32 @@ void aggregate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation
template <>
void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation const& agg)
{
auto null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
agg.do_hash();

if (cache.has_result(col_idx, agg)) return;
if (cache.has_result(col_idx, agg)) { return; }

auto const null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
auto result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);

cache.add_result(col_idx, agg, std::move(result));
};

template <>
void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation const& agg)
{
auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;

if (cache.has_result(col_idx, agg)) { return; }

auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;
auto const collect_result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);
rmm::mr::get_current_device_resource());
auto const nulls_equal =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
Expand All @@ -406,6 +402,40 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};

template <>
void aggregate_result_functor::operator()<aggregation::MERGE_LISTS>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) { return; }

cache.add_result(
col_idx,
agg,
detail::group_merge_lists(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};

template <>
void aggregate_result_functor::operator()<aggregation::MERGE_SETS>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) { return; }

auto const merged_result = detail::group_merge_lists(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
stream,
rmm::mr::get_current_device_resource());
auto const nulls_equal =
dynamic_cast<cudf::detail::merge_sets_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
dynamic_cast<cudf::detail::merge_sets_aggregation const&>(agg)._nans_equal;
cache.add_result(
col_idx,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(merged_result->view()), nulls_equal, nans_equal, stream, mr));
};

} // namespace detail

// Sort-based groupby
Expand Down
Loading