From 9890b23b0ac25b79e33adb9139eae817b06b2bdd Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 9 Jan 2024 13:23:57 -0800
Subject: [PATCH] Add jnp.vecdot

---
 docs/jax.numpy.rst          |  1 +
 jax/_src/numpy/lax_numpy.py | 15 +++++++++++++
 jax/_src/numpy/linalg.py    |  9 +-------
 jax/_src/test_util.py       | 10 +++++++++
 jax/numpy/__init__.py       |  1 +
 jax/numpy/__init__.pyi      |  3 +++
 tests/lax_numpy_test.py     | 42 +++++++++++++++++++++++++++++++++++--
 tests/linalg_test.py        | 42 +++++++++++--------------------------
 8 files changed, 83 insertions(+), 40 deletions(-)

diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst
index 4f94b0672b41..36e6ceae6638 100644
--- a/docs/jax.numpy.rst
+++ b/docs/jax.numpy.rst
@@ -415,6 +415,7 @@ namespace; they are listed below.
     vander
     var
     vdot
+    vecdot
     vectorize
     vsplit
     vstack
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 49c06e1581b1..c9dcb674fadd 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -3389,6 +3389,21 @@ def vdot(
                  preferred_element_type=preferred_element_type)
 
 
+@util._wraps(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
+             extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
+def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
+           precision: PrecisionLike = None,
+           preferred_element_type: DTypeLike | None = None) -> Array:
+  util.check_arraylike("jnp.vecdot", x1, x2)
+  x1_arr, x2_arr = asarray(x1), asarray(x2)
+  if x1_arr.shape[axis] != x2_arr.shape[axis]:
+    raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
+  x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
+  x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
+  return vectorize(partial(vdot, precision=precision, preferred_element_type=preferred_element_type),
+                   signature="(n),(n)->()")(x1_arr, x2_arr)
+
+
 @util._wraps(np.tensordot, lax_description=_PRECISION_DOC,
              extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
 def tensordot(a: ArrayLike, b: ArrayLike,
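
For reference, a short sketch (not part of the patch) of how the new jnp.vecdot
defined above is expected to behave; the input values are illustrative
assumptions, not taken from the change itself:

    import jax.numpy as jnp

    # Batched dot product over the contracted axis (default axis=-1):
    # each row of x is dotted with the matching row of y.
    x = jnp.arange(6.0).reshape(2, 3)
    y = jnp.ones((2, 3))
    print(jnp.vecdot(x, y))          # [ 3. 12.]

    # The contracted axis may sit anywhere; axis=0 contracts the leading
    # axis of both inputs, leaving the remaining axes as batch dimensions.
    print(jnp.vecdot(x, y, axis=0))  # [3. 5. 7.]
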
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index f6d92f9bdfc6..5e50b7ef13b5 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -741,14 +741,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
 
 @_wraps(getattr(np.linalg, "vecdot", None))
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
-  """Computes the (vector) dot product of two arrays."""
-  check_arraylike("jnp.linalg.vecdot", x1, x2)
-  x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
-  if x1_arr.shape[axis] != x2_arr.shape[axis]:
-    raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
-  x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
-  x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
-  return jax.numpy.vectorize(jnp.vdot, signature="(n),(n)->()")(x1_arr, x2_arr)
+  return jnp.vecdot(x1, x2, axis=axis)
 
 
 @_wraps(getattr(np.linalg, "matmul", None))
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 9efdbf92dcc5..48c7de087c00 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1370,3 +1370,13 @@ def fwd_bwd_jaxprs(f, *example_args):
       lambda *args: jax.vjp(f, *args), return_shape=True)(*example_args)
   bwd_jaxpr = jax.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape)
   return fwd_jaxpr, bwd_jaxpr
+
+
+def numpy_vecdot(x, y, axis):
+  """Implementation of numpy.vecdot for testing on numpy < 2.0.0"""
+  if numpy_version() >= (2, 0, 0):
+    raise ValueError("should be calling vecdot directly on numpy 2.0.0")
+  x = np.moveaxis(x, axis, -1)
+  y = np.moveaxis(y, axis, -1)
+  x, y = np.broadcast_arrays(x, y)
+  return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
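
As a sanity check, the numpy_vecdot helper above can be exercised standalone;
a minimal sketch assuming numpy < 2.0, with arbitrary example inputs (the
einsum call is only an independent cross-check, not part of the patch):

    import numpy as np

    def numpy_vecdot(x, y, axis):
        # Same logic as the helper above: move the contracted axis last,
        # broadcast the batch dimensions, then reduce with matmul.
        x = np.moveaxis(x, axis, -1)
        y = np.moveaxis(y, axis, -1)
        x, y = np.broadcast_arrays(x, y)
        return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]

    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 3))
    y = rng.standard_normal((3,))  # broadcasts against x's batch dimension
    np.testing.assert_allclose(numpy_vecdot(x, y, axis=-1),
                               np.einsum('ij,j->i', np.conj(x), y))
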
diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py
index d623d7cdaa36..6b51f45fa2d1 100644
--- a/jax/numpy/__init__.py
+++ b/jax/numpy/__init__.py
@@ -249,6 +249,7 @@
     unwrap as unwrap,
     vander as vander,
     vdot as vdot,
+    vecdot as vecdot,
     vsplit as vsplit,
     vstack as vstack,
     where as where,
diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi
index 9b574fafb39f..e60fa5a4e098 100644
--- a/jax/numpy/__init__.pyi
+++ b/jax/numpy/__init__.pyi
@@ -839,6 +839,9 @@ def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
 def vdot(
     a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
     preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
+def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ...,
+           precision: PrecisionLike = ...,
+           preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
 def vsplit(
     ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]
 ) -> list[Array]: ...
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 7d0537b21530..e0d34eac6f9f 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -62,6 +62,7 @@
 nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
 one_dim_array_shapes = [(1,), (6,), (12,)]
 empty_array_shapes = [(0,), (0, 4), (3, 0),]
+broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
 
 scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
 array_shapes = nonempty_array_shapes + empty_array_shapes
@@ -554,6 +555,39 @@ def np_fun(x, y):
     self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
     self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)
 
+  @jtu.sample_product(
+      lhs_batch=broadcast_compatible_shapes,
+      rhs_batch=broadcast_compatible_shapes,
+      axis_size=[2, 4],
+      axis=range(-2, 2),
+      dtype=number_dtypes,
+  )
+  @jax.default_matmul_precision("float32")
+  def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
+    # Construct vecdot-compatible shapes.
+    size = min(len(lhs_batch), len(rhs_batch))
+    axis = int(np.clip(axis, -size - 1, size))
+    if axis >= 0:
+      lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
+      rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
+    else:
+      laxis = axis + len(lhs_batch) + 1
+      lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
+      raxis = axis + len(rhs_batch) + 1
+      rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])
+
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
+    @jtu.promote_like_jnp
+    def np_fn(x, y, axis=axis):
+      f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot
+      return f(x, y, axis=axis).astype(x.dtype)
+    jnp_fn = partial(jnp.vecdot, axis=axis)
+    tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
+           np.complex64: 1E-3, np.complex128: 1e-12}
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
+
   @jtu.sample_product(
     [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
      for lhs_shape, rhs_shape, axes in [
@@ -5050,6 +5084,10 @@ def testPrecision(self):
         HIGHEST,
         partial(jnp.vdot, precision=HIGHEST),
         ones_1d, ones_1d)
+    jtu.assert_dot_precision(
+        HIGHEST,
+        partial(jnp.vecdot, precision=HIGHEST),
+        ones_1d, ones_1d)
     jtu.assert_dot_precision(
         HIGHEST,
         partial(jnp.tensordot, axes=2, precision=HIGHEST),
@@ -5076,7 +5114,7 @@ def testPrecision(self):
       ones_1d, ones_1d)
 
   @jtu.sample_product(
-    funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot']
+    funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot']
   )
   def testPreferredElementType(self, funcname):
     func = getattr(jnp, funcname)
@@ -5615,7 +5653,7 @@ def testWrappedSignaturesMatch(self):
 
 # TODO(jakevdp): implement missing ufuncs
 UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_invert', 'bitwise_left_shift',
-                        'bitwise_right_shift', 'pow', 'vecdot'}
+                        'bitwise_right_shift', 'pow'}
 
 
 def _all_numpy_ufuncs() -> Iterator[str]:
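
The shape construction in testVecdot above can be hard to parse at a glance:
it splices axis_size into each batch shape at the (possibly negative)
contraction axis, so both operands agree along that axis while their remaining
dimensions stay broadcast-compatible. A standalone sketch of the negative-axis
branch, with assumed sample parameters:

    import numpy as np

    lhs_batch, rhs_batch = (4, 1), (3,)  # assumed sample batch shapes
    axis_size, axis = 2, -2

    size = min(len(lhs_batch), len(rhs_batch))
    axis = int(np.clip(axis, -size - 1, size))  # keep axis valid for both

    # Negative axis: convert to a splice position counted from the right.
    laxis = axis + len(lhs_batch) + 1
    lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
    raxis = axis + len(rhs_batch) + 1
    rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])

    print(lhs_shape, rhs_shape)  # (4, 2, 1) (2, 3) -- both size 2 at axis=-2
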
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 037849f38084..92c1d6b0288c 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -43,7 +43,6 @@
 
 T = lambda x: np.swapaxes(x, -1, -2)
 
-broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
 float_types = jtu.dtypes.floating
 complex_types = jtu.dtypes.complex
 int_types = jtu.dtypes.all_integer
@@ -653,41 +652,24 @@ def np_fn(x, *, ord, keepdims, axis):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
     self._CompileAndCheck(jnp_fn, args_maker)
 
+  # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
   @jtu.sample_product(
-      lhs_batch=broadcast_compatible_shapes,
-      rhs_batch=broadcast_compatible_shapes,
-      axis_size=[2, 4],
-      axis=range(-2, 2),
-      dtype=float_types + complex_types,
+      [
+          dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=0),
+          dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=1),
+          dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=-1),
+      ],
+      dtype=int_types + float_types + complex_types
   )
   @jax.default_matmul_precision("float32")
-  def testVecDot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
-    # Construct vecdot-compatible shapes.
-    size = min(len(lhs_batch), len(rhs_batch))
-    axis = int(np.clip(axis, -size - 1, size))
-    if axis >= 0:
-      lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
-      rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
-    else:
-      laxis = axis + len(lhs_batch) + 1
-      lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
-      raxis = axis + len(rhs_batch) + 1
-      rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])
-
+  def testVecdot(self, lhs_shape, rhs_shape, axis, dtype):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
-    if jtu.numpy_version() < (2, 0, 0):
-      def np_fn(x, y, axis=axis):
-        x = np.moveaxis(x, axis, -1)
-        y = np.moveaxis(y, axis, -1)
-        x, y = np.broadcast_arrays(x, y)
-        return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
-    else:
-      np_fn = partial(np.linalg.vecdot, axis=axis)
-    np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
+    np_fn = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.linalg.vecdot
+    np_fn = jtu.promote_like_jnp(partial(np_fn, axis=axis))
     jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
-    tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
-           np.complex64: 1E-3, np.complex128: 1e-12}
+    tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
+           np.complex128: 1e-12}
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
     self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
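
Since jnp.linalg.vecdot is now a thin wrapper around jnp.vecdot, the two entry
points should agree; a quick illustrative check (values assumed for the
example):

    import jax.numpy as jnp

    x = jnp.arange(8.0).reshape(2, 2, 2)
    y = jnp.ones((2, 2))
    # Contract over the leading axis; remaining axes broadcast as batch dims.
    a = jnp.linalg.vecdot(x, y, axis=0)
    b = jnp.vecdot(x, y, axis=0)
    print(a.shape, bool(jnp.all(a == b)))  # (2, 2) True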