Skip to content

Commit

Permalink
Add collect list to dask-cudf groupby aggregations (#8045)
Browse files Browse the repository at this point in the history
Closes #7812

Adds support for cuDF's `collect` aggregation in dask-cuDF.

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

URL: #8045
  • Loading branch information
charlesbluca authored Jul 6, 2021
1 parent 3ee264c commit c54346e
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
50 changes: 44 additions & 6 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@ def aggregate(self, arg, split_every=None, split_out=1):
return self.size()
arg = _redirect_aggs(arg)

_supported = {"count", "mean", "std", "var", "sum", "min", "max"}
_supported = {
"count",
"mean",
"std",
"var",
"sum",
"min",
"max",
"collect",
}
if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
Expand Down Expand Up @@ -109,7 +118,16 @@ def aggregate(self, arg, split_every=None, split_out=1):
return self.size()
arg = _redirect_aggs(arg)

_supported = {"count", "mean", "std", "var", "sum", "min", "max"}
_supported = {
"count",
"mean",
"std",
"var",
"sum",
"min",
"max",
"collect",
}
if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
Expand Down Expand Up @@ -147,7 +165,7 @@ def groupby_agg(
This aggregation algorithm only supports the following options:
{"count", "mean", "std", "var", "sum", "min", "max"}
{"count", "mean", "std", "var", "sum", "min", "max", "collect"}
This "optimized" approach is more performant than the algorithm
in `dask.dataframe`, because it allows the cudf backend to
Expand All @@ -173,15 +191,24 @@ def groupby_agg(
# strings (no lists)
str_cols_out = True
for col in aggs:
if isinstance(aggs[col], str):
if isinstance(aggs[col], str) or callable(aggs[col]):
aggs[col] = [aggs[col]]
else:
str_cols_out = False
if col in gb_cols:
columns.append(col)

# Assert that aggregations are supported
_supported = {"count", "mean", "std", "var", "sum", "min", "max"}
_supported = {
"count",
"mean",
"std",
"var",
"sum",
"min",
"max",
"collect",
}
if not _is_supported(aggs, _supported):
raise ValueError(
f"Supported aggs include {_supported} for groupby_agg API. "
Expand Down Expand Up @@ -282,7 +309,13 @@ def groupby_agg(
def _redirect_aggs(arg):
""" Redirect aggregations to their corresponding name in cuDF
"""
redirects = {sum: "sum", max: "max", min: "min"}
redirects = {
sum: "sum",
max: "max",
min: "min",
list: "collect",
"list": "collect",
}
if isinstance(arg, dict):
new_arg = dict()
for col in arg:
Expand Down Expand Up @@ -400,6 +433,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep):
agg_dict[col] = ["sum"]
elif agg in ("min", "max"):
agg_dict[col] = [agg]
elif agg == "collect":
agg_dict[col] = ["collect"]
else:
raise ValueError(f"Unexpected aggregation: {agg}")

Expand Down Expand Up @@ -478,6 +513,9 @@ def _finalize_gb_agg(
gb.drop(columns=[sum_name], inplace=True)
if "count" not in agg_list:
gb.drop(columns=[count_name], inplace=True)
if "collect" in agg_list:
collect_name = _make_name(col, "collect", sep=sep)
gb[collect_name] = gb[collect_name].list.concat()

# Ensure sorted keys if `sort=True`
if sort:
Expand Down
27 changes: 27 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,33 @@ def test_groupby_std(func):
dd.assert_eq(a, b)


@pytest.mark.parametrize(
"func",
[
lambda df: df.groupby("x").agg({"y": "collect"}),
pytest.param(
lambda df: df.groupby("x").y.agg("collect"), marks=pytest.mark.skip
),
],
)
def test_groupby_collect(func):
pdf = pd.DataFrame(
{
"x": np.random.randint(0, 5, size=10000),
"y": np.random.normal(size=10000),
}
)

gdf = cudf.DataFrame.from_pandas(pdf)

ddf = dask_cudf.from_cudf(gdf, npartitions=5)

a = func(gdf).to_pandas()
b = func(ddf).compute().to_pandas()

dd.assert_eq(a, b)


# reason gotattr in cudf
@pytest.mark.parametrize(
"func",
Expand Down

0 comments on commit c54346e

Please sign in to comment.