diff --git a/CHANGELOG.md b/CHANGELOG.md index a4af11e7f72f..486998b6321b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -110,6 +110,9 @@ Remember to align the itemized text with the first line of an item within a list such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching. * {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0. + * {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D + solves with `b.ndim > 1`. In the future these will be treated as batched 2D + solves. ## jaxlib 0.4.24 diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index cf1b1c9b5eb3..082e99cdf794 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -16,6 +16,7 @@ from collections.abc import Sequence from functools import partial +import warnings import numpy as np import textwrap @@ -635,9 +636,15 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: def solve(a: ArrayLike, b: ArrayLike) -> Array: check_arraylike("jnp.linalg.solve", a, b) a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) - # TODO(jakevdp): this condition matches the broadcasting behavior in numpy < 2.0. - # For the array API specification, we would check only if b.ndim == 1. - if b.ndim == 1 or a.ndim == b.ndim + 1: + + if b.ndim == 1: + signature = "(m,m),(m)->(m)" + elif a.ndim == b.ndim + 1: + # Deprecation warning added 2024-02-06 + warnings.warn("Batched 1D solves with b.ndim > 1 are deprecated, " + "and in the future will be treated as a batched 2D solve. " + "Use solve(a, b[..., None])[..., 0] to avoid this warning.", + category=DeprecationWarning, stacklevel=2) signature = "(m,m),(m)->(m)" else: signature = "(m,m),(m,n)->(m,n)" diff --git a/tests/linalg_test.py b/tests/linalg_test.py index d62fea067cc7..116a18c986a9 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1035,6 +1035,7 @@ def testSolve(self, lhs_shape, rhs_shape, dtype): lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)], rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)] ) + @jtu.ignore_warning(category=DeprecationWarning, message="Batched 1D solves") def testSolveBroadcasting(self, lhs_shape, rhs_shape): # Batched solve can involve some ambiguities; this test checks # that we match NumPy's convention in all cases. @@ -1196,6 +1197,7 @@ def test(x): self.assertAllClose(xc, grad_test_jc(xc)) @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.ignore_warning(category=DeprecationWarning, message="Batched 1D solves") def testIssue1151(self): rng = self.rng() A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)