diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst
index d4cfd63c0f0a..12e795ecf582 100644
--- a/docs/jax.numpy.rst
+++ b/docs/jax.numpy.rst
@@ -449,6 +449,7 @@ jax.numpy.linalg
 
     cholesky
     cond
+    cross
     det
     eig
     eigh
@@ -460,6 +461,7 @@ jax.numpy.linalg
     matrix_rank
     multi_dot
     norm
+    outer
     pinv
     qr
     slogdet
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index 0934e97d35cc..3a8410acf384 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -677,3 +677,24 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
   if numpy_resid:
     return _lstsq(a, b, rcond, numpy_resid=True)
   return _jit_lstsq(a, b, rcond)
+
+
+@_wraps(getattr(np.linalg, "cross", None))
+def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
+  check_arraylike("jnp.linalg.cross", x1, x2)
+  x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
+  if x1.shape[axis] != 3 or x2.shape[axis] != 3:
+    raise ValueError(
+        "Both input arrays must be (arrays of) 3-dimensional vectors, "
+        f"but they have {x1.shape[axis]=} and {x2.shape[axis]=}"
+    )
+  return jnp.cross(x1, x2, axis=axis)
+
+
+@_wraps(getattr(np.linalg, "outer", None))
+def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
+  check_arraylike("jnp.linalg.outer", x1, x2)
+  x1, x2 = jnp.asarray(x1), jnp.asarray(x2)
+  if x1.ndim != 1 or x2.ndim != 1:
+    raise ValueError(f"Input arrays must be one-dimensional, but they are {x1.ndim=} {x2.ndim=}")
+  return x1[:, None] * x2[None, :]
diff --git a/jax/experimental/array_api/_linear_algebra_functions.py b/jax/experimental/array_api/_linear_algebra_functions.py
index 2ce616afd3b7..154d963a765d 100644
--- a/jax/experimental/array_api/_linear_algebra_functions.py
+++ b/jax/experimental/array_api/_linear_algebra_functions.py
@@ -47,7 +47,7 @@ def cross(x1, x2, /, *, axis=-1):
   """
   Returns the cross product of 3-element vectors.
   """
-  return jax.numpy.cross(x1, x2, axis=axis)
+  return jax.numpy.linalg.cross(x1, x2, axis=axis)
 
 def det(x, /):
   """
@@ -115,7 +115,7 @@ def outer(x1, x2, /):
   """
   Returns the outer product of two vectors x1 and x2.
   """
-  return jax.numpy.outer(x1, x2)
+  return jax.numpy.linalg.outer(x1, x2)
 
 def pinv(x, /, *, rtol=None):
   """
diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py
index 88f9b3781603..42536822ce5a 100644
--- a/jax/numpy/linalg.py
+++ b/jax/numpy/linalg.py
@@ -17,6 +17,7 @@
 
 from jax._src.numpy.linalg import (
   cholesky as cholesky,
+  cross as cross,
   det as det,
   eig as eig,
   eigh as eigh,
@@ -27,6 +28,7 @@
   matrix_power as matrix_power,
   matrix_rank as matrix_rank,
   norm as norm,
+  outer as outer,
   pinv as pinv,
   qr as qr,
   slogdet as slogdet,
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index c4295d53535c..1f73471a307a 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -1092,6 +1092,45 @@ def f(inp):
     cube_func = jax.jacfwd(hess_func)
     self.assertFalse(np.any(np.isnan(cube_func(a))))
 
+  @jtu.sample_product(
+    [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axis=axis)
+     for lhs_shape, rhs_shape, axis in [
+       [(3,), (3,), -1],
+       [(2, 3), (2, 3), -1],
+       [(3, 4), (3, 4), 0],
+       [(3, 5), (3, 4, 5), 0]
+     ]],
+    lhs_dtype=jtu.dtypes.numeric,
+    rhs_dtype=jtu.dtypes.numeric,
+  )
+  @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
+  def testCross(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, axis):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
+    lax_fun = partial(jnp.linalg.cross, axis=axis)
+    np_fun = jtu.promote_like_jnp(partial(
+      np.cross if jtu.numpy_version() < (2, 0, 0) else np.linalg.cross,
+      axis=axis))
+    with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
+      self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
+      self._CompileAndCheck(lax_fun, args_maker)
+
+  @jtu.sample_product(
+    lhs_shape=[(0,), (3,), (5,)],
+    rhs_shape=[(0,), (3,), (5,)],
+    lhs_dtype=jtu.dtypes.numeric,
+    rhs_dtype=jtu.dtypes.numeric,
+  )
+  def testOuter(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype):
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
+    lax_fun = jnp.linalg.outer
+    np_fun = jtu.promote_like_jnp(
+      np.outer if jtu.numpy_version() < (2, 0, 0) else np.linalg.outer)
+    with jtu.strict_promotion_if_dtypes_match([lhs_dtype, rhs_dtype]):
+      self._CheckAgainstNumpy(np_fun, lax_fun, args_maker)
+      self._CompileAndCheck(lax_fun, args_maker)
+
 
 class ScipyLinalgTest(jtu.JaxTestCase):
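
For reference, a minimal usage sketch of the two wrappers added above, assuming a JAX build that includes this change; the arrays and printed values are illustrative only:

import jax.numpy as jnp

# jnp.linalg.cross only accepts length-3 vectors along `axis` (unlike
# jnp.cross, which also allows 2-element vectors), matching np.linalg.cross
# in NumPy 2.0.
a = jnp.array([1., 0., 0.])
b = jnp.array([0., 1., 0.])
print(jnp.linalg.cross(a, b))        # [0. 0. 1.]

# jnp.linalg.outer only accepts one-dimensional inputs and is equivalent to
# x1[:, None] * x2[None, :].
x = jnp.arange(3.)
y = jnp.arange(4.)
print(jnp.linalg.outer(x, y).shape)  # (3, 4)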