Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: (NEP 18) implement and test np.in1d and np.interp #398

Merged
merged 2 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,12 @@ def isin(element, test_elements, *args, **kwargs):
)


@implements(np.in1d)
def in1d(ar1, ar2, *args, **kwargs):
_validate_units_consistency((ar1, ar2))
return np.isin._implementation(np.asarray(ar1), np.asarray(ar2), *args, **kwargs)


@implements(np.place)
def place(arr, mask, vals, *args, **kwargs) -> None:
_validate_units_consistency_v2(arr.units, vals)
Expand Down Expand Up @@ -938,3 +944,18 @@ def tensordot(a, b, *args, **kwargs):
def unwrap(p, *args, **kwargs):
ret_units = p.units
return np.unwrap._implementation(p.view(np.ndarray), *args, **kwargs) * ret_units


@implements(np.interp)
def interp(x, xp, fp, *args, **kwargs):
_validate_units_consistency((x, xp))

# return array type should match fp's
# so, the fallback multiplier is 1 instead of NULL_UNITS
# This avoid leaking a dimensionless unyt_array if reference data
# is a pure np.ndarray
ret_units = getattr(fp, "units", 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ret_units = getattr(fp, "units", 1)
ret_units = getattr(fp, "units", NULL_UNIT)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't really make much difference, the 1 just made me double-take reading this, NULL_UNIT makes the logic clearer.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intentionally didn't use get_units (which is equivalent to your suggestion) because I wanted the output to match fp's type (so a pure ndarray wouldn't leak as a dimensionless unyt_array). I could add a comment to make that intention clearer, or I could just go with your suggestion and change the test too, no strong opinions actually. Your call !

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good call please add a comment

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return (
np.interp(np.asarray(x), np.asarray(xp), np.asarray(fp), *args, **kwargs)
* ret_units
)
41 changes: 39 additions & 2 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@
# this set represents all functions that need inspection, tests, or both
# it is always possible that some of its elements belong in NOOP_FUNCTIONS
TODO_FUNCTIONS = {
np.in1d,
np.interp,
np.ix_,
np.linalg.svd,
}
Expand Down Expand Up @@ -1412,6 +1410,18 @@ def test_isin():
assert np.isin(1 * cm, a)


def test_in1d_mixed_units():
a = [1, 2, 3] * cm
with pytest.raises(UnitInconsistencyError):
np.in1d([1, 2], a)


def test_in1d():
a = [1, 2, 3] * cm
b = [1, 2] * cm
assert np.all(np.in1d(b, a))


def test_place_mixed_units():
arr = np.arange(6).reshape(2, 3) * cm
with pytest.raises(UnitInconsistencyError):
Expand Down Expand Up @@ -1662,3 +1672,30 @@ def test_unwrap():
res = np.unwrap(phase)
assert type(res) is unyt_array
assert res.units == rad


def test_interp():
_x = np.array([1.1, 2.2, 3.3])
_xp = np.array([1, 2, 3])
_fp = np.array([4, 8, 12])

# any of the three input array-like might be unitful
# let's test all relevant combinations
# return type should match fp's

with pytest.raises(UnitInconsistencyError):
np.interp(_x * cm, _xp, _fp)

with pytest.raises(UnitInconsistencyError):
res = np.interp(_x, _xp * cm, _fp)

res = np.interp(_x * cm, _xp * cm, _fp)
assert type(res) is np.ndarray

res = np.interp(_x, _xp, _fp * K)
assert type(res) is unyt_array
assert res.units == K

res = np.interp(_x * cm, _xp * cm, _fp * K)
assert type(res) is unyt_array
assert res.units == K