Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate Groupby.collect #15808

Merged
merged 9 commits into from
May 22, 2024
12 changes: 11 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@
from cudf.utils.utils import GetAttrGetItemMixin


def _deprecate_collect():
    """Warn that ``Groupby.collect`` is deprecated.

    Emits a :class:`FutureWarning` directing callers to use
    ``.agg(list)`` instead.  Shared by the cudf and dask_cudf
    groupby ``collect`` entry points (both import this helper).

    NOTE(review): relies on the module-level ``import warnings``
    (not visible in this diff hunk) — confirm it is present.
    """
    warnings.warn(
        "Groupby.collect is deprecated and "
        "will be removed in a future version. "
        "Use `.agg(list)` instead.",
        FutureWarning,
    )


# The three functions below return the quantiles [25%, 50%, 75%]
# respectively, which are called in the describe() method to output
# the summary stats of a GroupBy object
Expand Down Expand Up @@ -2180,7 +2189,8 @@ def func(x):
@_cudf_nvtx_annotate
def collect(self):
"""Get a list of all the values for each column in each group."""
return self.agg("collect")
_deprecate_collect()
return self.agg(list)

@_cudf_nvtx_annotate
def unique(self):
Expand Down
26 changes: 15 additions & 11 deletions python/dask_cudf/dask_cudf/expr/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@

from dask.dataframe.groupby import Aggregation

from cudf.core.groupby.groupby import _deprecate_collect

##
## Custom groupby classes
##


class Collect(SingleAggregation):
class ListAgg(SingleAggregation):
@staticmethod
def groupby_chunk(arg):
return arg.agg("collect")
return arg.agg(list)

@staticmethod
def groupby_aggregate(arg):
gb = arg.agg("collect")
gb = arg.agg(list)
if gb.ndim > 1:
for col in gb.columns:
gb[col] = gb[col].list.concat()
Expand All @@ -30,24 +32,24 @@ def groupby_aggregate(arg):
return gb.list.concat()


collect_aggregation = Aggregation(
name="collect",
chunk=Collect.groupby_chunk,
agg=Collect.groupby_aggregate,
list_aggregation = Aggregation(
name="list",
chunk=ListAgg.groupby_chunk,
agg=ListAgg.groupby_aggregate,
)


def _translate_arg(arg):
# Helper function to translate args so that
# they can be processed correctly by upstream
# dask & dask-expr. Right now, the only necessary
# translation is "collect" aggregations.
# translation is list aggregations.
if isinstance(arg, dict):
return {k: _translate_arg(v) for k, v in arg.items()}
elif isinstance(arg, list):
return [_translate_arg(x) for x in arg]
elif arg in ("collect", "list", list):
return collect_aggregation
return list_aggregation
else:
return arg

Expand Down Expand Up @@ -84,7 +86,8 @@ def __getitem__(self, key):
return g

def collect(self, **kwargs):
return self._single_agg(Collect, **kwargs)
_deprecate_collect()
return self._single_agg(ListAgg, **kwargs)

def aggregate(self, arg, **kwargs):
return super().aggregate(_translate_arg(arg), **kwargs)
Expand All @@ -96,7 +99,8 @@ def __init__(self, *args, observed=None, **kwargs):
super().__init__(*args, observed=observed, **kwargs)

def collect(self, **kwargs):
return self._single_agg(Collect, **kwargs)
_deprecate_collect()
return self._single_agg(ListAgg, **kwargs)

def aggregate(self, arg, **kwargs):
return super().aggregate(_translate_arg(arg), **kwargs)
23 changes: 14 additions & 9 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dask.utils import funcname

import cudf
from cudf.core.groupby.groupby import _deprecate_collect
from cudf.utils.nvtx_annotation import _dask_cudf_nvtx_annotate

from dask_cudf.sorting import _deprecate_shuffle_kwarg
Expand All @@ -28,7 +29,7 @@
"sum",
"min",
"max",
"collect",
list,
"first",
"last",
)
Expand Down Expand Up @@ -164,9 +165,10 @@ def max(self, split_every=None, split_out=1):
@_dask_cudf_nvtx_annotate
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
rjzamora marked this conversation as resolved.
Show resolved Hide resolved
_deprecate_collect()
return _make_groupby_agg_call(
self,
self._make_groupby_method_aggs("collect"),
self._make_groupby_method_aggs(list),
split_every,
split_out,
)
Expand Down Expand Up @@ -308,9 +310,10 @@ def max(self, split_every=None, split_out=1):
@_dask_cudf_nvtx_annotate
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
_deprecate_collect()
return _make_groupby_agg_call(
self,
{self._slice: "collect"},
{self._slice: list},
split_every,
split_out,
)[self._slice]
Expand Down Expand Up @@ -472,7 +475,7 @@ def groupby_agg(

This aggregation algorithm only supports the following options

* "collect"
* "list"
* "count"
* "first"
* "last"
Expand Down Expand Up @@ -667,8 +670,8 @@ def _redirect_aggs(arg):
sum: "sum",
max: "max",
min: "min",
list: "collect",
"list": "collect",
"collect": list,
"list": list,
}
if isinstance(arg, dict):
new_arg = dict()
Expand Down Expand Up @@ -704,7 +707,7 @@ def _aggs_optimized(arg, supported: set):
_global_set = set(arg)

return bool(_global_set.issubset(supported))
elif isinstance(arg, str):
elif isinstance(arg, (str, type)):
return arg in supported
return False

Expand Down Expand Up @@ -783,6 +786,8 @@ def _tree_node_agg(df, gb_cols, dropna, sort, sep):
agg = col.split(sep)[-1]
if agg in ("count", "sum"):
agg_dict[col] = ["sum"]
elif agg == "list":
agg_dict[col] = [list]
elif agg in OPTIMIZED_AGGS:
agg_dict[col] = [agg]
else:
Expand Down Expand Up @@ -873,8 +878,8 @@ def _finalize_gb_agg(
gb.drop(columns=[sum_name], inplace=True)
if "count" not in agg_list:
gb.drop(columns=[count_name], inplace=True)
if "collect" in agg_list:
collect_name = _make_name((col, "collect"), sep=sep)
if list in agg_list:
collect_name = _make_name((col, "list"), sep=sep)
gb[collect_name] = gb[collect_name].list.concat()

# Ensure sorted keys if `sort=True`
Expand Down
14 changes: 11 additions & 3 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dask.utils_test import hlg_layer

import cudf
from cudf.testing._utils import expect_warning_if

import dask_cudf
from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized
galipremsagar marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -47,7 +48,13 @@ def pdf(request):
return pdf


@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
# NOTE: We only want to test aggregation "methods" here,
# so we need to leave out `list`. We also include a
# deprecation check for "collect".
@pytest.mark.parametrize(
"aggregation",
sorted(tuple(set(OPTIMIZED_AGGS) - {list}) + ("collect",)),
)
@pytest.mark.parametrize("series", [False, True])
def test_groupby_basic(series, aggregation, pdf):
gdf = cudf.DataFrame.from_pandas(pdf)
Expand All @@ -62,8 +69,9 @@ def test_groupby_basic(series, aggregation, pdf):

check_dtype = aggregation != "count"

expect = getattr(gdf_grouped, aggregation)()
actual = getattr(ddf_grouped, aggregation)()
with expect_warning_if(aggregation == "collect"):
expect = getattr(gdf_grouped, aggregation)()
actual = getattr(ddf_grouped, aggregation)()

if not QUERY_PLANNING_ON:
assert_cudf_groupby_layers(actual)
Expand Down
Loading