Skip to content

Commit

Permalink
Merge pull request #19270 from jakevdp:permute-dims-sig
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 596992011
  • Loading branch information
jax authors committed Jan 9, 2024
2 parents 2356e57 + 707657e commit e76f514
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,9 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:


@util._wraps(getattr(np, "permute_dims", None))
def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array:
util.check_arraylike("permute_dims", x)
return lax.transpose(x, axes)
def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array:
util.check_arraylike("permute_dims", a)
return lax.transpose(a, axes)


@util._wraps(getattr(np, 'matrix_transpose', None))
Expand Down

0 comments on commit e76f514

Please sign in to comment.