Merge pull request #3744 from google:nnx-fix-demo
PiperOrigin-RevId: 616812307
Flax Authors committed Mar 18, 2024
2 parents e1f672b + 585e844 commit e3e8cb4
Showing 2 changed files with 87 additions and 83 deletions.
124 changes: 65 additions & 59 deletions flax/experimental/nnx/docs/demo.ipynb
@@ -30,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"id": "d99b73af",
"metadata": {
"outputId": "d8ef66d5-6866-4d5c-94c2-d22512bfe718"
@@ -43,20 +43,19 @@
"model = MLP(\n",
" blocks=[Block(\n",
" linear=Linear(\n",
" in_features=4,\n",
" out_features=4,\n",
" use_bias=True,\n",
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7f8aa7e24670>,\n",
" bias_init=<function zeros at 0x7f8b4a8d55a0>,\n",
" dot_general=<function dot_general at 0x7f8b4aed8f70>\n",
" ),\n",
" in_features=4,\n",
" out_features=4,\n",
" use_bias=True,\n",
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x28ae86dc0>,\n",
" bias_init=<function zeros at 0x122d39f70>,\n",
" dot_general=<function dot_general at 0x1218459d0>\n",
" ),\n",
" bn=BatchNorm(\n",
" num_features=4,\n",
" use_running_average=None,\n",
" \n",
" num_features=4,\n",
" \n",
"...\n"
]
}
@@ -65,15 +64,11 @@
"\n",
"class Block(nnx.Module):\n",
" def __init__(self, din, dout, *, rngs):\n",
" self.linear = nnx.Linear(din, dout, rngs=rngs,\n",
" kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal() , ('data', 'mp')))\n",
" self.linear = nnx.Linear(din, dout, rngs=rngs)\n",
" self.bn = nnx.BatchNorm(dout, rngs=rngs)\n",
"\n",
" def __call__(self, x, *, train: bool):\n",
" x = self.linear(x)\n",
" x = self.bn(x, use_running_average=not train)\n",
" x = nnx.relu(x)\n",
" return x\n",
" def __call__(self, x):\n",
" return nnx.relu(self.bn(self.linear(x)))\n",
"\n",
"\n",
"class MLP(nnx.Module):\n",
@@ -83,17 +78,18 @@
" ]\n",
" self.count = Count(0) # stateful variables are defined as attributes\n",
"\n",
" def __call__(self, x, *, train: bool):\n",
" self.count += 1 # in-place stateful updates\n",
" def __call__(self, x):\n",
" self.count.value += 1 # in-place stateful updates\n",
" for block in self.blocks:\n",
" x = block(x, train=train)\n",
" x = block(x)\n",
" return x\n",
"\n",
"class Count(nnx.Variable): # custom Variable types define the \"collections\"\n",
" pass\n",
"\n",
"model = MLP(5, 4, rngs=nnx.Rngs(0)) # no special `init` method\n",
"y = model(jnp.ones((2, 4)), train=False) # call methods directly\n",
"model.set_attributes(deterministic=False, use_running_average=False) # set flags\n",
"y = model(jnp.ones((2, 4))) # call methods directly\n",
"\n",
"print(f'{model = }'[:500] + '\\n...')"
]
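
The two behavioral changes in this hunk are that stateful updates now go through the Variable's .value attribute (self.count.value += 1 instead of self.count += 1) and that mode flags are set once with set_attributes instead of threading a train keyword through every call. A minimal usage sketch of the updated API, assuming the Block/MLP/Count definitions from the cell above (this sketch is not part of the committed notebook):

from flax.experimental import nnx
import jax.numpy as jnp

model = MLP(5, 4, rngs=nnx.Rngs(0))   # constructed eagerly, no separate init step
model.set_attributes(deterministic=False, use_running_average=False)  # set flags in place
y = model(jnp.ones((2, 4)))           # __call__ no longer takes train=
print(model.count.value)              # 1 -- the counter was bumped through .value
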
@@ -108,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 9,
"id": "6f278ec4",
"metadata": {
"outputId": "10a46b0f-2993-4677-c26d-36a4ddf33449"
@@ -118,11 +114,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"model.count = 1\n",
"model.blocks[0].linear.kernel = Array([[ 0.4541134 , -0.5264871 , -0.36505395, -0.57566494],\n",
" [ 0.3880299 , 0.56555384, 0.48706698, 0.22677685],\n",
" [-0.9015692 , 0.24465257, -0.58447087, 0.18421973],\n",
" [-0.06992681, -0.64693826, 0.20232539, 1.1200054 ]], dtype=float32)\n"
"model.count = Count(\n",
" raw_value=1\n",
")\n",
"model.blocks[0].linear.kernel = Param(\n",
" raw_value=Array([[-0.80345297, -0.34071913, -0.9408296 , 0.01005968],\n",
" [ 0.26146442, 1.1247735 , 0.54563737, -0.374164 ],\n",
" [ 1.0281805 , -0.6798804 , -0.1488401 , 0.05694951],\n",
" [-0.44308168, -0.60587114, 0.434087 , -0.40541083]], dtype=float32)\n",
")\n"
]
}
],
@@ -142,7 +142,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 10,
"id": "96f61108",
"metadata": {
"outputId": "e6f86be8-3537-4c48-f471-316ee0fb6c45"
@@ -160,12 +160,12 @@
"# Module sharing\n",
"model.blocks[1] = model.blocks[3]\n",
"# Weight tying\n",
"model.blocks[0].linear.variables.kernel = model.blocks[-1].linear.variables.kernel\n",
"model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel\n",
"# Monkey patching\n",
"def my_optimized_layer(x, *, train: bool): return x\n",
"def my_optimized_layer(x): return x\n",
"model.blocks[2] = my_optimized_layer\n",
"\n",
"y = model(jnp.ones((2, 4)), train=False) # still works\n",
"y = model(jnp.ones((2, 4))) # still works\n",
"print(f'{y.shape = }')"
]
},
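
The weight-tying line in this hunk now assigns the Param object directly; the .variables proxy used before this commit is gone. A short sketch of what that assignment does, assuming the model built earlier in the notebook (not part of the committed diff):

# Attribute access returns the Param Variable itself (see the printed
# Param(raw_value=...) output above), so assigning it to another block makes
# both blocks share one kernel -- that is the weight tying shown here.
kernel = model.blocks[-1].linear.kernel
model.blocks[0].linear.kernel = kernel
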
@@ -179,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 11,
"id": "c166dcc7",
"metadata": {
"outputId": "9a3f378b-739e-4f45-9968-574651200ede"
@@ -193,14 +193,15 @@
" 'blocks': {\n",
" '0': {\n",
" 'linear': {\n",
" 'kernel': Array([[-0.33674937, 1.0543901 , -0.524824 , 0.16665861],\n",
" [ 0.6607222 , 0.07498633, -0.165967 , -0.36928803],\n",
" [-0.7086948 , -0.5809104 , 0.2939486 , -0.6660238 ],\n",
" [-0.13412867, 0.09832543, 0.77024055, -0.2405255 ]], dtype=float32),\n",
" 'bias': Array([0., 0., 0., 0.], dtype=float32)\n",
" },\n",
" 'bn': {\n",
" 'mean': Array([0., 0., 0., 0.], dtype=float32),\n",
" 'kernel': Param(\n",
" raw_value=Array([[-0.33095378, 0.67149884, 0.33700302, 0.30972847],\n",
" [ 0.8662822 , -0.11225506, -1.0820619 , -0.9906892 ],\n",
" [ 0.88298297, -0.2143851 , 0.48143268, 0.6474548 ],\n",
" [-0.7710582 , 0.3372276 , 0.15487202, 0.6219269 ]], dtype=float32)\n",
" ),\n",
" 'bias': Param(\n",
" raw_value=Array([0., 0., 0., 0.], dtype=float32)\n",
" \n",
"...\n",
"\n",
"static = GraphDef(\n",
@@ -227,13 +228,13 @@
"# state is a dictionary-like JAX pytree\n",
"print(f'{state = }'[:500] + '\\n...')\n",
"\n",
"# static is also a JAX pytree, but containing no data, just metadata\n",
"# static is also a JAX pytree, but just metadata\n",
"print(f'\\n{static = }'[:300] + '\\n...')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"id": "9f03e3af",
"metadata": {
"outputId": "0007d357-152a-449e-bcb9-b1b5a91d2d8d"
@@ -244,7 +245,7 @@
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"model.count = Array(3, dtype=int32, weak_type=True)\n"
"model.count.value = Array(3, dtype=int32, weak_type=True)\n"
]
}
],
Expand All @@ -254,7 +255,7 @@
"@jax.jit\n",
"def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n",
" model = static.merge(state)\n",
" y = model(x, train=True)\n",
" y = model(x)\n",
" state, _ = model.split()\n",
" return y, state\n",
"\n",
@@ -264,7 +265,7 @@
"model.update(state)\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{model.count = }')"
"print(f'{model.count.value = }')"
]
},
{
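
The cell above uses split and merge inside jax.jit; the same pair also works eagerly, which makes the round trip easier to see. A hedged sketch (not part of the diff) that reuses only calls appearing in this notebook:

import jax.numpy as jnp

state, static = model.split()    # State: a pytree of Variables; GraphDef: just metadata
restored = static.merge(state)   # rebuild an equivalent stateful Module
y = restored(jnp.ones((2, 4)))   # behaves like calling model directly
model.update(state)              # a State can also be written back into the original Module
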
@@ -303,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"id": "2461bfe8",
"metadata": {},
"outputs": [
@@ -312,40 +313,45 @@
"output_type": "stream",
"text": [
"y.shape = (2, 4)\n",
"parent.model.count = Array(5, dtype=int32, weak_type=True)\n"
"parent.model.count.value = Array(4, dtype=int32, weak_type=True)\n"
]
}
],
"source": [
"class Parent(nnx.Module):\n",
"\n",
" def __init__(self, model: MLP):\n",
" self.model = model\n",
"\n",
" def __call__(self, x, *, train: bool):\n",
"\n",
" def __call__(self, x):\n",
" params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)\n",
"\n",
" @jax.jit\n",
" def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n",
" model = static.merge(params, batch_stats, counts)\n",
" y = model(x, train=True)\n",
" y = model(x)\n",
" params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n",
" return y, params, batch_stats, counts\n",
"\n",
" y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n",
"\n",
" self.model.update(params, batch_stats, counts)\n",
"\n",
" return y\n",
"\n",
"parent = Parent(model)\n",
"\n",
"y = parent(jnp.ones((2, 4)), train=False)\n",
"y = parent(jnp.ones((2, 4)))\n",
"\n",
"print(f'{y.shape = }')\n",
"print(f'{parent.model.count = }')"
"print(f'{parent.model.count.value = }')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e340bcb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -362,7 +368,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.9.18"
}
},
"nbformat": 4,
