Skip to content

Commit

Permalink
[array api] fix linalg.solve and enable test
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 5, 2024
1 parent ed62f28 commit 3825738
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
11 changes: 4 additions & 7 deletions jax/experimental/array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,10 @@ def solve(x1, x2, /):
Returns the solution of a square system of linear equations with a unique solution.
"""
if x2.ndim == 1:
x2 = x2.reshape(*x1.shape[:-2], *x2.shape, 1)
return jax.numpy.linalg.solve(x1, x2)[..., 0]
if x2.ndim > x1.ndim:
x1 = x1.reshape(*x2.shape[:-2], *x1.shape)
elif x1.ndim > x2.ndim:
x2 = x2.reshape(*x1.shape[:-2], *x2.shape)
return jax.numpy.linalg.solve(x1, x2)
signature = "(m,m),(m)->(m)"
else:
signature = "(m,m),(m,n)->(m,n)"
return jax.numpy.vectorize(jax.numpy.linalg.solve, signature=signature)(x1, x2)


def svd(x, /, *, full_matrices=True):
Expand Down
2 changes: 0 additions & 2 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ array_api_tests/test_special_cases.py::test_nan_propagation
array_api_tests/test_special_cases.py::test_unary
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_linalg.py::test_matrix_power
array_api_tests/test_linalg.py::test_solve

# fft test suite is buggy as of 83f0bcdc
array_api_tests/test_fft.py

0 comments on commit 3825738

Please sign in to comment.