Skip to content

Commit

Permalink
Merge pull request #19700 from jakevdp:fix-cholesky-upper
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605071135
  • Loading branch information
jax authors committed Feb 7, 2024
2 parents a9a3865 + 57f27e7 commit ff65926
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
5 changes: 2 additions & 3 deletions jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
check_arraylike("jnp.linalg.cholesky", a)
a, = promote_dtypes_inexact(jnp.asarray(a))
if upper:
a = jax.numpy.matrix_transpose(a).conj()
return lax_linalg.cholesky(a)
L = lax_linalg.cholesky(a)
return L.mT.conj() if upper else L

@overload
def svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[True],
Expand Down
5 changes: 3 additions & 2 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@ def np_fun(x, upper=upper):
# Upper argument added in NumPy 2.0.0
if jtu.numpy_version() >= (2, 0, 0):
return np.linalg.cholesky(x, upper=upper)
result = np.linalg.cholesky(x)
if upper:
axes = list(range(x.ndim))
axes[-1], axes[-2] = axes[-2], axes[-1]
x = np.transpose(x, axes).conj()
return np.linalg.cholesky(x)
return np.transpose(result, axes).conj()
return result

self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=1e-3)
Expand Down

0 comments on commit ff65926

Please sign in to comment.