fix tests
cgarciae committed Nov 9, 2023
1 parent 8528e4b commit ab71a72
Showing 14 changed files with 143 additions and 194 deletions.
2 changes: 1 addition & 1 deletion flax/experimental/nnx/README.md
@@ -58,7 +58,7 @@ y = model(jnp.ones((8, 12))) # call methods directly
assert model.count == 1
```

In this example `nnx.Rngs(0)` creates a `PRNGKey` for `params` with seed `0`; this key is used by `rngs.<rng-name>()` inside `__init__` to generate a random key to initialize the parameters.
In this example `nnx.Rngs(0)` creates a `random.key` for `params` with seed `0`; this key is used by `rngs.<rng-name>()` inside `__init__` to generate a random key to initialize the parameters.
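
A rough sketch of the idea in plain JAX (the single split and the `(12, 8)` kernel shape are illustrative assumptions, not the actual NNX internals):

```python
import jax

# roughly what nnx.Rngs(0) wraps: an integer seed becomes a typed key
params_key = jax.random.key(0)
# each rngs.<rng-name>() call derives a fresh subkey from the stream
params_key, subkey = jax.random.split(params_key)
kernel = jax.random.normal(subkey, (12, 8))  # hypothetical parameter init
```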

### Training with the Functional API

2 changes: 1 addition & 1 deletion flax/experimental/nnx/__init__.py
@@ -69,7 +69,7 @@
from .nnx.pytreelib import Pytree as Pytree
from .nnx.pytreelib import TreeNode as TreeNode
from .nnx.rnglib import Rngs as Rngs
from .nnx.rnglib import merge_rngs as merge_rngs
from .nnx.rnglib import RngStream as RngStream
from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
from .nnx.spmd import get_partition_spec as get_partition_spec
from .nnx.spmd import with_partitioning as with_partitioning
8 changes: 4 additions & 4 deletions flax/experimental/nnx/docs/tiny_nnx.ipynb
@@ -29,7 +29,7 @@
"A = tp.TypeVar(\"A\")\n",
"M = tp.TypeVar(\"M\", bound=\"Module\")\n",
"Sharding = tp.Tuple[tp.Optional[str], ...]\n",
"KeyArray = random.KeyArray\n",
"Array = random.Array\n",
"\n",
"\n",
"class Variable(tp.Generic[A]):\n",
@@ -393,9 +393,9 @@
}
],
"source": [
"module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.PRNGKey(0)))\n",
"x = jax.random.normal(random.PRNGKey(0), (2, 10))\n",
"y = module(x, train=True, rngs=Rngs(random.PRNGKey(1)))\n",
"module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.key(0)))\n",
"x = jax.random.normal(random.key(0), (2, 10))\n",
"y = module(x, train=True, rngs=Rngs(random.key(1)))\n",
"\n",
"state, moduledef = module.split()\n",
"print(\"state =\", jax.tree_map(jnp.shape, state))\n",
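These hunks are part of the commit-wide migration from legacy `jax.random.PRNGKey` keys to new-style typed keys from `jax.random.key`. A quick comparison in plain JAX, assuming nothing beyond public JAX APIs:

```python
import jax

legacy = jax.random.PRNGKey(0)  # raw uint32 array of shape (2,)
typed = jax.random.key(0)       # typed key array of shape ()
print(legacy.shape, legacy.dtype)  # (2,) uint32
print(typed.shape, typed.dtype)    # () key<fry>

# both key styles work with the same sampling functions
x = jax.random.normal(typed, (3,))
```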
2 changes: 1 addition & 1 deletion flax/experimental/nnx/examples/05_vae.py
@@ -166,7 +166,7 @@ def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array:


# %%
key = jax.random.PRNGKey(0)
key = jax.random.key(0)

for epoch in range(epochs):
losses = []
29 changes: 12 additions & 17 deletions flax/experimental/nnx/examples/06_scan_over_layers.py
@@ -41,42 +41,37 @@ class ScanMLP(nnx.Module):

def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs):
self.n_layers = n_layers
# split Rngs and split the `params` key
keys, rngsdef = rngs.split()
params_key = jax.random.split(keys['params'], n_layers)
# fork Rngs, split keys into `n_layers`
keys = rngs.fork(n_layers)

def create_block(params_key):
# merge back Rngs using the sliced `params` key
rngs = rngsdef.merge({'params': params_key})
def create_block(keys):
# create Block instance and return its split
return Block(dim, rngs=rngs).split()
return Block(dim, rngs=nnx.Rngs(keys)).split()

# call vmap over create_block, passing the split `params` key
# and immediately merge to get a Block instance
self.layers = nnx.merge(jax.vmap(create_block)(params_key))
self.layers = nnx.merge(jax.vmap(create_block)(keys))

def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array:
# split Rngs and split the `dropout` key
keys, rngsdef = rngs.split()
dropout_key = jax.random.split(keys['dropout'], self.n_layers)
# fork Rngs, split keys into `n_layers`
keys = rngs.fork(self.n_layers)
# split Module to get params
params, moduledef = self.layers.split(nnx.Param)

def scan_fn(
x: jax.Array, inputs: Tuple[nnx.State, jax.Array]
x: jax.Array, inputs: Tuple[nnx.State, dict[str, nnx.RngStream]]
) -> Tuple[jax.Array, nnx.State]:
params, dropout_key = inputs
params, keys = inputs
# merge back Module and Rngs
rngs = rngsdef.merge({'dropout': dropout_key})
module = moduledef.merge(params)
# forward pass
x = module(x, rngs=rngs)
x = module(x, rngs=nnx.Rngs(keys))
# split state and return
params, _ = module.split(nnx.Param)
return x, params

# call scan passing x as the carry, and params + dropout_key as the input
x, params = jax.lax.scan(scan_fn, x, (params, dropout_key))
# call scan passing x as the carry, and params + keys as the input
x, params = jax.lax.scan(scan_fn, x, (params, keys))
# update layers state and return
self.layers.update(params)
return x
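For reference, a minimal pure-JAX sketch of the pattern used above, with NNX stripped away (the `tanh` layer body and shapes are illustrative assumptions): `vmap` creates one parameter set per layer as a stacked array, and `scan` threads `x` through the stack as the carry.

```python
import jax
import jax.numpy as jnp

n_layers, dim = 5, 10

def create_layer(key):
  # one (dim, dim) weight matrix per layer
  return jax.random.normal(key, (dim, dim))

keys = jax.random.split(jax.random.key(0), n_layers)
stacked = jax.vmap(create_layer)(keys)  # shape (n_layers, dim, dim)

def scan_fn(x, w):
  # carry the activations; no per-step output
  return jnp.tanh(x @ w), None

x = jnp.ones((2, dim))
y, _ = jax.lax.scan(scan_fn, x, stacked)
```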
4 changes: 2 additions & 2 deletions flax/experimental/nnx/examples/07_transformer.py
@@ -390,10 +390,10 @@ def __call__(self, cfg: Config, x, *, rngs: nnx.Rngs):
assert isinstance(self.layers, DecoderBlock)

state, moduledef = self.layers.split()
rngs, rngsdef = rngs.split()
rngs, rngsdef = rngs.fork()
dropout_key = jax.random.split(rngs['dropout'], cfg.layers)

def scan_fn(x, s: tp.Tuple[jax.random.KeyArray, nnx.State]):
def scan_fn(x, s: tp.Tuple[jax.Array, nnx.State]):
dropout_key, state = s
rngs = rngsdef.merge({'dropout': dropout_key})
y, (state, _) = moduledef.apply(state)(cfg, x, rngs=rngs)
2 changes: 1 addition & 1 deletion flax/experimental/nnx/examples/10_quantization.py
@@ -155,7 +155,7 @@ def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array:


# %%
key = jax.random.PRNGKey(0)
key = jax.random.key(0)

for epoch in range(epochs):
for step in range(steps_per_epoch):
7 changes: 3 additions & 4 deletions flax/experimental/nnx/nnx/nn/initializers.py
@@ -35,14 +35,13 @@

Shape = tp.Sequence[int]
DTypeLikeInexact = tp.Any
KeyArray = jax.random.KeyArray
Array = jax.Array


class Initializer(tp.Protocol):
@staticmethod
def __call__(
key: KeyArray, shape: Shape, dtype: DTypeLikeInexact = jnp.float_
key: Array, shape: Shape, dtype: DTypeLikeInexact = jnp.float_
) -> Array:
...

@@ -53,7 +52,7 @@ def zeros() -> Initializer:
>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)
"""
@@ -66,7 +65,7 @@ def ones() -> Initializer:
>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32)
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
[1., 1.],
[1., 1.]], dtype=float32)
12 changes: 6 additions & 6 deletions flax/experimental/nnx/nnx/nn/linear.py
Expand Up @@ -25,7 +25,7 @@
from flax.experimental.nnx.nnx.nn import dtypes, initializers

Array = jax.Array
PRNGKey = tp.Any
KeyArray = jax.Array
Shape = tp.Tuple[int, ...]
Dtype = tp.Any # this could be a real type?
PrecisionLike = tp.Union[
@@ -101,10 +101,10 @@ def __init__(
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
kernel_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = default_kernel_init,
bias_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = initializers.zeros(),
dot_general: DotGeneralT = lax.dot_general,
rngs: rnglib.Rngs,
@@ -210,10 +210,10 @@ def __init__(
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
kernel_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = default_kernel_init,
bias_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = initializers.zeros(),
conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated,
rngs: rnglib.Rngs,
@@ -394,7 +394,7 @@ def __init__(
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
embedding_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = default_embed_init,
rngs: rnglib.Rngs,
):
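The alias rename above (`PRNGKey = tp.Any` to `KeyArray = jax.Array`) only tightens the annotation; any callable matching `[KeyArray, Shape, Dtype] -> Array` still satisfies `kernel_init`/`bias_init`. A hedged sketch of such a callable (the name and value range are made up for illustration):

```python
import jax
import jax.numpy as jnp

def small_uniform_init(key: jax.Array, shape, dtype=jnp.float32) -> jax.Array:
  # hypothetical initializer matching the kernel_init/bias_init signature
  return jax.random.uniform(key, shape, dtype, minval=-0.01, maxval=0.01)

w = small_uniform_init(jax.random.key(42), (4, 3))
```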
10 changes: 5 additions & 5 deletions flax/experimental/nnx/nnx/nn/normalization.py
@@ -23,7 +23,7 @@
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import dtypes, initializers

PRNGKey = jax.Array
KeyArray = jax.Array
Array = jax.Array
Shape = tp.Tuple[int, ...]
Dtype = tp.Any # this could be a real type?
@@ -202,10 +202,10 @@ def __init__(
use_bias: bool = True,
use_scale: bool = True,
bias_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = initializers.zeros(),
scale_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = initializers.ones(),
axis_name: tp.Optional[str] = None,
axis_index_groups: tp.Any = None,
@@ -333,10 +333,10 @@ def __init__(
use_bias: bool = True,
use_scale: bool = True,
bias_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = initializers.zeros(),
scale_init: tp.Callable[
[PRNGKey, Shape, Dtype], Array
[KeyArray, Shape, Dtype], Array
] = initializers.ones(),
reduction_axes: Axes = -1,
feature_axes: Axes = -1,
39 changes: 19 additions & 20 deletions flax/experimental/nnx/nnx/rnglib.py
@@ -67,7 +67,7 @@ def fork(self, pattern: Pattern) -> 'RngStream':
num_splits = int(np.prod([x for x in pattern if x is not None]))
axis_size = tuple(x if x is not None else 1 for x in pattern)
# reshape key
key = jax.random.split(key, num_splits).reshape(*axis_size, -1)
key = jax.random.split(key, num_splits).reshape(*axis_size)
count_path = [0]
return RngStream(key, count_path)
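
Dropping the trailing `-1` works because new-style typed keys are scalar elements: `jax.random.split` returns shape `(num_splits,)`, which reshapes directly to the pattern's axes, whereas legacy keys carried an extra key-data dimension. A standalone sketch of the same arithmetic, assuming a hypothetical pattern:

```python
import jax
import numpy as np

pattern = (2, None, 3)  # hypothetical fork pattern
num_splits = int(np.prod([x for x in pattern if x is not None]))  # 6
axis_size = tuple(x if x is not None else 1 for x in pattern)     # (2, 1, 3)

keys = jax.random.split(jax.random.key(0), num_splits).reshape(*axis_size)
print(keys.shape)  # (2, 1, 3)
```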

@@ -77,8 +77,8 @@ def copy(self) -> 'RngStream':

jax.tree_util.register_pytree_node(
RngStream,
lambda rng: (rng.key, tuple(rng.counts)),
lambda counts, key: RngStream(key, list(counts)),
lambda rng: ((rng.key,), tuple(rng.counts)),
lambda counts, nodes: RngStream(nodes[0], list(counts)),
)
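
The fix here wraps the key in a one-element tuple because the flatten function passed to `register_pytree_node` must return a sequence of children plus static auxiliary data, and the unflatten function receives `(aux_data, children)`. A minimal sketch with a hypothetical class:

```python
import jax
import jax.numpy as jnp

class Tagged:
  def __init__(self, value, tag):
    self.value = value  # dynamic leaf, visible to jax transforms
    self.tag = tag      # static auxiliary data

jax.tree_util.register_pytree_node(
    Tagged,
    lambda t: ((t.value,), t.tag),                  # children as a tuple
    lambda tag, children: Tagged(children[0], tag),
)

doubled = jax.tree_util.tree_map(lambda x: x * 2, Tagged(jnp.ones(3), 'a'))
```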

RngValue = tp.Union[int, jax.Array, RngStream]
@@ -106,7 +106,7 @@ def __init__(

self._rngs = {
name: (
RngStream(jax.random.PRNGKey(value), [0])
RngStream(jax.random.key(value), [0])
if isinstance(value, int)
else RngStream(value, [0])
if isinstance(value, jax.Array)
@@ -176,7 +176,7 @@ def fork(

def fork(
self,
__default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
_default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING,
**patterns: Pattern,
) -> dict[str, RngStream] | tuple[dict[str, RngStream], dict[str, RngStream]]:
if not self.is_valid():
@@ -185,15 +185,15 @@
)

filter_patterns: list[tuple[filterlib.Filter, Pattern]]
if isinstance(__default, dict):
if isinstance(_default, dict):
# merge default and patterns
filter_patterns = [
*__default.items(),
*_default.items(),
*patterns.items(),
(..., None), # broadcast all remaining
]
else:
default = None if isinstance(__default, Missing) else __default
default = None if isinstance(_default, Missing) else _default
filter_patterns = [
*patterns.items(),
(..., default), # split all remaining with default
@@ -204,25 +204,24 @@
for filter_, pattern in filter_patterns
]

rngs = self._rngs.copy()
splits: dict[str, RngStream] = {}
broadcasts: dict[str, RngStream] = {}

for name, stream in rngs.items():
for name, stream in self._rngs.items():
for predicate, pattern in predicate_pattern:
if predicate(name, stream):
rngs[name] = stream.fork(pattern)
fork = stream.fork(pattern)
if pattern is None:
broadcasts[name] = fork
else:
splits[name] = fork
break
else:
raise RuntimeError(
f'Stream {name!r} did not match any predicate, this is a bug.'
)

if isinstance(__default, dict) or patterns:
split_rngs, broadcast_rngs = {}, {}
for name, stream in rngs.items():
if patterns[name] is None:
broadcast_rngs[name] = stream
else:
split_rngs[name] = stream
return split_rngs, broadcast_rngs
if isinstance(_default, dict) or patterns:
return splits, broadcasts
else:
return rngs
return {**splits, **broadcasts}
