-
Notifications
You must be signed in to change notification settings - Fork 908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add collect list to dask-cudf groupby aggregations #8045
Changes from 9 commits
df93b7f
63e4d62
b0218b4
b3fab63
6dee893
f16ec9d
a9dc883
b768d3f
1c12d9c
405b2c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,7 +62,16 @@ def aggregate(self, arg, split_every=None, split_out=1): | |
return self.size() | ||
arg = _redirect_aggs(arg) | ||
|
||
_supported = {"count", "mean", "std", "var", "sum", "min", "max"} | ||
_supported = { | ||
"count", | ||
"mean", | ||
"std", | ||
"var", | ||
"sum", | ||
"min", | ||
"max", | ||
"collect", | ||
} | ||
if ( | ||
isinstance(self.obj, DaskDataFrame) | ||
and isinstance(self.index, (str, list)) | ||
|
@@ -109,7 +118,16 @@ def aggregate(self, arg, split_every=None, split_out=1): | |
return self.size() | ||
arg = _redirect_aggs(arg) | ||
|
||
_supported = {"count", "mean", "std", "var", "sum", "min", "max"} | ||
_supported = { | ||
"count", | ||
"mean", | ||
"std", | ||
"var", | ||
"sum", | ||
"min", | ||
"max", | ||
"collect", | ||
} | ||
if ( | ||
isinstance(self.obj, DaskDataFrame) | ||
and isinstance(self.index, (str, list)) | ||
|
@@ -147,7 +165,7 @@ def groupby_agg( | |
|
||
This aggregation algorithm only supports the following options: | ||
|
||
{"count", "mean", "std", "var", "sum", "min", "max"} | ||
{"count", "mean", "std", "var", "sum", "min", "max", "collect"} | ||
|
||
This "optimized" approach is more performant than the algorithm | ||
in `dask.dataframe`, because it allows the cudf backend to | ||
|
@@ -173,15 +191,24 @@ def groupby_agg( | |
# strings (no lists) | ||
str_cols_out = True | ||
for col in aggs: | ||
if isinstance(aggs[col], str): | ||
if isinstance(aggs[col], str) or callable(aggs[col]): | ||
aggs[col] = [aggs[col]] | ||
else: | ||
str_cols_out = False | ||
if col in gb_cols: | ||
columns.append(col) | ||
|
||
# Assert that aggregations are supported | ||
_supported = {"count", "mean", "std", "var", "sum", "min", "max"} | ||
_supported = { | ||
"count", | ||
"mean", | ||
"std", | ||
"var", | ||
"sum", | ||
"min", | ||
"max", | ||
"collect", | ||
} | ||
if not _is_supported(aggs, _supported): | ||
raise ValueError( | ||
f"Supported aggs include {_supported} for groupby_agg API. " | ||
|
@@ -282,7 +309,13 @@ def groupby_agg( | |
def _redirect_aggs(arg): | ||
""" Redirect aggregations to their corresponding name in cuDF | ||
""" | ||
redirects = {sum: "sum", max: "max", min: "min"} | ||
redirects = { | ||
sum: "sum", | ||
max: "max", | ||
min: "min", | ||
list: "collect", | ||
"list": "collect", | ||
} | ||
if isinstance(arg, dict): | ||
new_arg = dict() | ||
for col in arg: | ||
|
@@ -400,6 +433,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep): | |
agg_dict[col] = ["sum"] | ||
elif agg in ("min", "max"): | ||
agg_dict[col] = [agg] | ||
elif agg == "collect": | ||
agg_dict[col] = ["collect"] | ||
else: | ||
raise ValueError(f"Unexpected aggregation: {agg}") | ||
|
||
|
@@ -478,6 +513,9 @@ def _finalize_gb_agg( | |
gb.drop(columns=[sum_name], inplace=True) | ||
if "count" not in agg_list: | ||
gb.drop(columns=[count_name], inplace=True) | ||
if "collect" in agg_list: | ||
collect_name = _make_name(col, "collect", sep=sep) | ||
gb[collect_name] = gb[collect_name].list.concat() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Worth the wait - Thanks for this. |
||
|
||
# Ensure sorted keys if `sort=True` | ||
if sort: | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -125,6 +125,38 @@ def test_groupby_std(func): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dd.assert_eq(a, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"func", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
lambda df: df.groupby("x").agg({"y": "collect"}), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pytest.param( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
lambda df: df.groupby("x").y.agg("collect"), marks=pytest.mark.skip | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+128
to
+136
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any reason to define There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This param skip, and the index nulling, I lifted from another dask-cudf groupby test: cudf/python/dask_cudf/dask_cudf/tests/test_groupby.py Lines 47 to 77 in 5b8895d
I can see what happens when we don't set the index to __________________________________________ test_groupby_collect[<lambda>1] ___________________________________________
func = <function <lambda> at 0x7f3c804a0d30>
@pytest.mark.parametrize(
"func",
[
lambda df: df.groupby("x").agg({"y": "collect"}),
lambda df: df.groupby("x").y.agg("collect"),
],
)
def test_groupby_collect(func):
pdf = pd.DataFrame(
{
"x": np.random.randint(0, 5, size=10000),
"y": np.random.normal(size=10000),
}
)
gdf = cudf.DataFrame.from_pandas(pdf)
ddf = dask_cudf.from_cudf(gdf, npartitions=5)
a = func(gdf).to_pandas()
b = func(ddf).compute().to_pandas()
a.index.name = None
a.name = None
b.index.name = None
b.name = None
> dd.assert_eq(a, b)
dask_cudf/tests/test_groupby.py:155:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../compose/etc/conda/cuda_11.2.72/envs/rapids/lib/python3.8/site-packages/dask/dataframe/utils.py:559: in assert_eq
tm.assert_series_equal(a, b, check_names=check_names, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
left = 0 [-0.02279966962796973, -0.2268040371246616, 0....
1 [1.0547561143327269, 0.07632651478542447, -0.0...
2 [1.... [0.35305396010499146, 2.022936601816015, -0.02...
4 [0.7639835327097312, 0.9458744987601149, 0.370...
dtype: object
right = y
0 [-0.02279966962796973, -0.2268040371246616, 0....
1 [1.054756...796, 1.498...
3 [0.35305396010499146, 2.022936601816015, -0.02...
4 [0.7639835327097312, 0.9458744987601149, 0.370...
cls = <class 'pandas.core.series.Series'>
def _check_isinstance(left, right, cls):
"""
Helper method for our assert_* methods that ensures that
the two objects being compared have the right type before
proceeding with the comparison.
Parameters
----------
left : The first object being compared.
right : The second object being compared.
cls : The class type to check against.
Raises
------
AssertionError : Either `left` or `right` is not an instance of `cls`.
"""
cls_name = cls.__name__
if not isinstance(left, cls):
raise AssertionError(
f"{cls_name} Expected type {cls}, found {type(left)} instead"
)
if not isinstance(right, cls):
> raise AssertionError(
f"{cls_name} Expected type {cls}, found {type(right)} instead"
)
E AssertionError: Series Expected type <class 'pandas.core.series.Series'>, found <class 'pandas.core.frame.DataFrame'> instead
It looks like we are creating a dataframe here when we should be making a series. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Ah - It seems like this was already a problem before this PR. In that case, it is probably okay to fix that in a follow-up PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, do you know if there is an open issue for this problem? If not, I can open one so we can keep track of the follow-up fix. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def test_groupby_collect(func): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pdf = pd.DataFrame( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"x": np.random.randint(0, 5, size=10000), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"y": np.random.normal(size=10000), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
gdf = cudf.DataFrame.from_pandas(pdf) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ddf = dask_cudf.from_cudf(gdf, npartitions=5) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a = func(gdf).to_pandas() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b = func(ddf).compute().to_pandas() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a.index.name = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
a.name = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b.index.name = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
b.name = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
charlesbluca marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
dd.assert_eq(a, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# reason getattr in cudf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"func", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this approach.