diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e1d99ee8c3e2..0934e97d35cc 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -606,6 +606,8 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Union[Array, tuple[Array, Array]] def solve(a: ArrayLike, b: ArrayLike) -> Array: check_arraylike("jnp.linalg.solve", a, b) a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) + if a.ndim >= 2 and b.ndim > a.ndim: + a = lax.expand_dims(a, tuple(range(b.ndim - a.ndim))) return lax_linalg._solve(a, b) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d239fffc9861..c4295d53535c 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -897,6 +897,7 @@ def tensor_maker(): ((4, 4), (4,)), ((8, 8), (8, 4)), ((1, 2, 2), (3, 2)), + ((2, 2), (3, 2, 2)), ((2, 1, 3, 3), (1, 4, 3, 4)), ((1, 0, 0), (1, 0, 2)), ]