Skip to content

Commit

Permalink
Avoid explicit np.nan, np.inf (#383)
Browse files Browse the repository at this point in the history
* Handle dtypes.NA properly for datetime/timedelta

* Add Aggregation.preserves_dtype

* Fix ffill, bfill
  • Loading branch information
dcherian authored Aug 14, 2024
1 parent f0ce343 commit 4dbadae
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 80 deletions.
9 changes: 6 additions & 3 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from . import xrdtypes as dtypes
from .xrutils import is_scalar, isnull, notnull


Expand Down Expand Up @@ -98,7 +99,7 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
with np.errstate(invalid="ignore"):
cmplx = labels_broadcast + 1j * array
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:
a_ = cmplx.imag
Expand Down Expand Up @@ -158,6 +159,8 @@ def _np_grouped_op(


def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
if fillna in [dtypes.INF, dtypes.NINF]:
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
# np.nanmax([np.nan, np.nan]) = np.nan
# To recover this behaviour, we need to search for the fillna value
Expand All @@ -175,9 +178,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
prod = partial(_np_grouped_op, op=np.multiply.reduceat)
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
max = partial(_np_grouped_op, op=np.maximum.reduceat)
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
Expand Down
109 changes: 37 additions & 72 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,60 +115,6 @@ def generic_aggregate(
return result


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
if dtype is None:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
if array_dtype.kind in "fcmM":
dtype = array_dtype
else:
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


def _maybe_promote_int(dtype) -> np.dtype:
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
# The dtype of a is used by default unless a has an integer dtype of less precision
# than the default platform integer.
if not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if dtype.kind == "i":
dtype = np.result_type(dtype, np.intp)
elif dtype.kind == "u":
dtype = np.result_type(dtype, np.uintp)
return dtype


def _get_fill_value(dtype, fill_value):
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
if fill_value in [None, dtypes.NA] and dtype.kind in "US":
return ""
if fill_value == dtypes.INF or fill_value is None:
return dtypes.get_pos_infinity(dtype, max_for_int=True)
if fill_value == dtypes.NINF:
return dtypes.get_neg_infinity(dtype, min_for_int=True)
if fill_value == dtypes.NA:
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
# This is madness, but npg checks that fill_value is compatible
# with array dtype even if the fill_value is never used.
elif (
np.issubdtype(dtype, np.integer)
or np.issubdtype(dtype, np.timedelta64)
or np.issubdtype(dtype, np.datetime64)
):
return dtypes.get_neg_infinity(dtype, min_for_int=True)
else:
return None
return fill_value


def _atleast_1d(inp, min_length: int = 1):
if xrutils.is_scalar(inp):
inp = (inp,) * min_length
Expand Down Expand Up @@ -210,6 +156,7 @@ def __init__(
final_dtype: DTypeLike | None = None,
reduction_type: Literal["reduce", "argreduce"] = "reduce",
new_dims_func: Callable | None = None,
preserves_dtype: bool = False,
):
"""
Blueprint for computing grouped aggregations.
Expand Down Expand Up @@ -256,6 +203,8 @@ def __init__(
Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions
added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2,
so returns (2,)
preserves_dtype: bool,
Whether a function preserves the dtype on return E.g. min, max, first, last, mode
"""
self.name = name
# preprocess before blockwise
Expand Down Expand Up @@ -292,6 +241,7 @@ def __init__(
self.new_dims_func: Callable = (
returns_empty_tuple if new_dims_func is None else new_dims_func
)
self.preserves_dtype = preserves_dtype

@cached_property
def new_dims(self) -> tuple[Dim]:
Expand Down Expand Up @@ -434,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True)
nanmin = Aggregation(
"nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
nanmax = Aggregation(
"nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True
)


def argreduce_preprocess(array, axis):
Expand Down Expand Up @@ -525,10 +479,14 @@ def _pick_second(*x):
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=None)
last = Aggregation("last", chunk=None, combine=None, fill_value=None)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA)
first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
nanfirst = Aggregation(
"nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True
)
nanlast = Aggregation(
"nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True
)

all_ = Aggregation(
"all",
Expand Down Expand Up @@ -579,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
final_dtype=np.floating,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
mode = Aggregation(
name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
)
nanmode = Aggregation(
name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
)


@dataclass
Expand Down Expand Up @@ -634,7 +596,7 @@ def last(self) -> AlignedArrays:
# TODO: automate?
engine="flox",
dtype=self.array.dtype,
fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA),
expected_groups=None,
)
return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
Expand Down Expand Up @@ -729,6 +691,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
binary_op=None,
reduction="nanlast",
scan="ffill",
# Important: this must be NaN otherwise, ffill does not work.
identity=np.nan,
mode="concat_then_scan",
)
Expand All @@ -737,6 +700,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
binary_op=None,
reduction="nanlast",
scan="ffill",
# Important: this must be NaN otherwise, bfill does not work.
identity=np.nan,
mode="concat_then_scan",
preprocess=reverse,
Expand Down Expand Up @@ -815,17 +779,18 @@ def _initialize_aggregation(
dtype_: np.dtype | None = (
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)

final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
final_dtype = _maybe_promote_int(final_dtype)
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
)
if not agg.preserves_dtype:
final_dtype = dtypes._maybe_promote_int(final_dtype)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
"numpy": (final_dtype,),
"intermediate": tuple(
(
_normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
if int_dtype is None
else np.dtype(int_dtype)
)
Expand All @@ -838,10 +803,10 @@ def _initialize_aggregation(
# Replace sentinel fill values according to dtype
agg.fill_value["user"] = fill_value
agg.fill_value["intermediate"] = tuple(
_get_fill_value(dt, fv)
dtypes._get_fill_value(dt, fv)
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
)
agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func])

fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
if _is_arg_reduction(agg):
Expand Down
55 changes: 55 additions & 0 deletions flox/xrdtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

import numpy as np
from numpy.typing import DTypeLike

from . import xrutils as utils

Expand Down Expand Up @@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False):
def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types"""
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
if dtype is None:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
if array_dtype.kind in "fcmM":
dtype = array_dtype
else:
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, INF, NINF, NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


def _maybe_promote_int(dtype) -> np.dtype:
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
# The dtype of a is used by default unless a has an integer dtype of less precision
# than the default platform integer.
if not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if dtype.kind == "i":
dtype = np.result_type(dtype, np.intp)
elif dtype.kind == "u":
dtype = np.result_type(dtype, np.uintp)
return dtype


def _get_fill_value(dtype, fill_value):
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
if fill_value in [None, NA] and dtype.kind in "US":
return ""
if fill_value == INF or fill_value is None:
return get_pos_infinity(dtype, max_for_int=True)
if fill_value == NINF:
return get_neg_infinity(dtype, min_for_int=True)
if fill_value == NA:
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
# This is madness, but npg checks that fill_value is compatible
# with array dtype even if the fill_value is never used.
elif np.issubdtype(dtype, np.integer):
return get_neg_infinity(dtype, min_for_int=True)
elif np.issubdtype(dtype, np.timedelta64):
return np.timedelta64("NaT")
elif np.issubdtype(dtype, np.datetime64):
return np.datetime64("NaT")
else:
return None
return fill_value
11 changes: 6 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from numpy_groupies.aggregate_numpy import aggregate

import flox
from flox import xrdtypes as dtypes
from flox import xrutils
from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
HAS_NUMBAGG,
_choose_engine,
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_groupby_reduce(
if func == "mean" or func == "nanmean":
expected_result = np.array(expected, dtype=np.float64)
elif func == "sum":
expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype))
expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype))
elif func == "count":
expected_result = np.array(expected, dtype=np.intp)

Expand Down Expand Up @@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func):
array = np.ones((2, 12), dtype=dtype)
by = np.array([labels] * 2)
result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func)
expect_dtype = _maybe_promote_int(array.dtype)
expect_dtype = dtypes._maybe_promote_int(array.dtype)
assert result.dtype == expect_dtype


Expand Down Expand Up @@ -1054,7 +1055,7 @@ def test_dtype_preservation(dtype, func, engine):
# https://github.com/numbagg/numbagg/issues/121
pytest.skip()
if func == "sum":
expected = _maybe_promote_int(dtype)
expected = dtypes._maybe_promote_int(dtype)
elif func == "mean" and "int" in dtype:
expected = np.float64
else:
Expand Down Expand Up @@ -1085,7 +1086,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method)
assert_equal(actual_groups, np.arange(6, dtype=labels.dtype))

expect_dtype = _maybe_promote_int(dtype)
expect_dtype = dtypes._maybe_promote_int(dtype)
assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))


Expand Down

0 comments on commit 4dbadae

Please sign in to comment.