Skip to content

Commit

Permalink
Add logging to track deprecated codepaths.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652597860
  • Loading branch information
cpgaffney1 authored and Flax Authors committed Aug 12, 2024
1 parent 8e0228e commit ec9f3de
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ def is_multi_device_array(value: Any) -> bool:


def save_args_from_target(target: Any) -> Any:
return jax.tree_util.tree_map(
lambda x: ocp.SaveArgs(aggregate=not is_multi_device_array(x)),
target,
)
return jax.tree_util.tree_map(lambda _: ocp.SaveArgs(), target)


def maybe_construct_transformations(
Expand Down

0 comments on commit ec9f3de

Please sign in to comment.