diff --git a/docs/experimental/nnx/index.rst b/docs/experimental/nnx/index.rst index 886542a41e..05ef0b027c 100644 --- a/docs/experimental/nnx/index.rst +++ b/docs/experimental/nnx/index.rst @@ -98,7 +98,7 @@ Basic usage self.din, self.dout = din, dout def __call__(self, x: jax.Array): - return x @ self.w + self.b + return x @ self.w.value + self.b.value rngs = nnx.Rngs(0) # explicit RNG handling model = Linear(din=2, dout=3, rngs=rngs) # initialize the model diff --git a/docs/experimental/nnx/nnx_basics.ipynb b/docs/experimental/nnx/nnx_basics.ipynb index 16dd9420d7..8276e5f7dc 100644 --- a/docs/experimental/nnx/nnx_basics.ipynb +++ b/docs/experimental/nnx/nnx_basics.ipynb @@ -50,7 +50,7 @@ " self.din, self.dout = din, dout\n", "\n", " def __call__(self, x: jax.Array):\n", - " return x @ self.w + self.b" + " return x @ self.w.value + self.b.value" ] }, { @@ -80,9 +80,9 @@ " din=2,\n", " dout=3\n", ")\n", - "model.w = Array([[0.19007349, 0.31424356, 0.3686391 ],\n", + "model.w.value = Array([[0.19007349, 0.31424356, 0.3686391 ],\n", " [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n", - "model.b = Array([0., 0., 0.], dtype=float32)\n" + "model.b.value = Array([0., 0., 0.], dtype=float32)\n" ] } ], @@ -90,8 +90,8 @@ "model = Linear(din=2, dout=3, rngs=nnx.Rngs(params=0))\n", "\n", "print(f'{model = }')\n", - "print(f'{model.w = }')\n", - "print(f'{model.b = }')" + "print(f'{model.w.value = }')\n", + "print(f'{model.b.value = }')" ] }, { @@ -99,7 +99,7 @@ "metadata": {}, "source": [ "This is very handy for debugging as it allows accessing the entire structure or\n", - "modify it. Similarly, computation can ran directly." + "modify it. Similarly, computation can be ran directly." ] }, { @@ -145,15 +145,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "counter.count = 0\n", - "counter.count = 1\n" + "counter.count.value = 0\n", + "counter.count.value = 1\n" ] } ], @@ -163,12 +163,12 @@ " self.count = nnx.Variable(0)\n", "\n", " def __call__(self):\n", - " self.count += 1\n", + " self.count.value += 1\n", "\n", "counter = Counter()\n", - "print(f'{counter.count = }')\n", + "print(f'{counter.count.value = }')\n", "counter()\n", - "print(f'{counter.count = }')" + "print(f'{counter.count.value = }')" ] }, { @@ -199,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -215,9 +215,9 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x1307e2a60>,\n", - " bias_init=,\n", - " dot_general=\n", + " kernel_init=.init at 0x169773f70>,\n", + " bias_init=,\n", + " dot_general=\n", " ),\n", " bn=BatchNorm(\n", " num_features=2,\n", @@ -257,22 +257,22 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "model.blocks[1].linear.kernel = Array([[0.992858 , 0.9711272],\n", + "model.blocks[1].linear.kernel.value = Array([[0.992858 , 0.9711272],\n", " [1.4061186, 0.4704619]], dtype=float32)\n", - "model.blocks[0].bn.scale = Array([1., 1.], dtype=float32)\n" + "model.blocks[0].bn.scale.value = Array([1., 1.], dtype=float32)\n" ] } ], "source": [ - "print(f'{model.blocks[1].linear.kernel = }')\n", - "print(f'{model.blocks[0].bn.scale = }')" + "print(f'{model.blocks[1].linear.kernel.value = }')\n", + "print(f'{model.blocks[0].bn.scale.value = }')" ] }, { @@ -316,7 +316,7 @@ "model.blocks[2] = awesome_layer\n", "\n", "# Variable sharing (weight tying)\n", - "model.blocks[-1].linear.variables.kernel = model.blocks[0].linear.variables.kernel\n", + "model.blocks[-1].linear.kernel = model.blocks[0].linear.kernel\n", "\n", "model(jnp.ones((1, 2)))" ] @@ -353,8 +353,8 @@ " self.count = Count(0)\n", "\n", " def __call__(self, x: jax.Array):\n", - " self.count += 1\n", - " return x @ self.w + self.b\n", + " self.count.value += 1\n", + " return x @ self.w.value + self.b.value\n", " \n", "model = StatefulLinear(din=2, dout=3, rngs=nnx.Rngs(0))" ] @@ -382,10 +382,16 @@ "output_type": "stream", "text": [ "state = State({\n", - " 'w': Array([[0.19007349, 0.31424356, 0.3686391 ],\n", - " [0.7862853 , 0.03352201, 0.50682676]], dtype=float32),\n", - " 'b': Array([0., 0., 0.], dtype=float32),\n", - " 'count': 0\n", + " 'w': Param(\n", + " raw_value=Array([[0.19007349, 0.31424356, 0.3686391 ],\n", + " [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n", + " ),\n", + " 'b': Param(\n", + " raw_value=Array([0., 0., 0.], dtype=float32)\n", + " ),\n", + " 'count': Count(\n", + " raw_value=0\n", + " )\n", "})\n", "\n", "static = GraphDef(\n", @@ -431,8 +437,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "model.count = 0\n", - "model.count = Array(1, dtype=int32, weak_type=True)\n" + "model.count = Count(\n", + " raw_value=0\n", + ")\n", + "model.count.value = Array(1, dtype=int32, weak_type=True)\n" ] } ], @@ -456,7 +464,7 @@ "# 5. Update the state of the original Module\n", "model.update(state)\n", "\n", - "print(f'{model.count = }')" + "print(f'{model.count.value = }')" ] }, { @@ -503,13 +511,19 @@ "output_type": "stream", "text": [ "params = State({\n", - " 'w': Array([[0.19007349, 0.31424356, 0.3686391 ],\n", - " [0.7862853 , 0.03352201, 0.50682676]], dtype=float32),\n", - " 'b': Array([0., 0., 0.], dtype=float32)\n", + " 'w': Param(\n", + " raw_value=Array([[0.19007349, 0.31424356, 0.3686391 ],\n", + " [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n", + " ),\n", + " 'b': Param(\n", + " raw_value=Array([0., 0., 0.], dtype=float32)\n", + " )\n", "})\n", "\n", "counts = State({\n", - " 'count': Array(1, dtype=int32, weak_type=True)\n", + " 'count': Count(\n", + " raw_value=Array(1, dtype=int32, weak_type=True)\n", + " )\n", "})\n" ] } diff --git a/docs/experimental/nnx/nnx_basics.md b/docs/experimental/nnx/nnx_basics.md index d76fd43510..311679903c 100644 --- a/docs/experimental/nnx/nnx_basics.md +++ b/docs/experimental/nnx/nnx_basics.md @@ -38,7 +38,7 @@ class Linear(nnx.Module): self.din, self.dout = din, dout def __call__(self, x: jax.Array): - return x @ self.w + self.b + return x @ self.w.value + self.b.value ``` As shown above dynamic state is stored in `nnx.Variable`s such as `nnx.Param`, @@ -54,12 +54,12 @@ for inspection. model = Linear(din=2, dout=3, rngs=nnx.Rngs(params=0)) print(f'{model = }') -print(f'{model.w = }') -print(f'{model.b = }') +print(f'{model.w.value = }') +print(f'{model.b.value = }') ``` This is very handy for debugging as it allows accessing the entire structure or -modify it. Similarly, computation can ran directly. +modify it. Similarly, computation can be ran directly. ```{code-cell} ipython3 x = jnp.ones((1, 2)) @@ -84,12 +84,12 @@ class Counter(nnx.Module): self.count = nnx.Variable(0) def __call__(self): - self.count += 1 + self.count.value += 1 counter = Counter() -print(f'{counter.count = }') +print(f'{counter.count.value = }') counter() -print(f'{counter.count = }') +print(f'{counter.count.value = }') ``` **This looks too easy, what is the catch?** @@ -136,8 +136,8 @@ One of the benefits of NNX is that nested Modules as easy to inspect and static analyzers can help you while doing so. ```{code-cell} ipython3 -print(f'{model.blocks[1].linear.kernel = }') -print(f'{model.blocks[0].bn.scale = }') +print(f'{model.blocks[1].linear.kernel.value = }') +print(f'{model.blocks[0].bn.scale.value = }') ``` #### Model Surgery @@ -160,7 +160,7 @@ def awesome_layer(x): return x model.blocks[2] = awesome_layer # Variable sharing (weight tying) -model.blocks[-1].linear.variables.kernel = model.blocks[0].linear.variables.kernel +model.blocks[-1].linear.kernel = model.blocks[0].linear.kernel model(jnp.ones((1, 2))) ``` @@ -187,8 +187,8 @@ class StatefulLinear(nnx.Module): self.count = Count(0) def __call__(self, x: jax.Array): - self.count += 1 - return x @ self.w + self.b + self.count.value += 1 + return x @ self.w.value + self.b.value model = StatefulLinear(din=2, dout=3, rngs=nnx.Rngs(0)) ``` @@ -236,7 +236,7 @@ y, state = forward(static, state, x=jnp.ones((1, 2))) # 5. Update the state of the original Module model.update(state) -print(f'{model.count = }') +print(f'{model.count.value = }') ``` The key insight of this pattern is that using mutable references is diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md index 5b3b9f3193..0d4b177f85 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/experimental/nnx/README.md @@ -2,7 +2,7 @@ # NNX -_**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/experimental/nnx/index.html) | +_**N**eural **N**etworks for JA**X**_ NNX is a JAX-based neural network library designed for simplicity and power. Its modular approach follows standard Python conventions, making it both intuitive and compatible with the broader JAX ecosystem. @@ -14,10 +14,12 @@ NNX is a JAX-based neural network library designed for simplicity and power. Its * **User-friendly**: NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen to provide a streamlined experience. -> [!NOTE] -> NNX is currently in an experimental state and is subject to change. Linen is still the - recommended option for large-scale projects. Feedback and contributions are welcome! - +#### Table of Contents +* [Installation](#installation) +* [Getting Started](#getting-started) +* [Examples](#examples) +* [FAQs](#faqs) +* [User Guide](#user-guide) ## Installation @@ -26,104 +28,67 @@ To get started with `nnx`, install Flax from GitHub: pip install git+https://github.com/google/flax.git ``` -## What does NNX look like? - -We provide three examples using the NNX API: a simple multi-layer perceptron, a CNN and an auto-encoder. +## Getting Started -To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#) guide. +The following example guides you through creating a basic `Linear` model with NNX and executing a forward pass. It also demonstrate how handle mutable state by showing how to keep track of the number of times the model has been called. ```python +from flax.experimental import nnx import jax import jax.numpy as jnp -from flax.experimental import nnx - +class Count(nnx.Variable): pass # typed Variable collections -class MLP(nnx.Module): - def __init__(self, features: list[int], *, rngs: nnx.Rngs): - self.layers = [ - nnx.Linear(din, dout, rngs=rngs) - for din, dout in zip(features[:-1], features[1:]) - ] +class Linear(nnx.Module): + def __init__(self, din, dout, *, rngs: nnx.Rngs): # explicit RNG management + key = rngs() + # put dynamic state in Variable types + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = Count(0) + # other types as treated as static + self.din = din + self.dout = dout - def __call__(self, x: jax.Array) -> jax.Array: - for layer in self.layers[:-1]: - x = nnx.relu(layer(x)) - x = self.layers[-1](x) - return x + def __call__(self, x): + self.count += 1 # inplace stateful updates + return x @ self.w + self.b +model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # no special `init` method +x = jnp.ones((8, 12)) +y = model(x) # call methods directly -model = MLP([784, 64, 32, 10], rngs=nnx.Rngs(0)) -y = model(jnp.ones((1, 784))) +assert model.count == 1 ``` -```python -class CNN(nnx.Module): - def __init__(self, *, rngs: nnx.Rngs): - self.conv1 = nnx.Conv(1, 64, kernel_size=(3, 3), rngs=rngs) - self.conv2 = nnx.Conv(64, 32, kernel_size=(3, 3), rngs=rngs) - self.linear1 = nnx.Linear(7 * 7 * 32, 256, rngs=rngs) - self.linear2 = nnx.Linear(256, 10, rngs=rngs) +In this example `nnx.Rngs(0)` create a `random.key` for `params` with seed `0`, this is used by `rngs.()` inside `__init__` to generate a random key to initialize the parameters. - def __call__(self, x): - x = nnx.relu(self.conv1(x)) - x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nnx.relu(self.conv2(x)) - x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nnx.relu(self.linear1(x)) - logits = self.linear2(x) - return logits - - -model = CNN(rngs=nnx.Rngs(0)) -x = jnp.ones((1, 28, 28, 1)) # (N, H, W, C) format -logits = model(x) -``` - -```python -class AutoEncoder(nnx.Module): - def __init__( - self, - input_features: int, - encoder_features: list[int], - decoder_features: list[int], - *, - rngs: nnx.Rngs, - ): - self.encoder = MLP([input_features, *encoder_features], rngs=rngs) - self.decoder = MLP([*decoder_features, input_features], rngs=rngs) - - def __call__(self, x): - return self.decode(self.encode(x)) +### Interacting with JAX - def encode(self, x): - return self.encoder(x) +While NNX Modules inherently follow reference semantics, they can be easily converted into a pure functional representation that can be used with JAX transformations and other value-based, functional code. - def decode(self, z): - return nnx.sigmoid(self.decoder(z)) +NNX has two very simple APIs to interact with JAX: `split` and `merge`. +The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `ModuleDef` object that contains the static structure of the Module. -model = AutoEncoder( - input_features=784, - encoder_features=[64, 32], - decoder_features=[32, 64], - rngs=nnx.Rngs(0), -) -x = jnp.ones((1, 784)) -z = model.encode(x) -y = model.decode(z) +```python +state, static = model.split() +``` +``` +state = State({ + 'b': Array(..., dtype=float32), + 'count': Array(1, dtype=int32), + 'w': Array(..., dtype=float32) +}) ``` -### Interacting with JAX - -To interact with JAX NNX provides the [Functional API](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#the-functional-api) which consists of 3 simple methods: `split`, `merge`, and `update`. Using these methods any Module can be lifted to be used in JAX transformations. Here is a simple jitted `forward` function as an example: +The `ModuleDef.merge` method allows you to take a `ModuleDef` and one or more `State` objects and merge them back into a `Module` object. -```pythonthon -state, static = model.split() +Using `split` and `merge` in conjunction allows you to carry your Module in and out of any JAX transformation. Here is a simple jitted `forward` function as an example: +```python @jax.jit -def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array): +def forward(static: nnx.ModuleDef, state: nnx.State, x: jax.Array): model = static.merge(state) y = model(x) state, _ = model.split() @@ -131,18 +96,407 @@ def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array): x = jnp.ones((2, 4)) y, state = forward(static, state, x) +``` +``` +state["count"] = Array(2, dtype=int32) +``` + +For simple use cases, you can use `nnx.jit` which is a lifted transform that automatically splits, merges, and updates the outside Module for you: + +```python +state, static = model.split() -model.update(state) +@nnx.jit +def forward(model: Linear, x: jax.Array): + return model(x) + +y = forward(model, x=jnp.ones((2, 4))) + +assert model.count == 3 # state automatically updated! ``` -### Examples +#### Training Example + +Using `split` and `merge` (the [Functional API](#functional-api)) is the recommended way to use NNX as it provides tight control over the state, allows you to use regular JAX transformations, and it minimizes overhead. In this example we will create a simple training step that implements Stochastic Gradient Descent (SGD): + +```python +params, counts, static = model.split(nnx.Param, Count) + +@jax.jit +def train_step(params, counts, x, y): + def loss_fn(params): + model = static.merge(params, counts) + y_pred = model(x) + counts = model.extract(Count) # get updated Counts + loss = jax.numpy.mean((y_pred - y) ** 2) + return loss, counts + + # compute gradient + grads, counts = jax.grad(loss_fn, has_aux=True)(params) + # SGD update + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) + + return params, counts + +# execute the training step +params, counts = train_step(params, counts, x, y) +model = static.merge(params, counts) +assert model.count == 4 +``` +Here `...` is a `Filter` (much like `nnx.Param`) that matches any node type, see the [Filters](#filters) section for more information. + +#### Training with Lifted Transforms + +[Lifted Transforms](#lifted-transforms) provide a convenient way interact with NNX Modules. In this example, we use the `nnx.jit` and `nnx.grad` lifted transforms to define the training step. The model is trained using Stochastic Gradient Descent (SGD). Because lifted transforms automatically update the Module's state, `train_step` doesn't require a return statement. + +```python +@nnx.jit +def train_step(model, x, y): + def loss_fn(model): + y_pred = model(x) + return jax.numpy.mean((y_pred - y) ** 2) + + # compute gradient + grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) + # SGD update + params, *_ = model.split(nnx.Param, ...) + model.update( + jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) + ) + +# execute the training step +train_step(model, x, y) +assert model.count == 3 +``` + +**Note**: Using `nnx.jit` introduces some overhead when compared to using `jax.jit` directly. Use `nnx.jit` for simple prototypes, but for production code use `jax.jit` directly. + +## Examples + +* [Using the Functional API](https://github.com/cgarciae/nnx/blob/main/examples/01_functional_api.py): Shows how to train a simple model using the functional API. +* [Using Lifted Transforms](https://github.com/cgarciae/nnx/blob/main/examples/02_lifted_transforms.py): Shows how to train a simple model using lifted transforms. +* [Using TrainState](https://github.com/cgarciae/nnx/blob/main/examples/03_train_state.py): Shows how to train a simple model using the functional API with the help of `TrainState`. +* [Training a VAE](https://github.com/cgarciae/nnx/blob/main/examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset, uses the functional API, `TrainState`, and shows how to use capture intermediate values to retrieve `kl_loss`. +* [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. +* [Creating a Transformer](https://github.com/cgarciae/nnx/blob/main/examples/07_transformer.py): Shows how to create a Transformer with an auto-regressive decoder that uses scan over layers and a kv-cache for fast inference. Credits to @levskaya. + +## FAQs + +### Status +NNX is still in early development so expect bugs and breaking changes. + +### How is it different from Flax? +NNX takes the best features that allow Flax to scale to large projects and integrates them into a much simpler Module system with pythonic semantics. + +One place in which NNX strongly deviates from Flax is that (currently) it avoids shape inference in favor of static initialization. It is not a technical limitation but rather a design choice. This design both simplifies the internal implementation and makes it easier to reason about the code for the user, at the cost of being more verbose at times. On the other hand, Pytorch users will feel right at home. + +## User Guide + +### Modules + +NNX Modules are normal python classes, they obey regular python semantics such as mutability and reference sharing, including reference cycles. They can contain 2 types of attributes: node attributes and static attributes. Node attributes include NNX `Variable`s (e.g. `nnx.Param`) and sub-Modules. All other types are treated as static attributes. + +```python +class Foo(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + # node attributes + self.param = nnx.Param(jnp.array(1)) + self.submodule = nnx.Linear(12, 3, rngs=rngs) + self.container = [4, nnx.Linear(5, 6, rngs=rngs), 7] + # static attributes + self.int = 8 + self.float = 9.0 + self.str = 'hello' + + def __call__(self, x): + return self.submodule(x + self.param) + + def some_method(self, x): + return x + 1 + +model = Foo(rngs=nnx.Rngs(0)) +``` +As shown above, python container types such as `list`, `tuple`, and `dict` are treated as node attributes, +this means you can naturally have e.g. `list`s or `dict`s of Modules. + +### Functional API + +NNX Modules are not pytrees so they cannot be passed to JAX transformations. In order to interact with JAX, a Module must be partitioned into a `State` and `GraphDef` objects. The `State` object is a flat dictionary-like pytree structure that contains all the deduplicated node attributes, and the `GraphDef` contains the static attributes and structural information needed to reconstruct the Module. + +```python +state, static = model.split() +``` +``` +State({ + 'param': Array(1, dtype=int32, weak_type=True), + 'submodule': { + 'kernel': Array(..., dtype=float32), + 'bias': Array(..., dtype=float32) + }, + 'container': { + '1': { + 'kernel': Array(..., dtype=float32), + 'bias': Array(..., dtype=float32) + } + } +}) +``` + +`State` and `GraphDef` are pytrees so they can be passed to JAX transformations. More over, `GraphDef` provides 2 very important methods: `merge` and `apply`. The `merge` method can be used to create a new `Module` from a `State` object: + +```python +model = static.merge(state) +``` +This can be use to e.g. recreate a module inside a JAX transformation. The `apply` provides a functional interface to the module, it can be used call any method or submodule and get the output and the updated state: + +```python +# run __call__ +y, (state, static) = static.apply(state)(x) +# run some_method +y, (state, static) = static.apply(state).some_method(x) +# run submodule +y, (state, static) = static.apply(state).submodule(x) +``` + +`apply` can call any nested method or submodule as long as it can be accessed via the `.` or `[]` operators. + +### Partitioning State +In NNX you can filter based on any node type, most commonly you will want to filter based on `nnx.Variable` subclasses such as `nnx.Param` or `nnx.BatchStat`. + +Here are various examples of how you can use the `split` method to split a module into multiple substates: + +```python +# split the module into the state with all the nodes and the static +state, static = model.split() +# verify that the state contains only params, else raise an error +params, static = model.split(nnx.Param) +# split the state into params and batch_stats, verify no nodes are left +params, batch_stats, static = model.split(nnx.Param, nnx.BatchStat) +# if there are any nodes left, use the `...` filter to capture them +params, batch_stats, rest, static = model.split(nnx.Param, nnx.BatchStat, ...) +# using `...` as the only filter is equivalent to not passing any filters +model.split(...) = model.split() +``` +`split` will make sure all nodes are match by atleast one filter, else it will raise an error. You can use the `...` filter which will any (remaining) nodes. For a more general filter you can pass a predicate function that can use both the path and value of the node: + +```python +(path: Tuple[str, ...], value: Any) -> bool +``` +To reconstruct the module from a set of substates, you can use `merge` as usual but passing the substates as additional arguments: -* [LM1B](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. +```python +model = static.merge(params, batch_stats, rest) +``` + +The same is true for `apply`. + +```python +y, (state, static) = static.apply(params, batch_stats, rest)(x) +``` + + Note that `apply` will return a single `state` object, if you need to `split` the state you can use `State`'s own `split` method: + +```python +params, batch_stats, rest = state.split(nnx.Param, nnx.BatchStat, ...) +``` + +Alternatively, if you are just interested in a subset of partitions, you can use the `State.extract` method which will not raise an error if some nodes are not matched by any filter: + +```python +# only get params +params = state.extract(nnx.Param) +# get params and batch_stats +params, batch_stats = state.extract(nnx.Param, nnx.BatchStat) +``` + +### Filters + +Filters let you select subsets of nodes based on some criteria. These are use throughout the API in methods like `split`, `extract`, and `pop`. There are 4 types of filters: + +* `type`: matches all node instances of the given type. +* `...`: matches all nodes. +* `(path, any) -> bool`: a predicate function that takes a node path and value and returns a boolean. +* `Tuple[Filter, ...]`: a tuple of filters, matches all nodes that match any of the filters. + +NNX also provides the following custom filters: + +* `nnx.Not(filter)`: matches all nodes that do not match the given filter +* `nnx.All(*filters)`: matches nodes that match all filters + +Here is an example of how to use `Not`: +```python +non_params = module.extract(nnx.Not(nnx.Param)) +``` + + +### Capturing Intermediate Values +In NNX you can easily propagate intemediate values by simply assigning them to an attribute at runtime. For convenience, you should assign them to a `Variable` attribute with a `collection` name by using `nnx.var` so you can easily retrieve them later. + +Here is an example of how to create a `Linear` module that captures its output into a `Variable` attribute with the `intermediates` collection name: + +```python +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + key = rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + y = x @ self.w + self.b + self.y = nnx.Intermediate(y) + return y + +model = Linear(12, 2, rngs=nnx.Rngs(0)) +``` +Since `y` is only created when the module is called, it is not available upon initialization. However, once you call the module `y` will be created. It is recommended that you use `pop` to retrieve temporary collections like `Intermediate`: + +```python +y = model(jnp.ones((8, 12))) +intermediates = model.pop(nnx.Intermediate) +``` +`pop` will return a `State` object with the nodes that match the given filter and remove them from the module's attributes. + +``` +intermediates +``` +``` +State({ + 'y: Intermediate(value=Array(...)) +}) +``` + +If you use the functional API to call the module instead, the `Intermediate` nodes will be present in the output `state`. To retrieve the `Intermediate` nodes and optionally separate them from the output `state` you can use `State.split`: + +```python +state, static = model.split() +y, (state, static) = static.apply(state)(jnp.ones((8, 12))) +# "pop" the intermediates from the state +intermediates, state = state.split(nnx.Intermediate, ...) +``` + +Alternatively, you can use `State.extract` to retrieve the `Intermediate` nodes without removing them from the `state`. + + +### Lifted Transforms + +NNX lifted transforms analogous versions of JAX transforms but they know how to work with Modules. They usually perform the following tasks: + +* Handle the Module's substates and Rngs's RNG streams according to the transform's semantics. +* Properly propagating state in and out of the transform, including updating the input Module's state with updates that happen inside the transform. + +Here's a diagram illustrating how lifted transformations work: + +![lifted-transforms](https://raw.githubusercontent.com/cgarciae/nnx/main/docs/images/stateful-transforms.png) + +Currently NNX provides the `jit`, `grad`, `scan`, and `remat`, lifted transforms. + +#### Manual Lifting + +In case you want to use JAX transforms directly you can always use the functional API +to manually lift your Modules. + +Here we will create an example of how to implement an MLP that uses "scan over layers" to efficiently process a sequence of inputs assuming that each layer has the same parameters and input/output dimensions. The first thing we need to do is create a `Block` module that represents a single layer, this block with just contain a `Linear` layer, a `Dropout` layer, and a `GELU` activation function: + +```python +class Block(nnx.Module): + def __init__(self, dim: int, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(dim, dim, rngs=rngs) + self.dropout = nnx.Dropout(0.5) + + def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array: + x = self.linear(x) + x = self.dropout(x, deterministic=not train, rngs=rngs) + x = jax.nn.gelu(x) + return x +``` + +Now we will define `ScanMLP`. During `__init__`, instead of creating a list of `Block`s, we will use `jax.vmap` to create a single `Block` whose parameters have an addtional `layer` axis. This will allow us to pass the parameters as inputs to scan so it will apply a layer at each step. + +```python +class ScanMLP(nnx.Module): + def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): + params_key = jax.random.split(rngs.params(), n_layers) + self.n_layers = n_layers + state, static = jax.vmap( + lambda key: Block(dim, rngs=nnx.Rngs(params=key)).split() + )(params_key) + self.layers = static.merge(state) + +``` +Note that we split the `params` key into `n_layers` keys so each layer has different parameters. + +Now we will define `__call__`. Here we need to split the `dropout` key into `n_layers` keys so each layer has a different dropout mask, and `split` the layers to get their `params`. Both `params` and `dropout_key` will be passed as inputs, `x` will be the carry value. Inside the `scan_fn` we will merge the `params` back into a `Block` module and +apply it to the input `x`, passing the sliced `dropout_key` as part of the `Rngs`. + + +```python + def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array: + dropout_key = jax.random.split(rngs.dropout(), self.n_layers) + params, static = self.layers.split(nnx.Param) + + def scan_fn(x, inputs): + params, dropout_key = inputs + module = static.merge(params) + x = module(x, train=train, rngs=nnx.Rngs(dropout=dropout_key)) + return x, module.extract(nnx.Param) + + x, params = jax.lax.scan(scan_fn, x, (params, dropout_key)) + self.layers.update(params) + return x +``` +Finally we apply `jax.lax.scan`, update the `layers` state with the new `params`, and return the final `x` value. + +Here is a simple way to test our `ScanMLP`: + +```python +model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0)) + +x = jnp.ones((3, 10)) +y = model(x, train=True, rngs=nnx.Rngs(dropout=1)) +``` + +For a more robust implementation with comments take a look at the [Scan over layers](https://github.com/cgarciae/nnx/blob/main/examples/06_scan_over_layers.py) example. + +### Case Studies +#### Shared State + +In NNX, you can create modules that share state between them. This is useful when designing complex neural network architectures, as it allows you to reuse certain layers and reduce the number of learnable parameters. + +Here's an example of creating a module with shared state: + +```python +class Block(nnx.Module): + def __init__(self, linear: nnx.Linear, *, rngs: nnx.Rngs): + self.linear = linear + self.bn = nnx.BatchNorm(2, rngs=rngs) + + def __call__(self, x, *, rngs: nnx.Rngs): + x = self.linear(x) + x = self.bn(x, rngs=rngs) + x = nnx.relu(x) + return x + +class Model(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + shared = nnx.Linear(2, 2, rngs=rngs) + self.block1 = Block(shared, rngs=rngs) + self.block2 = Block(shared, rngs=rngs) + + def __call__(self, x): + x = self.block1(x) + x = self.block2(x) + return x +``` + +In this example, the `Model` module contains two instances of the `Block` module. Each instance shares the same `nnx.Linear` module. To run the model, you can use the Rngs `flags` argument to set the `use_running_average` flag for all `BatchNorm` modules. + +Here's an example of computing the loss for a `Model` instance: + +```python +def loss_fn(model: Model, x: jax.Array, y: jax.Array): + with nnx.flags(use_running_average=True): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) +``` -#### Toy Examples -* [Using the Functional API](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. -* [Using Lifted Transforms](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using lifted transforms. -* [Using TrainState](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/03_train_state.py): Shows how to train a simple model using the functional API with the help of `TrainState`. -* [Training a VAE](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset, uses the functional API, `TrainState`, and shows how to use capture intermediate values to retrieve `kl_loss`. -* [Scan over layers](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. It uses the functional API along with `jax.vmap` and `jax.lax.scan`. -* [Creating a Transformer](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/07_transformer.py): Shows how to create a Transformer with an auto-regressive decoder that uses scan over layers and a kv-cache for fast inference. Credits to @levskaya. \ No newline at end of file +It's important to note that the state for the shared `nnx.Linear` module will be kept in sync at all times on both `Block` instances, including during gradient updates. diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb index f029ded8bd..34049e7b2c 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/experimental/nnx/docs/why.ipynb @@ -93,7 +93,7 @@ " self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections\n", "\n", " def __call__(self, x):\n", - " self.count += 1 # in-place stateful updates\n", + " self.count.value += 1 # in-place stateful updates\n", " return self.linear(x)\n", "\n", "\n", diff --git a/flax/experimental/nnx/docs/why.md b/flax/experimental/nnx/docs/why.md index c83c454418..0152b7380e 100644 --- a/flax/experimental/nnx/docs/why.md +++ b/flax/experimental/nnx/docs/why.md @@ -61,7 +61,7 @@ class CounterLinear(nnx.Module): self.count = Count(jnp.zeros((), jnp.int32)) # typed Variable collections def __call__(self, x): - self.count += 1 # in-place stateful updates + self.count.value += 1 # in-place stateful updates return self.linear(x) diff --git a/flax/experimental/nnx/examples/lm1b/models.py b/flax/experimental/nnx/examples/lm1b/models.py index 66b8af9975..0d262f6107 100644 --- a/flax/experimental/nnx/examples/lm1b/models.py +++ b/flax/experimental/nnx/examples/lm1b/models.py @@ -158,22 +158,23 @@ def __call__(self, inputs: jax.Array, inputs_positions=None): 'Number of dimensions should be 3, but it is: %d' % inputs.ndim ) length = inputs.shape[1] - pos_embedding = self.pos_embedding - if pos_embedding is None: + if self.pos_embedding is None: # Use a fixed (non-learned) sinusoidal position embedding. pos_embedding = sinusoidal_init(max_len=config.max_len)( None, self.pos_emb_shape ) + else: + pos_embedding = self.pos_embedding.value # We use a cache position index for tracking decoding position. if flaglib.flags.get('decode', False): _, _, df = pos_embedding.shape # equivalent to pos_embedding[:, i:i+1] but traceable pos_embedding = lax.dynamic_slice( - pos_embedding, jnp.array((0, self.cache_index, 0)), (1, 1, df) + pos_embedding, jnp.array((0, self.cache_index.value, 0)), (1, 1, df) ) - self.cache_index += 1 + self.cache_index.value += 1 else: pos_embedding = pos_embedding[:, :length, :] diff --git a/flax/experimental/nnx/examples/lm1b/models_test.py b/flax/experimental/nnx/examples/lm1b/models_test.py index 607996ac49..007c6da206 100644 --- a/flax/experimental/nnx/examples/lm1b/models_test.py +++ b/flax/experimental/nnx/examples/lm1b/models_test.py @@ -18,7 +18,6 @@ from pathlib import Path from typing import Any - # add project_root to import lm1b Linen model project_root = str(Path(__file__).parents[6]) sys.path.append(project_root) @@ -36,11 +35,11 @@ from flax import traverse_util from flax.experimental import nnx from flax.experimental.nnx.examples.lm1b.configs import default -from flax.experimental.nnx.examples.lm1b.utils import HasCache from flax.experimental.nnx.examples.lm1b.models import ( TransformerConfig, TransformerLM, ) +from flax.experimental.nnx.examples.lm1b.utils import HasCache jax.config.update('jax_disable_most_optimizations', True) @@ -91,10 +90,10 @@ def apply_rules(names: tuple[str, ...]): def copy_var(nnx_name, linen_name): assert ( - flat_params_nnx[nnx_name].value.shape + flat_params_nnx[nnx_name].raw_value.shape == flat_params_linen[linen_name].value.shape ) - flat_params_nnx[nnx_name].value = flat_params_linen[linen_name].value + flat_params_nnx[nnx_name].raw_value = flat_params_linen[linen_name].value assert flat_params_nnx[nnx_name].sharding == apply_rules( flat_params_linen[linen_name].names ) @@ -171,10 +170,10 @@ def transfer_cache( def copy_var(nnx_name, linen_name): assert ( - flat_cache_nnx[nnx_name].value.shape + flat_cache_nnx[nnx_name].raw_value.shape == flat_cache_linen[linen_name].shape ) - flat_cache_nnx[nnx_name].value = flat_cache_linen[linen_name] + flat_cache_nnx[nnx_name].raw_value = flat_cache_linen[linen_name] for idx in range(config.num_layers): copy_var( diff --git a/flax/experimental/nnx/examples/lm1b/train.py b/flax/experimental/nnx/examples/lm1b/train.py index 8937eb4439..13eb172982 100644 --- a/flax/experimental/nnx/examples/lm1b/train.py +++ b/flax/experimental/nnx/examples/lm1b/train.py @@ -534,27 +534,27 @@ def constructor(config: models.TransformerConfig, key: jax.Array): # Since the inputs and rngkey args for predict_step will be batched, # we must vmap them, otherwise the global arrays will be seen in each device jit_pred_step = jax.jit( - jax.vmap( - predict_step, - in_axes=( - 0, - jax.tree_util.tree_map(lambda x: None, state.params), - 0, - None, - None, - None, - None, - None, - None, - ), + jax.vmap( + predict_step, + in_axes=( + 0, + jax.tree_map(lambda x: None, state.params), + 0, + None, + None, + None, + None, + None, + None, ), - in_shardings=( - data_sharding, - state_sharding.params, - data_sharding, - ), # type: ignore - out_shardings=data_sharding, # type: ignore - static_argnums=tuple(range(3, 9)), + ), + in_shardings=( + data_sharding, + state_sharding.params, + data_sharding, + ), # type: ignore + out_shardings=data_sharding, # type: ignore + static_argnums=tuple(range(3, 9)), ) # Main Train Loop @@ -582,7 +582,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): # Shard data to devices and do a training step. with jax.profiler.StepTraceAnnotation('train', step_num=step): batch = next(train_iter) - batch = jax.tree_util.tree_map(lambda x: jnp.asarray(x), batch) + batch = jax.tree_map(lambda x: jnp.asarray(x), batch) state, metrics = jit_train_step( state, batch, learning_rate_fn, 0.0, dropout_rngs ) diff --git a/flax/experimental/nnx/examples/lm1b/utils.py b/flax/experimental/nnx/examples/lm1b/utils.py index c4299bad1d..1f7e8e03ea 100644 --- a/flax/experimental/nnx/examples/lm1b/utils.py +++ b/flax/experimental/nnx/examples/lm1b/utils.py @@ -161,7 +161,7 @@ def setup_initial_state( state = TrainState.create( apply_fn=static.apply, params=params, tx=tx, graphdef=static ) - state = jax.tree_util.tree_map(_to_array, state) + state = jax.tree_map(_to_array, state) state_spec = nnx.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) diff --git a/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb b/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb index 7497ca7b99..b39fa835c5 100644 --- a/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb +++ b/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb @@ -35,7 +35,7 @@ " self.b = nnx.Param(jnp.zeros((dout,)))\n", "\n", " def __call__(self, x):\n", - " return x @ self.w + self.b\n", + " return x @ self.w.value + self.b.value\n", "\n", "\n", "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", @@ -56,9 +56,13 @@ "output_type": "stream", "text": [ "State({\n", - " 'w': Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32),\n", - " 'b': Array([0., 0.], dtype=float32)\n", + " 'w': Param(\n", + " raw_value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " ),\n", + " 'b': Param(\n", + " raw_value=Array([0., 0.], dtype=float32)\n", + " )\n", "})\n", "GraphDef(\n", " type=Linear,\n", @@ -115,9 +119,13 @@ "text/plain": [ "State({\n", " 'linear': {\n", - " 'w': Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32),\n", - " 'b': Array([0., 0.], dtype=float32)\n", + " 'w': Param(\n", + " raw_value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " ),\n", + " 'b': Param(\n", + " raw_value=Array([0., 0.], dtype=float32)\n", + " )\n", " }\n", "})" ] @@ -169,7 +177,7 @@ " self.submodule = self\n", "\n", " def __call__(self, x):\n", - " return x @ self.submodule.w + self.submodule.b\n", + " return x @ self.submodule.w.value + self.submodule.b.value\n", "\n", "\n", "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", @@ -190,9 +198,13 @@ "output_type": "stream", "text": [ "State({\n", - " 'w': Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32),\n", - " 'b': Array([0., 0.], dtype=float32)\n", + " 'w': Param(\n", + " raw_value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " ),\n", + " 'b': Param(\n", + " raw_value=Array([0., 0.], dtype=float32)\n", + " )\n", "})\n", "GraphDef(\n", " type=Linear,\n", @@ -293,7 +305,7 @@ " self.b = nnx.Param(jnp.zeros((dout,)))\n", "\n", " def __call__(self, x):\n", - " y = x @ self.w + self.b\n", + " y = x @ self.w.value + self.b.value\n", " self.y = nnx.Intermediate(y)\n", " return y\n", "\n", @@ -316,14 +328,22 @@ "output_type": "stream", "text": [ "State({\n", - " 'y': Array([[0.63114893, 1.2928092 ],\n", - " [0.63114893, 1.2928092 ]], dtype=float32)\n", + " 'y': Intermediate(\n", + " raw_value=Array([[0.63114893, 1.2928092 ],\n", + " [0.63114893, 1.2928092 ]], dtype=float32)\n", + " )\n", "})\n", "State({\n", - " 'w': Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32),\n", - " 'b': Array([0., 0.], dtype=float32),\n", - " 'y': Empty\n", + " 'w': Param(\n", + " raw_value=Array([[0.31696808, 0.55285215],\n", + " [0.31418085, 0.7399571 ]], dtype=float32)\n", + " ),\n", + " 'b': Param(\n", + " raw_value=Array([0., 0.], dtype=float32)\n", + " ),\n", + " 'y': Intermediate(\n", + " raw_value=Empty\n", + " )\n", "})\n" ] } @@ -477,7 +497,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py b/flax/experimental/nnx/examples/toy_examples/01_functional_api.py index 291a585bb6..879e32146a 100644 --- a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py +++ b/flax/experimental/nnx/examples/toy_examples/01_functional_api.py @@ -36,7 +36,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - return x @ self.w + self.b + return x @ self.w.value + self.b.value class Count(nnx.Variable[nnx.A]): @@ -50,7 +50,7 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): - self.count += 1 + self.count.value += 1 x = self.linear1(x) x = jax.nn.relu(x) x = self.linear2(x) @@ -74,7 +74,7 @@ def loss_fn(params): grad, counts = jax.grad(loss_fn, has_aux=True)(params) # |-------- sgd ---------| - params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grad) + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad) return params, counts @@ -99,7 +99,7 @@ def test_step(params: nnx.State, counts: nnx.State, batch): break model = modeldef.merge(params, counts) -print('times called:', model.count) +print('times called:', model.count.value) y_pred = model(X) diff --git a/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py b/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py index bc577e4b56..3b9a8c3525 100644 --- a/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py +++ b/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py @@ -36,7 +36,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - return x @ self.w + self.b + return x @ self.w.value + self.b.value class Count(nnx.Variable): @@ -50,7 +50,7 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): - self.count += 1 + self.count.value += 1 x = self.linear1(x) x = jax.nn.relu(x) x = self.linear2(x) @@ -72,9 +72,7 @@ def loss_fn(model: MLP): grad: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) # sdg update model.update( - jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grad - ) + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grad) ) # no return!!! @@ -99,7 +97,7 @@ def test_step(model: MLP, batch): if step >= total_steps - 1: break -print('times called:', model.count) +print('times called:', model.count.value) y_pred = model(X) diff --git a/flax/experimental/nnx/examples/toy_examples/03_train_state.py b/flax/experimental/nnx/examples/toy_examples/03_train_state.py index 3177792b82..2f8eb030d5 100644 --- a/flax/experimental/nnx/examples/toy_examples/03_train_state.py +++ b/flax/experimental/nnx/examples/toy_examples/03_train_state.py @@ -37,7 +37,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - return x @ self.w + self.b + return x @ self.w.value + self.b.value class Count(nnx.Variable[nnx.A]): @@ -51,7 +51,7 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): self.linear2 = Linear(dhidden, dout, rngs=rngs) def __call__(self, x): - self.count = self.count + 1 + self.count.value += 1 x = self.linear1(x) x = jax.nn.relu(x) x = self.linear2(x) @@ -108,7 +108,7 @@ def test_step(state: nnx.TrainState[MLP], batch): break model = static.merge(state.params, state.counts) -print('times called:', model.count) +print('times called:', model.count.value) y_pred = model(X) diff --git a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py index f982eb5934..e3c9d3f658 100644 --- a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py +++ b/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py @@ -83,5 +83,5 @@ def scan_fn( with nnx.flags(deterministic=False): y = model(x, rngs=nnx.Rngs(dropout=1)) -print(jax.tree_util.tree_map(jnp.shape, model.get_state())) +print(jax.tree_map(jnp.shape, model.get_state())) print(y.shape) diff --git a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py b/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py index 7254f85844..e3af70efc0 100644 --- a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py +++ b/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import jax from flax.experimental import nnx @@ -51,10 +52,5 @@ def __call__(self, x): # split the parameters into trainable and non-trainable parameters trainable_params, non_trainable, static = model.split(is_trainable, ...) -print( - 'trainable_params =', - jax.tree_util.tree_map(jax.numpy.shape, trainable_params), -) -print( - 'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable) -) +print('trainable_params =', jax.tree_map(jax.numpy.shape, trainable_params)) +print('non_trainable = ', jax.tree_map(jax.numpy.shape, non_trainable)) diff --git a/flax/experimental/nnx/examples/toy_examples/10_quantization.py b/flax/experimental/nnx/examples/toy_examples/10_quantization.py index 32707b2776..37ea869f88 100644 --- a/flax/experimental/nnx/examples/toy_examples/10_quantization.py +++ b/flax/experimental/nnx/examples/toy_examples/10_quantization.py @@ -166,7 +166,7 @@ def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array: state, loss = train_step(state, x_batch, y_batch) metrics = eval_step(state, X_test, Y_test) - metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics) + metrics = jax.tree_map(lambda x: x.item(), metrics) print(f'Epoch {epoch} - {metrics}') # %% @@ -212,19 +212,19 @@ def __init__(self, din: int, dout: int): def __call__(self, x: jax.Array) -> jax.Array: x = self.quantize(x, 8, jnp.uint8) - print(x.shape, self.qkernel.shape, self.qbias.shape) - x = jnp.dot(x, self.qkernel, preferred_element_type=jnp.uint16) - x = (x + self.qbias).astype(jnp.uint32) + print(x.shape, self.qkernel.value.shape, self.qbias.value.shape) + x = jnp.dot(x, self.qkernel.value, preferred_element_type=jnp.uint16) + x = (x + self.qbias.value).astype(jnp.uint32) x = self.dequantize(x) return x def quantize(self, x: jax.Array, b: int, dtype: jnp.dtype) -> jax.Array: return jnp.clip( - diff_round(x / self.scale) + self.zero_point, 0, 2**b - 1 + diff_round(x / self.scale.value) + self.zero_point.value, 0, 2**b - 1 ).astype(dtype) def dequantize(self, x: jax.Array) -> jax.Array: - return (x - self.zero_point) * self.scale + return (x - self.zero_point.value) * self.scale.value def optimize( self, @@ -238,7 +238,7 @@ def optimize( tx = optax.adam(1e-3) opt_state = tx.init(q_hparams) - print(jax.tree_util.tree_map(lambda x: x.shape, q_hparams)) + print(jax.tree_map(lambda x: x.shape, q_hparams)) @jax.jit def optimization_step( @@ -251,9 +251,13 @@ def optimization_step( def loss_fn(q_hparams: nnx.State): model = static.merge(q_hparams, rest) - model.qkernel = model.quantize(pretrained.kernel, 8, jnp.uint8) + model.qkernel.value = model.quantize( + pretrained.kernel.value, 8, jnp.uint8 + ) assert pretrained.bias is not None - model.qbias = model.quantize(pretrained.bias, 16, jnp.uint16) + model.qbias.value = model.quantize( + pretrained.bias.value, 16, jnp.uint16 + ) y_quant = model(x) y_unquant = pretrained(x) @@ -276,9 +280,9 @@ def loss_fn(q_hparams: nnx.State): self.update(q_hparams) - self.qkernel = self.quantize(pretrained.kernel, 8, jnp.uint8) - assert pretrained.bias is not None - self.qbias = self.quantize(pretrained.bias, 16, jnp.uint16) + self.qkernel.value = self.quantize(pretrained.kernel.value, 8, jnp.uint8) + assert pretrained.bias.value is not None + self.qbias.value = self.quantize(pretrained.bias.value, 16, jnp.uint16) def optimize2( diff --git a/flax/experimental/nnx/ideas/shape_inference.py b/flax/experimental/nnx/ideas/shape_inference.py index bff4df2717..0352f08ffa 100644 --- a/flax/experimental/nnx/ideas/shape_inference.py +++ b/flax/experimental/nnx/ideas/shape_inference.py @@ -150,12 +150,12 @@ def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array: # eager m1 = Linear(din=32, dout=10, rngs=nnx.Rngs(params=0)) y = m1(x=jnp.ones((1, 32))) -print(jax.tree_util.tree_map(jnp.shape, m1.get_state())) +print(jax.tree_map(jnp.shape, m1.get_state())) # lazy m2 = Linear(dout=10) y = m2.init(x=jnp.ones((1, 32)), rngs=nnx.Rngs(params=0)) -print(jax.tree_util.tree_map(jnp.shape, m2.get_state())) +print(jax.tree_map(jnp.shape, m2.get_state())) # usage y1 = m1(x=jnp.ones((1, 32))) @@ -199,7 +199,7 @@ def __call__(self, x: jax.Array, _, *, train: bool, rngs: nnx.Rngs): mlp = MLP(din=10, dout=10, rngs=nnx.Rngs(params=0)) y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1)) print(f'{y.shape=}') -print('state =', jax.tree_util.tree_map(jnp.shape, mlp.get_state())) +print('state =', jax.tree_map(jnp.shape, mlp.get_state())) print() # lazy @@ -207,4 +207,4 @@ def __call__(self, x: jax.Array, _, *, train: bool, rngs: nnx.Rngs): mlp.init(jnp.ones((1, 10)), None, train=False, rngs=nnx.Rngs(params=0)) y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1)) print(f'{y.shape=}') -print('state =', jax.tree_util.tree_map(jnp.shape, mlp.get_state())) +print('state =', jax.tree_map(jnp.shape, mlp.get_state())) diff --git a/flax/experimental/nnx/nnx/compatibility.py b/flax/experimental/nnx/nnx/compatibility.py index 973ec81cd8..e46664eac1 100644 --- a/flax/experimental/nnx/nnx/compatibility.py +++ b/flax/experimental/nnx/nnx/compatibility.py @@ -17,7 +17,6 @@ from typing import Any from flax import linen -from flax.experimental.nnx.nnx import helpers from flax.experimental.nnx.nnx import variables as variableslib from flax.experimental.nnx.nnx.module import GraphDef, Module from flax.experimental.nnx.nnx.rnglib import Rngs @@ -75,10 +74,10 @@ def __init__( variables = module.init(_rngs, *args, **kwargs) - self.states = helpers.Dict( - (collection, variableslib.variable_type(collection)(value)) + self.states = { + collection: variableslib.variable_type(collection)(value) for collection, value in variables.items() - ) + } def __call__( self, *args: Any, rngs: tp.Optional[Rngs] = None, **kwargs: Any @@ -87,7 +86,9 @@ def __call__( {name: stream.key for name, stream in rngs._rngs.items()} if rngs else {} ) - variables = {collection: value for collection, value in self.states.items()} + variables = { + collection: value.value for collection, value in self.states.items() + } out = self.module.apply(variables, *args, rngs=_rngs, **kwargs) if kwargs.get('mutable', False) != False: diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py index 85cb415dc3..339cb94b28 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -20,7 +20,7 @@ import jax -from flax.experimental.nnx.nnx import filterlib, reprlib +from flax.experimental.nnx.nnx import filterlib, reprlib, tracers from flax.experimental.nnx.nnx.proxy_caller import ( ApplyCaller, CallableProxy, @@ -191,13 +191,15 @@ class VariableDef(reprlib.Representable): @classmethod def from_variable(cls, variable: Variable[tp.Any], index: int) -> VariableDef: metadata = vars(variable).copy() - del metadata['value'] + del metadata['raw_value'] + del metadata['_trace_state'] return cls(type(variable), index, metadata) def to_variable(self, value: Node) -> Variable[Node]: variables = object.__new__(self._type) - variables.value = value - vars(variables).update(self._metadata) + vars(variables).update( + self._metadata, raw_value=value, _trace_state=tracers.TraceState() + ) return variables def __init__( @@ -457,6 +459,7 @@ def _graph_unflatten( if graphdef.index in index_to_node: raise RuntimeError(f'GraphDef index {graphdef.index} already used.') + # TODO(cgarciae): why copy here? state = state.copy() node_impl = get_node_impl(graphdef.type) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index 512da8ca27..aafcfb395e 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -63,64 +63,6 @@ def setup(self) -> None: SEEN_MODULES_REPR: tp.Optional[tp.Set[ids.UUID]] = None -class ModuleVariablesMapping( - tp.MutableMapping[str, Variable[tp.Any]], reprlib.Representable -): - __slots__ = ('_module',) - - def __init__(self, module: Module): - if tp.TYPE_CHECKING: - self._module = module - else: - object.__setattr__(self, '_module', module) - - def __getitem__(self, key: str | int) -> Variable[tp.Any]: - if isinstance(key, int): - key = str(key) - - module_vars = vars(self._module) - if key not in module_vars: - raise KeyError(f'Variable {key} not found') - value = module_vars[key] - - if not isinstance(value, Variable): - raise KeyError(f"Variable '{key}' is not found.") - - return value - - def __setitem__(self, name: str, value: Variable[tp.Any]) -> None: - vars(self._module)[name] = value - - def __getattr__(self, name: str) -> Variable[tp.Any]: - module_vars = vars(self._module) - if name not in module_vars: - raise AttributeError(f'Variable {name!r} not found') - value = module_vars[name] - if not isinstance(value, Variable): - raise AttributeError(f"Variable '{name}' is not found.") - return value - - def __setattr__(self, name: str, value: Variable[tp.Any]) -> None: - vars(self._module)[name] = value - - def __delitem__(self, name: str) -> None: - delattr(self._module, name) - - def __iter__(self) -> tp.Iterator[str]: - for name, value in vars(self._module).items(): - if isinstance(value, Variable): - yield name - - def __len__(self) -> int: - return sum(1 for _ in self) - - def __nnx_repr__(self): - yield reprlib.Object(type(self), start='{', end='}', value_sep=': ') - for name, value in vars(self._module).items(): - if isinstance(value, Variable): - yield reprlib.Attr(repr(name), value) - - class ModuleState(reprlib.Representable): __slots__ = ('_trace_state', '_id') @@ -195,12 +137,6 @@ class Module(reprlib.Representable, metaclass=ModuleMeta): if not tp.TYPE_CHECKING: - def __getattribute__(self, name: str) -> Any: - value = object.__getattribute__(self, name) - if isinstance(value, Variable): - return value.get_value() - return value - def __setattr__(self, name: str, value: Any) -> None: self._setattr(name, value) @@ -210,34 +146,15 @@ def _setattr(self, name: str, value: tp.Any) -> None: 'Cannot mutate Module from different trace level' ) - vars_dict = vars(self) - if name in vars_dict: - if isinstance(variable := vars_dict[name], Variable): - if isinstance(value, Variable): - if type(value) != type(variable): - raise ValueError( - f"Trying to assign a Variable of type '{type(value).__name__}' " - f"to the Module attribute '{name}' of a different type " - f"'{type(variable).__name__}'." - ) - variable.copy_from(value) - else: - variable.set_value(value) - else: - vars_dict[name] = value - else: - if isinstance(value, (jax.Array, np.ndarray, State)): - raise ValueError( - f"Trying to assign a '{type(value).__name__}' to the Module" - f" attribute '{name}'. This is not supported. Non-hashable " - 'objects are not valid static state in JAX. Please wrap ' - 'the value in a Variable type instead.' - ) - vars_dict[name] = value + if isinstance(value, (jax.Array, np.ndarray, State)): + raise ValueError( + f"Trying to assign a '{type(value).__name__}' to the Module" + f" attribute '{name}'. This is not supported. Non-hashable " + 'objects are not valid static state in JAX. Please wrap ' + 'the value in a Variable type instead.' + ) - @property - def variables(self) -> ModuleVariablesMapping: - return ModuleVariablesMapping(self) + object.__setattr__(self, name, value) def __deepcopy__(self: M, memo=None) -> M: state, graphdef = self.split() @@ -522,7 +439,7 @@ def sow( init_fn: tp.Callable[[], B] = tuple_init, # type: ignore ) -> None: if hasattr(self, name): - variable = vars(self)[name] + variable = getattr(self, name) if not isinstance(variable, variableslib.Variable): raise ValueError( f"Expected '{name}' to be a Variable, got {type(variable).__name__}" @@ -532,8 +449,7 @@ def sow( f"Expected '{name}' to be of type '{variable_type.__name__}', " f"got '{type(variable).__name__}'" ) - reduced_value = reduce_fn(variable.value, value) - setattr(self, name, reduced_value) + variable.raw_value = reduce_fn(variable.raw_value, value) else: reduced_value = reduce_fn(init_fn(), value) setattr(self, name, variable_type(reduced_value)) @@ -608,9 +524,15 @@ def _module_graph_get_key(module: Module, name: str) -> tp.Any: return vars(module)[name] -def _module_graph_set_key(module: M, name: str, value: tp.Any) -> M: - setattr(module, name, value) - return module +def _module_graph_set_key(module: Module, name: str, value: tp.Any): + if ( + hasattr(module, name) + and isinstance(variable := getattr(module, name), Variable) + and isinstance(value, Variable) + ): + variable.copy_from(value) + else: + setattr(module, name, value) def _module_graph_has_key(module: Module, name: str) -> bool: diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py index e294e2af93..86f77044d7 100644 --- a/flax/experimental/nnx/nnx/nn/attention.py +++ b/flax/experimental/nnx/nnx/nn/attention.py @@ -521,7 +521,7 @@ def __call__( max_length, num_heads, depth_per_head, - ) = self.cached_key.shape + ) = self.cached_key.value.shape # shape check of cached keys against query input expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) if expected_shape != query.shape: @@ -531,14 +531,14 @@ def __call__( % (expected_shape, query.shape) ) # update key, value caches with our new 1d spatial slices - cur_index = self.cache_index + cur_index = self.cache_index.value zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype)) indices = (zero,) * len(batch_dims) + (cur_index, zero, zero) - key = lax.dynamic_update_slice(self.cached_key, key, indices) - value = lax.dynamic_update_slice(self.cached_value, value, indices) - self.cached_key = key - self.cached_value = value - self.cache_index += 1 + key = lax.dynamic_update_slice(self.cached_key.value, key, indices) + value = lax.dynamic_update_slice(self.cached_value.value, value, indices) + self.cached_key.value = key + self.cached_value.value = value + self.cache_index.value += 1 # causal mask for cached decoder self-attention: # our single query position should only attend to those key # positions that have already been generated and cached, diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py index c081154910..7fa582ca48 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -115,12 +115,12 @@ class LinearGeneral(Module): >>> # output features (4, 5) >>> layer = nn.LinearGeneral(features=(4, 5)) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3))) - >>> jax.tree_util.tree_map(jnp.shape, params) + >>> jax.tree_map(jnp.shape, params) {'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}} >>> # apply transformation on the the second and last axes >>> layer = nn.LinearGeneral(features=(4, 5), axis=(1, -1)) >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7))) - >>> jax.tree_util.tree_map(jnp.shape, params) + >>> jax.tree_map(jnp.shape, params) {'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}} Attributes: @@ -187,16 +187,16 @@ def __init__( n_in_features = len(self.in_features) n_out_features = len(self.out_features) - def kernel_init_wrap(rng, shape, dtype) -> jax.Array: + def kernel_init_wrap(rng, shape, dtype): flat_shape = ( np.prod(shape[:n_batch_axis]) * np.prod(shape[n_batch_axis : n_in_features + n_batch_axis]), np.prod(shape[-n_out_features:]), ) - flat_shape = jax.tree_util.tree_map(int, flat_shape) + flat_shape = jax.tree_map(int, flat_shape) kernel = self.kernel_init(rng, flat_shape, dtype) if isinstance(kernel, variables.VariableMetadata): - kernel.value = jnp.reshape(kernel.value, shape) + kernel.raw_value = jnp.reshape(kernel.raw_value, shape) else: kernel = jnp.reshape(kernel, shape) @@ -214,11 +214,11 @@ def kernel_init_wrap(rng, shape, dtype) -> jax.Array: if self.use_bias: - def bias_init_wrap(rng, shape, dtype) -> jax.Array: + def bias_init_wrap(rng, shape, dtype): flat_shape = (int(np.prod(shape)),) bias = self.bias_init(rng, flat_shape, dtype) if isinstance(bias, variables.VariableMetadata): - bias.value = jnp.reshape(bias.value, shape) + bias.raw_value = jnp.reshape(bias.raw_value, shape) else: bias = jnp.reshape(bias, shape) return bias @@ -252,8 +252,8 @@ def __call__(self, inputs: Array) -> Array: for ax in range(inputs.ndim) if ax not in axis ) - kernel = self.kernel - bias = self.bias + kernel = self.kernel.value + bias = self.bias.value batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) @@ -339,8 +339,8 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - kernel = self.kernel - bias = self.bias + kernel = self.kernel.value + bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( inputs, kernel, bias, dtype=self.dtype @@ -529,12 +529,12 @@ def maybe_broadcast( # One shared convolutional kernel for all pixels in the output. assert self.in_features % self.feature_group_count == 0 - kernel = self.kernel + kernel = self.kernel.value if self.mask_fn is not None: kernel = self.mask_fn(kernel) - bias = self.bias + bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( inputs, kernel, bias, dtype=self.dtype @@ -603,7 +603,7 @@ def __init__( self.num_embeddings = num_embeddings self.features = features - self.dtype = dtype or self.embedding.dtype + self.dtype = dtype or self.embedding.value.dtype self.param_dtype = param_dtype self.embedding_init = embedding_init @@ -623,7 +623,7 @@ def __call__(self, inputs: Array) -> Array: # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. (embedding,) = dtypes.promote_dtype( - self.embedding, dtype=self.dtype, inexact=False + self.embedding.value, dtype=self.dtype, inexact=False ) if self.num_embeddings == 1: return jnp.where( @@ -648,6 +648,6 @@ def attend(self, query: Array) -> Array: in NLP models. """ query, embedding = dtypes.promote_dtype( - query, self.embedding, dtype=self.dtype + query, self.embedding.value, dtype=self.dtype ) return jnp.dot(query, embedding.T) diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py index 8826bc8347..1d0faad048 100644 --- a/flax/experimental/nnx/nnx/nn/normalization.py +++ b/flax/experimental/nnx/nnx/nn/normalization.py @@ -291,7 +291,7 @@ def __call__( reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes) if use_running_average: - mean, var = self.mean, self.var + mean, var = self.mean.value, self.var.value else: mean, var = _compute_stats( x, @@ -303,15 +303,19 @@ def __call__( mask=mask, ) - self.mean = self.momentum * self.mean + (1 - self.momentum) * mean - self.var = self.momentum * self.var + (1 - self.momentum) * var + self.mean.value = ( + self.momentum * self.mean.value + (1 - self.momentum) * mean + ) + self.var.value = ( + self.momentum * self.var.value + (1 - self.momentum) * var + ) return _normalize( x, mean, var, - self.scale, - self.bias, + self.scale.value, + self.bias.value, reduction_axes, feature_axes, self.dtype, @@ -421,8 +425,8 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale, - self.bias, + self.scale.value, + self.bias.value, self.reduction_axes, self.feature_axes, self.dtype, @@ -522,7 +526,7 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None): x, mean, var, - self.scale, + self.scale.value, None, self.reduction_axes, self.feature_axes, diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py index 23c3fb0b2d..64cf057f54 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -17,12 +17,12 @@ import jax.numpy as jnp from jax import lax, random -from flax.experimental.nnx.nnx import dataclasses as nnx_dataclasses from flax.experimental.nnx.nnx import flaglib, rnglib from flax.experimental.nnx.nnx.module import Module, first_from +import dataclasses -@nnx_dataclasses.dataclass +@dataclasses.dataclass class Dropout(Module): """Create a dropout layer. diff --git a/flax/experimental/nnx/nnx/pytreelib.py b/flax/experimental/nnx/nnx/pytreelib.py index a4c78e2ecd..c459340d73 100644 --- a/flax/experimental/nnx/nnx/pytreelib.py +++ b/flax/experimental/nnx/nnx/pytreelib.py @@ -94,7 +94,7 @@ class Pytree(reprlib.Representable, metaclass=PytreeMeta): def __getattribute__(self, name: str) -> tp.Any: value = object.__getattribute__(self, name) if isinstance(value, variables.Variable): - return value.value + return value.raw_value return value def __setattr__(self, name: str, value: tp.Any) -> None: @@ -125,7 +125,7 @@ def _setattr(self: P, name: str, value: tp.Any): ) vars_dict[name] = value else: - variable.set_value(value) + variable.value = value else: vars_dict[name] = value else: diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py index 0feb6de5d3..5f14762623 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/experimental/nnx/nnx/spmd.py @@ -29,7 +29,6 @@ Sharding, ) - A = tp.TypeVar('A') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) PARTITION_NAME = 'partition_name' @@ -57,8 +56,8 @@ def _add_axis(x: tp.Any): x.add_axis(axis_name, index) return x - return jax.tree_util.tree_map( - _add_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) + return jax.tree_map( + _add_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) ) @@ -76,8 +75,8 @@ def _remove_axis(x: tp.Any): x.remove_axis(axis_name, index) return x - return jax.tree_util.tree_map( - _remove_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) + return jax.tree_map( + _remove_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) ) @@ -102,25 +101,23 @@ def _maybe_replicate(x): def f(x): if isinstance(x, variables.Variable): if isinstance(x, HasSharding) and x.sharding: - return x.replace(value=PartitionSpec(*x.sharding)) + return x.replace(raw_value=PartitionSpec(*x.sharding)) else: - return x.replace(value=_maybe_replicate(x.value)) + return x.replace(raw_value=_maybe_replicate(x.raw_value)) return _maybe_replicate(x) - return jax.tree_util.tree_map( - f, - tree, - is_leaf=lambda x: isinstance(x, variables.Variable) - and not isinstance(x, TreeNode), + return jax.tree_map( + f, + tree, + is_leaf=lambda x: isinstance(x, variables.Variable) + and not isinstance(x, TreeNode), ) def get_named_sharding(tree: A, mesh: jax.sharding.Mesh) -> A: spec = get_partition_spec(tree) - sharding = jax.tree_util.tree_map( - lambda p: jax.sharding.NamedSharding(mesh, p), spec - ) + sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), spec) return sharding @@ -192,7 +189,7 @@ def sharding_hook( if _global_mesh_defined() or ( isinstance(node, Partitioned) and node.mesh is not None ): - spec = get_partition_spec(node).value + spec = get_partition_spec(node).raw_value return with_sharding_constraint(value, spec, mesh=node.mesh) return value diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index d204ae2a2f..4340daf98b 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -35,64 +35,12 @@ from flax import traverse_util from flax.experimental.nnx.nnx import filterlib, reprlib from flax.experimental.nnx.nnx.variables import Variable -from flax.typing import Leaf, Path +from flax.typing import Path A = tp.TypeVar('A') Key = str -FlatState = dict[Path, Variable[Leaf]] - - -class StateVariablesMapping( - tp.MutableMapping[str, Variable[tp.Any]], reprlib.Representable -): - __slots__ = ('_mapping',) - - def __init__(self, mapping: dict[Key, tp.Any]): - if tp.TYPE_CHECKING: - self._mapping = mapping - else: - object.__setattr__(self, '_mapping', mapping) - - def __getitem__(self, key: str | int) -> Variable[tp.Any]: - if isinstance(key, int): - key = str(key) - - value = self._mapping[key] - - if not isinstance(value, Variable): - raise KeyError(f"Variable '{key}' not found.") - - return value - - def __setitem__(self, name: str, value: Variable[tp.Any]) -> None: - self._mapping[name] = value - - def __getattr__(self, name: str) -> Variable[tp.Any]: - value = self._mapping[name] - if not isinstance(value, Variable): - raise AttributeError(f"Variable '{name}' not found.") - return value - - def __setattr__(self, name: str, value: Variable[tp.Any]) -> None: - self._mapping[name] = value - - def __delitem__(self, name: str) -> None: - del self._mapping[name] - - def __iter__(self) -> tp.Iterator[str]: - for name, value in self._mapping.items(): - if isinstance(value, Variable): - yield name - - def __len__(self) -> int: - return sum(1 for _ in self) - - def __nnx_repr__(self): - yield reprlib.Object(type(self), start='{', end='}', value_sep=': ') - for name, value in vars(self._mapping).items(): - if isinstance(value, Variable): - yield reprlib.Attr(repr(name), value) +FlatState = dict[Path, Variable[Variable]] class NestedStateRepr(reprlib.Representable): @@ -126,36 +74,33 @@ def __init__( def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: return self._mapping - @property - def variables(self) -> StateVariablesMapping: - return StateVariablesMapping(self._mapping) - - def __getitem__(self, key: Key | int) -> Leaf | State: + def __getitem__(self, key: Key | int) -> Variable | State: if isinstance(key, int): key = str(key) value = self._mapping[key] if isinstance(value, Variable): - return value.value + return value return State(value) - def __getattr__(self, key: Key) -> Leaf | State: + def __getattr__(self, key: Key) -> Variable | State: if '_mapping' not in vars(self) or key not in self._mapping: raise AttributeError(f'No attribute {key} in State') return self[key] - def __setitem__(self, key: Key | int, value: Leaf | State) -> None: + def __setitem__(self, key: Key | int, value: Variable | State) -> None: if isinstance(key, int): key = str(key) + + if not isinstance(value, (Variable, State)): + raise ValueError( + f'Trying to set key {key} to a value' + f' that is not a Variable or State, got: {value}.' + ) if isinstance(value, State): self._mapping[key] = value._mapping else: - if not isinstance(self._mapping[key], Variable): - raise ValueError( - f'Trying to set key {key} to a leaf value ' - f'but current value is not a Variable: {self._mapping[key]}.' - ) - self._mapping[key].value = value + self._mapping[key] = value __setattr__ = __setitem__ @@ -176,7 +121,7 @@ def __nnx_repr__(self): v = NestedStateRepr(v) yield reprlib.Attr(repr(k), v) - def flat_state(self) -> dict[Key, Variable[Leaf]]: + def flat_state(self) -> dict[Key, Variable[Variable]]: return traverse_util.flatten_dict(self._mapping, sep='/') # type: ignore @classmethod @@ -289,7 +234,7 @@ def _state_flatten_with_keys(x: State): def _state_unflatten( static: tp.Tuple[Path, ...] | None, - leaves: tp.Tuple[Leaf, ...] | tuple[dict[str, Leaf]], + leaves: tp.Tuple[Variable, ...] | tuple[dict[str, Variable]], ): return State(zip(static, leaves)) if static else State(leaves[0]) diff --git a/flax/experimental/nnx/nnx/tracers.py b/flax/experimental/nnx/nnx/tracers.py index 88266823da..1e8688f4eb 100644 --- a/flax/experimental/nnx/nnx/tracers.py +++ b/flax/experimental/nnx/nnx/tracers.py @@ -57,3 +57,6 @@ def is_valid(self) -> bool: def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') yield reprlib.Attr('jax_trace', self._jax_trace) + + def __eq__(self, other): + return isinstance(other, TraceState) and self._jax_trace is other._jax_trace diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 1cb7256348..6d94fd6516 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -36,8 +36,7 @@ import jax import jax.tree_util as jtu -from flax.experimental.nnx.nnx import reprlib -from flax.typing import Sharding +from flax.experimental.nnx.nnx import reprlib, tracers A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -76,43 +75,23 @@ def __hash__(self): @dataclasses.dataclass class VariableMetadata(tp.Generic[A]): - value: A - set_value_hooks: tuple[SetValueHook[A], ...] - get_value_hooks: tuple[GetValueHook[A], ...] - create_value_hooks: tuple[CreateValueHook[A], ...] - add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] - remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] - metadata: tp.Mapping[str, tp.Any] - - -class VariableMetaclass(ABCMeta): - def __call__(self, value: A, **metadata: tp.Any) -> A: - if isinstance(value, Variable): - container = value - value = container.value - else: - container = None - - obj = super().__call__(value, **metadata) - - if container is not None and not container.is_equivalent(obj): - raise ValueError( - f"input value of type '{type(container).__name__}' is not compatible " - f"with return type '{type(obj).__name__}'" - ) - - return obj + raw_value: A + set_value_hooks: tuple[SetValueHook[A], ...] = () + get_value_hooks: tuple[GetValueHook[A], ...] = () + create_value_hooks: tuple[CreateValueHook[A], ...] = () + add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] = () + remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] = () + metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) -class Variable( - tp.Generic[A], reprlib.Representable, metaclass=VariableMetaclass -): - value: A +class Variable(tp.Generic[A], reprlib.Representable): + raw_value: A set_value_hooks: tuple[SetValueHook[A], ...] get_value_hooks: tuple[GetValueHook[A], ...] create_value_hooks: tuple[CreateValueHook[A], ...] add_axis_hooks: tuple[AddAxisHook['Variable[A]'], ...] remove_axis_hooks: tuple[RemoveAxisHook['Variable[A]'], ...] + _trace_state: tracers.TraceState def __init__( self, @@ -135,6 +114,7 @@ def __init__( ] = (), **metadata: tp.Any, ): + vars(self)['_trace_state'] = tracers.TraceState() if set_value_hooks: if callable(set_value_hooks): set_value_hooks = (set_value_hooks,) @@ -198,7 +178,7 @@ def __init__( remove_axis_hooks = value.remove_axis_hooks metadata.update(value_metadata) - value = tp.cast(A, value.value) + value = tp.cast(A, value.raw_value) if hasattr(self, 'on_get_value'): on_get_value = getattr(type(self), 'on_get_value') @@ -225,7 +205,7 @@ def __init__( if on_remove_axis not in remove_axis_hooks: remove_axis_hooks = (on_remove_axis, *remove_axis_hooks) - self.value = value + self.raw_value = value self.get_value_hooks = get_value_hooks self.set_value_hooks = set_value_hooks self.create_value_hooks = create_value_hooks @@ -234,23 +214,27 @@ def __init__( vars(self).update(metadata) # run create_value hooks - self.value = self.create_value(self.value) - - @property - def is_empty(self) -> bool: - return self.value is EMPTY + self.raw_value = self.create_value(self.raw_value) if tp.TYPE_CHECKING: def __getattr__(self, name: str) -> tp.Any: ... + else: + def __setattr__(self, name: str, value: Any) -> None: + return self._setattr(name, value) - def get_value(self) -> A: - value = self.value - if self.get_value_hooks: - for hook in self.get_value_hooks: - value = hook(self, value) - return value + def _setattr(self, name: str, value: tp.Any): + if not self._trace_state.is_valid(): + raise ValueError( + 'Cannot mutate Variable from a different trace level' + ) + + object.__setattr__(self, name, value) + + @property + def is_empty(self) -> bool: + return self.raw_value is EMPTY def copy_from(self, other: 'Variable[A]') -> None: if not self.is_equivalent(other): @@ -264,11 +248,24 @@ def copy_from(self, other: 'Variable[A]') -> None: vars_dict.clear() vars_dict.update(vars(other)) - def set_value(self, value: A): + @property + def value(self) -> A: + value = self.raw_value + if self.get_value_hooks: + for hook in self.get_value_hooks: + value = hook(self, value) + return value + + @value.setter + def value(self, value: A): + if isinstance(value, Variable): + raise ValueError( + 'Cannot set value to a Variable, ' 'use `copy_from` method instead' + ) if self.set_value_hooks: for hook in self.set_value_hooks: value = hook(self, value) - self.value = value + self.raw_value = value def create_value(self, value: A): for hook in self.create_value_hooks: @@ -284,8 +281,6 @@ def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): hook(self, axis_name, axis_index) def __eq__(self, other: object) -> bool: - if not isinstance(other, Variable): - return False return type(self) is type(other) and vars(other) == vars(self) @tp.overload @@ -298,9 +293,11 @@ def replace(self, **kwargs) -> 'Variable[A]': def replace(self, **kwargs) -> 'Variable[tp.Any]': # return `value` if it is a Variable - if 'value' in kwargs and isinstance(value := kwargs['value'], Variable): + if 'raw_value' in kwargs and isinstance( + value := kwargs['raw_value'], Variable + ): # remove value from kwargs - kwargs.pop('value') + kwargs.pop('raw_value') if not self.is_equivalent(value): raise ValueError( 'Cannot replace value from incompatible container, ' @@ -322,20 +319,22 @@ def replace(self, **kwargs) -> 'Variable[tp.Any]': return obj def as_empty(self: V) -> V: - return self.replace(value=EMPTY) + return self.replace(raw_value=EMPTY) def is_equivalent(self, other: tp.Any) -> bool: return type(self) is type(other) def copy(self: 'Variable[A]') -> 'Variable[A]': obj = object.__new__(type(self)) - vars(obj).update(vars(self)) + attributes = vars(self).copy() + attributes['_trace_state'] = tracers.TraceState() + vars(obj).update(attributes) return obj def __nnx_repr__(self): yield reprlib.Object(type=type(self)) for name, value in vars(self).items(): - if name.endswith('_hooks'): + if name.endswith('_hooks') or name == "_trace_state": continue yield reprlib.Attr(name, repr(value)) @@ -364,15 +363,18 @@ def on_create_value(self, value: A) -> A: def on_add_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V: raise NotImplementedError - def on_remove_axis(self: V, axis_name: AxisName, axis_index: AxisIndex) -> V: + def on_remove_axis( + self: V, axis_name: AxisName, axis_index: AxisIndex + ) -> V: raise NotImplementedError def _variable_flatten(x: Variable[tp.Any], *, with_keys: bool): attributes = vars(x).copy() - value = attributes.pop('value') + del attributes['_trace_state'] + value = attributes.pop('raw_value') if with_keys: - node = (jtu.GetAttrKey('value'), value) + node = (jtu.GetAttrKey('raw_value'), value) else: node = value @@ -386,8 +388,7 @@ def _variable_unflatten( cls: type[Variable[A]], ) -> Variable[A]: variable = object.__new__(cls) - variable.value = children[0] - vars(variable).update(metadata) + vars(variable).update(metadata, _trace_state=tracers.TraceState(), raw_value=children[0]) return variable @@ -422,7 +423,7 @@ def __init__(self, value: jax.Array, *, tag: str, **metadata: tp.Any): super().__init__(value, tag=tag, **metadata) def on_get_value(self, value: jax.Array): - self.value, value = jax.random.split(value) + self.raw_value, value = jax.random.split(value) return value diff --git a/flax/experimental/nnx/scripts/run-all-examples.bash b/flax/experimental/nnx/scripts/run-all-examples.bash index e7f1f33081..523fa3cf49 100644 --- a/flax/experimental/nnx/scripts/run-all-examples.bash +++ b/flax/experimental/nnx/scripts/run-all-examples.bash @@ -4,7 +4,7 @@ cd ../../.. source .venv/bin/activate cd flax/experimental/nnx -for f in $(find examples -name "*.py" -maxdepth 1); do +for f in $(find examples/toy_examples -name "*.py" -maxdepth 1); do echo -e "\n---------------------------------" echo "$f" echo "---------------------------------" diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/experimental/nnx/tests/nn/test_attention.py index 230854bddb..776abf46f6 100644 --- a/flax/experimental/nnx/tests/nn/test_attention.py +++ b/flax/experimental/nnx/tests/nn/test_attention.py @@ -69,19 +69,13 @@ def __call__(self, x, sow_weights=False): with nnx.flags(decode=False): _ = module(x, True) intermediates = module.pop(nnx.Intermediate) - assert intermediates['attention_layers/0/attention_weights'][0].shape == ( - 4, - 8, - 6, - 6, - ) + assert intermediates['attention_layers/0/attention_weights'].raw_value[ + 0 + ].shape == (4, 8, 6, 6) assert 'attention_layers/1/attention_weights' not in intermediates - assert intermediates['attention_layers/2/attention_weights'][0].shape == ( - 4, - 8, - 6, - 6, - ) + assert intermediates['attention_layers/2/attention_weights'].raw_value[ + 0 + ].shape == (4, 8, 6, 6) with nnx.flags(decode=False): _ = module(x) @@ -99,8 +93,8 @@ def test_autoregressive_decode_with_x64(self): rngs=nnx.Rngs(0), ) module.init_cache(x.shape, dtype=x.dtype) - assert module.cached_key.shape == (1, 4, 2, 2) - assert module.cached_value.shape == (1, 4, 2, 2) + assert module.cached_key.value.shape == (1, 4, 2, 2) + assert module.cached_value.value.shape == (1, 4, 2, 2) y1 = module(x[:, :1, :]) y2 = module(x[:, 1:2, :]) @@ -164,13 +158,11 @@ def test_nnx_attention_equivalence( variables = model.init(key, x) for qkvo in ('query', 'key', 'value', 'out'): - setattr( - getattr(model_nnx, qkvo), 'kernel', variables['params'][qkvo]['kernel'] - ) + getattr(model_nnx, qkvo).kernel.value = variables['params'][qkvo][ + 'kernel' + ] if use_bias: - setattr( - getattr(model_nnx, qkvo), 'bias', variables['params'][qkvo]['bias'] - ) + getattr(model_nnx, qkvo).bias.value = variables['params'][qkvo]['bias'] if decode: model_nnx.init_cache(x.shape, dtype=dtype) diff --git a/flax/experimental/nnx/tests/nn/test_conv.py b/flax/experimental/nnx/tests/nn/test_conv.py index d5f05d5a9e..db5a3a4c84 100644 --- a/flax/experimental/nnx/tests/nn/test_conv.py +++ b/flax/experimental/nnx/tests/nn/test_conv.py @@ -23,25 +23,20 @@ from flax import linen from flax.experimental import nnx -from flax.typing import ( - PaddingLike, - Dtype, - PrecisionLike -) +from flax.typing import PaddingLike, Dtype, PrecisionLike class TestConvLinenConsistency(parameterized.TestCase): - @parameterized.product( - strides = [None, (2, 3)], - padding = ['VALID', (4, 2)], - input_dilation = [(2, 3)], - kernel_dilation = [(2, 3)], - feature_group_count = [3], - use_bias = [True, False], - dtype = [jnp.float32], - param_dtype = [jnp.float16], - precision = [Precision.HIGHEST], + strides=[None, (2, 3)], + padding=['VALID', (4, 2)], + input_dilation=[(2, 3)], + kernel_dilation=[(2, 3)], + feature_group_count=[3], + use_bias=[True, False], + dtype=[jnp.float32], + param_dtype=[jnp.float16], + precision=[Precision.HIGHEST], ) def test_nnx_linen_equivalence( self, @@ -96,9 +91,9 @@ def test_nnx_linen_equivalence( precision=precision, ) variables = model.init(key, x) - model_nnx.kernel = variables['params']['kernel'] + model_nnx.kernel.value = variables['params']['kernel'] if use_bias: - model_nnx.bias = variables['params']['bias'] + model_nnx.bias.value = variables['params']['bias'] out_nnx = model_nnx(x) out = model.apply(variables, x) diff --git a/flax/experimental/nnx/tests/nn/test_embed.py b/flax/experimental/nnx/tests/nn/test_embed.py index 9eca0047bf..92962fc734 100644 --- a/flax/experimental/nnx/tests/nn/test_embed.py +++ b/flax/experimental/nnx/tests/nn/test_embed.py @@ -55,7 +55,7 @@ def test_nnx_linen_equivalence( NUM_EMBEDDINGS, IN_FEATURES, dtype=dtype, param_dtype=param_dtype ) variables = model.init(key, x) - model_nnx.embedding = variables['params']['embedding'] + model_nnx.embedding.value = variables['params']['embedding'] out_nnx = model_nnx(x) out = model.apply(variables, x) diff --git a/flax/experimental/nnx/tests/nn/test_linear.py b/flax/experimental/nnx/tests/nn/test_linear.py index e6092186b8..2267a9df53 100644 --- a/flax/experimental/nnx/tests/nn/test_linear.py +++ b/flax/experimental/nnx/tests/nn/test_linear.py @@ -31,27 +31,26 @@ def test_basic(self): y = module(jnp.ones((1, 2))) assert y.shape == (1, 3) - assert module.kernel.shape == (2, 3) - assert module.bias is not None - assert module.bias.shape == (3,) + assert module.kernel.value.shape == (2, 3) + assert module.bias.value is not None + assert module.bias.value.shape == (3,) def test_basic_multi_features(self): module = nnx.LinearGeneral(2, (3, 4), rngs=nnx.Rngs(0)) y = module(jnp.ones((1, 2))) assert y.shape == (1, 3, 4) - assert module.kernel.shape == (2, 3, 4) - assert module.bias is not None - assert module.bias.shape == (3, 4) + assert module.kernel.value.shape == (2, 3, 4) + assert module.bias.value is not None + assert module.bias.value.shape == (3, 4) class TestLinenConsistency(parameterized.TestCase): - @parameterized.product( - use_bias = [True, False], - dtype = [jnp.float32, jnp.float16], - param_dtype = [jnp.float32, jnp.float16], - precision = [Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST], + use_bias=[True, False], + dtype=[jnp.float32, jnp.float16], + param_dtype=[jnp.float32, jnp.float16], + precision=[Precision.DEFAULT, Precision.HIGH, Precision.HIGHEST], ) def test_nnx_linen_equivalence( self, @@ -83,9 +82,9 @@ def test_nnx_linen_equivalence( precision=precision, ) variables = model.init(key, x) - model_nnx.kernel = variables['params']['kernel'] + model_nnx.kernel.value = variables['params']['kernel'] if use_bias: - model_nnx.bias = variables['params']['bias'] + model_nnx.bias.value = variables['params']['bias'] out_nnx = model_nnx(x) out = model.apply(variables, x) diff --git a/flax/experimental/nnx/tests/nn/test_normalization.py b/flax/experimental/nnx/tests/nn/test_normalization.py index 83dd654b1a..854c367ae3 100644 --- a/flax/experimental/nnx/tests/nn/test_normalization.py +++ b/flax/experimental/nnx/tests/nn/test_normalization.py @@ -95,8 +95,8 @@ def __call__(self, x, *, mask=None): use_fast_variance=use_fast_variance, rngs=rngs, ) - nnx_model.linear.kernel = variables['params']['linear']['kernel'] - nnx_model.linear.bias = variables['params']['linear']['bias'] + nnx_model.linear.kernel.value = variables['params']['linear']['kernel'] + nnx_model.linear.bias.value = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) assert_array_equal(linen_out, nnx_out) @@ -167,8 +167,8 @@ def __call__(self, x, *, mask=None): use_fast_variance=use_fast_variance, rngs=rngs, ) - nnx_model.linear.kernel = variables['params']['linear']['kernel'] - nnx_model.linear.bias = variables['params']['linear']['bias'] + nnx_model.linear.kernel.value = variables['params']['linear']['kernel'] + nnx_model.linear.bias.value = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) assert_array_equal(linen_out, nnx_out) @@ -239,8 +239,8 @@ def __call__(self, x, *, mask=None): use_fast_variance=use_fast_variance, rngs=rngs, ) - nnx_model.linear.kernel = variables['params']['linear']['kernel'] - nnx_model.linear.bias = variables['params']['linear']['bias'] + nnx_model.linear.kernel.value = variables['params']['linear']['kernel'] + nnx_model.linear.bias.value = variables['params']['linear']['bias'] nnx_out = nnx_model(x, mask=mask) assert_array_equal(linen_out, nnx_out) diff --git a/flax/experimental/nnx/tests/test_containers.py b/flax/experimental/nnx/tests/test_containers.py index 617c9ca31c..582d661ab8 100644 --- a/flax/experimental/nnx/tests/test_containers.py +++ b/flax/experimental/nnx/tests/test_containers.py @@ -12,58 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from flax.experimental import nnx class TestContainers: - def test_node_idenpotence(self): - x = nnx.Variable(1) - x = nnx.Variable(x) - - assert isinstance(x, nnx.Variable) - - def test_variable_idenpotence(self): - x = nnx.Variable(1) - x = nnx.Variable(x) - - assert isinstance(x, nnx.Variable) - assert x.value == 1 - - def test_variable_cannot_change_collection(self): - x = nnx.Param(1) - - with pytest.raises(ValueError, match='is not compatible with return type'): - x = nnx.BatchStat(x) - - def test_container_cannot_change_type(self): - x = nnx.Variable(1) - - with pytest.raises(ValueError, match='is not compatible with return type'): - x = nnx.Param(x) - - x = nnx.Param(2) - - with pytest.raises(ValueError, match='is not compatible with return type'): - x = nnx.Variable(x) - def test_unbox(self): - x: nnx.Param[int] = nnx.Param( + x = nnx.Param( 1, get_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore ) - assert x.get_value() == 4 + assert x.value == 4 def test_box(self): x: nnx.Param[int] = nnx.Param( 1, # type: ignore set_value_hooks=[lambda c, x: x + 1, lambda c, x: x * 2], # type: ignore ) - x.set_value(5) + x.value = 5 - assert x.value == 12 + assert x.raw_value == 12 def test_module_unbox(self): class Foo(nnx.Module): @@ -74,8 +43,8 @@ def __init__(self) -> None: module = Foo() - assert module.x == 4 - assert vars(module)['x'].value == 1 + assert module.x.value == 4 + assert vars(module)['x'].raw_value == 1 def test_module_box(self): class Foo(nnx.Module): @@ -85,7 +54,7 @@ def __init__(self) -> None: ) module = Foo() - module.x = 5 + module.x.value = 5 - assert module.x == 12 - assert vars(module)['x'].value == 12 + assert module.x.value == 12 + assert vars(module)['x'].raw_value == 12 diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index e75d75e7da..8ee03eb0ca 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -25,8 +25,8 @@ def test_flatten(self): state, static = nnx.graph_utils.graph_flatten(g) - state['0']['b'] = 2 - state['3'] = 4 + state['0']['b'].raw_value = 2 + state['3'].raw_value = 4 def test_unflatten(self): a = {'a': 1, 'b': nnx.Param(2)} @@ -45,8 +45,8 @@ def test_unflatten_empty(self): g = static.merge(nnx.State({})) assert g[0] is g[2] - assert g[0]['b'].value is nnx.EMPTY - assert g[3].value is nnx.EMPTY + assert g[0]['b'].raw_value is nnx.EMPTY + assert g[3].raw_value is nnx.EMPTY def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(2)} @@ -54,11 +54,11 @@ def test_update_dynamic(self): state, static = nnx.graph_utils.graph_flatten(g) - state['0']['b'] = 3 + state['0']['b'].raw_value = 3 nnx.graph_utils.graph_update_dynamic(g, state) - assert g[0]['b'].value == 3 - assert g[2]['b'].value == 3 + assert g[0]['b'].raw_value == 3 + assert g[2]['b'].raw_value == 3 def test_update_static(self): a = {'a': 1, 'b': nnx.Param(2)} @@ -110,12 +110,12 @@ def test_module_list(self): state, static = nnx.graph_utils.graph_flatten(ls) - assert state['0']['kernel'].shape == (2, 2) - assert state['0']['bias'].shape == (2,) - assert state['1']['scale'].shape == (2,) - assert state['1']['bias'].shape == (2,) - assert state['1']['mean'].shape == (2,) - assert state['1']['var'].shape == (2,) + assert state['0']['kernel'].raw_value.shape == (2, 2) + assert state['0']['bias'].raw_value.shape == (2,) + assert state['1']['scale'].raw_value.shape == (2,) + assert state['1']['bias'].raw_value.shape == (2,) + assert state['1']['mean'].raw_value.shape == (2,) + assert state['1']['var'].raw_value.shape == (2,) def test_shared_variables(self): v = nnx.Param(1) @@ -136,7 +136,7 @@ def __init__(self, *, rngs: nnx.Rngs) -> None: self.baz = nnx.Linear(2, 2, rngs=rngs) # tie the weights - self.baz.variables.kernel = self.bar.variables.kernel + self.baz.kernel = self.bar.kernel node = Foo(rngs=nnx.Rngs(0)) state, static = nnx.graph_utils.graph_flatten(node) @@ -145,7 +145,7 @@ def __init__(self, *, rngs: nnx.Rngs) -> None: node2 = static.merge(state) - assert node2.bar.variables.kernel is node2.baz.variables.kernel + assert node2.bar.kernel is node2.baz.kernel def test_tied_weights_example(self): class LinearTranspose(nnx.Module): @@ -155,7 +155,7 @@ def __init__(self, dout: int, din: int, *, rngs: nnx.Rngs) -> None: ) def __call__(self, x): - return x @ self.kernel.T + return x @ self.kernel.value.T class Encoder(nnx.Module): def __init__(self, *, rngs: nnx.Rngs) -> None: @@ -164,7 +164,7 @@ def __init__(self, *, rngs: nnx.Rngs) -> None: self.linear_out = LinearTranspose(10, 2, rngs=rngs) # tie the weights - self.linear_out.variables.kernel = self.embed.variables.embedding + self.linear_out.kernel = self.embed.embedding def __call__(self, x): x = self.embed(x) @@ -189,17 +189,17 @@ def __init__(self): m = Foo() state, static = m.split() - assert isinstance(m.variables.a, nnx.Param) - assert isinstance(state.variables.a, nnx.Param) - assert m.variables.a is not state.variables.a - assert m.a == state.a + assert isinstance(m.a, nnx.Param) + assert isinstance(state.a, nnx.Param) + assert m.a is not state.a + assert m.a.value == state.a.raw_value m2 = static.merge(state) - assert isinstance(m2.variables.a, nnx.Param) - assert isinstance(state.variables.a, nnx.Param) - assert m2.variables.a is not state.variables.a - assert m2.a == state.a + assert isinstance(m2.a, nnx.Param) + assert isinstance(state.a, nnx.Param) + assert m2.a is not state.a + assert m2.a.value == state.a.raw_value def test_shared_state_variables_not_shared_with_graph(self): class Foo(nnx.Module): @@ -211,22 +211,22 @@ def __init__(self): m = Foo() state, static = m.split() - assert isinstance(m.variables.a, nnx.Param) - assert isinstance(m.variables.b, nnx.Param) - assert isinstance(state.variables.a, nnx.Param) + assert isinstance(m.a, nnx.Param) + assert isinstance(m.b, nnx.Param) + assert isinstance(state.a, nnx.Param) assert 'b' not in state - assert m.variables.a is not state.variables.a - assert m.variables.b is not state.variables.a - assert m.a == state.a - assert m.b == state.a + assert m.a is not state.a + assert m.b is not state.a + assert m.a.value == state.a.raw_value + assert m.b.value == state.a.raw_value m2 = static.merge(state) - assert isinstance(m2.variables.a, nnx.Param) - assert isinstance(m2.variables.b, nnx.Param) - assert isinstance(state.variables.a, nnx.Param) - assert m2.variables.a is not state.variables.a - assert m2.variables.b is not state.variables.a - assert m2.a == state.a - assert m2.b == state.a + assert isinstance(m2.a, nnx.Param) + assert isinstance(m2.b, nnx.Param) + assert isinstance(state.a, nnx.Param) + assert m2.a is not state.a + assert m2.b is not state.a + assert m2.a.value == state.a.raw_value + assert m2.b.value == state.a.raw_value assert m2.a is m2.b diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/experimental/nnx/tests/test_helpers.py index 4f83533546..3b5bae2028 100644 --- a/flax/experimental/nnx/tests/test_helpers.py +++ b/flax/experimental/nnx/tests/test_helpers.py @@ -68,6 +68,6 @@ def __call__(self, x: jax.Array, train: bool) -> jax.Array: assert y.shape == (1, 4) # fake gradient - grads = jax.tree_util.tree_map(jnp.ones_like, state.params) + grads = jax.tree_map(jnp.ones_like, state.params) # test apply_gradients state = state.apply_gradients(grads) diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/experimental/nnx/tests/test_integration.py index ddf014401d..8a534608c8 100644 --- a/flax/experimental/nnx/tests/test_integration.py +++ b/flax/experimental/nnx/tests/test_integration.py @@ -56,9 +56,7 @@ def loss_fn(model: Model): grads = loss_fn(model) model.update( - jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads - ) + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) ) model = Model(rngs=nnx.Rngs(0)) @@ -107,9 +105,7 @@ def loss_fn(model: Model): grads = loss_fn(model) model.update( - jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads - ) + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) ) return model.split() @@ -143,14 +139,14 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.count = State(0) def __call__(self, x): - self.count += 1 - return x @ self.w + self.b[None] + self.count.value += 1 + return x @ self.w.value + self.b.value[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) - assert model.count == 1 + assert model.count.value == 1 @nnx.jit def train_step(model, x, y): @@ -162,14 +158,12 @@ def loss_fn(model): grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) # SGD update model.update( - jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads - ) + jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) ) # execute the training step train_step(model, x, y) - assert model.count == 2 + assert model.count.value == 2 def test_functional_example(self): class Count(nnx.Variable[A]): @@ -183,14 +177,14 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.count = Count(0) def __call__(self, x): - self.count += 1 - return x @ self.w + self.b[None] + self.count.value += 1 + return x @ self.w.value + self.b.value[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) - assert model.count == 1 + assert model.count.value == 1 params, counts, graphdef = model.split(nnx.Param, Count) @@ -204,14 +198,14 @@ def loss_fn(params): # compute gradient grads, counts = jax.grad(loss_fn, has_aux=True)(params) # SGD update - params = jax.tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grads) + params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grads) return params, counts # execute the training step params, counts = train_step(params, counts, x, y) model = graphdef.merge(params, counts) - assert model.count == 2 + assert model.count.value == 2 def test_intermediates_example(self): class Linear(nnx.Module): @@ -221,7 +215,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - y = x @ self.w + self.b[None] + y = x @ self.w.value + self.b.value[None] self.y = nnx.Intermediate(y) return y @@ -241,7 +235,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - y = x @ self.w + self.b[None] + y = x @ self.w.value + self.b.value[None] self.y = nnx.Intermediate(y) return y diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index c3c24de280..a97e6fe5fc 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -53,14 +53,14 @@ def test_tree_map(self): state, static = m.split() - state = jax.tree_util.tree_map(lambda x: x + 1, state) + state = jax.tree_map(lambda x: x + 1, state) def test_split_2(self): m = nnx.Dict(a=nnx.Param(1)) empty, some, static = m.split(None, ...) - some = jax.tree_util.tree_map(lambda x: x + 1, some) + some = jax.tree_map(lambda x: x + 1, some) def test_split_merge(self): m = nnx.Dict(a=nnx.Param(1)) @@ -97,7 +97,7 @@ def __init__(self, c: float, *, rngs: nnx.Rngs): def __call__(self, x, *, rngs: nnx.Rngs): key = rngs.e() - return self.w * x + jax.random.normal(key, ()) + self.c + return self.w.value * x + jax.random.normal(key, ()) + self.c foo = Foo(c=1.0, rngs=nnx.Rngs(0)) @@ -159,16 +159,16 @@ def test_cross_barrier(self): m = nnx.Dict(a=nnx.Param(1)) @jax.jit - def g(state: nnx.State, graphdef: nnx.GraphDef[nnx.Dict[int]]): + def g(state: nnx.State, graphdef: nnx.GraphDef[nnx.Dict[nnx.Param[int]]]): m = graphdef.merge(state) - m.a += 1 + m.a.value += 1 return m.split() state, graphdef = g(*m.split()) m2 = graphdef.merge(state) assert m2 is not m - assert m.a == 1 - assert m2.a == 2 + assert m.a.value == 1 + assert m2.a.value == 2 def test_no_rejit(self): n = 0 @@ -179,15 +179,15 @@ def g(state_and_def): nonlocal n n += 1 m = nnx.merge(state_and_def) - m.a += 1 + m.a.value += 1 return m.split() m2 = nnx.merge(g(m.split())) assert n == 1 assert m2 is not m - assert m.a == 1 - assert m2.a == 2 + assert m.a.value == 1 + assert m2.a.value == 2 g(m.split()) assert n == 1 @@ -262,15 +262,15 @@ def __call__(self, x): assert y1 == 3 assert y2 == 11 - assert m.y == (3, 11) + assert m.y.value == (3, 11) intermediates = m.pop(nnx.Intermediate) - assert isinstance(intermediates.variables.y, nnx.Intermediate) - assert intermediates['y'] == (3, 11) + assert isinstance(intermediates.y, nnx.Intermediate) + assert intermediates['y'].raw_value == (3, 11) assert hasattr(m, 'y') - assert m.y is nnx.EMPTY + assert m.y.value is nnx.EMPTY def test_sow_existing_non_variable_field(self): class Foo(nnx.Module): @@ -441,8 +441,8 @@ def add_submodule(self): def test_create_abstract(self): linear = nnx.Linear.create_abstract(2, 3, rngs=nnx.Rngs(0)) - assert linear.kernel == jax.ShapeDtypeStruct((2, 3), jnp.float32) - assert linear.bias == jax.ShapeDtypeStruct((3,), jnp.float32) + assert linear.kernel.value == jax.ShapeDtypeStruct((2, 3), jnp.float32) + assert linear.bias.value == jax.ShapeDtypeStruct((3,), jnp.float32) def test_partial_init(self): linear = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) @@ -454,9 +454,9 @@ def test_partial_init(self): 2, 3, bias_init=nnx.initializers.ones_init(), rngs=nnx.Rngs(1) ) - np.testing.assert_allclose(linear.kernel, linear2.kernel) - np.testing.assert_allclose(linear.bias, 0) - np.testing.assert_allclose(linear2.bias, 1) + np.testing.assert_allclose(linear.kernel.value, linear2.kernel.value) + np.testing.assert_allclose(linear.bias.value, 0) + np.testing.assert_allclose(linear2.bias.value, 1) def test_deepcopy(self): class Foo(nnx.Module): @@ -485,9 +485,9 @@ def __init__(self): m = Foo() - m = jax.tree_util.tree_map(lambda x: x + 1, m) + m = jax.tree_map(lambda x: x + 1, m) - assert m.node == 2 + assert m.node.value == 2 assert m.static == 1 @@ -514,20 +514,10 @@ class Foo(nnx.Module): state, graphdef = m.split() assert len(state) == 4 - assert state.variables.b == nnx.TreeNode(2) - assert state.variables.c == nnx.Param(3) - assert state.variables.d == nnx.BatchStat(4) - assert state.variables.f == nnx.Variable(6) - - def test_no_override(self): - @nnx.dataclass - class Foo(nnx.Module): - a: int = nnx.treenode_field() - - with pytest.raises(ValueError, match='is not compatible with return type'): - _m = Foo(a=nnx.Param(1)) - - _m = Foo(a=nnx.TreeNode(1)) + assert state.b == nnx.TreeNode(2) + assert state.c == nnx.Param(3) + assert state.d == nnx.BatchStat(4) + assert state.f == nnx.Variable(6) def test_context_none_after_init(self): @dataclasses.dataclass @@ -575,7 +565,7 @@ def __init__(self, c: float, *, rngs: nnx.Rngs): def __call__(self, x, *, rngs: nnx.Rngs): key = rngs.e() - return self.w * x + jax.random.normal(key, ()) + self.c + return self.w.value * x + jax.random.normal(key, ()) + self.c rngs = nnx.Rngs(0) foo = Foo(c=1.0, rngs=rngs) @@ -583,7 +573,7 @@ def __call__(self, x, *, rngs: nnx.Rngs): states, graphdef = foo.split() assert isinstance(states, nnx.State) - assert isinstance(states.variables.w, nnx.Param) + assert isinstance(states.w, nnx.Param) # assert isinstance(states["c"], jax.Array) y, _updates = graphdef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) @@ -600,7 +590,7 @@ def __init__(self, c: float, *, rngs: nnx.Rngs): def __call__(self, x, *, rngs: nnx.Rngs): key = rngs.e() - return self.w * x + jax.random.normal(key, ()) + self.c + return self.w.value * x + jax.random.normal(key, ()) + self.c.value foo = Foo(c=1.0, rngs=nnx.Rngs(0)) @@ -608,8 +598,8 @@ def __call__(self, x, *, rngs: nnx.Rngs): assert isinstance(graphdef, nnx.GraphDef) assert isinstance(state, nnx.State) - assert isinstance(state.variables.w, nnx.Param) - assert isinstance(state.variables.c, nnx.Variable) + assert isinstance(state.w, nnx.Param) + assert isinstance(state.c, nnx.Variable) y, (state, graphdef) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/experimental/nnx/tests/test_partitioning.py index 7b5a095647..14a218e721 100644 --- a/flax/experimental/nnx/tests/test_partitioning.py +++ b/flax/experimental/nnx/tests/test_partitioning.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import jax import pytest @@ -32,17 +33,17 @@ def test_partition(self): assert len(rest) == 1 # check params - assert params['a']['0'] == m.a[0] - assert params['b'] == m.b + assert params['a']['0'].raw_value == m.a[0].value + assert params['b'].raw_value == m.b.value # check rest - assert rest['a']['1'] == m.a[1] + assert rest['a']['1'].raw_value == m.a[1].value m2 = graphdef.merge(params, rest) - assert m2.a[0] == m.a[0] - assert m2.a[1] == m.a[1] - assert m2.b == m.b + assert m2.a[0].value == m.a[0].value + assert m2.a[1].value == m.a[1].value + assert m2.b.value == m.b.value assert m2.c == 100 def test_complete_partitioning(self): @@ -93,13 +94,13 @@ def test_update_from(self): ) state = m.split()[0] - state = jax.tree_util.tree_map(lambda x: x * 2, state) + state = jax.tree_map(lambda x: x * 2, state) m.update(state) - assert m.a[0] == 2 - assert m.a[1] == 6 - assert m.b == 4 + assert m.a[0].value == 2 + assert m.a[1].value == 6 + assert m.b.value == 4 assert m.c == 100 def test_update_from_with_array_leaf(self): @@ -110,14 +111,14 @@ def test_update_from_with_array_leaf(self): ) state, graphdef = m.split() - state = jax.tree_util.tree_map(lambda x: x * 2, state) + state = jax.tree_map(lambda x: x * 2, state) m.update(state) - assert m.a[0] == 2 - assert m.a[1] == 6 - assert m.b == 4 - assert m.c == 200 + assert m.a[0].value == 2 + assert m.a[1].value == 6 + assert m.b.value == 4 + assert m.c.value == 200 def test_grad_example(self): m = nnx.Dict( @@ -134,9 +135,9 @@ def loss(params): grads = jax.grad(loss)(params) m.update(grads) - assert m.a[0] == 2.0 - assert m.a[1] == -10 - assert m.b == 2.0 + assert m.a[0].value == 2.0 + assert m.a[1].value == -10 + assert m.b.value == 2.0 assert m.c == 100 def test_get_paritition(self): @@ -151,8 +152,8 @@ def test_get_paritition(self): assert vars(m.a)['0'] is not vars(m)['b'] state = m.extract(nnx.Variable) - assert state['a']['0'] == m.a[0] - assert state['a']['1'] == m.a[1] - assert state['b'] == m.b - assert state.variables.b is not state.a.variables[0] + assert state['a']['0'].raw_value == m.a[0].value + assert state['a']['1'].raw_value == m.a[1].value + assert state['b'].raw_value == m.b.value + assert state.b is not state.a[0] assert len(state.flat_state()) == 3 diff --git a/flax/experimental/nnx/tests/test_pytree.py b/flax/experimental/nnx/tests/test_pytree.py index 0ac1126620..ad89dcb42e 100644 --- a/flax/experimental/nnx/tests/test_pytree.py +++ b/flax/experimental/nnx/tests/test_pytree.py @@ -33,7 +33,7 @@ def __init__(self, y) -> None: leaves = jax.tree_util.tree_leaves(pytree) assert leaves == [3] - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) + pytree = jax.tree_map(lambda x: x * 2, pytree) assert pytree.x == 2 assert pytree.y == 6 @@ -57,7 +57,7 @@ class Foo(nnx.Pytree): leaves = jax.tree_util.tree_leaves(pytree) assert leaves == [3] - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) + pytree = jax.tree_map(lambda x: x * 2, pytree) assert pytree.x == 2 assert pytree.y == 6 @@ -150,8 +150,8 @@ class Foo(nnx.Pytree): path_values, treedef = jax.tree_util.tree_flatten_with_path(foo) path_values = [(list(map(str, path)), value) for path, value in path_values] - assert path_values[0] == (['.x', '.value'], 3) - assert path_values[1] == (['.z', '.value', '.a', '.value'], 1) + assert path_values[0] == (['.x', '.raw_value'], 3) + assert path_values[1] == (['.z', '.raw_value', '.a', '.raw_value'], 1) def test_replace_unknown_fields_error(self): class Foo(nnx.Pytree): @@ -184,7 +184,7 @@ def __new__(cls, a): pytree = A(a=1) - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) + pytree = jax.tree_map(lambda x: x * 2, pytree) def test_deterministic_order(self): class A(nnx.Pytree): @@ -217,7 +217,7 @@ def __init__(self, y) -> None: leaves = jax.tree_util.tree_leaves(pytree) assert leaves == [3] - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) + pytree = jax.tree_map(lambda x: x * 2, pytree) assert pytree.x == 2 assert pytree.y == 6 @@ -251,7 +251,7 @@ class Foo(nnx.Pytree, mutable=True): leaves = jax.tree_util.tree_leaves(pytree) assert leaves == [3] - pytree = jax.tree_util.tree_map(lambda x: x * 2, pytree) + pytree = jax.tree_map(lambda x: x * 2, pytree) assert pytree.x == 2 assert pytree.y == 6 diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/experimental/nnx/tests/test_spmd.py index 124af14112..743fad62ca 100644 --- a/flax/experimental/nnx/tests/test_spmd.py +++ b/flax/experimental/nnx/tests/test_spmd.py @@ -70,6 +70,10 @@ def __call__(self, x): ) state_spec = nnx.get_partition_spec(state) - assert state_spec.params['w'] == PartitionSpec('row', 'col') - assert state_spec.opt_state[0].mu['w'] == PartitionSpec('row', 'col') - assert state_spec.opt_state[0].nu['w'] == PartitionSpec('row', 'col') + assert state_spec.params['w'].raw_value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].mu['w'].raw_value == PartitionSpec( + 'row', 'col' + ) + assert state_spec.opt_state[0].nu['w'].raw_value == PartitionSpec( + 'row', 'col' + ) diff --git a/flax/experimental/nnx/tests/test_state.py b/flax/experimental/nnx/tests/test_state.py index 20052aaa3f..3a9d73475c 100644 --- a/flax/experimental/nnx/tests/test_state.py +++ b/flax/experimental/nnx/tests/test_state.py @@ -21,34 +21,34 @@ class StateTest(TestCase): def test_create_state(self): state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) - assert state['a'] == 1 - assert state['b']['c'] == 2 + assert state['a'].raw_value == 1 + assert state['b']['c'].raw_value == 2 def test_get_attr(self): state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) - assert state.a == 1 - assert state.b.c == 2 + assert state.a.raw_value == 1 + assert state.b.c.raw_value == 2 def test_set_attr(self): state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) - state.a = 3 - state.b.c = 4 + state.a.raw_value = 3 + state.b.c.raw_value = 4 - assert state['a'] == 3 - assert state['b']['c'] == 4 + assert state['a'].raw_value == 3 + assert state['b']['c'].raw_value == 4 def test_set_attr_variables(self): state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) - state.a = 3 - state.b.c = 4 + state.a.raw_value = 3 + state.b.c.raw_value = 4 - assert isinstance(state.variables.a, nnx.Param) - assert state.variables.a.value == 3 - assert isinstance(state.b.variables.c, nnx.Param) - assert state.b.variables.c.value == 4 + assert isinstance(state.a, nnx.Param) + assert state.a.raw_value == 3 + assert isinstance(state.b.c, nnx.Param) + assert state.b.c.raw_value == 4 def test_integer_access(self): class Foo(nnx.Module): @@ -58,7 +58,7 @@ def __init__(self, *, rngs: nnx.Rngs): module = Foo(rngs=nnx.Rngs(0)) state = module.get_state() - assert module.layers[0].kernel.shape == (1, 2) - assert state.layers[0].kernel.shape == (1, 2) - assert module.layers[1].kernel.shape == (2, 3) - assert state.layers[1].kernel.shape == (2, 3) + assert module.layers[0].kernel.value.shape == (1, 2) + assert state.layers[0].kernel.raw_value.shape == (1, 2) + assert module.layers[1].kernel.value.shape == (2, 3) + assert state.layers[1].kernel.raw_value.shape == (2, 3) diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py index c71b861bae..f441adf16f 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -52,12 +52,12 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): m = Foo(2, 3, rngs=nnx.Rngs(0)) assert n == 1 - assert m.w.shape == (2, 3) + assert m.w.value.shape == (2, 3) assert m.din == 2 assert m.dout == 3 assert isinstance(m.din, int) assert isinstance(m.dout, int) - assert isinstance(m.w, jax.Array) + assert isinstance(m.w.value, jax.Array) m = Foo(2, 3, rngs=nnx.Rngs(0)) assert n == 1 @@ -76,15 +76,15 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): def __call__(self, x: jax.Array) -> jax.Array: nonlocal n n += 1 - return jnp.dot(x, self.w) + return jnp.dot(x, self.w.value) m = Foo(2, 3, rngs=nnx.Rngs(0)) - assert m.w.shape == (2, 3) + assert m.w.value.shape == (2, 3) assert m.din == 2 assert m.dout == 3 assert isinstance(m.din, int) assert isinstance(m.dout, int) - assert isinstance(m.w, jax.Array) + assert isinstance(m.w.value, jax.Array) y = m(jnp.ones((1, 2))) assert y.shape == (1, 3) @@ -106,7 +106,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): def __call__(self, x: jax.Array) -> jax.Array: nonlocal n n += 1 - return jnp.dot(x, self.w) + return jnp.dot(x, self.w.value) m = nnx.JIT(Foo)(2, 3, rngs=nnx.Rngs(0)) @@ -132,24 +132,24 @@ def test_grad(self): @nnx.grad def f(m: nnx.Dict): # sum all params - return m['a'][0] + m['a'][1] + m['b'] + return m['a'][0].value + m['a'][1].value + m['b'].value grads = f(m) - assert m.a.variables['0'] is m.variables.b + assert m.a[0] is m.b assert isinstance(grads, nnx.State) - assert grads['a']['0'] == 2.0 - assert isinstance(grads.a.variables['0'], nnx.Variable) - assert grads['a']['1'] == 1.0 - assert isinstance(grads.a.variables['1'], nnx.Variable) + assert grads['a']['0'].raw_value == 2.0 + assert isinstance(grads.a['0'], nnx.Variable) + assert grads['a']['1'].raw_value == 1.0 + assert isinstance(grads.a['1'], nnx.Variable) assert len(grads.flat_state()) == 2 m.update(grads) - assert m.a.variables['0'] is m.variables.b - assert m['a'][0] == 2.0 - assert m['a'][1] == 1.0 - assert m['b'] == 2.0 + assert m.a[0] is m.b + assert m['a'][0].value == 2.0 + assert m['a'][1].value == 1.0 + assert m['b'].value == 2.0 assert m['c'] == 7 assert m['d'] == 5.0 @@ -164,20 +164,20 @@ def test_grad_with_multiple_ref_types(self): @nnx.grad def f(m: nnx.Dict): # sum all params - return m.a[0] + m.a[1] + m.b + return m.a[0].value + m.a[1].value + m.b.value grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a']['0'] == 1.0 - assert isinstance(grads.a.variables['0'], nnx.Param) + assert grads['a']['0'].raw_value == 1.0 + assert isinstance(grads.a['0'], nnx.Param) assert len(grads) == 2 m.update(grads) - assert m.a[0] == 1.0 - assert m.a[1] == 20.0 - assert m.b == 1.0 + assert m.a[0].value == 1.0 + assert m.a[1].value == 20.0 + assert m.b.value == 1.0 assert m.c == 7 assert m.d == 5.0 @@ -192,20 +192,20 @@ def test_grad_with_type_predicate(self): @partial(nnx.grad, wrt=nnx.BatchStat) def f(m: nnx.Dict): # sum all params - return m.a[0] + m.a[1] + m.b + return m.a[0].value + m.a[1].value + m.b.value grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a']['1'] == 1.0 - assert isinstance(grads.a.variables['1'], nnx.BatchStat) + assert grads['a']['1'].raw_value == 1.0 + assert isinstance(grads.a['1'], nnx.BatchStat) assert len(grads) == 1 m.update(grads) - assert m.a[0] == 10.0 - assert m.a[1] == 1.0 - assert m.b == 10.0 + assert m.a[0].value == 10.0 + assert m.a[1].value == 1.0 + assert m.b.value == 10.0 assert m.c == 7 assert m.d == 5.0 @@ -230,9 +230,9 @@ def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) y, out = module(x) @@ -260,9 +260,9 @@ def __call__(self, x: jax.Array): module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) y = module(x) @@ -289,9 +289,9 @@ def __call__(self, x: jax.Array): module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) c, (y1, y2) = module(x) @@ -323,9 +323,9 @@ def __call__( module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) a = jnp.ones((5, 1, 3)) @@ -359,9 +359,9 @@ def __call__( module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) a = jnp.ones((5, 1, 3)) @@ -392,9 +392,9 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) with nnx.flags(deterministic=False, use_running_average=False): @@ -428,9 +428,9 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: module = MLP(rngs=nnx.Rngs(0)) - assert module.scan_module.linear.kernel.shape == (5, 3, 3) - assert module.scan_module.linear.bias.shape == (5, 3) - assert module.scan_module.node.shape == (2,) + assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) + assert module.scan_module.linear.bias.value.shape == (5, 3) + assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) with nnx.flags(deterministic=False, use_running_average=False): @@ -467,9 +467,9 @@ def __call__( module = Block(rngs=nnx.Rngs(0)) assert module.d == 3 - assert module.linear.kernel.shape == (5, 3, 3) - assert module.linear.bias.shape == (5, 3) - assert module.node.shape == (2,) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert module.node.value.shape == (2,) x = jnp.ones((1, 3)) with nnx.flags(deterministic=False, use_running_average=False): @@ -500,10 +500,10 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: # test sharding layer axes is not present inside scan state = self.linear.get_state() - assert state.kernel.shape == (3, 3) - assert state.variables.kernel.sharding == ('din', 'dout') - assert state.bias.shape == (3,) - assert state.variables.bias.sharding == ('dout',) + assert state.kernel.raw_value.shape == (3, 3) + assert state.kernel.sharding == ('din', 'dout') + assert state.bias.raw_value.shape == (3,) + assert state.bias.sharding == ('dout',) return x, None @@ -518,14 +518,18 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: # test sharding layers axes is set state = m.get_state() - assert state.scan_module.linear.variables.kernel.value.shape == (5, 3, 3) - assert state.scan_module.linear.variables.kernel.sharding == ( + assert state.scan_module.linear.kernel.raw_value.shape == ( + 5, + 3, + 3, + ) + assert state.scan_module.linear.kernel.sharding == ( 'layers', 'din', 'dout', ) - assert state.scan_module.linear.variables.bias.value.shape == (5, 3) - assert state.scan_module.linear.variables.bias.sharding == ( + assert state.scan_module.linear.bias.raw_value.shape == (5, 3) + assert state.scan_module.linear.bias.sharding == ( 'layers', 'dout', ) @@ -535,14 +539,14 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: # test sharding axes is preserved state = m.get_state() - assert state.scan_module.linear.kernel.shape == (5, 3, 3) - assert state.scan_module.linear.variables.kernel.sharding == ( + assert state.scan_module.linear.kernel.raw_value.shape == (5, 3, 3) + assert state.scan_module.linear.kernel.sharding == ( 'layers', 'din', 'dout', ) - assert state.scan_module.linear.bias.shape == (5, 3) - assert state.scan_module.linear.variables.bias.sharding == ( + assert state.scan_module.linear.bias.raw_value.shape == (5, 3) + assert state.scan_module.linear.bias.sharding == ( 'layers', 'dout', ) @@ -634,8 +638,8 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: m = ScanRematLinear(rngs=nnx.Rngs(0)) - assert m.scan_module.remat_module.linear.kernel.shape == (5, 3, 3) - assert m.scan_module.remat_module.linear.bias.shape == (5, 3) + assert m.scan_module.remat_module.linear.kernel.value.shape == (5, 3, 3) + assert m.scan_module.remat_module.linear.bias.value.shape == (5, 3) y, _ = m(jnp.ones((1, 3)), None) assert y.shape == (1, 3) @@ -663,8 +667,8 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: m = ScanLinear(rngs=nnx.Rngs(0)) - assert m.linear.kernel.shape == (5, 3, 3) - assert m.linear.bias.shape == (5, 3) + assert m.linear.kernel.value.shape == (5, 3, 3) + assert m.linear.bias.value.shape == (5, 3) y, _ = m(jnp.ones((1, 3)), None) assert y.shape == (1, 3) @@ -686,11 +690,11 @@ def __call__(self, x: jax.Array) -> jax.Array: module = MLP(rngs=nnx.Rngs(0)) assert not jnp.allclose( - module.vmap_module.linear.kernel[0], - module.vmap_module.linear.kernel[1], + module.vmap_module.linear.kernel.value[0], + module.vmap_module.linear.kernel.value[1], ) - assert module.vmap_module.linear.kernel.shape == (5, 3, 3) - assert module.vmap_module.linear.bias.shape == (5, 3) + assert module.vmap_module.linear.kernel.value.shape == (5, 3, 3) + assert module.vmap_module.linear.bias.value.shape == (5, 3) x = jnp.ones((5, 1, 3)) y = module(x) diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/experimental/nnx/tests/test_variable.py index 3a44864fd9..bdc49aaf88 100644 --- a/flax/experimental/nnx/tests/test_variable.py +++ b/flax/experimental/nnx/tests/test_variable.py @@ -24,10 +24,10 @@ class TestVariable: def test_value(self): r1 = nnx.Variable(1) - assert r1.value == 1 + assert r1.raw_value == 1 - r2 = jax.tree_util.tree_map(lambda x: x + 1, r1) + r2 = jax.tree_map(lambda x: x + 1, r1) - assert r1.value == 1 - assert r2.value == 2 + assert r1.raw_value == 1 + assert r2.raw_value == 2 assert r1 is not r2