Skip to content

Commit

Permalink
use unsafe jax.core function rather than jax.core internal data struc…
Browse files Browse the repository at this point in the history
…tures so

as to be slightly more robust to jax core changes

PiperOrigin-RevId: 684632943
  • Loading branch information
mattjj authored and Scenic Authors committed Oct 11, 2024
1 parent 8f58121 commit c5a27ea
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions scenic/projects/lang4video/trainer/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ def _fun(*args_, **kwargs_):


def axis_name_exists(axis_name: Hashable) -> bool:
return any(frame.name == axis_name
for frame in jax.core.thread_local_state.trace_state.axis_env)
return axis_name in jax.core.unsafe_get_axis_names() # type: ignore


# The following are the same functions in `jax.example_libraries.optimizers` but
Expand Down

0 comments on commit c5a27ea

Please sign in to comment.