diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3db6aa119930..49c06e1581b1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -543,9 +543,9 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: @util._wraps(getattr(np, "permute_dims", None)) -def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: - util.check_arraylike("permute_dims", x) - return lax.transpose(x, axes) +def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: + util.check_arraylike("permute_dims", a) + return lax.transpose(a, axes) @util._wraps(getattr(np, 'matrix_transpose', None))