From 105b43659c2b8159b72ae938defd43eb6191aaba Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Tue, 18 May 2021 12:15:56 -0400 Subject: [PATCH 1/2] Actually test equality in assert_groupby_results_equal --- python/cudf/cudf/tests/test_groupby.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index d1458c72770..cbcc46dcf7b 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -36,8 +36,7 @@ def assert_groupby_results_equal(expect, got, sort=True, **kwargs): if sort: expect = expect.sort_index() got = got.sort_index() - else: - assert_eq(expect.sort_index(), got.sort_index(), **kwargs) + assert_eq(expect, got, **kwargs) def make_frame( From f1257b1cbd058e558d34f14d96f5c7e3dca19f21 Mon Sep 17 00:00:00 2001 From: Ashwin Srinath Date: Wed, 19 May 2021 14:00:07 -0400 Subject: [PATCH 2/2] Fix revealed bugs in groupby --- python/cudf/cudf/core/groupby/groupby.py | 8 +++-- python/cudf/cudf/tests/test_groupby.py | 43 ++++++++++++++++++++---- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 61fe20636f0..c1060d5f505 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -110,6 +110,7 @@ def cumcount(self): ) .groupby(self.grouping, sort=self._sort) .agg("cumcount") + .reset_index(drop=True) ) @cached_property @@ -225,9 +226,10 @@ def nth(self, n): """ Return the nth row from each group. """ - result = self.agg(lambda x: x.nth(n)) - sizes = self.size() - return result[n < sizes] + result = self.agg(lambda x: x.nth(n)).sort_index() + sizes = self.size().sort_index() + + return result[sizes > n] def serialize(self): header = {} diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index cbcc46dcf7b..2430b0da5ef 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -30,12 +30,27 @@ _index_type_aggs = {"count", "idxmin", "idxmax", "cumcount"} -def assert_groupby_results_equal(expect, got, sort=True, **kwargs): +def assert_groupby_results_equal( + expect, got, sort=True, as_index=True, by=None, **kwargs +): # Because we don't sort by index by default in groupby, # sort expect and got by index before comparing if sort: - expect = expect.sort_index() - got = got.sort_index() + if as_index: + expect = expect.sort_index() + got = got.sort_index() + else: + assert by is not None + if isinstance(expect, (pd.DataFrame, cudf.DataFrame)): + expect = expect.sort_values(by=by).reset_index(drop=True) + else: + expect = expect.sort_values().reset_index(drop=True) + + if isinstance(got, cudf.DataFrame): + got = got.sort_values(by=by).reset_index(drop=True) + else: + got = got.sort_values().reset_index(drop=True) + assert_eq(expect, got, **kwargs) @@ -200,10 +215,16 @@ def test_groupby_getitem_getattr(as_index): pdf = pd.DataFrame({"x": [1, 3, 1], "y": [1, 2, 3], "z": [1, 4, 5]}) gdf = cudf.from_pandas(pdf) assert_groupby_results_equal( - pdf.groupby("x")["y"].sum(), gdf.groupby("x")["y"].sum(), + pdf.groupby("x")["y"].sum(), + gdf.groupby("x")["y"].sum(), + as_index=as_index, + by="x", ) assert_groupby_results_equal( - pdf.groupby("x").y.sum(), gdf.groupby("x").y.sum(), + pdf.groupby("x").y.sum(), + gdf.groupby("x").y.sum(), + as_index=as_index, + by="x", ) assert_groupby_results_equal( pdf.groupby("x")[["y"]].sum(), gdf.groupby("x")[["y"]].sum(), @@ -211,6 +232,8 @@ def test_groupby_getitem_getattr(as_index): assert_groupby_results_equal( pdf.groupby(["x", "y"], as_index=as_index).sum(), gdf.groupby(["x", "y"], as_index=as_index).sum(), + as_index=as_index, + by=["x", "y"], ) @@ -1087,7 +1110,13 @@ def test_groupby_datetime(nelem, as_index, agg): else: pdres = pdg.agg({"datetime": agg}) gdres = gdg.agg({"datetime": agg}) - assert_groupby_results_equal(pdres, gdres, check_dtype=check_dtype) + assert_groupby_results_equal( + pdres, + gdres, + check_dtype=check_dtype, + as_index=as_index, + by=["datetime"], + ) def test_groupby_dropna(): @@ -1348,6 +1377,8 @@ def test_reset_index_after_empty_groupby(): assert_groupby_results_equal( pdf.groupby("a").sum().reset_index(), gdf.groupby("a").sum().reset_index(), + as_index=False, + by="a", )