Skip to content

Commit

Permalink
Handle dtypes.NA properly for datetime/timedelta
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 7, 2024
1 parent f0ce343 commit 59482d0
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 77 deletions.
17 changes: 10 additions & 7 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from . import xrdtypes as dtypes
from .xrutils import is_scalar, isnull, notnull


Expand Down Expand Up @@ -98,7 +99,7 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
with np.errstate(invalid="ignore"):
cmplx = labels_broadcast + 1j * array
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:
a_ = cmplx.imag
Expand Down Expand Up @@ -158,6 +159,8 @@ def _np_grouped_op(


def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
if fillna in [dtypes.INF, dtypes.NINF]:
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
# np.nanmax([np.nan, np.nan]) = np.nan
# To recover this behaviour, we need to search for the fillna value
Expand All @@ -175,13 +178,13 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
prod = partial(_np_grouped_op, op=np.multiply.reduceat)
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
max = partial(_np_grouped_op, op=np.maximum.reduceat)
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True))
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False))
nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True))
# TODO: all, any


Expand Down
87 changes: 22 additions & 65 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,60 +115,6 @@ def generic_aggregate(
return result


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
if dtype is None:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
if array_dtype.kind in "fcmM":
dtype = array_dtype
else:
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


def _maybe_promote_int(dtype) -> np.dtype:
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
# The dtype of a is used by default unless a has an integer dtype of less precision
# than the default platform integer.
if not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if dtype.kind == "i":
dtype = np.result_type(dtype, np.intp)
elif dtype.kind == "u":
dtype = np.result_type(dtype, np.uintp)
return dtype


def _get_fill_value(dtype, fill_value):
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
if fill_value in [None, dtypes.NA] and dtype.kind in "US":
return ""
if fill_value == dtypes.INF or fill_value is None:
return dtypes.get_pos_infinity(dtype, max_for_int=True)
if fill_value == dtypes.NINF:
return dtypes.get_neg_infinity(dtype, min_for_int=True)
if fill_value == dtypes.NA:
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
# This is madness, but npg checks that fill_value is compatible
# with array dtype even if the fill_value is never used.
elif (
np.issubdtype(dtype, np.integer)
or np.issubdtype(dtype, np.timedelta64)
or np.issubdtype(dtype, np.datetime64)
):
return dtypes.get_neg_infinity(dtype, min_for_int=True)
else:
return None
return fill_value


def _atleast_1d(inp, min_length: int = 1):
if xrutils.is_scalar(inp):
inp = (inp,) * min_length
Expand Down Expand Up @@ -435,9 +381,9 @@ def _std_finalize(sumsq, sum_, count, ddof=0):


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA)


def argreduce_preprocess(array, axis):
Expand Down Expand Up @@ -634,7 +580,7 @@ def last(self) -> AlignedArrays:
# TODO: automate?
engine="flox",
dtype=self.array.dtype,
fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA),
expected_groups=None,
)
return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
Expand Down Expand Up @@ -729,15 +675,15 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
binary_op=None,
reduction="nanlast",
scan="ffill",
identity=np.nan,
identity=dtypes.NA,
mode="concat_then_scan",
)
bfill = Scan(
"bfill",
binary_op=None,
reduction="nanlast",
scan="ffill",
identity=np.nan,
identity=dtypes.NA,
mode="concat_then_scan",
preprocess=reverse,
finalize=reverse,
Expand Down Expand Up @@ -816,16 +762,27 @@ def _initialize_aggregation(
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)

final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
final_dtype = _maybe_promote_int(final_dtype)
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
)
if agg.name not in [
"first",
"last",
"nanfirst",
"nanlast",
"min",
"max",
"nanmin",
"nanmax",
]:
final_dtype = dtypes._maybe_promote_int(final_dtype)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
"numpy": (final_dtype,),
"intermediate": tuple(
(
_normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
if int_dtype is None
else np.dtype(int_dtype)
)
Expand All @@ -838,10 +795,10 @@ def _initialize_aggregation(
# Replace sentinel fill values according to dtype
agg.fill_value["user"] = fill_value
agg.fill_value["intermediate"] = tuple(
_get_fill_value(dt, fv)
dtypes._get_fill_value(dt, fv)
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
)
agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func])

fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
if _is_arg_reduction(agg):
Expand Down
55 changes: 55 additions & 0 deletions flox/xrdtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

import numpy as np
from numpy.typing import DTypeLike

from . import xrutils as utils

Expand Down Expand Up @@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False):
def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types"""
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
if dtype is None:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
if array_dtype.kind in "fcmM":
dtype = array_dtype
else:
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, INF, NINF, NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


def _maybe_promote_int(dtype) -> np.dtype:
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
# The dtype of a is used by default unless a has an integer dtype of less precision
# than the default platform integer.
if not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if dtype.kind == "i":
dtype = np.result_type(dtype, np.intp)
elif dtype.kind == "u":
dtype = np.result_type(dtype, np.uintp)
return dtype


def _get_fill_value(dtype, fill_value):
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
if fill_value in [None, NA] and dtype.kind in "US":
return ""
if fill_value == INF or fill_value is None:
return get_pos_infinity(dtype, max_for_int=True)
if fill_value == NINF:
return get_neg_infinity(dtype, min_for_int=True)
if fill_value == NA:
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
# This is madness, but npg checks that fill_value is compatible
# with array dtype even if the fill_value is never used.
elif np.issubdtype(dtype, np.integer):
return get_neg_infinity(dtype, min_for_int=True)
elif np.issubdtype(dtype, np.timedelta64):
return np.timedelta64("NaT")
elif np.issubdtype(dtype, np.datetime64):
return np.datetime64("NaT")
else:
return None
return fill_value
11 changes: 6 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from numpy_groupies.aggregate_numpy import aggregate

import flox
from flox import xrdtypes as dtypes
from flox import xrutils
from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
HAS_NUMBAGG,
_choose_engine,
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_groupby_reduce(
if func == "mean" or func == "nanmean":
expected_result = np.array(expected, dtype=np.float64)
elif func == "sum":
expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype))
expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype))
elif func == "count":
expected_result = np.array(expected, dtype=np.intp)

Expand Down Expand Up @@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func):
array = np.ones((2, 12), dtype=dtype)
by = np.array([labels] * 2)
result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func)
expect_dtype = _maybe_promote_int(array.dtype)
expect_dtype = dtypes._maybe_promote_int(array.dtype)
assert result.dtype == expect_dtype


Expand Down Expand Up @@ -1054,7 +1055,7 @@ def test_dtype_preservation(dtype, func, engine):
# https://github.com/numbagg/numbagg/issues/121
pytest.skip()
if func == "sum":
expected = _maybe_promote_int(dtype)
expected = dtypes._maybe_promote_int(dtype)
elif func == "mean" and "int" in dtype:
expected = np.float64
else:
Expand Down Expand Up @@ -1085,7 +1086,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method)
assert_equal(actual_groups, np.arange(6, dtype=labels.dtype))

expect_dtype = _maybe_promote_int(dtype)
expect_dtype = dtypes._maybe_promote_int(dtype)
assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))


Expand Down

0 comments on commit 59482d0

Please sign in to comment.