Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure all dask-cudf supported aggs are handled in _tree_node_agg #9487

Merged
merged 6 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 31 additions & 42 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@

import cudf

# Aggregation names accepted by dask_cudf's optimized groupby code path.
# This tuple is checked via ``_is_supported`` before dispatching to
# ``groupby_agg``; aggregations outside this set are rejected (ValueError
# in ``groupby_agg``) or handled elsewhere — presumably by falling back to
# dask.dataframe's generic implementation (fallback branch not visible here).
SUPPORTED_AGGS = (
    "count",
    "mean",
    "std",
    "var",
    "sum",
    "min",
    "max",
    "collect",
    "first",
    "last",
)
rjzamora marked this conversation as resolved.
Show resolved Hide resolved


class CudfDataFrameGroupBy(DataFrameGroupBy):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -47,6 +60,19 @@ def __getitem__(self, key):
g._meta = g._meta[key]
return g

def collect(self, split_every=None, split_out=1):
    """Collect the values of every non-key column into per-group lists.

    Dispatches directly to the optimized ``groupby_agg`` path with a
    ``"collect"`` aggregation for each column that is not part of the
    groupby key.

    Parameters
    ----------
    split_every : int, optional
        Tree-reduction width passed through to ``groupby_agg``.
    split_out : int, default 1
        Number of output partitions passed through to ``groupby_agg``.
    """
    # Build the agg spec only for data columns; key columns are excluded
    # so they are not collected into themselves.
    return groupby_agg(
        self.obj,
        self.index,
        {c: "collect" for c in self.obj.columns if c not in self.index},
        split_every=split_every,
        split_out=split_out,
        dropna=self.dropna,
        sep=self.sep,
        sort=self.sort,
        as_index=self.as_index,
    )

def mean(self, split_every=None, split_out=1):
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
return groupby_agg(
self.obj,
Expand All @@ -65,18 +91,6 @@ def aggregate(self, arg, split_every=None, split_out=1):
return self.size()
arg = _redirect_aggs(arg)

_supported = {
"count",
"mean",
"std",
"var",
"sum",
"min",
"max",
"collect",
"first",
"last",
}
if (
isinstance(self.obj, DaskDataFrame)
and (
Expand All @@ -86,7 +100,7 @@ def aggregate(self, arg, split_every=None, split_out=1):
and all(isinstance(x, str) for x in self.index)
)
)
and _is_supported(arg, _supported)
and _is_supported(arg, SUPPORTED_AGGS)
):
if isinstance(self._meta.grouping.keys, cudf.MultiIndex):
keys = self._meta.grouping.keys.names
Expand Down Expand Up @@ -134,23 +148,10 @@ def aggregate(self, arg, split_every=None, split_out=1):
return self.size()
arg = _redirect_aggs(arg)

_supported = {
"count",
"mean",
"std",
"var",
"sum",
"min",
"max",
"collect",
"first",
"last",
}

if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
and _is_supported({self._slice: arg}, _supported)
and _is_supported({self._slice: arg}, SUPPORTED_AGGS)
):
return groupby_agg(
self.obj,
Expand Down Expand Up @@ -201,21 +202,9 @@ def groupby_agg(
"""
# Assert that aggregations are supported
aggs = _redirect_aggs(aggs_in)
_supported = {
"count",
"mean",
"std",
"var",
"sum",
"min",
"max",
"collect",
"first",
"last",
}
if not _is_supported(aggs, _supported):
if not _is_supported(aggs, SUPPORTED_AGGS):
raise ValueError(
f"Supported aggs include {_supported} for groupby_agg API. "
f"Supported aggs include {SUPPORTED_AGGS} for groupby_agg API. "
f"Aggregations must be specified with dict or list syntax."
)

Expand Down Expand Up @@ -478,7 +467,7 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep):
agg = col.split(sep)[-1]
if agg in ("count", "sum"):
agg_dict[col] = ["sum"]
elif agg in ("min", "max", "collect"):
elif agg in SUPPORTED_AGGS:
agg_dict[col] = [agg]
else:
raise ValueError(f"Unexpected aggregation: {agg}")
Expand Down
29 changes: 2 additions & 27 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from cudf.core._compat import PANDAS_GE_120

import dask_cudf
from dask_cudf.groupby import _is_supported
from dask_cudf.groupby import SUPPORTED_AGGS, _is_supported


@pytest.mark.parametrize("aggregation", ["sum", "mean", "count", "min", "max"])
@pytest.mark.parametrize("aggregation", SUPPORTED_AGGS)
def test_groupby_basic_aggs(aggregation):
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
pdf = pd.DataFrame(
{
Expand Down Expand Up @@ -117,31 +117,6 @@ def test_groupby_std(func):
dd.assert_eq(a, b)


@pytest.mark.parametrize(
"func",
[
lambda df: df.groupby("x").agg({"y": "collect"}),
lambda df: df.groupby("x").y.agg("collect"),
],
)
def test_groupby_collect(func):
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved
pdf = pd.DataFrame(
{
"x": np.random.randint(0, 5, size=10000),
"y": np.random.normal(size=10000),
}
)

gdf = cudf.DataFrame.from_pandas(pdf)

ddf = dask_cudf.from_cudf(gdf, npartitions=5)

a = func(gdf).to_pandas()
b = func(ddf).compute().to_pandas()

dd.assert_eq(a, b)


# reason: getattr in cudf
@pytest.mark.parametrize(
"func",
Expand Down