Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] fix vmap #3995

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 63 additions & 61 deletions flax/nnx/nnx/transforms/parallelization.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,22 +159,23 @@ def _restore_vmap_keys(


def vmap_fn(
args: tuple[tp.Any, ...],
kwargs: dict[str, tp.Any],
graphdef: GraphDef[tuple[tp.Any, ...]],
split_keys: State,
split_counts: State,
broadcast_keys: State,
broadcast_counts: State,
vectorized_states: list[State],
broadcast_state: State,
transform_metadata: tp.Mapping[str, tp.Any],
state_axes: tp.Mapping[filterlib.Filter, int],
f: tp.Callable[..., tp.Any],
filters: tp.Tuple[filterlib.Filter, ...],
split_rngs: filterlib.Filter,
args: tuple[tp.Any, ...],
kwargs: dict[str, tp.Any],
graphdef: GraphDef[tuple[tp.Any, ...]],
split_keys: State,
split_counts: State,
broadcast_keys: State,
broadcast_counts: State,
vectorized_states: list[State],
broadcast_state: State,
transform_metadata: tp.Mapping[str, tp.Any],
state_axes_: list[tuple[filterlib.Filter, int]],
f: tp.Callable[..., tp.Any],
filters: tp.Tuple[filterlib.Filter, ...],
split_rngs: filterlib.Filter,
):
ctx = graph.current_update_context('vmap')
state_axes = dict(state_axes_)
# remove metadata axis name from Variable.sharding
if spmd.PARTITION_NAME in transform_metadata:
vectorized_states = [
Expand Down Expand Up @@ -230,6 +231,7 @@ def vmap_fn(
out,
)


def vmap(
f: F,
*,
Expand All @@ -247,33 +249,33 @@ def vmap(
vectorized_states_axes = list(state_axes.values())

vmapped_fn = jax.vmap(
vmap_fn,
in_axes=(
in_axes, # args_axes
in_axes_kwargs, # kwargs_axes
None, # graphdef_axes
0, # split_keys_axes
0, # split_counts_axes
None, # broadcast_keys_axes
None, # broadcast_counts_axes
vectorized_states_axes, # vectorized_states_axes
None, # broadcast_state_axes
None, # transform_metadata_axes
None, # states_axes_axes
None, # f_axes
None, # filters_axes
None, # split_rngs_axes
),
out_axes=(
None, # graphdef_out_axes
None, # broadcast_state_axes
vectorized_states_axes,
0, # keys_axes_out
out_axes, # out_axes
),
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
vmap_fn,
in_axes=(
in_axes, # args
in_axes_kwargs, # kwargs
None, # graphdef
0, # split_keys
0, # split_counts
None, # broadcast_keys
None, # broadcast_counts
vectorized_states_axes, # vectorized_states
None, # broadcast_state
None, # transform_metadata
None, # state_axes_
None, # f
None, # vectorized_states_filters
None, # split_rngs
),
out_axes=(
None, # graphdef_out
None, # broadcast_state
vectorized_states_axes,
0, # keys_out
out_axes, # out_axes
),
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
)

@functools.wraps(f)
Expand Down Expand Up @@ -328,26 +330,26 @@ def vmap_wrapper(*args, **kwargs):
)

(
graphdef_out,
broadcast_state,
vectorized_states,
split_keys_out,
out,
graphdef_out,
broadcast_state,
vectorized_states,
split_keys_out,
out,
) = vmapped_fn(
args,
kwargs,
graphdef,
split_keys,
split_counts,
broadcast_keys,
broadcast_counts,
vectorized_states,
broadcast_state,
transform_metadata,
state_axes,
f,
filters,
split_rngs,
args,
kwargs,
graphdef,
split_keys,
split_counts,
broadcast_keys,
broadcast_counts,
vectorized_states,
broadcast_state,
transform_metadata,
list(state_axes.items()),
f,
filters,
split_rngs,
)

_, output_graph_nodes = ctx.merge(
Expand Down Expand Up @@ -824,4 +826,4 @@ def _submodule(self) -> M:
def _call(self, accessor: DelayedAccessor, *args, **kwargs):
return self.pmap_call(
self._submodule, *args, _nnx_vmap_accessor=accessor, **kwargs
)
)
5 changes: 2 additions & 3 deletions flax/nnx/tests/test_traversals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.

"""Tests for flax.nnx.traversals."""
import jax
from absl.testing import absltest

from flax.core import freeze
from flax.experimental.nnx import traversals
from flax.nnx import traversals
import jax

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()
Expand Down
32 changes: 26 additions & 6 deletions flax/nnx/tests/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from functools import partial
import typing as tp

from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
import jax
Expand All @@ -25,7 +26,7 @@
import pytest


class TestJIT:
class TestJIT(absltest.TestCase):
def test_jit(self):
m = nnx.Dict(a=nnx.Param(1))

Expand Down Expand Up @@ -483,7 +484,7 @@ def test_multiple_graph_nodes(self, loss_fn, argnums):
assert grads_m2.bias.value.shape == (3,)


class TestScan:
class TestScan(absltest.TestCase):
def test_basic(self):
class Block(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
Expand Down Expand Up @@ -946,7 +947,7 @@ def f(_, rngs: nnx.Rngs):
assert jnp.equal(dropout_keys[1], dropout_keys[2])


class TestRemat:
class TestRemat(absltest.TestCase):
def test_basic_remat(self):
RematLinear = nnx.Remat.constructor(nnx.Linear)

Expand Down Expand Up @@ -1031,7 +1032,7 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]:
assert y.shape == (1, 3)


class TestVmap:
class TestVmap(absltest.TestCase):
def test_basic(self):
class Block(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
Expand Down Expand Up @@ -1213,8 +1214,23 @@ def __call__(self, x: jax.Array) -> jax.Array:

assert module.vmap_module.graphdef == 'hello'

def test_state_axes(self):

class TestPmap:
class Foo(nnx.Module):

def __init__(self):
self.param = nnx.Param(jnp.arange(5))

foo = Foo()

@partial(nnx.vmap, state_axes={...: 0})
def f(foo: Foo):
assert foo.param.value.shape == ()

f(foo)


class TestPmap(absltest.TestCase):

def test_basic_single(self):
class Block(nnx.Module):
Expand Down Expand Up @@ -1367,7 +1383,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
assert y.shape == (1, 5, 3)


class TestCond:
class TestCond(absltest.TestCase):
def test_basic(self):
class TimeStep(tp.NamedTuple):
step: jax.Array
Expand Down Expand Up @@ -1407,3 +1423,7 @@ def reward_0(self: Foo):
foo.update()
assert foo.timestep.step == 4
assert foo.timestep.reward == 0.0


if __name__ == '__main__':
absltest.main()
Loading