[nnx] simplify readme
cgarciae committed Feb 21, 2024
1 parent 736d133 commit 8f9e760
Showing 1 changed file with 85 additions and 23 deletions: flax/experimental/nnx/README.md
To get started with `nnx`, install Flax from GitHub:
pip install git+https://github.com/google/flax.git
```

## Getting Started

The following example guides you through creating a basic `Linear` model with NNX and executing a forward pass. It also demonstrates how to handle mutable state by keeping track of the number of times the model has been called.

```python
from flax.experimental import nnx
import jax
import jax.numpy as jnp


class Count(nnx.Variable): pass  # typed Variable collections


class Linear(nnx.Module):
  def __init__(self, din, dout, *, rngs: nnx.Rngs):  # explicit RNG management
    key = rngs()
    # put dynamic state in Variable types
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(0)
    # other types are treated as static
    self.din = din
    self.dout = dout

  def __call__(self, x):
    self.count += 1  # in-place stateful updates
    return x @ self.w + self.b


model = Linear(din=12, dout=2, rngs=nnx.Rngs(0))  # no special `init` method
x = jnp.ones((8, 12))
y = model(x)  # call methods directly

assert model.count == 1
```

In this example `nnx.Rngs(0)` creates a `random.key` for `params` with seed `0`; this is used by `rngs.<rng-name>()` inside `__init__` to generate a random key to initialize the parameters.

## What does NNX look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the `Module` abstraction, check out our [docs](https://flax.readthedocs.io/) and our [broad intro to the Module abstraction](https://github.com/google/flax/blob/main/docs/linen_intro.ipynb). For additional concrete demonstrations of best practices, refer to our
[guides](https://flax.readthedocs.io/en/latest/guides/index.html) and
[developer notes](https://flax.readthedocs.io/en/latest/developer_notes/index.html).

```py
import jax
import jax.numpy as jnp

from flax.experimental import nnx


class MLP(nnx.Module):
  def __init__(self, features: list[int], *, rngs: nnx.Rngs):
    self.layers = [
      nnx.Linear(din, dout, rngs=rngs)
      for din, dout in zip(features[:-1], features[1:])
    ]

  def __call__(self, x: jax.Array) -> jax.Array:
    for layer in self.layers[:-1]:
      x = nnx.relu(layer(x))
    x = self.layers[-1](x)
    return x


model = MLP([784, 64, 32, 10], rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 784)))
```
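Because NNX modules behave like regular Python objects, their submodules can be inspected and called directly. The following is a small sketch that uses only the `MLP` class defined above; the layer sizes and seed are arbitrary and chosen for illustration.

```py
# Sketch: submodules of an NNX Module are ordinary attributes.
# Assumes the MLP class defined above; sizes and seed are arbitrary.
mlp = MLP([4, 8, 2], rngs=nnx.Rngs(0))

first = mlp.layers[0]          # an nnx.Linear submodule
y = first(jnp.ones((1, 4)))    # submodules can be called on their own
assert y.shape == (1, 8)
```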
```py
class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(
      in_features=1, out_features=64, kernel_size=(3, 3), rngs=rngs
    )
    self.conv2 = nnx.Conv(
      in_features=64, out_features=32, kernel_size=(3, 3), rngs=rngs
    )
    self.linear1 = nnx.Linear(
      in_features=7 * 7 * 32, out_features=256, rngs=rngs
    )
    self.linear2 = nnx.Linear(in_features=256, out_features=10, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.conv1(x))
    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nnx.relu(self.conv2(x))
    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nnx.relu(self.linear1(x))
    logits = self.linear2(x)
    return logits


model = CNN(rngs=nnx.Rngs(0))
x = jnp.ones((1, 28, 28, 1)) # (N, H, W, C) format
logits = model(x)
```
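As a quick sanity check, the logits produced above can be turned into a cross-entropy loss with plain `jax.nn` utilities. This is a hedged sketch, not part of the original example; the `labels` array is dummy data used only for illustration.

```py
# Sketch: cross-entropy loss from the CNN logits above.
# Assumptions: the CNN defined above; `labels` is fake data for illustration.
assert logits.shape == (1, 10)

labels = jnp.array([3])                  # one dummy class label
log_probs = jax.nn.log_softmax(logits)   # shape (1, 10)
loss = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()
```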

```py
class AutoEncoder(nnx.Module):
  def __init__(
    self,
    input_features: int,
    encoder_features: list[int],
    decoder_features: list[int],
    *,
    rngs: nnx.Rngs,
  ):
    self.encoder = MLP([input_features, *encoder_features], rngs=rngs)
    self.decoder = MLP([*decoder_features, input_features], rngs=rngs)

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    return self.encoder(x)

  def decode(self, z):
    return nnx.sigmoid(self.decoder(z))


model = AutoEncoder(
  input_features=784,
  encoder_features=[64, 32],
  decoder_features=[32, 64],
  rngs=nnx.Rngs(0),
)
x = jnp.ones((1, 784))
z = model.encode(x)
y = model.decode(z)
```
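The same model can also be run end to end through `__call__`, which encodes and then decodes in a single step. The snippet below is a sketch that builds on the `AutoEncoder` defined above; mean squared error is chosen only for illustration and is not prescribed by this README.

```py
# Sketch: end-to-end reconstruction with the AutoEncoder defined above.
# MSE is just one possible objective, used here only as an illustration.
x = jnp.ones((1, 784))
x_recon = model(x)                         # encode then decode in one call
recon_loss = jnp.mean((x - x_recon) ** 2)
assert x_recon.shape == x.shape
```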

### Interacting with JAX

