Skip to content

Commit

Permalink
Merge pull request #351 from neutrinoceros/nep18_masks
Browse files Browse the repository at this point in the history
ENH: (NEP 18) test and implement combinatory functions
  • Loading branch information
ngoldbaum authored Jan 9, 2023
2 parents d939417 + 7d8d96e commit 2d55916
Show file tree
Hide file tree
Showing 4 changed files with 392 additions and 27 deletions.
169 changes: 165 additions & 4 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from numbers import Number

import numpy as np
from packaging.version import Version
Expand Down Expand Up @@ -152,32 +153,46 @@ def histogram_bin_edges(a, *args, **kwargs):
)


def get_units(arrays):
def get_units(objs):
units = []
for sub in arrays:
for sub in objs:
if isinstance(sub, np.ndarray):
units.append(getattr(sub, "units", NULL_UNIT))
elif isinstance(sub, Number):
units.append(NULL_UNIT)
else:
units.extend(get_units(sub))
return units


def _validate_units_consistency(arrays):
def _validate_units_consistency(objs):
"""
Return unique units or raise UnitInconsistencyError if units are mixed.
"""
# NOTE: we cannot validate that all arrays are unyt_arrays
# by using this as a guard clause in unyt_array.__array_function__
# because it's already a necessary condition for numpy to use our
# custom implementations
units = get_units(arrays)
units = get_units(objs)
sunits = set(units)
if len(sunits) == 1:
return units[0]
else:
raise UnitInconsistencyError(*units)


def _validate_units_consistency_v2(ref_units, *args) -> None:
"""
raise UnitInconsistencyError if units are mixed
if all args are pure numbers, they are treated as having ref_units,
otherwise they are treated as dimensionless
"""
if all(isinstance(_, Number) for _ in args):
return
else:
_validate_units_consistency((1 * ref_units, *args))


@implements(np.concatenate)
def concatenate(arrs, /, axis=0, out=None, dtype=None, casting="same_kind"):
ret_units = _validate_units_consistency(arrs)
Expand Down Expand Up @@ -714,3 +729,149 @@ def cumprod(a, *args, **kwargs):
@implements(np.pad)
def pad(array, *args, **kwargs):
return np.pad._implementation(array.view(np.ndarray), *args, **kwargs) * array.units


@implements(np.choose)
def choose(a, choices, out=None, *args, **kwargs):
if (au := getattr(a, "units", NULL_UNIT)) != NULL_UNIT:
raise TypeError(
f"The first argument to numpy.choose must be dimensionless, got units={au}"
)
retu = _validate_units_consistency(choices)

if out is None:
return (
np.choose._implementation(
a, [np.asarray(c) for c in choices], *args, **kwargs
)
* retu
)

res = np.choose._implementation(
a,
[np.asarray(c) for c in choices],
out=out.view(np.ndarray),
*args,
**kwargs,
)
if getattr(out, "units", None) is not None:
out.units = retu
return unyt_array(res, retu, bypass_validation=True)


@implements(np.fill_diagonal)
def fill_diagonal(a, val, *args, **kwargs) -> None:
_validate_units_consistency_v2(a.units, val)
np.fill_diagonal._implementation(a.view(np.ndarray), val, *args, **kwargs)


@implements(np.insert)
def insert(arr, obj, values, *args, **kwargs):
_validate_units_consistency_v2(arr.units, values)
return (
np.insert._implementation(
arr.view(np.ndarray), obj, np.asarray(values), *args, **kwargs
)
* arr.units
)


@implements(np.isin)
def isin(element, test_elements, *args, **kwargs):
_validate_units_consistency((element, test_elements))
return np.isin._implementation(
np.asarray(element), np.asarray(test_elements), *args, **kwargs
)


@implements(np.place)
def place(arr, mask, vals, *args, **kwargs) -> None:
_validate_units_consistency_v2(arr.units, vals)
np.place._implementation(
arr.view(np.ndarray), mask, vals.view(np.ndarray), *args, **kwargs
)


@implements(np.put)
def put(a, ind, v, *args, **kwargs) -> None:
_validate_units_consistency_v2(a.units, v)
np.put._implementation(a.view(np.ndarray), ind, v.view(np.ndarray))


@implements(np.put_along_axis)
def put_along_axis(arr, indices, values, axis, *args, **kwargs) -> None:
_validate_units_consistency_v2(arr.units, values)
np.put_along_axis._implementation(
arr.view(np.ndarray), indices, np.asarray(values), axis, *args, **kwargs
)


@implements(np.putmask)
def putmask(a, mask, values, *args, **kwargs) -> None:
_validate_units_consistency_v2(a.units, values)
np.putmask._implementation(
a.view(np.ndarray), mask, np.asarray(values), *args, **kwargs
)


@implements(np.searchsorted)
def searchsorted(a, v, *args, **kwargs):
_validate_units_consistency_v2(a.units, v)
return np.searchsorted._implementation(
a.view(np.ndarray), np.asarray(v), *args, **kwargs
)


@implements(np.select)
def select(condlist, choicelist, default=0, *args, **kwargs):
ref_units = choicelist[0].units
_validate_units_consistency_v2(ref_units, choicelist, default)
return (
np.select._implementation(
condlist, [np.asarray(c) for c in choicelist], default
)
* ref_units
)


@implements(np.setdiff1d)
def setdiff1d(ar1, ar2, *args, **kwargs):
retu = _validate_units_consistency((ar1, ar2))
return (
np.setdiff1d._implementation(np.asarray(ar1), np.asarray(ar2), *args, **kwargs)
* retu
)


@implements(np.sinc)
def sinc(x, *args, **kwargs):
# this implementation becomes necessary after implementing where
# we *want* this one to ignore units
return np.sinc._implementation(x.view(np.ndarray), *args, **kwargs)


@implements(np.clip)
def clip(a, a_min, a_max, out=None, *args, **kwargs):
_validate_units_consistency_v2(a.units, a_min, a_max)
if out is None:
return (
np.clip._implementation(
np.asarray(a), np.asarray(a_min), np.asarray(a_max), *args, **kwargs
)
* a.units
)

res = (
np.clip._implementation(
np.asarray(a),
np.asarray(a_min),
np.asarray(a_max),
out=out.view(np.ndarray),
*args,
**kwargs,
)
* a.units
)
if getattr(out, "units", None) is not None:
out.units = a.units
return unyt_array(res, a.units, bypass_validation=True)
31 changes: 30 additions & 1 deletion unyt/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import warnings

from unyt.array import allclose_units
import numpy.testing as npt

from unyt.array import NULL_UNIT, allclose_units


def assert_allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
Expand Down Expand Up @@ -49,6 +51,33 @@ def assert_allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
raise AssertionError


def assert_array_equal_units(x, y, **kwargs):
"""A thin wrapper around :func:`numpy.testing.assert_array_equal` that also
verifies unit consistency
Arrays without units are considered dimensionless.
Parameters
----------
x : array_like
The actual object to check.
y : array_like
The desired, expected object.
See Also
--------
:func:`numpy.testing.assert_array_equal`
Notes
-----
Also accepts additional keyword arguments accepted by
:func:`numpy.testing.assert_array_equel`, see the documentation of that
function for details.
"""
# see https://github.com/yt-project/unyt/issues/281
npt.assert_array_equal(x, y, **kwargs)
assert getattr(x, "units", NULL_UNIT) == getattr(y, "units", NULL_UNIT)


def _process_warning(op, message, warning_class, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
Expand Down
Loading

0 comments on commit 2d55916

Please sign in to comment.