Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

array api: add jnp.linalg.cross & jnp.linalg.outer #18928

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ jax.numpy.linalg

cholesky
cond
cross
det
eig
eigh
Expand All @@ -460,6 +461,7 @@ jax.numpy.linalg
matrix_rank
multi_dot
norm
outer
pinv
qr
slogdet
Expand Down
21 changes: 21 additions & 0 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,24 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)
return _jit_lstsq(a, b, rcond)


@_wraps(getattr(np.linalg, "cross", None))
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
check_arraylike("jnp.linalg.outer", x1, x2)
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
if x1.shape[axis] != 3 or x2.shape[axis] != 3:
raise ValueError(
"Both input arrays must be (arrays of) 3-dimensional vectors, "
f"but they have {x1.shape[axis]=} and {x2.shape[axis]=}"
)
return jnp.cross(x1, x2, axis=axis)


@_wraps(getattr(np.linalg, "outer", None))
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
check_arraylike("jnp.linalg.outer", x1, x2)
x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
if x1.ndim != 1 or x2.ndim != 1:
raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
return x1[:, None] * x2[None, :]
4 changes: 2 additions & 2 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def cross(x1, x2, /, *, axis=-1):
"""
Returns the cross product of 3-element vectors.
"""
return jax.numpy.cross(x1, x2, axis=axis)
return jax.numpy.linalg.cross(x1, x2, axis=axis)

def det(x, /):
"""
Expand Down Expand Up @@ -115,7 +115,7 @@ def outer(x1, x2, /):
"""
Returns the outer product of two vectors x1 and x2.
"""
return jax.numpy.outer(x1, x2)
return jax.numpy.linalg.outer(x1, x2)

def pinv(x, /, *, rtol=None):
"""
Expand Down
2 changes: 2 additions & 0 deletions jax/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from jax._src.numpy.linalg import (
cholesky as cholesky,
cross as cross,
det as det,
eig as eig,
eigh as eigh,
Expand All @@ -27,6 +28,7 @@
matrix_power as matrix_power,
matrix_rank as matrix_rank,
norm as norm,
outer as outer,
pinv as pinv,
qr as qr,
slogdet as slogdet,
Expand Down
39 changes: 39 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,45 @@ def f(inp):
cube_func = jax.jacfwd(hess_func)
self.assertFalse(np.any(np.isnan(cube_func(a))))

@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axis=axis)
for lhs_shape, rhs_shape, axis in [
[(3,), (3,), -1],
[(2, 3), (2, 3), -1],
[(3, 4), (3, 4), 0],
[(3, 5), (3, 4, 5), 0]
]],
lhs_dtype=jtu.dtypes.numeric,
rhs_dtype=jtu.dtypes.numeric,
)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testCross(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
lax_fun = partial(jnp.linalg.cross, axis=axis)
np_fun = jtu.promote_like_jnp(partial(
np.cross if jtu.numpy_version() < (2, 0, 0) else np.linalg.cross,
axis=axis))
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)

@jtu.sample_product(
lhs_shape=[(0,), (3,), (5,)],
rhs_shape=[(0,), (3,), (5,)],
lhs_dtype=jtu.dtypes.numeric,
rhs_dtype=jtu.dtypes.numeric,
)
def testOuter(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
lax_fun = jnp.linalg.outer
np_fun = jtu.promote_like_jnp(
np.outer if jtu.numpy_version() < (2, 0, 0) else np.linalg.outer)
with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)


class ScipyLinalgTest(jtu.JaxTestCase):

Expand Down