Skip to content

Commit

Permalink
Fix groupby of categoricals with just "max"
Browse files Browse the repository at this point in the history
I discovered pandas-dev/pandas#28641
while testing ... and I fixed it.
  • Loading branch information
adamhooper committed Sep 26, 2019
1 parent 6421b0b commit a73828f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
15 changes: 8 additions & 7 deletions groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,16 @@ def groupby(
and hasattr(table[colname], "cat")
}
for colname in category_colnames:
table[colname] = table[colname].cat.as_ordered()
table[colname].cat.as_ordered(inplace=True)
# Add dummy "size" to work around
# https://github.com/pandas-dev/pandas/issues/28641
agg_sets[colname].add("size")

if group_specs:
# aggs: DataFrame indexed by group
# out: just the group colnames, no values yet (we'll add them later)
grouped = table.groupby(group_specs)
if agg_sets:
aggs = grouped.agg(agg_sets)
grouped = table.groupby(group_specs, as_index=True)
aggs = grouped.agg(agg_sets)
out = aggs.index.to_frame(index=False)
# Remove unused categories (because `np.nan` deletes categories)
for column in out:
Expand All @@ -320,8 +322,7 @@ def groupby(
# aggs: DataFrame with just one row
# out: one empty row, no columns yet
grouped = table
if agg_sets:
aggs = table.agg(agg_sets)
aggs = table.agg(agg_sets)
out = pd.DataFrame(columns=[], index=[0])

# Now copy values from `aggs` into `out`. (They have the same index.)
Expand All @@ -348,7 +349,7 @@ def groupby(

# Remember those category colnames we converted to ordered? Now we need to
# undo that (and remove newly-unused categories).
for colname in out.columns:
for colname in list(out.columns):
column = out[colname]
if hasattr(column, "cat") and column.cat.ordered:
column.cat.remove_unused_categories(inplace=True)
Expand Down
16 changes: 16 additions & 0 deletions test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ def test_aggregate_text_category_values(self):
),
)

def test_aggregate_text_category_values_max(self):
    """Group by "A" and take only the MAX of categorical text column "B".

    Regression test for https://github.com/pandas-dev/pandas/issues/28641
    """
    # Single-row input: the aggregated column "B" is a text categorical.
    table = pd.DataFrame(
        {"A": [1997], "B": pd.Series(["30-SEP-97"], dtype="category")}
    )
    # The aggregated value is expected under the output name "X", still
    # with category dtype.
    expected = pd.DataFrame(
        {"A": [1997], "X": pd.Series(["30-SEP-97"], dtype="category")}
    )

    groups = [Group("A", None)]
    aggregations = [Aggregation(Operation.MAX, "B", "X")]
    result = groupby(table, groups, aggregations)

    assert_frame_equal(result, expected)

def test_aggregate_text_category_values_empty_still_has_object_dtype(self):
result = groupby(
pd.DataFrame({"A": [None]}, dtype=str).astype("category"),
Expand Down

0 comments on commit a73828f

Please sign in to comment.