
Merge pull request #19252 from jakevdp:fix-vecdot
PiperOrigin-RevId: 596762791
jax authors committed Jan 9, 2024
2 parents 9556b09 + f901bea commit 6a99e38
Showing 2 changed files with 27 additions and 15 deletions.
8 changes: 2 additions & 6 deletions jax/_src/numpy/linalg.py
@@ -744,15 +744,11 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1) -> Array:
"""Computes the (vector) dot product of two arrays."""
check_arraylike("jnp.linalg.vecdot", x1, x2)
x1_arr, x2_arr = jnp.asarray(x1), jnp.asarray(x2)
rank = max(x1_arr.ndim, x2_arr.ndim)
x1_arr = jax.lax.broadcast_to_rank(x1_arr, rank)
x2_arr = jax.lax.broadcast_to_rank(x2_arr, rank)
if x1_arr.shape[axis] != x2_arr.shape[axis]:
raise ValueError("x1 and x2 must have the same size along specified axis.")
raise ValueError(f"axes must match; got shapes {x1_arr.shape} and {x2_arr.shape} with {axis=}")
x1_arr = jax.numpy.moveaxis(x1_arr, axis, -1)
x2_arr = jax.numpy.moveaxis(x2_arr, axis, -1)
# TODO(jakevdp): call lax.dot_general directly
return jax.numpy.matmul(x1_arr[..., None, :], x2_arr[..., None])[..., 0, 0]
return jax.numpy.vectorize(jnp.vdot, signature="(n),(n)->()")(x1_arr, x2_arr)


@_wraps(getattr(np.linalg, "matmul", None))
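A minimal usage sketch of what the new vectorize-based implementation enables (not part of the diff; the shapes and values below are illustrative assumptions): batch dimensions broadcast against each other, the shared axis is reduced, and, because jnp.vdot conjugates its first argument, complex inputs follow the np.linalg.vecdot convention.

import jax.numpy as jnp

x1 = jnp.arange(24.0).reshape(2, 3, 4)   # batch dims (2, 3), contraction axis of size 4
x2 = jnp.arange(4.0)                     # rank-1 operand; its batch dims broadcast to (2, 3)
out = jnp.linalg.vecdot(x1, x2, axis=-1)
print(out.shape)                         # (2, 3)

# Complex inputs: the first argument is conjugated, matching np.linalg.vecdot.
z1 = jnp.array([1 + 1j, 2 - 1j])
z2 = jnp.array([1 - 1j, 1 + 2j])
print(jnp.linalg.vecdot(z1, z2))         # same result as jnp.vdot(z1, z2)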
34 changes: 25 additions & 9 deletions tests/linalg_test.py
@@ -43,7 +43,7 @@

T = lambda x: np.swapaxes(x, -1, -2)


broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
int_types = jtu.dtypes.all_integer
@@ -654,26 +654,42 @@ def np_fn(x, *, ord, keepdims, axis):
self._CompileAndCheck(jnp_fn, args_maker)

@jtu.sample_product(
[dict(lhs_shape=(2, 3, 4), rhs_shape=(1, 4), axis=-1),
dict(lhs_shape=(2, 3, 4), rhs_shape=(2, 1, 1), axis=0),
dict(lhs_shape=(2, 3, 4), rhs_shape=(3, 4), axis=1)],
lhs_batch=broadcast_compatible_shapes,
rhs_batch=broadcast_compatible_shapes,
axis_size=[2, 4],
axis=range(-2, 2),
dtype=float_types + complex_types,
)
def testVecDot(self, lhs_shape, rhs_shape, axis, dtype):
@jax.default_matmul_precision("float32")
def testVecDot(self, lhs_batch, rhs_batch, axis_size, axis, dtype):
# Construct vecdot-compatible shapes.
size = min(len(lhs_batch), len(rhs_batch))
axis = int(np.clip(axis, -size - 1, size))
if axis >= 0:
lhs_shape = (*lhs_batch[:axis], axis_size, *lhs_batch[axis:])
rhs_shape = (*rhs_batch[:axis], axis_size, *rhs_batch[axis:])
else:
laxis = axis + len(lhs_batch) + 1
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])
raxis = axis + len(rhs_batch) + 1
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])

rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fn(x, y, axis=axis):
x, y = np.broadcast_arrays(x, y)
x = np.moveaxis(x, axis, -1)
y = np.moveaxis(y, axis, -1)
return np.matmul(x[..., None, :], y[..., None])[..., 0, 0]
x, y = np.broadcast_arrays(x, y)
return np.matmul(np.conj(x[..., None, :]), y[..., None])[..., 0, 0]
else:
np_fn = partial(np.linalg.vecdot, axis=axis)
np_fn = jtu.promote_like_jnp(np_fn, inexact=True)
jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)
tol = {np.float16: 1e-2, np.float32: 1E-3, np.float64: 1e-12,
np.complex64: 1E-3, np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)

# jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
@jtu.sample_product(
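A worked example of the shape construction in testVecDot above (illustrative only, not part of the test): the sampled batch shapes are broadcast-compatible, and axis_size is inserted at the clipped axis in both operands so that jnp.linalg.vecdot sees a matching reduction axis.

import numpy as np

lhs_batch, rhs_batch, axis_size, axis = (4, 3), (3,), 2, -1
size = min(len(lhs_batch), len(rhs_batch))   # size = 1
axis = int(np.clip(axis, -size - 1, size))   # axis stays -1
laxis = axis + len(lhs_batch) + 1            # laxis = 2
raxis = axis + len(rhs_batch) + 1            # raxis = 1
lhs_shape = (*lhs_batch[:laxis], axis_size, *lhs_batch[laxis:])   # (4, 3, 2)
rhs_shape = (*rhs_batch[:raxis], axis_size, *rhs_batch[raxis:])   # (3, 2)
print(lhs_shape, rhs_shape)   # (4, 3, 2) (3, 2) -> vecdot output shape (4, 3)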
