Skip to content

Commit

Permalink
Stop catching TypeError in groupby methods (pandas-dev#29060)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and proost committed Dec 19, 2019
1 parent 76e312e commit 738b368
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 21 deletions.
33 changes: 24 additions & 9 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import numpy as np
cimport numpy as cnp
from numpy cimport (ndarray,
int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
uint32_t, uint64_t, float32_t, float64_t)
uint32_t, uint64_t, float32_t, float64_t, complex64_t, complex128_t)
cnp.import_array()


Expand Down Expand Up @@ -421,30 +421,38 @@ def group_any_all(uint8_t[:] out,
if values[i] == flag_val:
out[lab] = flag_val


# ----------------------------------------------------------------------
# group_add, group_prod, group_var, group_mean, group_ohlc
# ----------------------------------------------------------------------

ctypedef fused complexfloating_t:
float64_t
float32_t
complex64_t
complex128_t


@cython.wraparound(False)
@cython.boundscheck(False)
def _group_add(floating[:, :] out,
def _group_add(complexfloating_t[:, :] out,
int64_t[:] counts,
floating[:, :] values,
complexfloating_t[:, :] values,
const int64_t[:] labels,
Py_ssize_t min_count=0):
"""
Only aggregates on axis=0
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
floating val, count
floating[:, :] sumx, nobs
complexfloating_t val, count
complexfloating_t[:, :] sumx
int64_t[:, :] nobs

if len(values) != len(labels):
raise ValueError("len(index) != len(labels)")

nobs = np.zeros_like(out)
nobs = np.zeros((len(out), out.shape[1]), dtype=np.int64)
sumx = np.zeros_like(out)

N, K = (<object>values).shape
Expand All @@ -462,7 +470,12 @@ def _group_add(floating[:, :] out,
# not nan
if val == val:
nobs[lab, j] += 1
sumx[lab, j] += val
if (complexfloating_t is complex64_t or
complexfloating_t is complex128_t):
# clang errors if we use += with these dtypes
sumx[lab, j] = sumx[lab, j] + val
else:
sumx[lab, j] += val

for i in range(ncounts):
for j in range(K):
Expand All @@ -472,8 +485,10 @@ def _group_add(floating[:, :] out,
out[i, j] = sumx[i, j]


group_add_float32 = _group_add['float']
group_add_float64 = _group_add['double']
group_add_float32 = _group_add['float32_t']
group_add_float64 = _group_add['float64_t']
group_add_complex64 = _group_add['float complex']
group_add_complex128 = _group_add['double complex']


@cython.wraparound(False)
Expand Down
21 changes: 10 additions & 11 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,19 +1340,18 @@ def f(self, **kwargs):
# try a cython aggregation if we can
try:
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
except AssertionError:
raise
except DataError:
pass
except (TypeError, NotImplementedError):
# TODO:
# - TypeError: this is reached via test_groupby_complex
# and can be fixed by implementing _group_add for
# complex dtypes
# - NotImplementedError: reached in test_max_nan_bug,
# raised in _get_cython_function and should probably
# be handled inside _cython_agg_blocks
pass
except NotImplementedError as err:
if "function is not implemented for this dtype" in str(err):
# raised in _get_cython_function, in some cases can
# be trimmed by implementing cython funcs for more dtypes
pass
elif "decimal does not support skipna=True" in str(err):
# FIXME: kludge for test_decimal:test_in_numeric_groupby
pass
else:
raise

# apply a non-cython aggregation
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,13 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1, **kwargs):
func = self._get_cython_function(kind, how, values, is_numeric)
except NotImplementedError:
if is_numeric:
values = ensure_float64(values)
try:
values = ensure_float64(values)
except TypeError:
if lib.infer_dtype(values, skipna=False) == "complex":
values = values.astype(complex)
else:
raise
func = self._get_cython_function(kind, how, values, is_numeric)
else:
raise
Expand Down

0 comments on commit 738b368

Please sign in to comment.