Skip to content

Commit

Permalink
Merge pull request #19274 from jakevdp:vecdot
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597347328
  • Loading branch information
jax authors committed Jan 10, 2024
2 parents 6174145 + 9890b23 commit 0c4b680
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 40 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ namespace; they are listed below.
vander
var
vdot
vecdot
vectorize
vsplit
vstack
Expand Down
15 changes: 15 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3389,6 +3389,21 @@ def vdot(
preferred_element_type=preferred_element_type)


@util._wraps(getattr(np, "vecdot", None), lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array:
util.check_arraylike("jnp.vecdot", x1, x2)
x1_arr, x2_arr = asarray(x1), asarray(x2)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
return vectorize(partial(vdot, precision=precision, preferred_element_type=preferred_element_type),
signature="(n),(n)->()")(x1_arr, x2_arr)


@util._wraps(np.tensordot, lax_description=_PRECISION_DOC,
extra_params=_DOT_PREFERRED_ELEMENT_TYPE_DESCRIPTION)
def tensordot(a: ArrayLike, b: ArrayLike,
Expand Down
9 changes: 1 addition & 8 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,14 +741,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa

@_wraps(getattr(np.linalg, "vecdot", None))
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
"""Computes the (vector) dot product of two arrays."""
check_arraylike("jnp.linalg.vecdot", x1, x2)
x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
return jax.numpy.vectorize(jnp.vdot, signature="(n),(n)->()")(x1_arr, x2_arr)
return jnp.vecdot(x1, x2, axis=axis)


@_wraps(getattr(np.linalg, "matmul", None))
Expand Down
10 changes: 10 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,3 +1370,13 @@ def fwd_bwd_jaxprs(f, *example_args):
lambda *args: jax.vjp(f, *args), return_shape=True)(*example_args)
bwd_jaxpr = jax.make_jaxpr(lambda res, outs: res(outs))(res_shape, y_shape)
return fwd_jaxpr, bwd_jaxpr


def numpy_vecdot(x, y, axis):
"""Implementation of numpy.vecdot for testing on numpy < 2.0.0"""
if numpy_version() >= (2, 0, 0):
raise ValueError("should be calling vecdot directly on numpy 2.0.0")
x = np.moveaxis(x, axis, -1)
y = np.moveaxis(y, axis, -1)
x, y = np.broadcast_arrays(x, y)
return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@
unwrap as unwrap,
vander as vander,
vdot as vdot,
vecdot as vecdot,
vsplit as vsplit,
vstack as vstack,
where as where,
Expand Down
3 changes: 3 additions & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,9 @@ def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
def vdot(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = ...,
precision: PrecisionLike = ...,
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
def vsplit(
ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]
) -> list[Array]: ...
Expand Down
42 changes: 40 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes
one_dim_array_shapes = [(1,), (6,), (12,)]
empty_array_shapes = [(0,), (0, 4), (3, 0),]
broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]

scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE]
array_shapes = nonempty_array_shapes + empty_array_shapes
Expand Down Expand Up @@ -554,6 +555,39 @@ def np_fun(x, y):
self._CheckAgainstNumpy(np_fun, jnp.matmul, args_maker, tol=tol)
self._CompileAndCheck(jnp.matmul, args_maker, atol=tol, rtol=tol)

@jtu.sample_product(
lhs_batch=broadcast_compatible_shapes,
rhs_batch=broadcast_compatible_shapes,
axis_size=[2, 4],
axis=range(-2, 2),
dtype=number_dtypes,
)
@jax.default_matmul_precision("float32")
def testVecdot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
# Construct vecdot-compatible shapes.
size = min(len(lhs_batch), len(rhs_batch))
axis = int(np.clip(axis, -size - 1, size))
if axis >= 0:
lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
else:
laxis = axis + len(lhs_batch) + 1
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
raxis = axis + len(rhs_batch) + 1
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])

rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
@jtu.promote_like_jnp
def np_fn(x, y, axis=axis):
f = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.vecdot
return f(x, y, axis=axis).astype(x.dtype)
jnp_fn = partial(jnp.vecdot, axis=axis)
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
np.complex64: 1E-3, np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

@jtu.sample_product(
[dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, axes=axes)
for lhs_shape, rhs_shape, axes in [
Expand Down Expand Up @@ -5050,6 +5084,10 @@ def testPrecision(self):
HIGHEST,
partial(jnp.vdot, precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.vecdot, precision=HIGHEST),
ones_1d, ones_1d)
jtu.assert_dot_precision(
HIGHEST,
partial(jnp.tensordot, axes=2, precision=HIGHEST),
Expand All @@ -5076,7 +5114,7 @@ def testPrecision(self):
ones_1d, ones_1d)

@jtu.sample_product(
funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot']
funcname=['inner', 'matmul', 'dot', 'vdot', 'tensordot', 'vecdot']
)
def testPreferredElementType(self, funcname):
func = getattr(jnp, funcname)
Expand Down Expand Up @@ -5615,7 +5653,7 @@ def testWrappedSignaturesMatch(self):

# TODO(jakevdp): implement missing ufuncs
UNIMPLEMENTED_UFUNCS = {'spacing', 'bitwise_invert', 'bitwise_left_shift',
'bitwise_right_shift', 'pow', 'vecdot'}
'bitwise_right_shift', 'pow'}


def _all_numpy_ufuncs() -> Iterator[str]:
Expand Down
42 changes: 12 additions & 30 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

T = lambda x: np.swapaxes(x, -1, -2)

broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
int_types = jtu.dtypes.all_integer
Expand Down Expand Up @@ -653,41 +652,24 @@ def np_fn(x, *, ord, keepdims, axis):
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)

# jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
@jtu.sample_product(
lhs_batch=broadcast_compatible_shapes,
rhs_batch=broadcast_compatible_shapes,
axis_size=[2, 4],
axis=range(-2, 2),
dtype=float_types + complex_types,
[
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=0),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=1),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=-1),
],
dtype=int_types + float_types + complex_types
)
@jax.default_matmul_precision("float32")
def testVecDot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
# Construct vecdot-compatible shapes.
size = min(len(lhs_batch), len(rhs_batch))
axis = int(np.clip(axis, -size - 1, size))
if axis >= 0:
lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
else:
laxis = axis + len(lhs_batch) + 1
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
raxis = axis + len(rhs_batch) + 1
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])

def testVecdot(self, lhs_shape, rhs_shape, axis, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fn(x, y, axis=axis):
x = np.moveaxis(x, axis, -1)
y = np.moveaxis(y, axis, -1)
x, y = np.broadcast_arrays(x, y)
return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
else:
np_fn = partial(np.linalg.vecdot, axis=axis)
np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
np_fn = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.linalg.vecdot
np_fn = jtu.promote_like_jnp(partial(np_fn, axis=axis))
jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
np.complex64: 1E-3, np.complex128: 1e-12}
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

Expand Down

0 comments on commit 0c4b680

Please sign in to comment.