diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index ecfc73d5..d9e79572 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -2,7 +2,7 @@ from packaging.version import Version from unyt.array import NULL_UNIT, unyt_array -from unyt.exceptions import UnitInconsistencyError +from unyt.exceptions import UnitConversionError, UnitInconsistencyError NUMPY_VERSION = Version(np.__version__) _HANDLED_FUNCTIONS = {} @@ -405,3 +405,78 @@ def trapz(y, x=None, dx=1.0, *args, **kwargs): @implements(np.sort_complex) def sort_complex(a): return np.sort_complex._implementation(a.view(np.ndarray)) * a.units + + +def _array_comp_helper(a, b): + au = getattr(a, "units", NULL_UNIT) + bu = getattr(b, "units", NULL_UNIT) + if bu != au and au != NULL_UNIT and bu != NULL_UNIT: + if (bu / au).is_dimensionless: + b = np.array(b) * (1 * bu).to(au) + else: + raise UnitConversionError(au, au.dimensions, bu, bu.dimensions) + elif bu == NULL_UNIT: + b = np.array(b) * au + elif au == NULL_UNIT: + a = np.array(a) * bu + + return a, b + + +@implements(np.isclose) +def isclose(a, b, *args, **kwargs): + a, b = _array_comp_helper(a, b) + return np.isclose._implementation( + a.view(np.ndarray), b.view(np.ndarray), *args, **kwargs + ) + + +@implements(np.allclose) +def allclose(a, b, *args, **kwargs): + a, b = _array_comp_helper(a, b) + return np.allclose._implementation( + a.view(np.ndarray), b.view(np.ndarray), *args, **kwargs + ) + + +@implements(np.linspace) +def linspace(start, stop, *args, **kwargs): + _validate_units_consistency((start, stop)) + return ( + np.linspace._implementation( + start.view(np.ndarray), stop.view(np.ndarray), *args, **kwargs + ) + * start.units + ) + + +@implements(np.logspace) +def logspace(start, stop, *args, **kwargs): + _validate_units_consistency((start, stop)) + return ( + np.logspace._implementation( + start.view(np.ndarray), stop.view(np.ndarray), *args, **kwargs + ) + * start.units + ) + + +@implements(np.geomspace) +def geomspace(start, stop, *args, **kwargs): + _validate_units_consistency((start, stop)) + return ( + np.geomspace._implementation( + start.view(np.ndarray), stop.view(np.ndarray), *args, **kwargs + ) + * start.units + ) + + +@implements(np.copyto) +def copyto(dst, src, *args, **kwargs): + # note that np.copyto is heavily used internally + # in numpy, and it may be used with fundamental datatypes, + # so we don't attempt to pass ndarray views to keep generality + np.copyto._implementation(dst, src, *args, **kwargs) + if getattr(dst, "units", None) is not None: + dst.units = getattr(src, "units", dst.units) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 44b57590..9d92a777 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -1,14 +1,17 @@ # tests for NumPy __array_function__ support import re +from importlib.metadata import version import numpy as np import pytest +from packaging.version import Version from unyt import cm, g, km, s from unyt._array_functions import _HANDLED_FUNCTIONS as HANDLED_FUNCTIONS from unyt.array import unyt_array, unyt_quantity -from unyt.exceptions import UnitInconsistencyError +from unyt.exceptions import UnitConversionError, UnitInconsistencyError +NUMPY_VERSION = Version(version("numpy")) # this is a subset of NOT_HANDLED_FUNCTIONS for which there's nothing to do # because they don't apply to (real) numeric types # or they work as expected out of the box @@ -56,12 +59,24 @@ np.sum, # works out of the box (tested) np.repeat, # works out of the box (tested) np.tile, # works out of the box (tested) + np.shares_memory, # works out of the box (tested) + np.sometrue, # works out of the box (tested) + np.nonzero, # works out of the box (tested) + np.count_nonzero, # returns pure numbers + np.flatnonzero, # works out of the box (tested) + np.isneginf, # works out of the box (tested) + np.isposinf, # works out of the box (tested) + np.empty_like, # works out of the box (tested) + np.full_like, # works out of the box (tested) + np.ones_like, # works out of the box (tested) + np.zeros_like, # works out of the box (tested) + np.copy, # works out of the box (tested) + np.meshgrid, # works out of the box (tested) } # this set represents all functions that need inspection, tests, or both # it is always possible that some of its elements belong in NOOP_FUNCTIONS TODO_FUNCTIONS = { - np.allclose, np.apply_along_axis, np.apply_over_axes, np.array_equal, @@ -78,11 +93,8 @@ np.common_type, np.compress, np.convolve, - np.copy, - np.copyto, np.corrcoef, np.correlate, - np.count_nonzero, np.cov, np.cumprod, np.cumproduct, @@ -100,17 +112,13 @@ np.ediff1d, # note: should return delta_K for temperatures ! np.einsum, np.einsum_path, - np.empty_like, np.expand_dims, np.extract, np.fill_diagonal, np.fix, - np.flatnonzero, np.flip, np.fliplr, np.flipud, - np.full_like, - np.geomspace, np.gradient, # note: should return delta_K for temperatures ! np.histogram_bin_edges, np.hsplit, @@ -120,10 +128,7 @@ np.insert, np.interp, np.is_busday, - np.isclose, np.isin, - np.isneginf, - np.isposinf, np.ix_, np.lexsort, np.linalg.cholesky, @@ -142,10 +147,7 @@ np.linalg.solve, np.linalg.svd, np.linalg.tensorsolve, - np.linspace, - np.logspace, np.may_share_memory, - np.meshgrid, np.min_scalar_type, np.moveaxis, np.msort, @@ -157,8 +159,6 @@ np.nanstd, np.nansum, np.nanvar, - np.nonzero, - np.ones_like, np.packbits, np.pad, np.partition, @@ -202,9 +202,7 @@ np.select, np.setdiff1d, np.setxor1d, - np.shares_memory, np.sinc, - np.sometrue, np.split, np.squeeze, np.std, @@ -226,7 +224,6 @@ np.var, np.vsplit, np.where, - np.zeros_like, } removed_functions = { @@ -271,7 +268,9 @@ def test_wrapping_completeness(): """Ensure we wrap all numpy functions that support __array_function__""" handled_numpy_functions = set(HANDLED_FUNCTIONS.keys()) # ensure no functions appear in both NOT_HANDLED_FUNCTIONS and HANDLED_FUNCTIONS - assert NOT_HANDLED_FUNCTIONS.isdisjoint(handled_numpy_functions) + assert NOT_HANDLED_FUNCTIONS.isdisjoint( + handled_numpy_functions + ), NOT_HANDLED_FUNCTIONS.intersection(handled_numpy_functions) # get list of functions that support wrapping by introspection on numpy module wrappable_functions = get_wrapped_functions(np, np.fft, np.linalg) for function in HANDLED_FUNCTIONS: @@ -635,9 +634,15 @@ def test_trim_zeros(): assert type(res) is unyt_array -def test_any(): - assert not np.any([0, 0, 0] * cm) - assert np.any([1, 0, 0] * cm) +@pytest.mark.parametrize("func", [np.any, np.sometrue]) +def test_any(func): + assert not func([0, 0, 0] * cm) + assert func([1, 0, 0] * cm) + + x = [1, 2, 3] * cm + assert func(x >= 3) + assert func(x >= 3 * cm) + assert not func(x >= 3 * km) def test_append(): @@ -822,3 +827,141 @@ def test_tile(): res = np.tile(x, (2, 3)) assert type(res) is unyt_array assert res.units == cm + + +def test_shares_memory(): + x = [1, 2, 3] * cm + assert np.shares_memory(x, x.view(np.ndarray)) + + +def test_nonzero(): + x = [1, 2, 0] * cm + res = np.nonzero(x) + assert len(res) == 1 + np.testing.assert_array_equal(res[0], [0, 1]) + + res2 = np.flatnonzero(x) + np.testing.assert_array_equal(res[0], res2) + + +def test_isinf(): + x = [1, float("inf"), float("-inf")] * cm + res = np.isneginf(x) + np.testing.assert_array_equal(res, [False, False, True]) + res = np.isposinf(x) + np.testing.assert_array_equal(res, [False, True, False]) + + +def test_allclose(): + x = [1, 2, 3] * cm + y = [1, 2, 3] * km + assert not np.allclose(x, y) + + +@pytest.mark.parametrize( + "a, b, expected", + [ + ([1, 2, 3] * cm, [1, 2, 3] * km, [False] * 3), + ([1, 2, 3] * cm, [1, 2, 3], [True] * 3), + ], +) +def test_isclose(a, b, expected): + res = np.isclose(a, b) + np.testing.assert_array_equal(res, expected) + + +def test_iclose_error(): + x = [1, 2, 3] * cm + y = [1, 2, 3] * g + with pytest.raises(UnitConversionError): + np.isclose(x, y) + + +@pytest.mark.parametrize( + "func", + [ + np.linspace, + np.logspace, + np.geomspace, + ], +) +def test_xspace(func): + res = func(1 * cm, 11 * cm, 10) + assert type(res) is unyt_array + assert res.units == cm + + +def test_full_like(): + x = [1, 2, 3] * cm + res = np.full_like(x, 6 * cm) + assert type(res) is unyt_array + assert res.units == cm + + +@pytest.mark.parametrize( + "func", + [ + np.empty_like, + np.zeros_like, + np.ones_like, + ], +) +def test_x_like(func): + x = unyt_array([1, 2, 3], cm, dtype="float32") + res = func(x) + assert type(res) is unyt_array + assert res.units == x.units + assert res.shape == x.shape + assert res.dtype == x.dtype + + +def test_copy(): + x = [1, 2, 3] * cm + y = np.copy(x) + # by default, subok=False, so we shouldn't + # expect a unyt_array without switching this arg + assert type(y) is np.ndarray + + +@pytest.mark.skipif( + NUMPY_VERSION < Version("1.19"), reason="np.copy's subok arg requires numpy 1.19+" +) +def test_copy_subok(): + x = [1, 2, 3] * cm + y = np.copy(x, subok=True) + assert type(y) is unyt_array + assert y.units == cm + + +def test_copyto(): + x = [1, 2, 3] * cm + y = np.empty_like(x) + np.copyto(y, x) + assert type(y) is unyt_array + assert y.units == cm + np.testing.assert_array_equal(x, y) + + +def test_copyto_edge_cases(): + x = [1, 2, 3] * cm + y = [1, 2, 3] * g + # copying to an array with a different unit is supported + # to be in line with how we treat the 'out' param in most + # numpy operations + np.copyto(y, x) + assert type(y) is unyt_array + assert y.units == cm + + y = np.empty_like(x.view(np.ndarray)) + np.copyto(y, x) + assert type(y) is np.ndarray + + +def test_meshgrid(): + x = [1, 2, 3] * cm + y = [1, 2, 3] * s + x2d, y2d = np.meshgrid(x, y) + assert type(x2d) is unyt_array + assert type(y2d) is unyt_array + assert x2d.units == cm + assert y2d.units == s