Skip to content

Commit

Permalink
ENH: (NEP 18) implement and test scalar reducer functions
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Nov 12, 2022
1 parent 8eb9e17 commit cfb4667
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 11 deletions.
39 changes: 39 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,42 @@ def copyto(dst, src, *args, **kwargs):
np.copyto._implementation(dst, src, *args, **kwargs)
if getattr(dst, "units", None) is not None:
dst.units = getattr(src, "units", dst.units)


@implements(np.prod)
def prod(a, *args, **kwargs):
return (
np.prod._implementation(a.view(np.ndarray), *args, **kwargs) * a.units**a.size
)


@implements(np.var)
def var(a, *args, **kwargs):
return np.var._implementation(a.view(np.ndarray), *args, **kwargs) * a.units**2


@implements(np.trace)
def trace(a, *args, **kwargs):
return np.trace._implementation(a.view(np.ndarray), *args, **kwargs) * a.units


@implements(np.percentile)
def percentile(a, *args, **kwargs):
return np.percentile._implementation(a.view(np.ndarray), *args, **kwargs) * a.units


@implements(np.quantile)
def quantile(a, *args, **kwargs):
return np.quantile._implementation(a.view(np.ndarray), *args, **kwargs) * a.units


@implements(np.nanpercentile)
def nanpercentile(a, *args, **kwargs):
return (
np.nanpercentile._implementation(a.view(np.ndarray), *args, **kwargs) * a.units
)


@implements(np.nanquantile)
def nanquantile(a, *args, **kwargs):
return np.nanquantile._implementation(a.view(np.ndarray), *args, **kwargs) * a.units
60 changes: 49 additions & 11 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@
np.vsplit, # works out of the box (tested)
np.swapaxes, # works out of the box (tested)
np.moveaxis, # works out of the box (tested)
np.nansum, # works out of the box (tested)
np.product, # is implemented via np.prod
np.std, # works out of the box (tested)
np.nanstd, # works out of the box (tested)
}

# this set represents all functions that need inspection, tests, or both
Expand Down Expand Up @@ -164,15 +168,10 @@
np.msort,
np.nancumprod,
np.nancumsum,
np.nanpercentile,
np.nanprod,
np.nanquantile,
np.nanstd,
np.nansum,
np.nanvar,
np.packbits,
np.pad,
np.percentile,
np.piecewise,
np.place,
np.poly,
Expand All @@ -184,13 +183,10 @@
np.polymul,
np.polysub,
np.polyval,
np.prod,
np.product,
np.ptp, # note: should return delta_K for temperatures !
np.put,
np.put_along_axis,
np.putmask,
np.quantile,
np.ravel,
np.ravel_multi_index,
np.real,
Expand All @@ -208,11 +204,9 @@
np.setdiff1d,
np.setxor1d,
np.sinc,
np.std,
np.take,
np.take_along_axis,
np.tensordot,
np.trace,
np.tril,
np.tril_indices_from,
np.triu,
Expand All @@ -222,7 +216,6 @@
np.unravel_index,
np.unwrap,
np.vander,
np.var,
np.where,
}

Expand Down Expand Up @@ -1045,3 +1038,48 @@ def test_xsplit(func, args):
y = func(x, *args)
assert all(type(_) is unyt_array for _ in y)
assert all(_.units == cm for _ in y)


@pytest.mark.parametrize(
"func, expected_units",
[
(np.prod, cm**9),
(np.product, cm**9),
(np.var, cm**2),
(np.std, cm),
(np.nanprod, cm**9),
(np.nansum, cm),
(np.nanvar, cm**2),
(np.nanstd, cm),
(np.trace, cm),
],
)
def test_scalar_reducer(func, expected_units):
x = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
] * cm
y = func(x)
assert type(y) is unyt_quantity
assert y.units == expected_units


@pytest.mark.parametrize(
"func",
[
np.percentile,
np.quantile,
np.nanpercentile,
np.nanquantile,
],
)
def test_percentile(func):
x = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
] * cm
y = func(x, 1)
assert type(y) is unyt_quantity
assert y.units == cm

0 comments on commit cfb4667

Please sign in to comment.