Skip to content

Commit

Permalink
Merge pull request #22653 from rajasekharporeddy:testbranch1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655980588
  • Loading branch information
jax authors committed Jul 25, 2024
2 parents d54bf52 + 4ff69d3 commit 9ea79c6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5072,13 +5072,17 @@ def tril_indices(n: int, k: int = 0, m: int | None = None) -> tuple[Array, Array
@util.implements(np.triu_indices_from)
def triu_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return triu_indices(arr_shape[-2], k=k, m=arr_shape[-1])
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return triu_indices(arr_shape[0], k=k, m=arr_shape[1])


@util.implements(np.tril_indices_from)
def tril_indices_from(arr: ArrayLike, k: int = 0) -> tuple[Array, Array]:
arr_shape = shape(arr)
return tril_indices(arr_shape[-2], k=k, m=arr_shape[-1])
if len(arr_shape) != 2:
raise ValueError("Only 2-D inputs are accepted")
return tril_indices(arr_shape[0], k=k, m=arr_shape[1])


@util.implements(np.fill_diagonal, lax_description="""
Expand Down

0 comments on commit 9ea79c6

Please sign in to comment.