From 62dbabad79a0fd8931ea71cdb3ac71f39605a267 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 24 Nov 2024 15:24:33 +0100 Subject: [PATCH] Fix filter_vmap with out_axes!=0,1 producing the wrong axis order. Fixes https://github.com/patrick-kidger/equinox/issues/900 --- equinox/_vmap_pmap.py | 6 +++--- pyproject.toml | 2 +- tests/test_pmap.py | 15 +++++++++++++++ tests/test_vmap.py | 13 +++++++++++++ 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/equinox/_vmap_pmap.py b/equinox/_vmap_pmap.py index 2f548c9a..1b68afc6 100644 --- a/equinox/_vmap_pmap.py +++ b/equinox/_vmap_pmap.py @@ -125,8 +125,8 @@ def _bind(axis): return jtu.tree_map(_bind, out_axes) -def _swapaxes(array, axis): - return jnp.swapaxes(array, 0, axis) +def _moveaxis(array, axis): + return jnp.moveaxis(array, 0, axis) def _named_in_axes(fun, in_axes, args): @@ -230,7 +230,7 @@ def _fun_wrapper(_dynamic_args): nonvmapd = combine(nonvmapd_arr, nonvmapd_static) assert jtu.tree_structure(vmapd) == jtu.tree_structure(out_axes) - vmapd = jtu.tree_map(_swapaxes, vmapd, out_axes) + vmapd = jtu.tree_map(_moveaxis, vmapd, out_axes) return combine(vmapd, nonvmapd) diff --git a/pyproject.toml b/pyproject.toml index 81def3ed..7b4447c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "equinox" -version = "0.11.8" +version = "0.11.9" description = "Elegant easy-to-use neural networks in JAX." readme = "README.md" requires-python =">=3.9" diff --git a/tests/test_pmap.py b/tests/test_pmap.py index 8f61bf8a..0635d91c 100644 --- a/tests/test_pmap.py +++ b/tests/test_pmap.py @@ -287,3 +287,18 @@ def g(y): assert b.shape == (3, 1) filter_pmap(f)(jnp.arange(3).reshape(1, 3, 1)) + + +# https://github.com/patrick-kidger/equinox/issues/900 +# Unlike the vmap case we only test nonnegative integers, as pmap does not support +# negative indexing for `in_axes` or `out_axes`. +@pytest.mark.parametrize("out_axes", (0, 1, 2)) +def test_out_axes_with_at_least_three_dimensions(out_axes): + def foo(x): + return x * 2 + + x = jnp.arange(24).reshape((1, 2, 3, 4)) + y = jax.pmap(foo, out_axes=out_axes)(x) + z = filter_pmap(foo, out_axes=out_axes)(x) + assert y.shape == z.shape + assert (y == z).all() diff --git a/tests/test_vmap.py b/tests/test_vmap.py index e42e45e5..ba196be3 100644 --- a/tests/test_vmap.py +++ b/tests/test_vmap.py @@ -175,3 +175,16 @@ def g(y): assert b.shape == (3, 1) eqx.filter_vmap(f)(jnp.arange(6).reshape(2, 3, 1)) + + +# https://github.com/patrick-kidger/equinox/issues/900 +@pytest.mark.parametrize("out_axes", (0, 1, 2, -1, -2, -3)) +def test_out_axes_with_at_least_three_dimensions(out_axes): + def foo(x): + return x * 2 + + x = jnp.arange(24).reshape((2, 3, 4)) + y = jax.vmap(foo, out_axes=out_axes)(x) + z = eqx.filter_vmap(foo, out_axes=out_axes)(x) + assert y.shape == z.shape + assert (y == z).all()