Skip to content

Commit

Permalink
Merge pull request #3724 from google:nnx-simplify-trace-state
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611220805
  • Loading branch information
Flax Authors committed Feb 28, 2024
2 parents f5c48fe + 4b7ba09 commit d1f219f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 137 deletions.
58 changes: 2 additions & 56 deletions flax/experimental/nnx/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@

# Taken from flax/core/tracer.py 🏴‍☠️

import contextlib
import dataclasses
import threading
import typing as tp

import jax
Expand Down Expand Up @@ -44,70 +41,19 @@ def current_jax_trace() -> MainTrace:
return get_top_trace(())


def get_all_traces(pytree: tp.Union[tp.Any, Tracer]) -> tp.Set[MainTrace]:
  """Returns the set of main traces of all tracers found in ``pytree``.

  Args:
    pytree: either a single ``Tracer`` or an arbitrary pytree whose leaves may
      include tracers.

  Returns:
    A set containing ``_trace.main`` for every ``Tracer`` encountered. Leaves
    that are not tracers are ignored, so the set is empty when ``pytree``
    contains no tracer at all.
  """
  # A bare tracer is handled directly instead of being flattened as a pytree.
  if isinstance(pytree, Tracer):
    return {pytree._trace.main}

  return {
    leaf._trace.main
    for leaf in jax.tree_util.tree_leaves(pytree)
    if isinstance(leaf, Tracer)
  }


def trace_level(main):
  """Returns the level of the trace, or -infinity if there is no trace.

  Args:
    main: an object exposing a ``level`` attribute, or a falsy value such as
      ``None`` when no trace is available.

  Returns:
    ``main.level`` when ``main`` is truthy, otherwise ``float('-inf')`` so that
    a missing trace compares lower than any real level.
  """
  if main:
    return main.level
  return float('-inf')


@dataclasses.dataclass
class TraceContext(threading.local):
  """Thread-local container for the stack of NNX traces.

  Because the class derives from ``threading.local``, attribute values are
  kept separately for each thread that touches an instance.
  """

  # Stack of traces; a newly initialised stack holds a single entry, the value
  # returned by ``current_jax_trace()`` at the moment the default is built.
  nnx_trace_stack: tp.List[MainTrace] = dataclasses.field(
    default_factory=lambda: [current_jax_trace()]
  )


# Module-level TraceContext instance, created at import time.
TRACE_CONTEXT = TraceContext()


@contextlib.contextmanager
def nnx_trace(trace: MainTrace):
  """Context manager that pushes ``trace`` onto the NNX trace stack.

  ``trace`` is appended to ``TRACE_CONTEXT.nnx_trace_stack`` on entry and the
  top of that stack is popped on exit, including when the body raises.

  Args:
    trace: the trace to push for the duration of the ``with`` block.
  """
  TRACE_CONTEXT.nnx_trace_stack.append(trace)
  try:
    yield
  finally:
    # Always undo the push so the stack stays balanced after an exception.
    TRACE_CONTEXT.nnx_trace_stack.pop()


def current_nnx_trace() -> MainTrace:
  """Returns the trace on top of ``TRACE_CONTEXT.nnx_trace_stack``."""
  return TRACE_CONTEXT.nnx_trace_stack[-1]


# NOTE(review): this span looks like removed and added diff lines rendered
# together without +/- markers and without indentation — confirm against the
# real file before relying on it.
# Records the trace(s) that were current when the object was constructed and
# lets callers check whether they are still the current ones.
class TraceState(reprlib.Representable):
# NOTE(review): `__slots__` is assigned twice; if both lines were kept, the
# second assignment would win and `_nnx_trace` would not be a declared slot.
__slots__ = ['_jax_trace', '_nnx_trace']
__slots__ = ['_jax_trace']

def __init__(self):
# Capture the traces reported as current at construction time.
self._jax_trace = current_jax_trace()
self._nnx_trace = current_nnx_trace()

# Read-only access to the JAX trace captured in __init__.
@property
def jax_trace(self):
return self._jax_trace

# Read-only access to the NNX trace captured in __init__.
@property
def nnx_trace(self):
return self._nnx_trace

def is_valid(self) -> bool:
# NOTE(review): two return statements follow. The first compares both captured
# traces by identity with the current ones; the second compares only the JAX
# trace and would be unreachable as written — confirm which one is intended.
return (
self._jax_trace is current_jax_trace()
and self._nnx_trace is current_nnx_trace()
)
return self._jax_trace is current_jax_trace()

# Yields the pieces of the repr: the type name, then one entry per trace.
def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
yield reprlib.Attr('jax_trace', self._jax_trace)
yield reprlib.Attr('nnx_trace', self._nnx_trace)
126 changes: 61 additions & 65 deletions flax/experimental/nnx/nnx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
filterlib,
rnglib,
spmd,
tracers,
variables,
)
from flax.experimental.nnx.nnx.module import GraphDef, Module, ModuleMeta
Expand Down Expand Up @@ -196,12 +195,10 @@ def jitted_fn(
if isinstance(states, State):
states = (states,)

nnx_trace = tracers.get_top_trace((args, kwargs))
with tracers.nnx_trace(nnx_trace):
if 'rngs' in kwargs:
kwargs['rngs'] = rnglib.Rngs(kwargs['rngs'])
module = graphdef.merge(*states)
out = f(module, *args, **kwargs)
if 'rngs' in kwargs:
kwargs['rngs'] = rnglib.Rngs(kwargs['rngs'])
module = graphdef.merge(*states)
out = f(module, *args, **kwargs)

updates = module.split()
out = (updates, out)
Expand Down Expand Up @@ -479,9 +476,8 @@ def grad_apply(options: GradOptions, f, module: Module, *args, **kwargs):
def grad_fn(diff: State):
nonlocal graphdef

with tracers.nnx_trace(tracers.get_top_trace(diff)):
module = graphdef.merge(diff, nondiff)
out = f(module, *args, **kwargs)
module = graphdef.merge(diff, nondiff)
out = f(module, *args, **kwargs)

updates, graphdef = module.split()
if options.has_aux:
Expand Down Expand Up @@ -847,41 +843,41 @@ def scan_apply(

# transpose axes state
scan_states = tuple(
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state)
for axes_state, axis in zip(scan_states, options.variable_axes.values())
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state)
for axes_state, axis in zip(scan_states, options.variable_axes.values())
)
# transpose axes arg
scan_args = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, axis, 0), node
)
if axis is not None
else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, axis, 0), node
)
if axis is not None
else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
)
broadcast_args = jax.tree_util.tree_map(
lambda axis, node: None if axis is not None else node,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
lambda axis, node: None if axis is not None else node,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
)
scan_kwargs = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, axis, 0), node
)
if axis is not None
else None,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, axis, 0), node
)
if axis is not None
else None,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
)
broadcast_kwargs = jax.tree_util.tree_map(
lambda axis, node: None if axis is not None else node,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
lambda axis, node: None if axis is not None else node,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
)

# infer length
Expand Down Expand Up @@ -938,18 +934,18 @@ def scan_fn(

# merge args and kwargs
args = jax.tree_util.tree_map(
lambda axis, scan, broadcast: scan if axis is not None else broadcast,
options.in_args_axes,
scan_args,
broadcast_args,
is_leaf=lambda x: x is None,
lambda axis, scan, broadcast: scan if axis is not None else broadcast,
options.in_args_axes,
scan_args,
broadcast_args,
is_leaf=lambda x: x is None,
)
kwargs = jax.tree_util.tree_map(
lambda axis, scan, broadcast: scan if axis is not None else broadcast,
options.in_kwargs_axes,
scan_kwargs,
broadcast_kwargs,
is_leaf=lambda x: x is None,
lambda axis, scan, broadcast: scan if axis is not None else broadcast,
options.in_kwargs_axes,
scan_kwargs,
broadcast_kwargs,
is_leaf=lambda x: x is None,
)

# merge rng state
Expand Down Expand Up @@ -1018,16 +1014,16 @@ def scan_fn(

# transpose axes state
scan_states = tuple(
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state)
for axes_state, axis in zip(scan_states, options.variable_axes.values())
jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state)
for axes_state, axis in zip(scan_states, options.variable_axes.values())
)
# transpose axes arg
scan_out = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, 0, axis), node
),
options.out_axes,
scan_out,
lambda axis, node: jax.tree_util.tree_map(
lambda x: jnp.moveaxis(x, 0, axis), node
),
options.out_axes,
scan_out,
)
# slice new carry state
carry_state_new = jax.tree_util.tree_map(lambda x: x[0], carry_state_new)
Expand Down Expand Up @@ -1483,20 +1479,20 @@ def vmap_apply(
# infer length
axis_sizes: tp.Set[int] = set()
args_sizes = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node)
if axis is not None
else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node)
if axis is not None
else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
)
kwargs_sizes = jax.tree_util.tree_map(
lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node)
if axis is not None
else None,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node)
if axis is not None
else None,
options.in_kwargs_axes,
kwargs,
is_leaf=lambda x: x is None,
)
axis_sizes.update(jax.tree_util.tree_leaves(args_sizes))
axis_sizes.update(jax.tree_util.tree_leaves(kwargs_sizes))
Expand Down
16 changes: 0 additions & 16 deletions flax/experimental/nnx/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,6 @@ def f(_):

f(1.0)

def test_trace_level_error_on_nnx_grad(self):
  """Mutating a closed-over Module inside ``nnx.grad`` raises an error.

  The decorated function ignores its argument and assigns to the outer ``m``
  instead; that assignment is expected to raise ``nnx.TraceContextError``.
  """
  # error occurs because nnx updates its nnx_trace
  # in nnx.grad.
  m = nnx.Dict(a=nnx.Param(1.0))

  @nnx.grad
  def f(_):
    with pytest.raises(
      nnx.TraceContextError,
      match='Cannot mutate Module from different trace level',
    ):
      m.a = 2.0
    # Returned after the `with` block so the differentiated function still
    # produces a value once the expected error has been caught.
    return 1.0

  f(m)

def test_call(self):
class Foo(nnx.Module):
def __init__(self, c: float, *, rngs: nnx.Rngs):
Expand Down

0 comments on commit d1f219f

Please sign in to comment.