Skip to content

Commit

Permalink
Update Flax NNX Model Surgery
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Oct 9, 2024
1 parent 2ad9731 commit 8ca8131
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 60 deletions.
67 changes: 37 additions & 30 deletions docs_nnx/guides/surgery.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
"source": [
"# Model surgery\n",
"\n",
"> **Attention**: This page relates to the new Flax NNX API.\n",
"In this guide, you will learn how to perform model surgery in Flax NNX using several real-world scenarios:\n",
"\n",
"In this guide you will learn how to do model surgery with Flax NNX with several real-scenario use cases:\n",
"* __Pythonic `nnx.Module` manipulation__: Using Pythonic ways to manipulate sub-`Module`s given a model.\n",
"\n",
"* __Pythonic module manipulation__: Pythonic ways to manipulate sub-modules given a model.\n",
"* __Manipulation of an abstract model or state__: A key trick for playing with `flax.nnxModule`s and states without memory allocation.\n",
"\n",
"* __Manipulating an abstract model or state__: A key trick to play with Flax NNX modules and states without memory allocation.\n",
"\n",
"* __Checkpoint surgery: From a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n",
"* __Checkpoint surgery from a raw state to model__: How to manipulate parameter states when they are incompatible with existing model code.\n",
"\n",
"* __Partial initialization__: How to initialize only a part of the model from scratch using a naive method or a memory-efficient method."
]
Expand Down Expand Up @@ -63,11 +61,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pythonic module manipulation\n",
"## Pythonic `nnx.Module` manipulation\n",
"\n",
"It is easier to perform model surgery when:\n",
"\n",
"Doing model surgery is easiest when you already have a fully fleshed-out model loaded with correct parameters, and you don't intend to change your model definition code.\n",
"1) You already have a fully fleshed-out model loaded with correct parameters; and\n",
"2) You don't intend to change your model definition code.\n",
"\n",
"You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching:"
"You can perform a variety of Pythonic operations on its sub-`Module`s, such as sub-`Module` swapping, `Module` sharing, variable sharing, and monkey-patching:"
]
},
{
Expand All @@ -80,25 +81,25 @@
"x = jax.random.normal(jax.random.key(42), (3, 4))\n",
"np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))\n",
"\n",
"# Sub-module swapping\n",
"# Sub-`Module` swapping.\n",
"original1, original2 = model.linear1, model.linear2\n",
"model.linear1, model.linear2 = model.linear2, model.linear1\n",
"np.testing.assert_allclose(model(x), original1(original2(x)))\n",
"\n",
"# Module sharing (tying all weights)\n",
"# `Module` sharing (tying all weights together).\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"model.linear2 = model.linear1\n",
"assert not hasattr(nnx.state(model), 'linear2')\n",
"np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))\n",
"\n",
"# Variable sharing (weight-tying)\n",
"# Variable sharing (weight-tying).\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n",
"assert hasattr(nnx.state(model), 'linear2')\n",
"assert hasattr(nnx.state(model)['linear2'], 'bias')\n",
"assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n",
"\n",
"# Monkey-patching\n",
"# Monkey-patching.\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"def awesome_layer(x): return x\n",
"model.linear2 = awesome_layer\n",
Expand All @@ -111,13 +112,14 @@
"source": [
"## Creating an abstract model or state without memory allocation\n",
"\n",
"For more complex model surgery, a key technique is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n",
"To do more complex model surgery, the key technique you can use is creating and manipulating an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern on memory constraints.\n",
"\n",
"To create an abstract model:\n",
"\n",
"To create an abstract model,\n",
"* Create a function that returns a valid Flax NNX model; and\n",
"* Run `nnx.eval_shape` (not `jax.eval_shape`) upon it.\n",
"\n",
"Now you can use `nnx.split` as usual to get its abstract state. Note that all the fields that should be `jax.Array` in a real model are now an abstract `jax.ShapeDtypeStruct` with only shape/dtype/sharding information."
"Now you can use `nnx.split` as usual to get its abstract state. Note that all fields that should be `jax.Array`s in a real model are now of an abstract `jax.ShapeDtypeStruct` type with only shape/dtype/sharding information."
]
},
{
Expand Down Expand Up @@ -164,7 +166,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"When you fill every `VariableState` leaf's `value`s with real jax arrays, the abstract model becomes equivalent to a real model."
"When you fill every `nnx.VariableState` pytree leaf's `value`s with real `jax.Array`s, the abstract model becomes equivalent to a real model."
]
},
{
Expand All @@ -188,9 +190,11 @@
"source": [
"## Checkpoint surgery\n",
"\n",
"With the abstract state technique in hand, you can do arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make them fit with your given model code, then call `nnx.update` to merge them.\n",
"With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint - or runtime parameter pytree - to make them fit with your given model code, and then call `nnx.update` to merge them.\n",
"\n",
"This can be helpful if you are trying to significantly change the model code - for example, when migrating from Flax Linen to Flax NNX - and old weights are no longer naturally compatible.\n",
"\n",
"This can be helpful when you are trying to change model code significantly (for example, when migrating from Flax Linen to Flax NNX), and old weights are no longer naturally compatible. Let's run a simple example here:"
"Let's run a simple example here:"
]
},
{
Expand All @@ -209,7 +213,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this new model, the sub-modules are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure changed, it's impossible to load the old checkpoint with the new model state structure:"
"In this new model, the sub-`Module`s are renamed from `linear(1|2)` to `layer(1|2)`. Since the pytree structure has changed, it is impossible to load the old checkpoint with the new model state structure:"
]
},
{
Expand Down Expand Up @@ -247,7 +251,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"But you can load the parameter tree as a raw dictionary, make the renames, and generate a new state that is guaranteed to be compatible with your new model definition."
"However, you can load the parameter pytree as a raw dictionary, perform the renames, and generate a new state that is guaranteed to be compatible with your new model definition."
]
},
{
Expand Down Expand Up @@ -283,7 +287,7 @@
"source": [
"def process_raw_dict(raw_state_dict):\n",
" flattened = nnx.traversals.flatten_mapping(raw_state_dict)\n",
" # Cut off the '.value' postfix on every leaf path.\n",
" # Cut the '.value' postfix on every leaf path.\n",
" flattened = {(path[:-1] if path[-1] == 'value' else path): value\n",
" for path, value in flattened.items()}\n",
" return nnx.traversals.unflatten_mapping(flattened)\n",
Expand All @@ -309,7 +313,10 @@
"source": [
"## Partial initialization\n",
"\n",
"In some cases (such as with LoRA), you may want to randomly-initialize only *part of* your model parameters. This can be achieved through naive partial initialization or memory-efficient partial initialization."
"In some cases - such as with LoRA (Low-Rank Adaption) - you may want to randomly-initialize only *part of* your model parameters. This can be achieved through:\n",
"\n",
"- Naive partial initialization; or\n",
"- Memory-efficient partial initialization."
]
},
{
Expand All @@ -318,9 +325,9 @@
"source": [
"### Naive partial initialization\n",
"\n",
"You can simply initialize the whole model, then swap pre-trained parameters in. But this approach could allocate additional memory midway, if your modification requires re-creating module parameters that you will later discard. See this example below.\n",
"To do naive partial initialization, you can just initialize the whole model, then swap the pre-trained parameters in. However, this approach may allocate additional memory midway if your modification requires re-creating module parameters that you will later discard. Below is an example of this.\n",
"\n",
"> Note: You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single notebook cell multiple times (due to garbage-collecting old python variables), but restarting the kernel and running from scratch will always yield same output."
"> **Note:** You can use `jax.live_arrays()` to check all the arrays live in memory at any given time. This call can be messed up when you run a single Jupyter notebook cell multiple times (because of old garbage-collecting Python variables). However, restarting the Python kernel in the notebook and running the code from scratch will always yield the same output."
]
},
{
Expand All @@ -344,8 +351,8 @@
"\n",
"simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))\n",
"print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n",
"# On this line, extra kernel and bias is created inside the new LoRALinear!\n",
"# They are wasted since you are going to use the kernel and bias in `old_state` anyway.\n",
"# In this line, extra kernel and bias is created inside the new LoRALinear!\n",
"# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.\n",
"simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))\n",
"print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'\n",
" ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')\n",
Expand All @@ -360,7 +367,7 @@
"source": [
"### Memory-efficient partial initialization\n",
"\n",
"Use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:"
"To do memory-efficient partial initialization, use `nnx.jit`'s efficiently compiled code to make sure only the state parameters you need are initialized:"
]
},
{
Expand Down Expand Up @@ -391,10 +398,10 @@
" nnx.update(model, old_state)\n",
" return model\n",
"\n",
"print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')\n",
"print(f'Number of JAX Arrays in memory at start: {len(jax.live_arrays())}')\n",
"# Note that `old_state` will be deleted after this `partial_init` call.\n",
"good_model = partial_init(old_state, nnx.Rngs(42))\n",
"print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'\n",
"print(f'Number of JAX Arrays in memory at end: {len(jax.live_arrays())}'\n",
" ' (2 new created - lora_a and lora_b)')"
]
},
Expand Down
Loading

0 comments on commit 8ca8131

Please sign in to comment.