Skip to content

Commit

Permalink
Fix tests after applying JAX key-reuse checker. See:
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp authored and Flax Authors committed Mar 13, 2024
1 parent 0280160 commit 5d128a5
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,30 +810,34 @@ def find_axis_size(axis, x):
raise ValueError('axis_size should be specified manually.')
else:
d_axis_size = axis_size
split_fn = lambda rng: random.split(rng, d_axis_size)
# random.clone is only available on Jax versions 0.4.26 or newer
# see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html
split_fn = lambda rng: random.split(
random.clone(rng) if hasattr(random, 'clone') else rng, d_axis_size
)

rng_groups = tuple(
tree_map_rngs(split_fn, rng_group) if split else rng_group
for rng_group, split in zip(rng_groups, rng_splits)
tree_map_rngs(split_fn, rng_group) if split else rng_group
for rng_group, split in zip(rng_groups, rng_splits)
)

new_variable_groups = []
for var_group, axis in zip(variable_groups, variable_in_axes):
if axis is not None:
new_variable_groups.append(
meta.remove_axis(var_group, axis, metadata_params)
meta.remove_axis(var_group, axis, metadata_params)
)
else:
new_variable_groups.append(var_group)
variable_groups = tuple(new_variable_groups)

@functools.partial(
jax.vmap,
in_axes=(variable_in_axes, rng_axes, in_axes),
out_axes=(out_axes, variable_out_axes),
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
jax.vmap,
in_axes=(variable_in_axes, rng_axes, in_axes),
out_axes=(out_axes, variable_out_axes),
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
)
@functools.wraps(fn)
def mapped(variable_groups, rng_groups, args):
Expand Down Expand Up @@ -969,27 +973,31 @@ def find_length(axis, x):
raise ValueError('length should be specified manually.')
else:
d_length = length
split_fn = lambda rng: random.split(rng, d_length)
# random.clone is only available on Jax versions 0.4.26 or newer
# see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html
split_fn = lambda rng: random.split(
random.clone(rng) if hasattr(random, 'clone') else rng, d_length
)

rng_groups = tuple(
tree_map_rngs(split_fn, rng_group) if split else rng_group
for rng_group, split in zip(rng_groups, rng_splits)
tree_map_rngs(split_fn, rng_group) if split else rng_group
for rng_group, split in zip(rng_groups, rng_splits)
)

@functools.partial(
axes_scan.scan,
in_axes=(variable_in_axes, rng_axes, in_axes),
out_axes=(out_axes, variable_out_axes),
length=length,
reverse=reverse,
unroll=unroll,
axes_scan.scan,
in_axes=(variable_in_axes, rng_axes, in_axes),
out_axes=(out_axes, variable_out_axes),
length=length,
reverse=reverse,
unroll=unroll,
)
def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
carry_vars, c = carry
variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups
if data_transform is not None:
variable_groups, rng_groups = data_transform(
variable_groups, rng_groups
variable_groups, rng_groups
)
scope = scope_fn(variable_groups, rng_groups)
c, y = fn(scope, c, *args)
Expand Down

0 comments on commit 5d128a5

Please sign in to comment.