diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 3e4b8192888..bf24864c29d 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -40,6 +40,15 @@ from cudf.utils.utils import GetAttrGetItemMixin +def _deprecate_collect(): + warnings.warn( + "Groupby.collect is deprecated and " + "will be removed in a future version. " + "Use `.agg(list)` instead.", + FutureWarning, + ) + + # The three functions below return the quantiles [25%, 50%, 75%] # respectively, which are called in the describe() method to output # the summary stats of a GroupBy object @@ -2180,7 +2189,8 @@ def func(x): @_cudf_nvtx_annotate def collect(self): """Get a list of all the values for each column in each group.""" - return self.agg("collect") + _deprecate_collect() + return self.agg(list) @_cudf_nvtx_annotate def unique(self): diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 116893891e3..65688115b59 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -9,19 +9,21 @@ from dask.dataframe.groupby import Aggregation +from cudf.core.groupby.groupby import _deprecate_collect + ## ## Custom groupby classes ## -class Collect(SingleAggregation): +class ListAgg(SingleAggregation): @staticmethod def groupby_chunk(arg): - return arg.agg("collect") + return arg.agg(list) @staticmethod def groupby_aggregate(arg): - gb = arg.agg("collect") + gb = arg.agg(list) if gb.ndim > 1: for col in gb.columns: gb[col] = gb[col].list.concat() @@ -30,10 +32,10 @@ def groupby_aggregate(arg): return gb.list.concat() -collect_aggregation = Aggregation( - name="collect", - chunk=Collect.groupby_chunk, - agg=Collect.groupby_aggregate, +list_aggregation = Aggregation( + name="list", + chunk=ListAgg.groupby_chunk, + agg=ListAgg.groupby_aggregate, ) @@ -41,13 +43,13 @@ def _translate_arg(arg): # Helper function to translate args so that # they can be processed correctly by upstream # dask & dask-expr. Right now, the only necessary - # translation is "collect" aggregations. + # translation is list aggregations. if isinstance(arg, dict): return {k: _translate_arg(v) for k, v in arg.items()} elif isinstance(arg, list): return [_translate_arg(x) for x in arg] elif arg in ("collect", "list", list): - return collect_aggregation + return list_aggregation else: return arg @@ -84,7 +86,8 @@ def __getitem__(self, key): return g def collect(self, **kwargs): - return self._single_agg(Collect, **kwargs) + _deprecate_collect() + return self._single_agg(ListAgg, **kwargs) def aggregate(self, arg, **kwargs): return super().aggregate(_translate_arg(arg), **kwargs) @@ -96,7 +99,8 @@ def __init__(self, *args, observed=None, **kwargs): super().__init__(*args, observed=observed, **kwargs) def collect(self, **kwargs): - return self._single_agg(Collect, **kwargs) + _deprecate_collect() + return self._single_agg(ListAgg, **kwargs) def aggregate(self, arg, **kwargs): return super().aggregate(_translate_arg(arg), **kwargs) diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 43ad4f0fee3..ef47ea436c7 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -15,6 +15,7 @@ from dask.utils import funcname import cudf +from cudf.core.groupby.groupby import _deprecate_collect from cudf.utils.nvtx_annotation import _dask_cudf_nvtx_annotate from dask_cudf.sorting import _deprecate_shuffle_kwarg @@ -28,7 +29,7 @@ "sum", "min", "max", - "collect", + list, "first", "last", ) @@ -164,9 +165,10 @@ def max(self, split_every=None, split_out=1): @_dask_cudf_nvtx_annotate @_check_groupby_optimized def collect(self, split_every=None, split_out=1): + _deprecate_collect() return _make_groupby_agg_call( self, - self._make_groupby_method_aggs("collect"), + self._make_groupby_method_aggs(list), split_every, split_out, ) @@ -308,9 +310,10 @@ def max(self, split_every=None, split_out=1): @_dask_cudf_nvtx_annotate @_check_groupby_optimized def collect(self, split_every=None, split_out=1): + _deprecate_collect() return _make_groupby_agg_call( self, - {self._slice: "collect"}, + {self._slice: list}, split_every, split_out, )[self._slice] @@ -472,7 +475,7 @@ def groupby_agg( This aggregation algorithm only supports the following options - * "collect" + * "list" * "count" * "first" * "last" @@ -667,8 +670,8 @@ def _redirect_aggs(arg): sum: "sum", max: "max", min: "min", - list: "collect", - "list": "collect", + "collect": list, + "list": list, } if isinstance(arg, dict): new_arg = dict() @@ -704,7 +707,7 @@ def _aggs_optimized(arg, supported: set): _global_set = set(arg) return bool(_global_set.issubset(supported)) - elif isinstance(arg, str): + elif isinstance(arg, (str, type)): return arg in supported return False @@ -783,6 +786,8 @@ def _tree_node_agg(df, gb_cols, dropna, sort, sep): agg = col.split(sep)[-1] if agg in ("count", "sum"): agg_dict[col] = ["sum"] + elif agg == "list": + agg_dict[col] = [list] elif agg in OPTIMIZED_AGGS: agg_dict[col] = [agg] else: @@ -873,8 +878,8 @@ def _finalize_gb_agg( gb.drop(columns=[sum_name], inplace=True) if "count" not in agg_list: gb.drop(columns=[count_name], inplace=True) - if "collect" in agg_list: - collect_name = _make_name((col, "collect"), sep=sep) + if list in agg_list: + collect_name = _make_name((col, "list"), sep=sep) gb[collect_name] = gb[collect_name].list.concat() # Ensure sorted keys if `sort=True` diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index dc279bfa690..cf916b713b2 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -9,6 +9,7 @@ from dask.utils_test import hlg_layer import cudf +from cudf.testing._utils import expect_warning_if import dask_cudf from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized @@ -47,7 +48,13 @@ def pdf(request): return pdf -@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS) +# NOTE: We only want to test aggregation "methods" here, +# so we need to leave out `list`. We also include a +# deprecation check for "collect". +@pytest.mark.parametrize( + "aggregation", + sorted(tuple(set(OPTIMIZED_AGGS) - {list}) + ("collect",)), +) @pytest.mark.parametrize("series", [False, True]) def test_groupby_basic(series, aggregation, pdf): gdf = cudf.DataFrame.from_pandas(pdf) @@ -62,8 +69,9 @@ def test_groupby_basic(series, aggregation, pdf): check_dtype = aggregation != "count" - expect = getattr(gdf_grouped, aggregation)() - actual = getattr(ddf_grouped, aggregation)() + with expect_warning_if(aggregation == "collect"): + expect = getattr(gdf_grouped, aggregation)() + actual = getattr(ddf_grouped, aggregation)() if not QUERY_PLANNING_ON: assert_cudf_groupby_layers(actual)