Skip to content

Commit

Permalink
Stackless yashful
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681582933
  • Loading branch information
dougalm authored and Flax Authors committed Oct 3, 2024
1 parent 67e2917 commit 641bda9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def current_trace():
return jax.core.get_opaque_trace_state(convention="flax")

def check_trace_level(base_level):
level = current_trace()
if level != base_level:
raise errors.JaxTransformError()
pass
# level = current_trace()
# if level != base_level:
# raise errors.JaxTransformError()

0 comments on commit 641bda9

Please sign in to comment.