diff --git a/flax/core/__init__.py b/flax/core/__init__.py index bca72392df..f90775f731 100644 --- a/flax/core/__init__.py +++ b/flax/core/__init__.py @@ -48,9 +48,8 @@ from .tracers import ( check_trace_level as check_trace_level, current_trace as current_trace, - trace_level as trace_level, ) from flax.typing import ( Array as Array, -) \ No newline at end of file +) diff --git a/flax/core/scope.py b/flax/core/scope.py index e056d6ddb9..ea8a586b11 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -454,7 +454,7 @@ def __init__( self.flags = freeze({} if flags is None else flags) self._root = parent.root if parent else None - self.trace_level = tracers.trace_level(tracers.current_trace()) + self.trace_level = tracers.current_trace() self.rng_counters = {key: 0 for key in self.rngs} self.reservations = collections.defaultdict(set) diff --git a/flax/core/tracers.py b/flax/core/tracers.py index 9d8472bdc7..fe2ff874c0 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -20,18 +20,17 @@ def current_trace(): - """Returns the innermost Jax tracer.""" - return jax.core.find_top_trace(()) - - -def trace_level(main): - """Returns the level of the trace of -infinity if it is None.""" - if main: - return main.level - return float('-inf') + """Returns the current JAX state tracer.""" + if jax.__version_info__ <= (0, 4, 33): + top = jax.core.find_top_trace(()) + if top: + return top.level + else: + return float('-inf') + return jax.core.get_opaque_trace_state(convention="flax") def check_trace_level(base_level): - level = trace_level(current_trace()) + level = current_trace() if level != base_level: raise errors.JaxTransformError() diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index 3db066376b..cc78597395 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -15,14 +15,17 @@ # Taken from flax/core/tracer.py 🏴‍☠️ -from jax.core import MainTrace, thread_local_state +import jax +import jax.core from flax.nnx import reprlib -def current_jax_trace() -> MainTrace: - """Returns the innermost Jax tracer.""" - return thread_local_state.trace_state.trace_stack.dynamic +def current_jax_trace(): + """Returns the Jax tracing state.""" + if jax.__version_info__ <= (0, 4, 33): + return jax.core.thread_local_state.trace_state.trace_stack.dynamic + return jax.core.get_opaque_trace_state(convention="nnx") class TraceState(reprlib.Representable): @@ -36,7 +39,10 @@ def jax_trace(self): return self._jax_trace def is_valid(self) -> bool: - return self._jax_trace is current_jax_trace() + if jax.__version_info__ <= (0, 4, 33): + return self._jax_trace is current_jax_trace() + + return self._jax_trace == current_jax_trace() def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') @@ -52,4 +58,7 @@ def __treescope_repr__(self, path, subtree_renderer): ) def __eq__(self, other): - return isinstance(other, TraceState) and self._jax_trace is other._jax_trace + if jax.__version_info__ <= (0, 4, 33): + return isinstance(other, TraceState) and self._jax_trace is other._jax_trace + + return isinstance(other, TraceState) and self._jax_trace == other._jax_trace