Skip to content

Commit

Permalink
FIX: add support for numpy.dot
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Aug 20, 2021
1 parent 9b1d18c commit 3a51a41
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
21 changes: 21 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from unyt.exceptions import UnitConversionError

_HANDLED_FUNCTIONS = {}

Expand All @@ -18,6 +19,26 @@ def array2string(a, *args, **kwargs):
return np.array2string._implementation(a, *args, **kwargs) + f" {a.units}"


@implements(np.dot)
def dot(a, b, out=None):
prod_units = a.units * b.units
if out is None:
return np.dot._implementation(a.ndview, b.ndview) * prod_units

try:
conv_factor = (1 * out.units).to(prod_units).d
except (AttributeError, UnitConversionError) as exc:
raise TypeError(
"output array is not acceptable "
f"(units '{out.units}' cannot be converted to '{prod_units}')"
) from exc

np.dot._implementation(a.ndview, b.ndview, out=out.ndview)
if not out.units == prod_units:
out[:] *= conv_factor
return out


@implements(np.linalg.inv)
def linalg_inv(a, *args, **kwargs):
return np.linalg.inv._implementation(a.ndview, *args, **kwargs) / a.units
Expand Down
64 changes: 63 additions & 1 deletion unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,76 @@
# tests for NEP 18
import re

import numpy as np
import pytest

from unyt import cm, s, g
from unyt import cm, s, g, km
from unyt.array import unyt_array


def test_array_repr():
arr = [1, 2, 3] * cm
assert np.array_repr(arr) == "unyt_array([1, 2, 3] cm)"


def test_dot_vectors():
a = [1, 2, 3] * cm
b = [1, 2, 3] * s
res = np.dot(a, b)
assert res.units == cm * s
assert res.d == 14


@pytest.mark.parametrize(
"out",
[
None,
np.empty((3, 3), dtype="int64"),
np.empty((3, 3), dtype="int64") * cm * s,
np.empty((3, 3), dtype="int64") * km * s,
],
ids=[
"None",
"pure ndarray",
"same units",
"convertible units",
],
)
def test_dot_matrices(out):
a = np.arange(9) * cm
a.shape = (3, 3)
b = np.arange(9) * s
b.shape = (3, 3)

res = np.dot(a, b, out=out)

if out is not None:
np.testing.assert_array_equal(res, out)
assert res is out

if isinstance(out, unyt_array):
# check that the result can be converted to predictible units
res.in_units("cm * s")
assert out.units == res.units


def test_invalid_dot_matrices():
a = np.arange(9) * cm
a.shape = (3, 3)
b = np.arange(9) * s
b.shape = (3, 3)

out = np.empty((3, 3), dtype="int64") * s ** 2
with pytest.raises(
TypeError,
match=re.escape(
"output array is not acceptable "
"(units 's**2' cannot be converted to 'cm*s')"
),
):
np.dot(a, b, out=out)


def test_linalg_inv():
arr = np.random.random_sample((3, 3)) * cm
iarr = np.linalg.inv(arr)
Expand Down

0 comments on commit 3a51a41

Please sign in to comment.