Upgrade Flax NNX Haiku Linen migration doc
8bitmp3 committed Sep 16, 2024
1 parent 03e034d commit b69a7d7
Showing 1 changed file with 41 additions and 57 deletions.
docs_nnx/guides/haiku_linen_vs_nnx.rst

Migrate to Flax NNX from Haiku/Flax Linen
=========================================

This guide demonstrates the differences between Flax NNX (the next-generation Flax API),
Haiku, and Flax Linen. Both Haiku and Flax Linen enforce a functional paradigm with stateless
``Module``\ s, while Flax NNX embraces the Python language to provide a more intuitive
development experience.

.. testsetup:: Haiku, Linen, NNX

# TODO: make sure code lines are not too long
# TODO: make sure all code diffs are aligned

A basic example
---------------

To create custom ``Module``\ s, in both Haiku and Flax you subclass the ``Module``
base class. Note that:

- In Haiku and Flax Linen, ``Module``\ s can be defined inline using the ``@nn.compact`` decorator.
- In Flax NNX, ``Module``\ s can't be defined inline, and instead must be defined in ``__init__``.

In addition:

- Flax Linen requires a ``deterministic`` argument to control whether or not dropout is used.
- Flax NNX also uses a ``deterministic`` argument, but its value can be set later using the ``.eval()`` and ``.train()`` methods that will be shown in a later code snippet (and sketched briefly below).
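
The following is a minimal Flax NNX sketch of this pattern (the layer sizes and names here are illustrative assumptions, not the ones used in the comparison below):

.. code-block:: python

  from flax import nnx
  import jax

  class Block(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
      # Layers are created eagerly in __init__, not inline in __call__.
      self.linear = nnx.Linear(din, dout, rngs=rngs)
      self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

    def __call__(self, x):
      return jax.nn.relu(self.dropout(self.linear(x)))

  block = Block(4, 8, rngs=nnx.Rngs(0))
  block.train()  # sets the dropout layer's deterministic attribute to False
  block.eval()   # sets it back to True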

.. codediff::
:title: Haiku, Linen, NNX
x = self.linear(x)
return x

- In Haiku and Flax Linen, since ``Module``\ s are defined inline, the parameters are lazily initialized by inferring the shape of a sample input.
- In Flax NNX, the ``Module`` is stateful and is initialized eagerly. This means that the input shape must be explicitly passed during ``Module`` instantiation, since there is no shape inference in Flax NNX. (A brief sketch of this contrast follows; the full comparison is in the code diff below.)
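
To make the contrast concrete, here is a hedged sketch (the layer sizes are illustrative): Flax Linen infers the input dimension from a sample input at ``init`` time, while Flax NNX asks for it up front.

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from flax import linen as nn
  from flax import nnx

  # Flax Linen: only the output features are declared; the kernel shape is
  # inferred from the sample input when .init() is called.
  linen_layer = nn.Dense(features=10)
  variables = linen_layer.init(jax.random.key(0), jnp.ones((1, 784)))

  # Flax NNX: both input and output sizes are passed explicitly, and the
  # parameters exist as soon as the layer is constructed.
  nnx_layer = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
  assert nnx_layer.kernel.value.shape == (784, 10)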

.. codediff::
:title: Haiku, Linen, NNX
To get the model parameters in both Haiku and Linen, you use the ``init`` method
with a ``random.key`` plus some inputs to run the model.

In Flax NNX, the model parameters are automatically initialized when the user
instantiates the model because the input shapes are already explicitly passed at
instantiation time.

Since Flax NNX is eager and the module is bound upon instantiation, the user can access
the parameters (and other fields defined in ``__init__``) via dot-access. On the other
hand, Haiku and Flax Linen use lazy initialization, so the parameters can only be accessed
once the module is initialized with a sample input, and neither framework supports
dot-access of their attributes.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
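
For contrast, here is a hedged Haiku-side sketch (toy layer sizes; not this guide's model) showing that the parameters live in a nested dictionary rather than on the module itself:

.. code-block:: python

  import haiku as hk
  import jax
  import jax.numpy as jnp

  def forward(x):
    return hk.Linear(10)(x)

  model = hk.transform(forward)
  params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))

  # Parameters are returned as a nested dict keyed by module/parameter name;
  # there is no `model.linear.w`-style attribute access.
  print(jax.tree_util.tree_map(jnp.shape, params))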

Let's take a look at the parameter structure:

- In Haiku and Flax Linen, you can simply inspect the ``params`` object returned from ``.init()``.
- In Flax NNX, to view the parameter structure you can call ``nnx.split`` to generate ``GraphDef`` and ``State`` objects. The ``GraphDef`` is a static pytree denoting the structure of the model (for example usages, check the `Flax NNX basics guide <https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html>`__). ``State`` objects contain all the module variables (i.e. any class that subclasses ``nnx.Variable``). If you filter for ``nnx.Param``, you will generate a ``State`` object of all the learnable module parameters, as sketched right after this list.
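
As a minimal sketch of that Flax NNX flow (assuming a toy single-layer model rather than the one used in this guide):

.. code-block:: python

  from flax import nnx

  class Model(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
      self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
      return self.linear(x)

  model = Model(784, 10, rngs=nnx.Rngs(0))

  # Split the stateful module into a static GraphDef and a State of nnx.Param variables.
  graphdef, params = nnx.split(model, nnx.Param)
  print(params)  # a State pytree containing every nnx.Param in the model

  # nnx.merge puts the graph definition and state back together into a module.
  model = nnx.merge(graphdef, params)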

.. tab-set::

}
})

During training:

- In Haiku and Flax Linen, you pass the parameters structure to the ``apply`` method to run the forward pass. To use dropout, you must pass in ``training=True`` and provide a ``key`` to ``apply`` in order to generate the random dropout masks.
- In Flax NNX, to use dropout you first call ``model.train()``, which sets the dropout layer's ``deterministic`` attribute to ``False`` (conversely, calling ``model.eval()`` would set ``deterministic`` to ``True``). Since the stateful Flax NNX ``Module`` already contains both the parameters and the RNG key (used for dropout), you simply need to call the module to run the forward pass. You use ``nnx.split`` to extract the learnable parameters (all learnable parameters subclass the Flax NNX class ``nnx.Param``), and then apply the gradients and statefully update the model using ``nnx.update``.

To compile the ``train_step``:

- In Haiku and Flax Linen, you decorate the function with ``@jax.jit``.
- In Flax NNX, you decorate it with ``@nnx.jit``. Similar to ``@jax.jit``, ``@nnx.jit`` also compiles functions, with the additional feature of allowing the user to compile functions that take Flax NNX modules as arguments. A rough Flax NNX sketch follows this list; the full three-way comparison is in the code diff below.
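
As a rough sketch of the Flax NNX pattern (a toy model, plain SGD, and a squared-error loss are assumptions here, not the guide's exact training code):

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from flax import nnx

  class MLP(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
      self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
      self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
      self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x):
      return self.linear2(self.dropout(jax.nn.relu(self.linear1(x))))

  model = MLP(4, 8, 2, rngs=nnx.Rngs(0))
  model.train()  # enable dropout by setting deterministic=False

  @nnx.jit  # like @jax.jit, but accepts Flax NNX modules as arguments
  def train_step(model, x, y):
    def loss_fn(model):
      return jnp.mean((model(x) - y) ** 2)

    grads = nnx.grad(loss_fn)(model)         # State of gradients for every nnx.Param
    _, params = nnx.split(model, nnx.Param)  # extract the learnable parameters
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    nnx.update(model, new_params)            # statefully write the updates back

  train_step(model, jnp.ones((1, 4)), jnp.ones((1, 2)))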

.. codediff::
:title: Haiku, Linen, NNX

One thing to note is that ``GraphDef.apply`` takes ``State`` objects as arguments and
returns a callable function. This function can be called on the inputs to output the
model's logits, as well as updated ``GraphDef`` and ``State`` objects. This isn't needed
for our current example with dropout, but in the next section, you will see that using
these updated objects is relevant with layers like batch normalization. Notice the use of
``@jax.jit``, since we aren't passing Flax NNX modules into ``train_step``.
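
As a rough sketch of that pattern (a toy one-layer model is assumed; the return structure follows the description above):

.. code-block:: python

  import jax.numpy as jnp
  from flax import nnx

  model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
  graphdef, state = nnx.split(model)

  # GraphDef.apply(state) returns a callable; calling it on the inputs runs the
  # forward pass and also returns the updated GraphDef and State.
  y, (graphdef, state) = graphdef.apply(state)(jnp.ones((1, 4)))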

.. codediff::
y, _ = self.blocks(x, None)
return y

Notice how in Flax you pass ``None`` as the second argument to ``ScanBlock`` and ignore
its second output. These normally represent the per-step inputs/outputs, but here they are
``None`` because in this case there aren't any.
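
The ``None`` mirrors how ``jax.lax.scan`` treats per-step inputs and outputs. As a small standalone sketch of that convention (not taken from this guide's code):

.. code-block:: python

  import jax
  import jax.numpy as jnp

  def step(carry, x):
    # carry is threaded through every iteration; x is the per-step input.
    return carry + 1, None  # None: we don't collect any per-step output

  # With xs=None there are no per-step inputs, so `length` must be given.
  final_carry, ys = jax.lax.scan(step, jnp.asarray(0), xs=None, length=5)
  assert final_carry == 5 and ys is None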

Initializing each model is the same as in previous examples. In this example,
you will specify that you want to use ``5`` layers each with ``64`` features.
As before, you also pass in the input shape for Flax NNX.

.. codediff::
:title: Haiku, Linen, NNX
be set and accessed as normal using regular Python class semantics.

model = FooModule(rngs=nnx.Rngs(0))

_, params, counter = nnx.split(model, nnx.Param, Counter)
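
As a possible follow-up (hedged; this assumes ``Counter`` subclasses ``nnx.Variable`` as in the collapsed code above), the filtered ``State`` objects can be recombined with ``nnx.merge``:

.. code-block:: python

  # Filtered split: one State per filter, in the order the filters were given.
  graphdef, params, counter = nnx.split(model, nnx.Param, Counter)

  # The GraphDef plus the filtered states reconstruct an equivalent module.
  model = nnx.merge(graphdef, params, counter)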

