Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop catching TypeError in groupby methods #29060

Merged
merged 12 commits into from
Oct 17, 2019
33 changes: 24 additions & 9 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import numpy as np
cimport numpy as cnp
from numpy cimport (ndarray,
int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t,
uint32_t, uint64_t, float32_t, float64_t)
uint32_t, uint64_t, float32_t, float64_t, complex64_t, complex128_t)
cnp.import_array()


Expand Down Expand Up @@ -421,30 +421,38 @@ def group_any_all(uint8_t[:] out,
if values[i] == flag_val:
out[lab] = flag_val


# ----------------------------------------------------------------------
# group_add, group_prod, group_var, group_mean, group_ohlc
# ----------------------------------------------------------------------

ctypedef fused complexfloating_t:
float64_t
float32_t
complex64_t
complex128_t


@cython.wraparound(False)
@cython.boundscheck(False)
def _group_add(floating[:, :] out,
def _group_add(complexfloating_t[:, :] out,
int64_t[:] counts,
floating[:, :] values,
complexfloating_t[:, :] values,
const int64_t[:] labels,
Py_ssize_t min_count=0):
"""
Only aggregates on axis=0
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
floating val, count
floating[:, :] sumx, nobs
complexfloating_t val, count
complexfloating_t[:, :] sumx
int64_t[:, :] nobs

if len(values) != len(labels):
raise ValueError("len(index) != len(labels)")

nobs = np.zeros_like(out)
nobs = np.zeros((len(out), out.shape[1]), dtype=np.int64)
sumx = np.zeros_like(out)

N, K = (<object>values).shape
Expand All @@ -462,7 +470,12 @@ def _group_add(floating[:, :] out,
# not nan
if val == val:
nobs[lab, j] += 1
sumx[lab, j] += val
if (complexfloating_t is complex64_t or
complexfloating_t is complex128_t):
# clang errors if we use += with these dtypes
sumx[lab, j] = sumx[lab, j] + val
else:
sumx[lab, j] += val

for i in range(ncounts):
for j in range(K):
Expand All @@ -472,8 +485,10 @@ def _group_add(floating[:, :] out,
out[i, j] = sumx[i, j]


group_add_float32 = _group_add['float']
group_add_float64 = _group_add['double']
group_add_float32 = _group_add['float32_t']
group_add_float64 = _group_add['float64_t']
group_add_complex64 = _group_add['float complex']
group_add_complex128 = _group_add['double complex']


@cython.wraparound(False)
Expand Down
21 changes: 10 additions & 11 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,19 +1340,18 @@ def f(self, **kwargs):
# try a cython aggregation if we can
try:
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
except AssertionError:
raise
except DataError:
pass
except (TypeError, NotImplementedError):
# TODO:
# - TypeError: this is reached via test_groupby_complex
# and can be fixed by implementing _group_add for
# complex dtypes
# - NotImplementedError: reached in test_max_nan_bug,
# raised in _get_cython_function and should probably
# be handled inside _cython_agg_blocks
pass
except NotImplementedError as err:
if "function is not implemented for this dtype" in str(err):
# raised in _get_cython_function, in some cases can
# be trimmed by implementing cython funcs for more dtypes
pass
elif "decimal does not support skipna=True" in str(err):
# FIXME: kludge for test_decimal:test_in_numeric_groupby
pass
else:
raise

# apply a non-cython aggregation
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
Expand Down
8 changes: 7 additions & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,13 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1, **kwargs):
func = self._get_cython_function(kind, how, values, is_numeric)
except NotImplementedError:
if is_numeric:
values = ensure_float64(values)
try:
values = ensure_float64(values)
except TypeError:
if lib.infer_dtype(values, skipna=False) == "complex":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

future PR should push this to pandas/core/dtypes/cast.py and just call here (maybe make ensure_float64_or_complex)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yah. i think there was some discussion of a one-pass variant of lib.infer_dtype that would go well with that

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great

values = values.astype(complex)
else:
raise
func = self._get_cython_function(kind, how, values, is_numeric)
else:
raise
Expand Down