diff --git a/flax/experimental/nnx/docs/demo.ipynb b/flax/experimental/nnx/docs/demo.ipynb
index 6cec3f8677..d476ca1d6d 100644
--- a/flax/experimental/nnx/docs/demo.ipynb
+++ b/flax/experimental/nnx/docs/demo.ipynb
@@ -30,7 +30,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 2,
+ "execution_count": 7,
 "id": "d99b73af",
 "metadata": {
 "outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718"
 },
@@ -43,20 +43,19 @@
 "model = MLP(\n",
 " blocks=[Block(\n",
 " linear=Linear(\n",
- " in_features=4,\n",
- " out_features=4,\n",
- " use_bias=True,\n",
- " dtype=None,\n",
- " param_dtype=,\n",
- " precision=None,\n",
- " kernel_init=.init at 0x7f8aa7e24670>,\n",
- " bias_init=,\n",
- " dot_general=\n",
- " ),\n",
+ " in_features=4,\n",
+ " out_features=4,\n",
+ " use_bias=True,\n",
+ " dtype=None,\n",
+ " param_dtype=,\n",
+ " precision=None,\n",
+ " kernel_init=.init at 0x28ae86dc0>,\n",
+ " bias_init=,\n",
+ " dot_general=\n",
+ " ),\n",
 " bn=BatchNorm(\n",
- " num_features=4,\n",
- " use_running_average=None,\n",
- " \n",
+ " num_features=4,\n",
+ " \n",
 "...\n"
 ]
 }
@@ -65,15 +64,11 @@
 "\n",
 "class Block(nnx.Module):\n",
 " def __init__(self, din, dout, *, rngs):\n",
- " self.linear = nnx.Linear(din, dout, rngs=rngs,\n",
- " kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal() , ('data', 'mp')))\n",
+ " self.linear = nnx.Linear(din, dout, rngs=rngs)\n",
 " self.bn = nnx.BatchNorm(dout, rngs=rngs)\n",
 "\n",
- " def __call__(self, x, *, train: bool):\n",
- " x = self.linear(x)\n",
- " x = self.bn(x, use_running_average=not train)\n",
- " x = nnx.relu(x)\n",
- " return x\n",
+ " def __call__(self, x):\n",
+ " return nnx.relu(self.bn(self.linear(x)))\n",
 "\n",
 "\n",
 "class MLP(nnx.Module):\n",
@@ -83,17 +78,18 @@
 " ]\n",
 " self.count = Count(0) # stateful variables are defined as attributes\n",
 "\n",
- " def __call__(self, x, *, train: bool):\n",
- " self.count += 1 # in-place stateful updates\n",
+ " def __call__(self, x):\n",
+ " self.count.value += 1 # in-place stateful updates\n",
 " for block in self.blocks:\n",
- " x = block(x, train=train)\n",
+ " x = block(x)\n",
 " return x\n",
 "\n",
 "class Count(nnx.Variable): # custom Variable types define the \"collections\"\n",
 " pass\n",
 "\n",
 "model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n",
- "y = model(jnp.ones((2, 4)), train=False) # call methods directly\n",
+ "model.set_attributes(deterministic=False, use_running_average=False) # set flags\n",
+ "y = model(jnp.ones((2, 4))) # call methods directly\n",
 "\n",
 "print(f'{model = }'[:500] + '\\n...')"
 ]
 },
@@ -108,7 +104,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 3,
+ "execution_count": 9,
 "id": "6f278ec4",
 "metadata": {
 "outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449"
 },
@@ -118,11 +114,15 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "model.count = 1\n",
- "model.blocks[0].linear.kernel = Array([[ 0.4541134 , -0.5264871 , -0.36505395, -0.57566494],\n",
- " [ 0.3880299 , 0.56555384, 0.48706698, 0.22677685],\n",
- " [-0.9015692 , 0.24465257, -0.58447087, 0.18421973],\n",
- " [-0.06992681, -0.64693826, 0.20232539, 1.1200054 ]], dtype=float32)\n"
+ "model.count = Count(\n",
+ " raw_value=1\n",
+ ")\n",
+ "model.blocks[0].linear.kernel = Param(\n",
+ " raw_value=Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
+ " [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n",
+ " [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n",
+ " [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)\n",
+ ")\n"
 ]
 }
 ],
@@ -142,7 +142,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 10,
 "id": "96f61108",
 "metadata": {
 "outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45"
 },
@@ -160,12 +160,12 @@
 "# Module sharing\n",
 "model.blocks[1] = model.blocks[3]\n",
 "# Weight tying\n",
- "model.blocks[0].linear.variables.kernel = model.blocks[-1].linear.variables.kernel\n",
+ "model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel\n",
 "# Monkey patching\n",
- "def my_optimized_layer(x, *, train: bool): return x\n",
+ "def my_optimized_layer(x): return x\n",
 "model.blocks[2] = my_optimized_layer\n",
 "\n",
- "y = model(jnp.ones((2, 4)), train=False) # still works\n",
+ "y = model(jnp.ones((2, 4))) # still works\n",
 "print(f'{y.shape = }')"
 ]
 },
@@ -179,7 +179,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 11,
 "id": "c166dcc7",
 "metadata": {
 "outputId": "9a3f378b-739e-4f45-9968-574651200ede"
 },
@@ -193,14 +193,15 @@
 " 'blocks': {\n",
 " '0': {\n",
 " 'linear': {\n",
- " 'kernel': Array([[-0.33674937, 1.0543901 , -0.524824 , 0.16665861],\n",
- " [ 0.6607222 , 0.07498633, -0.165967 , -0.36928803],\n",
- " [-0.7086948 , -0.5809104 , 0.2939486 , -0.6660238 ],\n",
- " [-0.13412867, 0.09832543, 0.77024055, -0.2405255 ]], dtype=float32),\n",
- " 'bias': Array([0., 0., 0., 0.], dtype=float32)\n",
- " },\n",
- " 'bn': {\n",
- " 'mean': Array([0., 0., 0., 0.], dtype=float32),\n",
+ " 'kernel': Param(\n",
+ " raw_value=Array([[-0.33095378, 0.67149884, 0.33700302, 0.30972847],\n",
+ " [ 0.8662822 , -0.11225506, -1.0820619 , -0.9906892 ],\n",
+ " [ 0.88298297, -0.2143851 , 0.48143268, 0.6474548 ],\n",
+ " [-0.7710582 , 0.3372276 , 0.15487202, 0.6219269 ]], dtype=float32)\n",
+ " ),\n",
+ " 'bias': Param(\n",
+ " raw_value=Array([0., 0., 0., 0.], dtype=float32)\n",
+ " \n",
 "...\n",
 "\n",
 "static = GraphDef(\n",
@@ -227,13 +228,13 @@
 "# state is a dictionary-like JAX pytree\n",
 "print(f'{state = }'[:500] + '\\n...')\n",
 "\n",
- "# static is also a JAX pytree, but containing no data, just metadata\n",
+ "# static is also a JAX pytree, but just metadata\n",
 "print(f'\\n{static = }'[:300] + '\\n...')"
 ]
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 12,
 "id": "9f03e3af",
 "metadata": {
 "outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d"
 },
@@ -244,7 +245,7 @@
 "output_type": "stream",
 "text": [
 "y.shape = (2, 4)\n",
- "model.count = Array(3, dtype=int32, weak_type=True)\n"
+ "model.count.value = Array(3, dtype=int32, weak_type=True)\n"
 ]
 }
 ],
@@ -254,7 +255,7 @@
 "@jax.jit\n",
 "def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n",
 " model = static.merge(state)\n",
- " y = model(x, train=True)\n",
+ " y = model(x)\n",
 " state, _ = model.split()\n",
 " return y, state\n",
 "\n",
@@ -264,7 +265,7 @@
 "model.update(state)\n",
 "\n",
 "print(f'{y.shape = }')\n",
- "print(f'{model.count = }')"
+ "print(f'{model.count.value = }')"
 ]
 },
 {
@@ -303,7 +304,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 8,
+ "execution_count": 14,
 "id": "2461bfe8",
 "metadata": {},
 "outputs": [
@@ -312,40 +313,45 @@
 "output_type": "stream",
 "text": [
 "y.shape = (2, 4)\n",
- "parent.model.count = Array(5, dtype=int32, weak_type=True)\n"
+ "parent.model.count.value = Array(4, dtype=int32, weak_type=True)\n"
 ]
 }
 ],
 "source": [
 "class Parent(nnx.Module):\n",
- "\n",
 " def __init__(self, model: MLP):\n",
 " self.model = model\n",
 "\n",
- " def __call__(self, x, *, train: bool):\n",
- "\n",
+ " def __call__(self, x):\n",
 " params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)\n",
 "\n",
 " @jax.jit\n",
 " def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n",
 " model = static.merge(params, batch_stats, counts)\n",
- " y = model(x, train=True)\n",
+ " y = model(x)\n",
 " params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n",
 " return y, params, batch_stats, counts\n",
 "\n",
 " y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n",
 "\n",
 " self.model.update(params, batch_stats, counts)\n",
- "\n",
 " return y\n",
 "\n",
 "parent = Parent(model)\n",
 "\n",
- "y = parent(jnp.ones((2, 4)), train=False)\n",
+ "y = parent(jnp.ones((2, 4)))\n",
 "\n",
 "print(f'{y.shape = }')\n",
- "print(f'{parent.model.count = }')"
+ "print(f'{parent.model.count.value = }')"
 ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2e340bcb",
+ "metadata": {},
+ "outputs": [],
+ "source": []
 }
 ],
 "metadata": {
@@ -362,7 +368,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.9.18"
 }
 },
 "nbformat": 4,
diff --git a/flax/experimental/nnx/docs/demo.md b/flax/experimental/nnx/docs/demo.md
index 40d760f13b..fbd8f7564c 100644
--- a/flax/experimental/nnx/docs/demo.md
+++ b/flax/experimental/nnx/docs/demo.md
@@ -24,15 +24,11 @@ from flax.experimental import nnx
 class Block(nnx.Module):
   def __init__(self, din, dout, *, rngs):
-    self.linear = nnx.Linear(din, dout, rngs=rngs,
-      kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal() , ('data', 'mp')))
+    self.linear = nnx.Linear(din, dout, rngs=rngs)
     self.bn = nnx.BatchNorm(dout, rngs=rngs)
 
-  def __call__(self, x, *, train: bool):
-    x = self.linear(x)
-    x = self.bn(x, use_running_average=not train)
-    x = nnx.relu(x)
-    return x
+  def __call__(self, x):
+    return nnx.relu(self.bn(self.linear(x)))
 
 
 class MLP(nnx.Module):
@@ -42,17 +38,18 @@ class MLP(nnx.Module):
     ]
     self.count = Count(0) # stateful variables are defined as attributes
 
-  def __call__(self, x, *, train: bool):
-    self.count += 1 # in-place stateful updates
+  def __call__(self, x):
+    self.count.value += 1 # in-place stateful updates
     for block in self.blocks:
-      x = block(x, train=train)
+      x = block(x)
     return x
 
 class Count(nnx.Variable): # custom Variable types define the "collections"
   pass
 
 model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method
-y = model(jnp.ones((2, 4)), train=False) # call methods directly
+model.set_attributes(deterministic=False, use_running_average=False) # set flags
+y = model(jnp.ones((2, 4))) # call methods directly
 
 print(f'{model = }'[:500] + '\n...')
 ```
@@ -75,12 +72,12 @@ print(f'{model.blocks[0].linear.kernel = }')
 # Module sharing
 model.blocks[1] = model.blocks[3]
 # Weight tying
-model.blocks[0].linear.variables.kernel = model.blocks[-1].linear.variables.kernel
+model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel
 # Monkey patching
-def my_optimized_layer(x, *, train: bool): return x
+def my_optimized_layer(x): return x
 model.blocks[2] = my_optimized_layer
 
-y = model(jnp.ones((2, 4)), train=False) # still works
+y = model(jnp.ones((2, 4))) # still works
 print(f'{y.shape = }')
 ```
@@ -94,7 +91,7 @@ state, static = model.split()
 # state is a dictionary-like JAX pytree
 print(f'{state = }'[:500] + '\n...')
 
-# static is also a JAX pytree, but containing no data, just metadata
+# static is also a JAX pytree, but just metadata
 print(f'\n{static = }'[:300] + '\n...')
 ```
@@ -106,7 +103,7 @@ state, static = model.split()
 @jax.jit
 def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):
   model = static.merge(state)
-  y = model(x, train=True)
+  y = model(x)
   state, _ = model.split()
   return y, state
@@ -116,7 +113,7 @@ y, state = forward(static,state, x)
 model.update(state)
 
 print(f'{y.shape = }')
-print(f'{model.count = }')
+print(f'{model.count.value = }')
 ```
 
 ```{code-cell} ipython3
@@ -140,31 +137,32 @@ print(f'{model.count = }')
 ```{code-cell} ipython3
 class Parent(nnx.Module):
-
   def __init__(self, model: MLP):
     self.model = model
 
-  def __call__(self, x, *, train: bool):
-
+  def __call__(self, x):
     params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)
 
     @jax.jit
     def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):
       model = static.merge(params, batch_stats, counts)
-      y = model(x, train=True)
+      y = model(x)
       params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)
       return y, params, batch_stats, counts
 
     y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)
 
     self.model.update(params, batch_stats, counts)
-
    return y
 
 parent = Parent(model)
 
-y = parent(jnp.ones((2, 4)), train=False)
+y = parent(jnp.ones((2, 4)))
 
 print(f'{y.shape = }')
-print(f'{parent.model.count = }')
+print(f'{parent.model.count.value = }')
+```
+
+```{code-cell} ipython3
+
 ```
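The pattern this patch settles the demo on (flags set once with `set_attributes`, counters as `nnx.Variable` subclasses mutated through `.value`, and the `split`/`merge`/`update` round-trip for `jax.jit`) can be condensed into one self-contained sketch. It only restates what the updated notebook cells above already do, on a smaller hypothetical `Model`; every call it uses appears in the diff itself, and the exact signatures are not checked against any particular nnx release.

```python
# Minimal sketch of the updated demo's usage pattern (hypothetical Model;
# mirrors the notebook cells above, not verified against a specific nnx version).
import jax
import jax.numpy as jnp
from flax.experimental import nnx


class Count(nnx.Variable):  # custom Variable type defines its own "collection"
  pass


class Model(nnx.Module):
  def __init__(self, *, rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)
    self.bn = nnx.BatchNorm(4, rngs=rngs)
    self.count = Count(0)

  def __call__(self, x):
    self.count.value += 1  # in-place stateful update through .value
    return nnx.relu(self.bn(self.linear(x)))


model = Model(rngs=nnx.Rngs(0))
# flags are set once on the instance instead of being threaded through __call__
model.set_attributes(deterministic=False, use_running_average=False)

state, static = model.split()  # State (Variables) + GraphDef (metadata-only pytree)


@jax.jit
def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):
  model = static.merge(state)  # rebuild the stateful module inside the traced fn
  y = model(x)
  state, _ = model.split()     # hand the mutated state back out explicitly
  return y, state


y, state = forward(static, state, jnp.ones((2, 4)))
model.update(state)            # write the new state back into the original object
print(y.shape, model.count.value)
```

The `Parent` cell in the notebook applies the same round-trip per collection, splitting into `nnx.Param`, `nnx.BatchStat`, and `Count` so the jitted function inside a method can carry each group of state separately.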