Skip to content

Commit

Permalink
jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 6, 2024
1 parent 299b983 commit c107e96
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ Remember to align the itemized text with the first line of an item within a list
such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated
and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
* {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
solves with `b.ndim > 1`. In the future these will be treated as batched 2D
solves.

## jaxlib 0.4.24

Expand Down
13 changes: 10 additions & 3 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Sequence
from functools import partial
import warnings

import numpy as np
import textwrap
Expand Down Expand Up @@ -635,9 +636,15 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
def solve(a: ArrayLike, b: ArrayLike) -> Array:
check_arraylike("jnp.linalg.solve", a, b)
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
# TODO(jakevdp): this condition matches the broadcasting behavior in numpy < 2.0.
# For the array API specification, we would check only if b.ndim == 1.
if b.ndim == 1 or a.ndim == b.ndim + 1:

if b.ndim == 1:
signature = "(m,m),(m)->(m)"
elif a.ndim == b.ndim + 1:
# Deprecation warning added 2024-02-06
warnings.warn("Batched 1D solves with b.ndim > 1 are deprecated, "
"and in the future will be treated as a batched 2D solve. "
"Use solve(a, b[..., None])[..., 0] to avoid this warning.",
category=DeprecationWarning, stacklevel=2)
signature = "(m,m),(m)->(m)"
else:
signature = "(m,m),(m,n)->(m,n)"
Expand Down
2 changes: 2 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,7 @@ def testSolve(self, lhs_shape, rhs_shape, dtype):
lhs_shape=[(2, 2), (2, 2, 2), (2, 2, 2, 2), (2, 2, 2, 2, 2)],
rhs_shape=[(2,), (2, 2), (2, 2, 2), (2, 2, 2, 2)]
)
@jtu.ignore_warning(category=DeprecationWarning, message="Batched 1D solves")
def testSolveBroadcasting(self, lhs_shape, rhs_shape):
# Batched solve can involve some ambiguities; this test checks
# that we match NumPy's convention in all cases.
Expand Down Expand Up @@ -1196,6 +1197,7 @@ def test(x):
self.assertAllClose(xc, grad_test_jc(xc))

@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(category=DeprecationWarning, message="Batched 1D solves")
def testIssue1151(self):
rng = self.rng()
A = jnp.array(rng.randn(100, 3, 3), dtype=jnp.float32)
Expand Down

0 comments on commit c107e96

Please sign in to comment.