From 707657e5b7ccd0120143ea1ed027f588cc3268eb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Jan 2024 09:56:19 -0800 Subject: [PATCH] Adjust permute_dims signature to match NumPy This really doesn't matter because it's a position-only argument, but this change satisfies our tests and is easier than making the tests smarter. --- jax/_src/numpy/lax_numpy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3db6aa119930..49c06e1581b1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -543,9 +543,9 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: @util._wraps(getattr(np, "permute_dims", None)) -def permute_dims(x: ArrayLike, /, axes: tuple[int, ...]) -> Array: - util.check_arraylike("permute_dims", x) - return lax.transpose(x, axes) +def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: + util.check_arraylike("permute_dims", a) + return lax.transpose(a, axes) @util._wraps(getattr(np, 'matrix_transpose', None))