Skip to content

Commit

Permalink
array api: add jnp.linalg.cross & jnp.linalg.outer
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 12, 2023
1 parent 94d58b7 commit 5aec956
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ jax.numpy.linalg

cholesky
cond
cross
det
eig
eigh
Expand All @@ -460,6 +461,7 @@ jax.numpy.linalg
matrix_rank
multi_dot
norm
outer
pinv
qr
slogdet
Expand Down
21 changes: 21 additions & 0 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,24 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)
return _jit_lstsq(a, b, rcond)


@_wraps(getattr(np.linalg, "cross", None))
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
check_arraylike("jnp.linalg.outer", x1, x2)
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
if x1.shape[axis] != 3 or x2.shape[axis] != 3:
raise ValueError(
"Both input arrays must be (arrays of) 3-dimensional vectors, "
f"but they have {x1.shape[axis]=} and {x2.shape[axis]=}"
)
return jnp.cross(x1, x2, axis=axis)


@_wraps(getattr(np.linalg, "outer", None))
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("jnp.linalg.outer", x1, x2)
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
if x1.ndim != 1 or x2.ndim != 1:
raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
return x1[:, None] * x2[None, :]
4 changes: 2 additions & 2 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def cross(x1, x2, /, *, axis=-1):
"""
Returns the cross product of 3-element vectors.
"""
return jax.numpy.cross(x1, x2, axis=axis)
return jax.numpy.linalg.cross(x1, x2, axis=axis)

def det(x, /):
"""
Expand Down Expand Up @@ -115,7 +115,7 @@ def outer(x1, x2, /):
"""
Returns the outer product of two vectors x1 and x2.
"""
return jax.numpy.outer(x1, x2)
return jax.numpy.linalg.outer(x1, x2)

def pinv(x, /, *, rtol=None):
"""
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from jax._src.numpy.linalg import (
cholesky as cholesky,
cross as cross,
det as det,
eig as eig,
eigh as eigh,
Expand All @@ -27,6 +28,7 @@
matrix_power as matrix_power,
matrix_rank as matrix_rank,
norm as norm,
outer as outer,
pinv as pinv,
qr as qr,
slogdet as slogdet,
Expand Down
42 changes: 42 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,48 @@ def f(inp):
cube_func = jax.jacfwd(hess_func)
self.assertFalse(np.any(np.isnan(cube_func(a))))

@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axis=axis)
for lhs_shape, rhs_shape, axis in [
[(3,), (3,), -1],
[(2, 3), (2, 3), -1],
[(3, 4), (3, 4), 0],
[(3, 5), (3, 4, 5), 0]
]],
lhs_dtype=jtu.dtypes.numeric,
rhs_dtype=jtu.dtypes.numeric,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testCross(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
lax_fun = partial(jnp.linalg.cross, axis=axis)
if jtu.numpy_version() < (2, 0, 0):
np_fun = partial(np.cross, axis=axis)
else:
np_fun = partial(np.linalg.cross, axis=axis)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)

@jtu.sample_product(
lhs_shape=[(0,), (3,), (5,)],
rhs_shape=[(0,), (3,), (5,)],
lhs_dtype=jtu.dtypes.numeric,
rhs_dtype=jtu.dtypes.numeric,
)
def testOuter(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
lax_fun = jnp.linalg.outer
if jtu.numpy_version() < (2, 0, 0):
np_fun = np.outer
else:
np_fun = np.linalg.outer
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)


class ScipyLinalgTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 5aec956

Please sign in to comment.