Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid explicit np.nan, np.inf #383

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np

from . import xrdtypes as dtypes
from .xrutils import is_scalar, isnull, notnull


Expand Down Expand Up @@ -98,7 +99,7 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non
# partition the complex array in-place
labels_broadcast = np.broadcast_to(group_idx, array.shape)
with np.errstate(invalid="ignore"):
cmplx = labels_broadcast + 1j * array
cmplx = labels_broadcast + 1j * (array.view(int) if array.dtype.kind in "Mm" else array)
cmplx.partition(kth=kth, axis=-1)
if is_scalar_q:
a_ = cmplx.imag
Expand Down Expand Up @@ -158,6 +159,8 @@ def _np_grouped_op(


def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
if fillna in [dtypes.INF, dtypes.NINF]:
fillna = dtypes._get_fill_value(kwargs.get("dtype", array.dtype), fillna)
result = func(group_idx, np.where(isnull(array), fillna, array), *args, **kwargs)
# np.nanmax([np.nan, np.nan]) = np.nan
# To recover this behaviour, we need to search for the fillna value
Expand All @@ -175,9 +178,9 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
prod = partial(_np_grouped_op, op=np.multiply.reduceat)
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
max = partial(_np_grouped_op, op=np.maximum.reduceat)
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF)
min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF)
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False))
Expand Down
109 changes: 37 additions & 72 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,60 +115,6 @@ def generic_aggregate(
return result


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
if dtype is None:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
if array_dtype.kind in "fcmM":
dtype = array_dtype
else:
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


def _maybe_promote_int(dtype) -> np.dtype:
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
# The dtype of a is used by default unless a has an integer dtype of less precision
# than the default platform integer.
if not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if dtype.kind == "i":
dtype = np.result_type(dtype, np.intp)
elif dtype.kind == "u":
dtype = np.result_type(dtype, np.uintp)
return dtype


def _get_fill_value(dtype, fill_value):
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
if fill_value in [None, dtypes.NA] and dtype.kind in "US":
return ""
if fill_value == dtypes.INF or fill_value is None:
return dtypes.get_pos_infinity(dtype, max_for_int=True)
if fill_value == dtypes.NINF:
return dtypes.get_neg_infinity(dtype, min_for_int=True)
if fill_value == dtypes.NA:
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
# This is madness, but npg checks that fill_value is compatible
# with array dtype even if the fill_value is never used.
elif (
np.issubdtype(dtype, np.integer)
or np.issubdtype(dtype, np.timedelta64)
or np.issubdtype(dtype, np.datetime64)
):
return dtypes.get_neg_infinity(dtype, min_for_int=True)
else:
return None
return fill_value


def _atleast_1d(inp, min_length: int = 1):
if xrutils.is_scalar(inp):
inp = (inp,) * min_length
Expand Down Expand Up @@ -210,6 +156,7 @@ def __init__(
final_dtype: DTypeLike | None = None,
reduction_type: Literal["reduce", "argreduce"] = "reduce",
new_dims_func: Callable | None = None,
preserves_dtype: bool = False,
):
"""
Blueprint for computing grouped aggregations.
Expand Down Expand Up @@ -256,6 +203,8 @@ def __init__(
Function that receives finalize_kwargs and returns a tupleof sizes of any new dimensions
added by the reduction. For e.g. quantile for q=(0.5, 0.85) adds a new dimension of size 2,
so returns (2,)
preserves_dtype: bool,
Whether a function preserves the dtype on return E.g. min, max, first, last, mode
"""
self.name = name
# preprocess before blockwise
Expand Down Expand Up @@ -292,6 +241,7 @@ def __init__(
self.new_dims_func: Callable = (
returns_empty_tuple if new_dims_func is None else new_dims_func
)
self.preserves_dtype = preserves_dtype

@cached_property
def new_dims(self) -> tuple[Dim]:
Expand Down Expand Up @@ -434,10 +384,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, preserves_dtype=True)
nanmin = Aggregation(
"nanmin", chunk="nanmin", combine="nanmin", fill_value=dtypes.NA, preserves_dtype=True
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, preserves_dtype=True)
nanmax = Aggregation(
"nanmax", chunk="nanmax", combine="nanmax", fill_value=dtypes.NA, preserves_dtype=True
)


def argreduce_preprocess(array, axis):
Expand Down Expand Up @@ -525,10 +479,14 @@ def _pick_second(*x):
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=None)
last = Aggregation("last", chunk=None, combine=None, fill_value=None)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA)
first = Aggregation("first", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
last = Aggregation("last", chunk=None, combine=None, fill_value=None, preserves_dtype=True)
nanfirst = Aggregation(
"nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=dtypes.NA, preserves_dtype=True
)
nanlast = Aggregation(
"nanlast", chunk="nanlast", combine="nanlast", fill_value=dtypes.NA, preserves_dtype=True
)

all_ = Aggregation(
"all",
Expand Down Expand Up @@ -579,8 +537,12 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
final_dtype=np.floating,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)
mode = Aggregation(
name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
)
nanmode = Aggregation(
name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True
)


@dataclass
Expand Down Expand Up @@ -634,7 +596,7 @@ def last(self) -> AlignedArrays:
# TODO: automate?
engine="flox",
dtype=self.array.dtype,
fill_value=_get_fill_value(self.array.dtype, dtypes.NA),
fill_value=dtypes._get_fill_value(self.array.dtype, dtypes.NA),
expected_groups=None,
)
return AlignedArrays(array=reduced["intermediates"][0], group_idx=reduced["groups"])
Expand Down Expand Up @@ -729,6 +691,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
binary_op=None,
reduction="nanlast",
scan="ffill",
# Important: this must be NaN otherwise, ffill does not work.
identity=np.nan,
mode="concat_then_scan",
)
Expand All @@ -737,6 +700,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan)
binary_op=None,
reduction="nanlast",
scan="ffill",
# Important: this must be NaN otherwise, bfill does not work.
identity=np.nan,
mode="concat_then_scan",
preprocess=reverse,
Expand Down Expand Up @@ -815,17 +779,18 @@ def _initialize_aggregation(
dtype_: np.dtype | None = (
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)

final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
if agg.name not in ["first", "last", "nanfirst", "nanlast", "min", "max", "nanmin", "nanmax"]:
final_dtype = _maybe_promote_int(final_dtype)
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, fill_value
)
if not agg.preserves_dtype:
final_dtype = dtypes._maybe_promote_int(final_dtype)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
"numpy": (final_dtype,),
"intermediate": tuple(
(
_normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
dtypes._normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
if int_dtype is None
else np.dtype(int_dtype)
)
Expand All @@ -838,10 +803,10 @@ def _initialize_aggregation(
# Replace sentinel fill values according to dtype
agg.fill_value["user"] = fill_value
agg.fill_value["intermediate"] = tuple(
_get_fill_value(dt, fv)
dtypes._get_fill_value(dt, fv)
for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
)
agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])
agg.fill_value[func] = dtypes._get_fill_value(agg.dtype["final"], agg.fill_value[func])

fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
if _is_arg_reduction(agg):
Expand Down
55 changes: 55 additions & 0 deletions flox/xrdtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

import numpy as np
from numpy.typing import DTypeLike

from . import xrutils as utils

Expand Down Expand Up @@ -147,3 +148,57 @@ def get_neg_infinity(dtype, min_for_int=False):
def is_datetime_like(dtype):
"""Check if a dtype is a subclass of the numpy datetime types"""
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
if dtype is None:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
if array_dtype.kind in "fcmM":
dtype = array_dtype
else:
dtype = np.dtype("float64")
elif not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if fill_value not in [None, INF, NINF, NA]:
dtype = np.result_type(dtype, fill_value)
return dtype


def _maybe_promote_int(dtype) -> np.dtype:
# https://numpy.org/doc/stable/reference/generated/numpy.prod.html
# The dtype of a is used by default unless a has an integer dtype of less precision
# than the default platform integer.
if not isinstance(dtype, np.dtype):
dtype = np.dtype(dtype)
if dtype.kind == "i":
dtype = np.result_type(dtype, np.intp)
elif dtype.kind == "u":
dtype = np.result_type(dtype, np.uintp)
return dtype


def _get_fill_value(dtype, fill_value):
"""Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
if fill_value in [None, NA] and dtype.kind in "US":
return ""
if fill_value == INF or fill_value is None:
return get_pos_infinity(dtype, max_for_int=True)
if fill_value == NINF:
return get_neg_infinity(dtype, min_for_int=True)
if fill_value == NA:
if np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating):
return np.nan
# This is madness, but npg checks that fill_value is compatible
# with array dtype even if the fill_value is never used.
elif np.issubdtype(dtype, np.integer):
return get_neg_infinity(dtype, min_for_int=True)
elif np.issubdtype(dtype, np.timedelta64):
return np.timedelta64("NaT")
elif np.issubdtype(dtype, np.datetime64):
return np.datetime64("NaT")
else:
return None
return fill_value
11 changes: 6 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from numpy_groupies.aggregate_numpy import aggregate

import flox
from flox import xrdtypes as dtypes
from flox import xrutils
from flox.aggregations import Aggregation, _initialize_aggregation, _maybe_promote_int
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
HAS_NUMBAGG,
_choose_engine,
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_groupby_reduce(
if func == "mean" or func == "nanmean":
expected_result = np.array(expected, dtype=np.float64)
elif func == "sum":
expected_result = np.array(expected, dtype=_maybe_promote_int(array.dtype))
expected_result = np.array(expected, dtype=dtypes._maybe_promote_int(array.dtype))
elif func == "count":
expected_result = np.array(expected, dtype=np.intp)

Expand Down Expand Up @@ -389,7 +390,7 @@ def test_groupby_reduce_preserves_dtype(dtype, func):
array = np.ones((2, 12), dtype=dtype)
by = np.array([labels] * 2)
result, _ = groupby_reduce(from_array(array, chunks=(-1, 4)), by, func=func)
expect_dtype = _maybe_promote_int(array.dtype)
expect_dtype = dtypes._maybe_promote_int(array.dtype)
assert result.dtype == expect_dtype


Expand Down Expand Up @@ -1054,7 +1055,7 @@ def test_dtype_preservation(dtype, func, engine):
# https://github.com/numbagg/numbagg/issues/121
pytest.skip()
if func == "sum":
expected = _maybe_promote_int(dtype)
expected = dtypes._maybe_promote_int(dtype)
elif func == "mean" and "int" in dtype:
expected = np.float64
else:
Expand Down Expand Up @@ -1085,7 +1086,7 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
actual, actual_groups = groupby_reduce(array, labels, func="sum", method=method)
assert_equal(actual_groups, np.arange(6, dtype=labels.dtype))

expect_dtype = _maybe_promote_int(dtype)
expect_dtype = dtypes._maybe_promote_int(dtype)
assert_equal(actual, np.array([0, 4, 24, 6, 12, 20], dtype=expect_dtype))


Expand Down
Loading