From ab71a7237d63d308516e6660726b5790835ae230 Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Thu, 9 Nov 2023 00:49:42 +0000
Subject: [PATCH] fix tests

---
 flax/experimental/nnx/README.md | 2 +-
 flax/experimental/nnx/__init__.py | 2 +-
 flax/experimental/nnx/docs/tiny_nnx.ipynb | 8 +-
 flax/experimental/nnx/examples/05_vae.py | 2 +-
 .../nnx/examples/06_scan_over_layers.py | 29 ++-
 .../nnx/examples/07_transformer.py | 4 +-
 .../nnx/examples/10_quantization.py | 2 +-
 flax/experimental/nnx/nnx/nn/initializers.py | 7 +-
 flax/experimental/nnx/nnx/nn/linear.py | 12 +-
 flax/experimental/nnx/nnx/nn/normalization.py | 10 +-
 flax/experimental/nnx/nnx/rnglib.py | 39 ++--
 flax/experimental/nnx/nnx/transforms.py | 189 ++++++------------
 flax/experimental/nnx/tests/test_rngs.py | 27 ++-
 pyproject.toml | 4 +
 14 files changed, 143 insertions(+), 194 deletions(-)

diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md
index ca2a5dca4b..5cc033c3aa 100644
--- a/flax/experimental/nnx/README.md
+++ b/flax/experimental/nnx/README.md
@@ -58,7 +58,7 @@ y = model(jnp.ones((8, 12)))  # call methods directly

 assert model.count == 1
 ```

-In this example `nnx.Rngs(0)` create a `PRNGKey` for `params` with seed `0`, this is used by `rngs.params()` inside `__init__` to generate a random key to initialize the parameters.
+In this example `nnx.Rngs(0)` creates a `random.key` for `params` with seed `0`; this is used by `rngs.params()` inside `__init__` to generate a random key to initialize the parameters.

 ### Training with the Functional API

diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py
index 1d2799e776..adfad6b761 100644
--- a/flax/experimental/nnx/__init__.py
+++ b/flax/experimental/nnx/__init__.py
@@ -69,7 +69,7 @@ from .nnx.pytreelib import Pytree as Pytree
 from .nnx.pytreelib import TreeNode as TreeNode
 from .nnx.rnglib import Rngs as Rngs
-from .nnx.rnglib import merge_rngs as merge_rngs
+from .nnx.rnglib import RngStream as RngStream
 from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
 from .nnx.spmd import get_partition_spec as get_partition_spec
 from .nnx.spmd import with_partitioning as with_partitioning

diff --git a/flax/experimental/nnx/docs/tiny_nnx.ipynb b/flax/experimental/nnx/docs/tiny_nnx.ipynb
index 91294555da..05d35fd34a 100644
--- a/flax/experimental/nnx/docs/tiny_nnx.ipynb
+++ b/flax/experimental/nnx/docs/tiny_nnx.ipynb
@@ -29,7 +29,7 @@
    "A = tp.TypeVar(\"A\")\n",
    "M = tp.TypeVar(\"M\", bound=\"Module\")\n",
    "Sharding = tp.Tuple[tp.Optional[str], ...]\n",
-    "KeyArray = random.KeyArray\n",
+    "Array = random.Array\n",
    "\n",
    "\n",
    "class Variable(tp.Generic[A]):\n",
@@ -393,9 +393,9 @@
    }
   ],
   "source": [
-    "module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.PRNGKey(0)))\n",
-    "x = jax.random.normal(random.PRNGKey(0), (2, 10))\n",
-    "y = module(x, train=True, rngs=Rngs(random.PRNGKey(1)))\n",
+    "module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.key(0)))\n",
+    "x = jax.random.normal(random.key(0), (2, 10))\n",
+    "y = module(x, train=True, rngs=Rngs(random.key(1)))\n",
    "\n",
    "state, moduledef = module.split()\n",
    "print(\"state =\", jax.tree_map(jnp.shape, state))\n",

diff --git a/flax/experimental/nnx/examples/05_vae.py b/flax/experimental/nnx/examples/05_vae.py
index 04bbea8d48..719bab9ef4 100644
--- a/flax/experimental/nnx/examples/05_vae.py
+++ b/flax/experimental/nnx/examples/05_vae.py
@@ -166,7 +166,7 @@ def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array:


 # %%
-key = jax.random.PRNGKey(0)
+key = jax.random.key(0)

for 
epoch in range(epochs):
   losses = []

diff --git a/flax/experimental/nnx/examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/06_scan_over_layers.py
index eca6dc2f69..24dcfdb22c 100644
--- a/flax/experimental/nnx/examples/06_scan_over_layers.py
+++ b/flax/experimental/nnx/examples/06_scan_over_layers.py
@@ -41,42 +41,37 @@ class ScanMLP(nnx.Module):

   def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
     self.n_layers = n_layers
-    # split Rngs and split the `params` key
-    keys, rngsdef = rngs.split()
-    params_key = jax.random.split(keys['params'], n_layers)
+    # fork Rngs, split keys into `n_layers`
+    keys = rngs.fork(n_layers)

     def create_block(keys):
-      # merge back Rngs using the sliced `params` key
-      rngs = rngsdef.merge({'params': params_key})
       # create Block instance and return its split
-      return Block(dim, rngs=rngs).split()
+      return Block(dim, rngs=nnx.Rngs(keys)).split()

-    # call vmap over create_block, passing the split `params` key
+    # call vmap over create_block, passing the forked keys
     # and immediately merge to get a Block instance
-    self.layers = nnx.merge(jax.vmap(create_block)(params_key))
+    self.layers = nnx.merge(jax.vmap(create_block)(keys))

   def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array:
-    # split Rngs and split the `dropout` key
-    keys, rngsdef = rngs.split()
-    dropout_key = jax.random.split(keys['dropout'], self.n_layers)
+    # fork Rngs, split keys into `n_layers`
+    keys = rngs.fork(self.n_layers)
     # split Module to get params
     params, moduledef = self.layers.split(nnx.Param)

     def scan_fn(
-      x: jax.Array, inputs: Tuple[nnx.State, jax.Array]
+      x: jax.Array, inputs: Tuple[nnx.State, dict[str, nnx.RngStream]]
     ) -> Tuple[jax.Array, nnx.State]:
-      params, dropout_key = inputs
+      params, keys = inputs
       # merge back Module and Rngs
-      rngs = rngsdef.merge({'dropout': dropout_key})
       module = moduledef.merge(params)
       # forward pass
-      x = module(x, rngs=rngs)
+      x = module(x, rngs=nnx.Rngs(keys))
       # split state and return
       params, _ = module.split(nnx.Param)
       return x, params

-    # call scan passing x as the carry, and params + dropout_key as the input
-    x, params = jax.lax.scan(scan_fn, x, (params, dropout_key))
+    # call scan passing x as the carry, and params + keys as the input
+    x, params = jax.lax.scan(scan_fn, x, (params, keys))
     # update layers state and return
     self.layers.update(params)
     return x

diff --git a/flax/experimental/nnx/examples/07_transformer.py b/flax/experimental/nnx/examples/07_transformer.py
index 70c6306b9f..d0352e32dd 100644
--- a/flax/experimental/nnx/examples/07_transformer.py
+++ b/flax/experimental/nnx/examples/07_transformer.py
@@ -390,10 +390,10 @@ def __call__(self, cfg: Config, x, *, rngs: nnx.Rngs):
     assert isinstance(self.layers, DecoderBlock)

     state, moduledef = self.layers.split()
-    rngs, rngsdef = rngs.split()
+    rngs, rngsdef = rngs.fork()
     dropout_key = jax.random.split(rngs['dropout'], cfg.layers)

-    def scan_fn(x, s: tp.Tuple[jax.random.KeyArray, nnx.State]):
+    def scan_fn(x, s: tp.Tuple[jax.Array, nnx.State]):
       dropout_key, state = s
       rngs = rngsdef.merge({'dropout': dropout_key})
       y, (state, _) = moduledef.apply(state)(cfg, x, rngs=rngs)

diff --git a/flax/experimental/nnx/examples/10_quantization.py b/flax/experimental/nnx/examples/10_quantization.py
index 33c855f407..0ac4ac7de2 100644
--- a/flax/experimental/nnx/examples/10_quantization.py
+++ b/flax/experimental/nnx/examples/10_quantization.py
@@ -155,7 +155,7 @@ def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array:


 # %%
-key = jax.random.PRNGKey(0)
+key = jax.random.key(0)

for 
epoch in range(epochs): for step in range(steps_per_epoch): diff --git a/flax/experimental/nnx/nnx/nn/initializers.py b/flax/experimental/nnx/nnx/nn/initializers.py index 638b01b556..0e989c80ca 100644 --- a/flax/experimental/nnx/nnx/nn/initializers.py +++ b/flax/experimental/nnx/nnx/nn/initializers.py @@ -35,14 +35,13 @@ Shape = tp.Sequence[int] DTypeLikeInexact = tp.Any -KeyArray = jax.random.KeyArray Array = jax.Array class Initializer(tp.Protocol): @staticmethod def __call__( - key: KeyArray, shape: Shape, dtype: DTypeLikeInexact = jnp.float_ + key: Array, shape: Shape, dtype: DTypeLikeInexact = jnp.float_ ) -> Array: ... @@ -53,7 +52,7 @@ def zeros() -> Initializer: >>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import zeros_init >>> zeros_initializer = zeros_init() - >>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ @@ -66,7 +65,7 @@ def ones() -> Initializer: >>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import ones_init >>> ones_initializer = ones_init() - >>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32) + >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py index 3f4e5a8d5a..1bedc53605 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -25,7 +25,7 @@ from flax.experimental.nnx.nnx.nn import dtypes, initializers Array = jax.Array -PRNGKey = tp.Any +KeyArray = jax.Array Shape = tp.Tuple[int, ...] Dtype = tp.Any # this could be a real type? PrecisionLike = tp.Union[ @@ -101,10 +101,10 @@ def __init__( param_dtype: Dtype = jnp.float32, precision: PrecisionLike = None, kernel_init: tp.Callable[ - [PRNGKey, Shape, Dtype], Array + [KeyArray, Shape, Dtype], Array ] = default_kernel_init, bias_init: tp.Callable[ - [PRNGKey, Shape, Dtype], Array + [KeyArray, Shape, Dtype], Array ] = initializers.zeros(), dot_general: DotGeneralT = lax.dot_general, rngs: rnglib.Rngs, @@ -210,10 +210,10 @@ def __init__( param_dtype: Dtype = jnp.float32, precision: PrecisionLike = None, kernel_init: tp.Callable[ - [PRNGKey, Shape, Dtype], Array + [KeyArray, Shape, Dtype], Array ] = default_kernel_init, bias_init: tp.Callable[ - [PRNGKey, Shape, Dtype], Array + [KeyArray, Shape, Dtype], Array ] = initializers.zeros(), conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated, rngs: rnglib.Rngs, @@ -394,7 +394,7 @@ def __init__( dtype: tp.Optional[Dtype] = None, param_dtype: Dtype = jnp.float32, embedding_init: tp.Callable[ - [PRNGKey, Shape, Dtype], Array + [KeyArray, Shape, Dtype], Array ] = default_embed_init, rngs: rnglib.Rngs, ): diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py index 6c5c5a37e7..1df89eeb9f 100644 --- a/flax/experimental/nnx/nnx/nn/normalization.py +++ b/flax/experimental/nnx/nnx/nn/normalization.py @@ -23,7 +23,7 @@ from flax.experimental.nnx.nnx.module import Module, first_from from flax.experimental.nnx.nnx.nn import dtypes, initializers -PRNGKey = jax.Array +KeyArray = jax.Array Array = jax.Array Shape = tp.Tuple[int, ...] Dtype = tp.Any # this could be a real type? 
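
Note: the `PRNGKey = tp.Any` / `KeyArray` aliases being rewritten to `jax.Array` above, and the `jax.random.PRNGKey(...)` calls becoming `jax.random.key(...)` throughout this patch, track JAX's migration to typed PRNG keys. A minimal sketch of the behavioral difference, assuming a recent JAX release that provides `jax.random.key` (the variable names are illustrative only):

```python
import jax

legacy = jax.random.PRNGKey(0)  # raw uint32 array, shape (2,)
typed = jax.random.key(0)       # typed key array, scalar shape ()

# both drive the same samplers:
jax.random.normal(legacy, (3,))
jax.random.normal(typed, (3,))

# but splitting no longer carries the trailing key-data dimension:
assert jax.random.split(legacy, 4).shape == (4, 2)
assert jax.random.split(typed, 4).shape == (4,)
```

That dropped trailing `2` is why `RngStream.fork` below reshapes to `*axis_size` instead of `(*axis_size, -1)`, and why the expected key shapes in `tests/test_rngs.py` lose their last dimension later in this patch.
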
@@ -202,10 +202,10 @@ def __init__(
     use_bias: bool = True,
     use_scale: bool = True,
     bias_init: tp.Callable[
-      [PRNGKey, Shape, Dtype], Array
+      [KeyArray, Shape, Dtype], Array
     ] = initializers.zeros(),
     scale_init: tp.Callable[
-      [PRNGKey, Shape, Dtype], Array
+      [KeyArray, Shape, Dtype], Array
     ] = initializers.ones(),
     axis_name: tp.Optional[str] = None,
     axis_index_groups: tp.Any = None,
@@ -333,10 +333,10 @@ def __init__(
     use_bias: bool = True,
     use_scale: bool = True,
     bias_init: tp.Callable[
-      [PRNGKey, Shape, Dtype], Array
+      [KeyArray, Shape, Dtype], Array
     ] = initializers.zeros(),
     scale_init: tp.Callable[
-      [PRNGKey, Shape, Dtype], Array
+      [KeyArray, Shape, Dtype], Array
     ] = initializers.ones(),
     reduction_axes: Axes = -1,
     feature_axes: Axes = -1,
diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py
index 78f9b9e831..e12ca1a60c 100644
--- a/flax/experimental/nnx/nnx/rnglib.py
+++ b/flax/experimental/nnx/nnx/rnglib.py
@@ -67,7 +67,7 @@ def fork(self, pattern: Pattern) -> 'RngStream':
     num_splits = int(np.prod([x for x in pattern if x is not None]))
     axis_size = tuple(x if x is not None else 1 for x in pattern)
     # reshape key
-    key = jax.random.split(key, num_splits).reshape(*axis_size, -1)
+    key = jax.random.split(key, num_splits).reshape(*axis_size)
     count_path = [0]
     return RngStream(key, count_path)

@@ -77,8 +77,8 @@ def copy(self) -> 'RngStream':

 jax.tree_util.register_pytree_node(
   RngStream,
-  lambda rng: (rng.key, tuple(rng.counts)),
-  lambda counts, key: RngStream(key, list(counts)),
+  lambda rng: ((rng.key,), tuple(rng.counts)),
+  lambda counts, nodes: RngStream(nodes[0], list(counts)),
 )

 RngValue = tp.Union[int, jax.Array, RngStream]
@@ -106,7 +106,7 @@ def __init__(

     self._rngs = {
       name: (
-        RngStream(jax.random.PRNGKey(value), [0])
+        RngStream(jax.random.key(value), [0])
         if isinstance(value, int)
         else RngStream(value, [0])
         if isinstance(value, jax.Array)
@@ -176,7 +176,7 @@ def fork(

   def fork(
     self,
-    __default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
+    _default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
     **patterns: Pattern,
   ) -> dict[str, RngStream] | tuple[dict[str, RngStream], dict[str, RngStream]]:
     if not self.is_valid():
       raise (
       )

     filter_patterns: list[tuple[filterlib.Filter, Pattern]]
-    if isinstance(__default, dict):
+    if isinstance(_default, dict):
       # merge default and patterns
       filter_patterns = [
-        *__default.items(),
+        *_default.items(),
         *patterns.items(),
         (..., None),  # broadcast all remaining
       ]
     else:
-      default = None if isinstance(__default, Missing) else __default
+      default = None if isinstance(_default, Missing) else _default
       filter_patterns = [
         *patterns.items(),
         (..., default),  # split all remaining with default
       ]

     predicate_pattern = [
       (filterlib.to_predicate(filter_), pattern)
       for filter_, pattern in filter_patterns
     ]

-    rngs = self._rngs.copy()
+    splits: dict[str, RngStream] = {}
+    broadcasts: dict[str, RngStream] = {}

-    for name, stream in rngs.items():
+    for name, stream in self._rngs.items():
       for predicate, pattern in predicate_pattern:
         if predicate(name, stream):
-          rngs[name] = stream.fork(pattern)
+          fork = stream.fork(pattern)
+          if pattern is None:
+            broadcasts[name] = fork
+          else:
+            splits[name] = fork
           break
       else:
         raise RuntimeError(
           f'Stream {name!r} did not match any predicate, this is a bug.'
) - if isinstance(__default, dict) or patterns: - split_rngs, broadcast_rngs = {}, {} - for name, stream in rngs.items(): - if patterns[name] is None: - broadcast_rngs[name] = stream - else: - split_rngs[name] = stream - return split_rngs, broadcast_rngs + if isinstance(_default, dict) or patterns: + return splits, broadcasts else: - return rngs + return {**splits, **broadcasts} diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 08874968ad..db86214d4f 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -113,7 +113,7 @@ class JITOptions: inline: bool abstracted_axes: tp.Optional[tp.Any] - def get_kwargs(self) -> tp.Dict[str, tp.Any]: + def get_kwargs(self) -> dict[str, tp.Any]: kwargs = vars(self).copy() if kwargs['in_shardings'] is UNSPECIFIED: kwargs.pop('in_shardings') @@ -187,7 +187,7 @@ def jitted_fn( nnx_trace = tracers.get_top_trace((args, kwargs)) with tracers.nnx_trace(nnx_trace): if 'rngs' in kwargs: - kwargs['rngs'] = rnglib.merge_rngs(kwargs['rngs']) + kwargs['rngs'] = rnglib.Rngs(kwargs['rngs']) module = moduledef.merge(*states) out = f(module, *args, **kwargs) @@ -211,10 +211,11 @@ def jit_init( module = tp.cast(M, module) if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], rnglib.Rngs): - kwargs['rngs'] = rngs.split() + kwargs['rngs'] = rngs.fork() state_and_def = module.split() - updates, _ = jitted_fn(state_and_def, *args, **kwargs) + out = jitted_fn(state_and_def, *args, **kwargs) + updates, _ = out module.update(updates) @@ -230,7 +231,7 @@ def jit_apply( module = tp.cast(M, module) if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], rnglib.Rngs): - kwargs['rngs'] = rngs.split() + kwargs['rngs'] = rngs.fork() state_and_def = module.split() updates, out = jitted_fn(state_and_def, *args, **kwargs) @@ -255,7 +256,7 @@ def __init__( abstracted_axes: tp.Optional[tp.Any] = None, # submodule args module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ): self.options = JITOptions( in_shardings=in_shardings, @@ -418,7 +419,7 @@ def __init__( return_value: bool = False, # submodule args module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ): self.options = GradOptions( wrt=wrt, @@ -686,7 +687,7 @@ def __init__( scan_output: bool = True, # submodule args module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ): self.module_constructor = module_constructor self.options = ScanOptions( @@ -747,7 +748,7 @@ def scan_init( options: ScanOptions, module_constructor: tp.Callable[..., M], module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ) -> M: if options.variable_axes and options.length is None: raise ValueError('Cannot use variable_axes without specifying a length') @@ -759,42 +760,31 @@ def scan_init( if rngs is not None and not isinstance(rngs, rnglib.Rngs): raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - key_values = [] + split_keys = [] if rngs is not None: if not isinstance(rngs, rnglib.Rngs): raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - keys, rngsdef = rngs.split() - broadcast_predicate = filterlib.to_predicate(options.broadcast_rngs) - - key_axes = [] - key_names = tuple(keys.keys()) - - for name, key in keys.items(): - if broadcast_predicate(name, key): - 
key_axes.append(None) - else: - if options.length is None: - raise ValueError('Cannot split RNGs without specifying a length') - key = jax.random.split(key, options.length) - key_axes.append(0) - key_values.append(key) + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): options.length} + ) + + if split_keys and options.length is None: + raise ValueError('Cannot split RNGs without specifying a length') + else: - key_names = None - rngsdef = None - key_axes = None + split_keys = None + broadcast_keys = None moduledef: tp.Optional[ModuleDef[M]] = None - def _init_state(*key_values): + def _init_state(split_keys, broadcast_keys): nonlocal moduledef - if rngsdef is not None: - assert key_names is not None - keys = dict(zip(key_names, key_values)) - rngs = rngsdef.merge(keys) - module_init_kwargs['rngs'] = rngs + if split_keys is not None: + assert broadcast_keys is not None + module_init_kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) module = module_constructor(*module_init_args, **module_init_kwargs) @@ -805,16 +795,16 @@ def _init_state(*key_values): return tuple(states) - if rngsdef is not None or options.variable_axes: + if split_keys is not None or options.variable_axes: init_out_axes = (*options.variable_axes.values(), None) _init_state = jax.vmap( _init_state, - in_axes=key_axes, + in_axes=(0, None), out_axes=init_out_axes, axis_size=options.length, ) - *axes_states, carry_state = _init_state(*key_values) + *axes_states, carry_state = _init_state(split_keys, broadcast_keys) moduledef = tp.cast(ModuleDef[M], moduledef) # add additional axis name to Variable.sharding @@ -905,26 +895,14 @@ def scan_apply( ) # split rng state - scan_keys: tp.Optional[tp.Dict[str, jax.Array]] - broadcast_keys: tp.Optional[tp.Dict[str, jax.Array]] if rngs is not None: if not isinstance(rngs, rnglib.Rngs): raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - - scan_keys = {} - broadcast_keys = {} - - keys, rngsdef = rngs.split() - broadcast_predicate = filterlib.to_predicate(options.broadcast_rngs) - - for name, key in keys.items(): - if broadcast_predicate(name, key): - broadcast_keys[name] = key - else: - scan_keys[name] = jax.random.split(key, length) + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): length} + ) else: - rngsdef = None - scan_keys = None + split_keys = None broadcast_keys = None moduledef_out: tp.Optional[ModuleDef[Module]] = None @@ -932,7 +910,7 @@ def scan_apply( def scan_fn( carry: tuple[State, tp.Any], scan: tuple[ - tp.Optional[tp.Dict[str, jax.Array]], + dict[str, rnglib.RngStream] | None, tuple[State, ...], tuple[tp.Any, ...], dict[str, tp.Any], @@ -940,7 +918,7 @@ def scan_fn( ): nonlocal moduledef_out carry_state, carry_arg = carry - scan_keys, scan_states, scan_args, scan_kwargs = scan + split_keys, scan_states, scan_args, scan_kwargs = scan # merge args and kwargs args = jax.tree_map( @@ -959,10 +937,9 @@ def scan_fn( ) # merge rng state - if rngsdef is not None: - assert scan_keys is not None and broadcast_keys is not None - rngs = rngsdef.merge({**scan_keys, **broadcast_keys}) - kwargs['rngs'] = rngs + if split_keys is not None: + assert broadcast_keys is not None + kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) # remove metadata axis name from Variable.sharding if spmd.PARTITION_NAME in options.scan_metadata: @@ -1009,7 +986,7 @@ def scan_fn( return full_carry_out, full_scan_out carry = (carry_state, carry_arg) - scan = (scan_keys, scan_states, scan_args, scan_kwargs) + 
scan = (split_keys, scan_states, scan_args, scan_kwargs) full_carry_out, full_scan_out = jax.lax.scan( scan_fn, @@ -1169,7 +1146,7 @@ def __init__( policy: tp.Optional[tp.Callable[..., bool]] = None, # submodule args module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ): self.options = RematOptions( prevent_cse=prevent_cse, @@ -1218,22 +1195,16 @@ def remat_apply( _check_args(args) state, moduledef = module.split() - - if rngs is not None: - keys, rngsdef = rngs.split() - else: - keys = None - rngsdef = None + keys = rngs.fork() if rngs is not None else None def _remat_fn( state: State, - keys: tp.Optional[tp.Dict[str, jax.Array]], + keys: tp.Optional[dict[str, jax.Array]], *args, ) -> tuple[tuple[State, ModuleDef[Module]], tp.Any]: kwargs = {} if keys is not None: - assert rngsdef is not None - kwargs['rngs'] = rngsdef.merge(keys) + kwargs['rngs'] = rnglib.Rngs(keys) module = moduledef.merge(state) out = f(module, *args, **kwargs) @@ -1360,7 +1331,7 @@ def __init__( vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), # submodule args module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ): self.module_constructor = module_constructor self.options = VmapOptions( @@ -1413,7 +1384,7 @@ def vmap_init( options: VmapOptions, module_constructor: tp.Callable[..., M], module_init_args: tuple[tp.Any, ...], - module_init_kwargs: tp.Dict[str, tp.Any], + module_init_kwargs: dict[str, tp.Any], ) -> M: if options.variable_axes and options.axis_size is None: raise ValueError('Cannot use variable_axes without specifying a length') @@ -1425,42 +1396,26 @@ def vmap_init( if rngs is not None and not isinstance(rngs, rnglib.Rngs): raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - key_values = [] - if rngs is not None: if not isinstance(rngs, rnglib.Rngs): raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - - keys, rngsdef = rngs.split() - broadcast_predicate = filterlib.to_predicate(options.broadcast_rngs) - - key_axes = [] - key_names = tuple(keys.keys()) - - for name, key in keys.items(): - if broadcast_predicate(name, key): - key_axes.append(None) - else: - if options.axis_size is None: - raise ValueError('Cannot split RNGs without specifying a length') - key = jax.random.split(key, options.axis_size) - key_axes.append(0) - key_values.append(key) + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): options.axis_size} + ) + if split_keys and options.axis_size is None: + raise ValueError('Cannot split RNGs without specifying a length') else: - key_names = None - rngsdef = None - key_axes = None + split_keys = None + broadcast_keys = None moduledef: tp.Optional[ModuleDef[M]] = None - def _init_state(*key_values): + def _init_state(split_keys, broadcast_keys): nonlocal moduledef - if rngsdef is not None: - assert key_names is not None - keys = dict(zip(key_names, key_values)) - rngs = rngsdef.merge(keys) - module_init_kwargs['rngs'] = rngs + if split_keys is not None: + assert broadcast_keys is not None + module_init_kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) module = module_constructor(*module_init_args, **module_init_kwargs) @@ -1471,16 +1426,16 @@ def _init_state(*key_values): return tuple(states) - if rngsdef is not None or options.variable_axes: + if split_keys is not None or options.variable_axes: init_out_axes = (*options.variable_axes.values(), None) _init_state = jax.vmap( 
_init_state, - in_axes=key_axes, + in_axes=(0, None), out_axes=init_out_axes, axis_size=options.axis_size, ) - *axes_states, carry_state = _init_state(*key_values) + *axes_states, carry_state = _init_state(split_keys, broadcast_keys) moduledef = tp.cast(ModuleDef[M], moduledef) # add additional axis name to Variable.sharding @@ -1549,26 +1504,15 @@ def vmap_apply( ) # split rng state - vectorized_keys: tp.Optional[tp.Dict[str, jax.Array]] - broadcast_keys: tp.Optional[tp.Dict[str, jax.Array]] if rngs is not None: if not isinstance(rngs, rnglib.Rngs): raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - vectorized_keys = {} - broadcast_keys = {} - - keys, rngsdef = rngs.split() - broadcast_predicate = filterlib.to_predicate(options.broadcast_rngs) - - for name, key in keys.items(): - if broadcast_predicate(name, key): - broadcast_keys[name] = key - else: - vectorized_keys[name] = jax.random.split(key, axis_size) + split_keys, broadcast_keys = rngs.fork( + {filterlib.Not(options.broadcast_rngs): axis_size} + ) else: - rngsdef = None - vectorized_keys = None + split_keys = None broadcast_keys = None moduledef_out: tp.Optional[ModuleDef[Module]] = None @@ -1588,7 +1532,7 @@ def vmap_apply( spmd_axis_name=options.spmd_axis_name, ) def vmap_fn( - vectorized_keys: tp.Optional[tp.Dict[str, jax.Array]], + split_keys: dict[str, rnglib.RngStream] | None, vectorized_states: list[State], args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], @@ -1596,10 +1540,9 @@ def vmap_fn( nonlocal moduledef_out # merge rng state - if rngsdef is not None: - assert vectorized_keys is not None and broadcast_keys is not None - rngs = rngsdef.merge({**vectorized_keys, **broadcast_keys}) - kwargs['rngs'] = rngs + if split_keys is not None: + assert broadcast_keys is not None + kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) # remove metadata axis name from Variable.sharding if spmd.PARTITION_NAME in options.vmap_metadata: @@ -1632,7 +1575,7 @@ def vmap_fn( return broadcast_state_out, vectorized_states_out, output broadcast_state, vectorized_states, output = vmap_fn( - vectorized_keys, vectorized_states, args, kwargs + split_keys, vectorized_states, args, kwargs ) assert moduledef_out is not None diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/experimental/nnx/tests/test_rngs.py index 44a15a5e9c..adede26d52 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/experimental/nnx/tests/test_rngs.py @@ -127,6 +127,7 @@ def test_partition_merge(self): def test_fork_broadcast(self): rngs = nnx.Rngs(params=0, dropout=1) + jax.random.key keys = rngs.fork() # all broadcast @@ -139,33 +140,41 @@ def test_fork_split(self): rngs = nnx.Rngs(params=0, dropout=1) keys = rngs.fork(4) # split all - assert keys['params'].key.shape == (4, 2) - assert keys['dropout'].key.shape == (4, 2) + assert keys['params'].key.shape == (4,) + assert keys['dropout'].key.shape == (4,) def test_fork_split_and_broadcast(self): rngs = nnx.Rngs(params=0, dropout=1) splits, broadcasts = rngs.fork(params=4, dropout=None) - assert splits['params'].key.shape == (4, 2) - assert broadcasts['dropout'].key.shape == (2,) + assert splits['params'].key.shape == (4,) + assert broadcasts['dropout'].key.shape == () def test_fork_filters(self): rngs = nnx.Rngs(params=0, dropout=1) splits, broadcasts = rngs.fork({'params': 4}) - assert splits['params'].key.shape == (4, 2) - assert broadcasts['dropout'].key.shape == (2,) + assert splits['params'].key.shape == (4,) + assert broadcasts['dropout'].key.shape == () def 
test_fork_multidimensional_split(self): rngs = nnx.Rngs(params=0, dropout=1) keys = rngs.fork((4, None, 3)) # split all - assert keys['params'].key.shape == (4, 1, 3, 2) - assert keys['dropout'].key.shape == (4, 1, 3, 2) + assert keys['params'].key.shape == (4, 1, 3) + assert keys['dropout'].key.shape == (4, 1, 3) def test_fork_multidimensional_split_mixed(self): rngs = nnx.Rngs(params=0, dropout=1) splits, broadcasts = rngs.fork(params=(4, None, 3)) # split all - assert splits['params'].key.shape == (4, 1, 3, 2) + assert splits['params'].key.shape == (4, 1, 3) assert broadcasts['dropout'].key.shape == () + + def test_rng_stream_pytree(self): + rngs = nnx.Rngs(params=0, dropout=1) + stream = rngs.fork()['params'] + + stream2 = jax.tree_map(lambda x: x, stream) + + assert stream.key is stream2.key diff --git a/pyproject.toml b/pyproject.toml index f15cf5369a..34d4263e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,10 @@ filterwarnings = [ "ignore:.*jax.config.define_bool_state is deprecated.:DeprecationWarning", # pytest-cov uses a deprecated feature of pytest-xdist. (2023-11-06) "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning", + # DeprecationWarning: jax.random.KeyArray is deprecated. + "ignore:.*jax.random.KeyArray is deprecated.*:DeprecationWarning", + # DeprecationWarning: jax.core.Shape is deprecated. + "ignore:.*jax.core.Shape is deprecated.*:DeprecationWarning", ] [tool.coverage.report]
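
Note: taken together, these changes retire the `rngs.split()` / `rngsdef.merge(...)` round-trip (and the removed `merge_rngs` export) in favor of `Rngs.fork()`, whose resulting `RngStream`s are registered pytrees that can flow through `jax.vmap` / `jax.lax.scan` and be rebuilt into an `Rngs`. A rough usage sketch distilled from the tests and `transforms.py` above (shapes assume JAX's typed keys; this is not an exhaustive description of the API):

```python
import jax
from flax.experimental import nnx

rngs = nnx.Rngs(params=0, dropout=1)

# no pattern: every stream is broadcast
keys = rngs.fork()

# per-stream patterns: returns a (splits, broadcasts) pair of dicts
splits, broadcasts = rngs.fork(params=4, dropout=None)
assert splits['params'].key.shape == (4,)
assert broadcasts['dropout'].key.shape == ()

# forked streams are pytrees, so jax transforms can carry them; inside a
# vmap/scan body they are merged back by constructing a new Rngs, just as
# transforms.py does with rnglib.Rngs(**split_keys, **broadcast_keys):
inner = nnx.Rngs(**splits, **broadcasts)
```
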