Skip to content

Commit

Permalink
Preserve dtype better when specified. (#389)
Browse files Browse the repository at this point in the history
* Preserve dtype better when specified.

* Add one more test

* tweak test

* more test

* [revert] test with Xarray PR branch

* tweak

* show versions

* Drop python 3.9, use ruff

* switch to Ruff

* fix mypy

* remove toctrees

* fix

* one more
  • Loading branch information
dcherian authored Sep 8, 2024
1 parent 0438a7e commit 7421cb1
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 15 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
- name: Run Tests
id: status
run: |
python -c "import xarray; xarray.show_versions()"
pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci
- name: Upload code coverage to Codecov
uses: codecov/[email protected]
Expand Down Expand Up @@ -98,7 +99,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
repository: "pydata/xarray"
repository: "dcherian/xarray"
fetch-depth: 0 # Fetch all history for all branches and tags.
- name: Set up conda environment
uses: mamba-org/setup-micromamba@v1
Expand All @@ -112,6 +113,7 @@ jobs:
pint>=0.22
- name: Install xarray
run: |
git checkout flox-preserve-dtype
python -m pip install --no-deps .
- name: Install upstream flox
run: |
Expand Down
3 changes: 2 additions & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ dependencies:
- pytest-pretty
- pytest-xdist
- syrupy
- xarray
- pre-commit
- numpy_groupies>=0.9.19
- pooch
- toolz
- numba
- numbagg>=0.3
- hypothesis
- pip:
- git+https://github.com/dcherian/xarray.git@flox-preserve-dtype
3 changes: 2 additions & 1 deletion ci/no-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ dependencies:
- pytest-pretty
- pytest-xdist
- syrupy
- xarray
- numpydoc
- pre-commit
- numpy_groupies>=0.9.19
- pooch
- toolz
- numba
- numbagg>=0.3
- pip:
- git+https://github.com/dcherian/xarray.git@flox-preserve-dtype
13 changes: 8 additions & 5 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,20 +549,23 @@ def quantile_new_dims_func(q) -> tuple[Dim]:
return (Dim(name="quantile", values=q),)


# if the input contains integers or floats smaller than float64,
# the output data-type is float64. Otherwise, the output data-type is the same as that
# of the input.
quantile = Aggregation(
name="quantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.floating,
final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
nanquantile = Aggregation(
name="nanquantile",
fill_value=dtypes.NA,
chunk=None,
combine=None,
final_dtype=np.floating,
final_dtype=np.float64,
new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True)
Expand Down Expand Up @@ -801,9 +804,9 @@ def _initialize_aggregation(
dtype_: np.dtype | None = (
np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
)
final_dtype = dtypes._normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
if not agg.preserves_dtype:
final_dtype = dtypes._maybe_promote_int(final_dtype)
final_dtype = dtypes._normalize_dtype(
dtype_ or agg.dtype_init["final"], array_dtype, agg.preserves_dtype, fill_value
)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
Expand Down
9 changes: 7 additions & 2 deletions flox/xrdtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,14 @@ def is_datetime_like(dtype):
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)


def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
def _normalize_dtype(
dtype: DTypeLike, array_dtype: np.dtype, preserves_dtype: bool, fill_value=None
) -> np.dtype:
if dtype is None:
dtype = array_dtype
if not preserves_dtype:
dtype = _maybe_promote_int(array_dtype)
else:
dtype = array_dtype
if dtype is np.floating:
# mean, std, var always result in floating
# but we preserve the array's dtype if it is floating
Expand Down
4 changes: 2 additions & 2 deletions tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:


# TODO: stop excluding everything but U
array_dtype_st = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
array_dtypes = supported_dtypes().filter(lambda x: x.kind not in "cmMU")
by_dtype_st = supported_dtypes()

NON_NUMPY_FUNCS = [
Expand All @@ -43,7 +43,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:

func_st = st.sampled_from([f for f in ALL_FUNCS if f not in NON_NUMPY_FUNCS and f not in SKIPPED_FUNCS])
numeric_arrays = npst.arrays(
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtype_st
elements={"allow_subnormal": False}, shape=npst.array_shapes(), dtype=array_dtypes
)
all_arrays = npst.arrays(
elements={"allow_subnormal": False},
Expand Down
15 changes: 14 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _get_array_func(func: str) -> Callable:

def npfunc(x, **kwargs):
x = np.asarray(x)
return (~np.isnan(x)).sum()
return (~xrutils.isnull(x)).sum(**kwargs)

elif func in ["nanfirst", "nanlast"]:
npfunc = getattr(xrutils, func)
Expand Down Expand Up @@ -1984,3 +1984,16 @@ def test_blockwise_nans():
)
assert_equal(expected_groups, actual_groups)
assert_equal(expected, actual)


@pytest.mark.parametrize("func", ["sum", "prod", "count", "nansum"])
@pytest.mark.parametrize("engine", ["flox", "numpy"])
def test_agg_dtypes(func, engine):
# regression test for GH388
counts = np.array([0, 2, 1, 0, 1])
group = np.array([1, 1, 1, 2, 2])
actual, _ = groupby_reduce(
counts, group, expected_groups=(np.array([1, 2]),), func=func, dtype="uint8", engine=engine
)
expected = _get_array_func(func)(counts, dtype="uint8")
assert actual.dtype == np.uint8 == expected.dtype
24 changes: 23 additions & 1 deletion tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flox.xrutils import notnull

from . import assert_equal
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import array_dtypes, by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import chunks as chunks_strategy

dask.config.set(scheduler="sync")
Expand Down Expand Up @@ -244,3 +244,25 @@ def test_first_last_useless(data, func):
actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
assert_equal(actual, expected)


@given(
func=st.sampled_from(["sum", "prod", "nansum", "nanprod"]),
engine=st.sampled_from(["numpy", "flox"]),
array_dtype=st.none() | array_dtypes,
dtype=st.none() | array_dtypes,
)
def test_agg_dtype_specified(func, array_dtype, dtype, engine):
# regression test for GH388
counts = np.array([0, 2, 1, 0, 1], dtype=array_dtype)
group = np.array([1, 1, 1, 2, 2])
actual, _ = groupby_reduce(
counts,
group,
expected_groups=(np.array([1, 2]),),
func=func,
dtype=dtype,
engine=engine,
)
expected = getattr(np, func)(counts, keepdims=True, dtype=dtype)
assert actual.dtype == expected.dtype
25 changes: 24 additions & 1 deletion tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# test against legacy xarray implementation
# avoid some compilation overhead
xr.set_options(use_flox=False, use_numbagg=False)
xr.set_options(use_flox=False, use_numbagg=False, use_bottleneck=False)
tolerance64 = {"rtol": 1e-15, "atol": 1e-18}
np.random.seed(123)

Expand Down Expand Up @@ -760,3 +760,26 @@ def test_direct_reduction(func):
with xr.set_options(use_flox=False):
expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
xr.testing.assert_identical(expected, actual)


@pytest.mark.parametrize("reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"])
def test_groupby_preserve_dtype(reduction):
# all groups are present, we should follow numpy exactly
ds = xr.Dataset(
{
"test": (
["x", "y"],
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"),
)
},
coords={"idx": ("x", [1, 2, 1])},
)

kwargs = {"engine": "numpy"}
if "nan" in reduction:
kwargs["skipna"] = True
with xr.set_options(use_flox=True):
actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(**kwargs).test.dtype
expected = getattr(np, reduction)(ds.test.data, axis=0).dtype

assert actual == expected

0 comments on commit 7421cb1

Please sign in to comment.