diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 61fe20636f0..c1060d5f505 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -110,6 +110,7 @@ def cumcount(self): ) .groupby(self.grouping, sort=self._sort) .agg("cumcount") + .reset_index(drop=True) ) @cached_property @@ -225,9 +226,10 @@ def nth(self, n): """ Return the nth row from each group. """ - result = self.agg(lambda x: x.nth(n)) - sizes = self.size() - return result[n < sizes] + result = self.agg(lambda x: x.nth(n)).sort_index() + sizes = self.size().sort_index() + + return result[sizes > n] def serialize(self): header = {} diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index d1458c72770..2430b0da5ef 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -30,14 +30,28 @@ _index_type_aggs = {"count", "idxmin", "idxmax", "cumcount"} -def assert_groupby_results_equal(expect, got, sort=True, **kwargs): +def assert_groupby_results_equal( + expect, got, sort=True, as_index=True, by=None, **kwargs +): # Because we don't sort by index by default in groupby, # sort expect and got by index before comparing if sort: - expect = expect.sort_index() - got = got.sort_index() - else: - assert_eq(expect.sort_index(), got.sort_index(), **kwargs) + if as_index: + expect = expect.sort_index() + got = got.sort_index() + else: + assert by is not None + if isinstance(expect, (pd.DataFrame, cudf.DataFrame)): + expect = expect.sort_values(by=by).reset_index(drop=True) + else: + expect = expect.sort_values().reset_index(drop=True) + + if isinstance(got, cudf.DataFrame): + got = got.sort_values(by=by).reset_index(drop=True) + else: + got = got.sort_values().reset_index(drop=True) + + assert_eq(expect, got, **kwargs) def make_frame( @@ -201,10 +215,16 @@ def test_groupby_getitem_getattr(as_index): pdf = pd.DataFrame({"x": [1, 3, 1], "y": [1, 2, 3], "z": [1, 4, 5]}) gdf = cudf.from_pandas(pdf) assert_groupby_results_equal( - pdf.groupby("x")["y"].sum(), gdf.groupby("x")["y"].sum(), + pdf.groupby("x")["y"].sum(), + gdf.groupby("x")["y"].sum(), + as_index=as_index, + by="x", ) assert_groupby_results_equal( - pdf.groupby("x").y.sum(), gdf.groupby("x").y.sum(), + pdf.groupby("x").y.sum(), + gdf.groupby("x").y.sum(), + as_index=as_index, + by="x", ) assert_groupby_results_equal( pdf.groupby("x")[["y"]].sum(), gdf.groupby("x")[["y"]].sum(), @@ -212,6 +232,8 @@ def test_groupby_getitem_getattr(as_index): assert_groupby_results_equal( pdf.groupby(["x", "y"], as_index=as_index).sum(), gdf.groupby(["x", "y"], as_index=as_index).sum(), + as_index=as_index, + by=["x", "y"], ) @@ -1088,7 +1110,13 @@ def test_groupby_datetime(nelem, as_index, agg): else: pdres = pdg.agg({"datetime": agg}) gdres = gdg.agg({"datetime": agg}) - assert_groupby_results_equal(pdres, gdres, check_dtype=check_dtype) + assert_groupby_results_equal( + pdres, + gdres, + check_dtype=check_dtype, + as_index=as_index, + by=["datetime"], + ) def test_groupby_dropna(): @@ -1349,6 +1377,8 @@ def test_reset_index_after_empty_groupby(): assert_groupby_results_equal( pdf.groupby("a").sum().reset_index(), gdf.groupby("a").sum().reset_index(), + as_index=False, + by="a", )