From 738b36876c64cce3d90aa9681f8ac9fafad32538 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Thu, 17 Oct 2019 15:17:13 -0700 Subject: [PATCH] Stop catching TypeError in groupby methods (#29060) --- pandas/_libs/groupby.pyx | 33 ++++++++++++++++++++++++--------- pandas/core/groupby/groupby.py | 21 ++++++++++----------- pandas/core/groupby/ops.py | 8 +++++++- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 4f7488c88630b8..68c21139e73845 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -8,7 +8,7 @@ import numpy as np cimport numpy as cnp from numpy cimport (ndarray, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, - uint32_t, uint64_t, float32_t, float64_t) + uint32_t, uint64_t, float32_t, float64_t, complex64_t, complex128_t) cnp.import_array() @@ -421,16 +421,23 @@ def group_any_all(uint8_t[:] out, if values[i] == flag_val: out[lab] = flag_val + # ---------------------------------------------------------------------- # group_add, group_prod, group_var, group_mean, group_ohlc # ---------------------------------------------------------------------- +ctypedef fused complexfloating_t: + float64_t + float32_t + complex64_t + complex128_t + @cython.wraparound(False) @cython.boundscheck(False) -def _group_add(floating[:, :] out, +def _group_add(complexfloating_t[:, :] out, int64_t[:] counts, - floating[:, :] values, + complexfloating_t[:, :] values, const int64_t[:] labels, Py_ssize_t min_count=0): """ @@ -438,13 +445,14 @@ def _group_add(floating[:, :] out, """ cdef: Py_ssize_t i, j, N, K, lab, ncounts = len(counts) - floating val, count - floating[:, :] sumx, nobs + complexfloating_t val, count + complexfloating_t[:, :] sumx + int64_t[:, :] nobs if len(values) != len(labels): raise ValueError("len(index) != len(labels)") - nobs = np.zeros_like(out) + nobs = np.zeros((len(out), out.shape[1]), dtype=np.int64) sumx = np.zeros_like(out) N, K = (values).shape @@ -462,7 +470,12 @@ def _group_add(floating[:, :] out, # not nan if val == val: nobs[lab, j] += 1 - sumx[lab, j] += val + if (complexfloating_t is complex64_t or + complexfloating_t is complex128_t): + # clang errors if we use += with these dtypes + sumx[lab, j] = sumx[lab, j] + val + else: + sumx[lab, j] += val for i in range(ncounts): for j in range(K): @@ -472,8 +485,10 @@ def _group_add(floating[:, :] out, out[i, j] = sumx[i, j] -group_add_float32 = _group_add['float'] -group_add_float64 = _group_add['double'] +group_add_float32 = _group_add['float32_t'] +group_add_float64 = _group_add['float64_t'] +group_add_complex64 = _group_add['float complex'] +group_add_complex128 = _group_add['double complex'] @cython.wraparound(False) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index fa651794698403..b27d5bb05ee8fa 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1340,19 +1340,18 @@ def f(self, **kwargs): # try a cython aggregation if we can try: return self._cython_agg_general(alias, alt=npfunc, **kwargs) - except AssertionError: - raise except DataError: pass - except (TypeError, NotImplementedError): - # TODO: - # - TypeError: this is reached via test_groupby_complex - # and can be fixed by implementing _group_add for - # complex dtypes - # - NotImplementedError: reached in test_max_nan_bug, - # raised in _get_cython_function and should probably - # be handled inside _cython_agg_blocks - pass + except NotImplementedError as err: + if "function is not implemented for this dtype" in str(err): + # raised in _get_cython_function, in some cases can + # be trimmed by implementing cython funcs for more dtypes + pass + elif "decimal does not support skipna=True" in str(err): + # FIXME: kludge for test_decimal:test_in_numeric_groupby + pass + else: + raise # apply a non-cython aggregation result = self.aggregate(lambda x: npfunc(x, axis=self.axis)) diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 27415a1bacdbd9..e380cf5930f97d 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -526,7 +526,13 @@ def _cython_operation(self, kind, values, how, axis, min_count=-1, **kwargs): func = self._get_cython_function(kind, how, values, is_numeric) except NotImplementedError: if is_numeric: - values = ensure_float64(values) + try: + values = ensure_float64(values) + except TypeError: + if lib.infer_dtype(values, skipna=False) == "complex": + values = values.astype(complex) + else: + raise func = self._get_cython_function(kind, how, values, is_numeric) else: raise