Skip to content

Commit

Permalink
TST: add assert_array_equal_units to unyt.testing
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jan 4, 2023
1 parent 508f138 commit 0c6a4cf
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
10 changes: 9 additions & 1 deletion unyt/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import warnings

from unyt.array import allclose_units
import numpy.testing as npt

from unyt.array import NULL_UNIT, allclose_units


def assert_allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
Expand Down Expand Up @@ -49,6 +51,12 @@ def assert_allclose_units(actual, desired, rtol=1e-7, atol=0, **kwargs):
raise AssertionError


def assert_array_equal_units(a, b):
# see https://github.com/yt-project/unyt/issues/281
npt.assert_array_equal(a, b)
assert getattr(a, "units", NULL_UNIT) == getattr(b, "units", NULL_UNIT)


def _process_warning(op, message, warning_class, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
Expand Down
10 changes: 2 additions & 8 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@

import numpy as np
import pytest
from numpy.testing import assert_array_equal
from packaging.version import Version

from unyt import K, cm, degC, degF, delta_degC, g, km, s
from unyt._array_functions import _HANDLED_FUNCTIONS as HANDLED_FUNCTIONS
from unyt.array import NULL_UNIT, unyt_array, unyt_quantity
from unyt.array import unyt_array, unyt_quantity
from unyt.exceptions import (
InvalidUnitOperation,
UnitConversionError,
UnitInconsistencyError,
UnytError,
)
from unyt.testing import assert_array_equal_units

NUMPY_VERSION = Version(version("numpy"))
# this is a subset of NOT_HANDLED_FUNCTIONS for which there's nothing to do
Expand Down Expand Up @@ -239,12 +239,6 @@ def get_wrapped_functions(*modules):
return dict(sorted(wrapped_functions.items()))


def assert_array_equal_units(a, b):
# see https://github.com/yt-project/unyt/issues/281
assert_array_equal(a, b)
assert getattr(a, "units", NULL_UNIT) == getattr(b, "units", NULL_UNIT)


def test_wrapping_completeness():
"""Ensure we wrap all numpy functions that support __array_function__"""
handled_numpy_functions = set(HANDLED_FUNCTIONS.keys())
Expand Down
11 changes: 5 additions & 6 deletions unyt/tests/test_unyt_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@
UnitParseError,
UnitsNotReducible,
)
from unyt.testing import _process_warning, assert_allclose_units
from unyt.testing import (
_process_warning,
assert_allclose_units,
assert_array_equal_units,
)
from unyt.unit_registry import UnitRegistry
from unyt.unit_symbols import cm, degree, g, m

Expand All @@ -71,11 +75,6 @@ def assert_isinstance(a, type):
assert isinstance(a, type)


def assert_array_equal_units(a, b):
assert_array_equal(a, b)
assert_equal(a.units, b.units)


def test_addition():
"""
Test addition of two unyt_arrays
Expand Down

0 comments on commit 0c6a4cf

Please sign in to comment.