diff --git a/flax/experimental/nnx/nnx/tracers.py b/flax/experimental/nnx/nnx/tracers.py index f64e8660e3..88266823da 100644 --- a/flax/experimental/nnx/nnx/tracers.py +++ b/flax/experimental/nnx/nnx/tracers.py @@ -14,9 +14,6 @@ # Taken from flax/core/tracer.py 🏴‍☠️ -import contextlib -import dataclasses -import threading import typing as tp import jax @@ -44,70 +41,19 @@ def current_jax_trace() -> MainTrace: return get_top_trace(()) -def get_all_traces(pytree: tp.Union[tp.Any, Tracer]) -> tp.Set[MainTrace]: - """Returns True if all tracers have the same main trace.""" - if isinstance(pytree, Tracer): - return {pytree._trace.main} - else: - return { - trace._trace.main - for trace in jax.tree_util.tree_leaves(pytree) - if isinstance(trace, Tracer) - } - - -def trace_level(main): - """Returns the level of the trace of -infinity if it is None.""" - if main: - return main.level - return float('-inf') - - -@dataclasses.dataclass -class TraceContext(threading.local): - nnx_trace_stack: tp.List[MainTrace] = dataclasses.field( - default_factory=lambda: [current_jax_trace()] - ) - - -TRACE_CONTEXT = TraceContext() - - -@contextlib.contextmanager -def nnx_trace(trace: MainTrace): - TRACE_CONTEXT.nnx_trace_stack.append(trace) - try: - yield - finally: - TRACE_CONTEXT.nnx_trace_stack.pop() - - -def current_nnx_trace() -> MainTrace: - return TRACE_CONTEXT.nnx_trace_stack[-1] - - class TraceState(reprlib.Representable): - __slots__ = ['_jax_trace', '_nnx_trace'] + __slots__ = ['_jax_trace'] def __init__(self): self._jax_trace = current_jax_trace() - self._nnx_trace = current_nnx_trace() @property def jax_trace(self): return self._jax_trace - @property - def nnx_trace(self): - return self._nnx_trace - def is_valid(self) -> bool: - return ( - self._jax_trace is current_jax_trace() - and self._nnx_trace is current_nnx_trace() - ) + return self._jax_trace is current_jax_trace() def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') yield reprlib.Attr('jax_trace', self._jax_trace) - yield reprlib.Attr('nnx_trace', self._nnx_trace) diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 7292eacf42..aedbb2adad 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -42,7 +42,6 @@ filterlib, rnglib, spmd, - tracers, variables, ) from flax.experimental.nnx.nnx.module import GraphDef, Module, ModuleMeta @@ -196,12 +195,10 @@ def jitted_fn( if isinstance(states, State): states = (states,) - nnx_trace = tracers.get_top_trace((args, kwargs)) - with tracers.nnx_trace(nnx_trace): - if 'rngs' in kwargs: - kwargs['rngs'] = rnglib.Rngs(kwargs['rngs']) - module = graphdef.merge(*states) - out = f(module, *args, **kwargs) + if 'rngs' in kwargs: + kwargs['rngs'] = rnglib.Rngs(kwargs['rngs']) + module = graphdef.merge(*states) + out = f(module, *args, **kwargs) updates = module.split() out = (updates, out) @@ -479,9 +476,8 @@ def grad_apply(options: GradOptions, f, module: Module, *args, **kwargs): def grad_fn(diff: State): nonlocal graphdef - with tracers.nnx_trace(tracers.get_top_trace(diff)): - module = graphdef.merge(diff, nondiff) - out = f(module, *args, **kwargs) + module = graphdef.merge(diff, nondiff) + out = f(module, *args, **kwargs) updates, graphdef = module.split() if options.has_aux: @@ -847,41 +843,41 @@ def scan_apply( # transpose axes state scan_states = tuple( - jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state) - for axes_state, axis in zip(scan_states, options.variable_axes.values()) + jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state) + for axes_state, axis in zip(scan_states, options.variable_axes.values()) ) # transpose axes arg scan_args = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map( - lambda x: jnp.moveaxis(x, axis, 0), node - ) - if axis is not None - else None, - options.in_args_axes, - args, - is_leaf=lambda x: x is None, + lambda axis, node: jax.tree_util.tree_map( + lambda x: jnp.moveaxis(x, axis, 0), node + ) + if axis is not None + else None, + options.in_args_axes, + args, + is_leaf=lambda x: x is None, ) broadcast_args = jax.tree_util.tree_map( - lambda axis, node: None if axis is not None else node, - options.in_args_axes, - args, - is_leaf=lambda x: x is None, + lambda axis, node: None if axis is not None else node, + options.in_args_axes, + args, + is_leaf=lambda x: x is None, ) scan_kwargs = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map( - lambda x: jnp.moveaxis(x, axis, 0), node - ) - if axis is not None - else None, - options.in_kwargs_axes, - kwargs, - is_leaf=lambda x: x is None, + lambda axis, node: jax.tree_util.tree_map( + lambda x: jnp.moveaxis(x, axis, 0), node + ) + if axis is not None + else None, + options.in_kwargs_axes, + kwargs, + is_leaf=lambda x: x is None, ) broadcast_kwargs = jax.tree_util.tree_map( - lambda axis, node: None if axis is not None else node, - options.in_kwargs_axes, - kwargs, - is_leaf=lambda x: x is None, + lambda axis, node: None if axis is not None else node, + options.in_kwargs_axes, + kwargs, + is_leaf=lambda x: x is None, ) # infer length @@ -938,18 +934,18 @@ def scan_fn( # merge args and kwargs args = jax.tree_util.tree_map( - lambda axis, scan, broadcast: scan if axis is not None else broadcast, - options.in_args_axes, - scan_args, - broadcast_args, - is_leaf=lambda x: x is None, + lambda axis, scan, broadcast: scan if axis is not None else broadcast, + options.in_args_axes, + scan_args, + broadcast_args, + is_leaf=lambda x: x is None, ) kwargs = jax.tree_util.tree_map( - lambda axis, scan, broadcast: scan if axis is not None else broadcast, - options.in_kwargs_axes, - scan_kwargs, - broadcast_kwargs, - is_leaf=lambda x: x is None, + lambda axis, scan, broadcast: scan if axis is not None else broadcast, + options.in_kwargs_axes, + scan_kwargs, + broadcast_kwargs, + is_leaf=lambda x: x is None, ) # merge rng state @@ -1018,16 +1014,16 @@ def scan_fn( # transpose axes state scan_states = tuple( - jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state) - for axes_state, axis in zip(scan_states, options.variable_axes.values()) + jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), axes_state) + for axes_state, axis in zip(scan_states, options.variable_axes.values()) ) # transpose axes arg scan_out = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map( - lambda x: jnp.moveaxis(x, 0, axis), node - ), - options.out_axes, - scan_out, + lambda axis, node: jax.tree_util.tree_map( + lambda x: jnp.moveaxis(x, 0, axis), node + ), + options.out_axes, + scan_out, ) # slice new carry state carry_state_new = jax.tree_util.tree_map(lambda x: x[0], carry_state_new) @@ -1483,20 +1479,20 @@ def vmap_apply( # infer length axis_sizes: tp.Set[int] = set() args_sizes = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node) - if axis is not None - else None, - options.in_args_axes, - args, - is_leaf=lambda x: x is None, + lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node) + if axis is not None + else None, + options.in_args_axes, + args, + is_leaf=lambda x: x is None, ) kwargs_sizes = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node) - if axis is not None - else None, - options.in_kwargs_axes, - kwargs, - is_leaf=lambda x: x is None, + lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node) + if axis is not None + else None, + options.in_kwargs_axes, + kwargs, + is_leaf=lambda x: x is None, ) axis_sizes.update(jax.tree_util.tree_leaves(args_sizes)) axis_sizes.update(jax.tree_util.tree_leaves(kwargs_sizes)) diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index ad38cbbcad..c3c24de280 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -88,22 +88,6 @@ def f(_): f(1.0) - def test_trace_level_error_on_nnx_grad(self): - # error occurs because nnx updates its nnx_trace - # in nnx.grad. - m = nnx.Dict(a=nnx.Param(1.0)) - - @nnx.grad - def f(_): - with pytest.raises( - nnx.TraceContextError, - match='Cannot mutate Module from different trace level', - ): - m.a = 2.0 - return 1.0 - - f(m) - def test_call(self): class Foo(nnx.Module): def __init__(self, c: float, *, rngs: nnx.Rngs):