From ac3fafbae2afb304884c4535063552b625e61a51 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Thu, 3 Oct 2024 23:22:57 +0000 Subject: [PATCH] Nitting and adding links --- docs_nnx/guides/flax_gspmd.ipynb | 12 ++++++------ docs_nnx/guides/flax_gspmd.md | 12 ++++++------ docs_nnx/nnx_basics.ipynb | 10 +++++----- docs_nnx/nnx_basics.md | 10 +++++----- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb index 5ea92a675d..168ae2bdb1 100644 --- a/docs_nnx/guides/flax_gspmd.ipynb +++ b/docs_nnx/guides/flax_gspmd.ipynb @@ -18,7 +18,7 @@ "\n", "Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s.\n", "\n", - "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jax.jit` will automatically compile and run it on multiple devices.\n", + "JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices.\n", "\n", "To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information.\n", "\n", @@ -26,10 +26,10 @@ "\n", "If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials:\n", "\n", - "- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with `jax.jit`, semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).\n", + "- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map).\n", "- [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html).\n", - "- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with `jax.jit` and `jax.lax.with_sharding_constraint`. Study it after the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html).\n", - "- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html)." + "- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html). Study it after the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html).\n", + "- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html)." ] }, { @@ -525,7 +525,7 @@ "\n", "Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`.\n", "\n", - "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level." + "> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level." ] }, { @@ -850,7 +850,7 @@ "\n", " * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming.\n", "\n", - " * Shardings of intermediate *activation* values can only be done via `jax.lax.with_sharding_constraint` and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", + " * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing.\n", "\n", "* **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*." ] diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md index 5a9c1f6a32..342774d945 100644 --- a/docs_nnx/guides/flax_gspmd.md +++ b/docs_nnx/guides/flax_gspmd.md @@ -18,7 +18,7 @@ This guide demonstrates how to scale up [Flax NNX `Module`s](https://flax-nnx.re Flax relies on [JAX](https://jax.readthedocs.io) for numeric computations and scaling the computations up across multiple devices, such as GPU and TPUs. At the core of scaling up is the [JAX just-in-time compiler `jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html). Throughout this guide, you will be using Flax’s own [`nnx.jit`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) which wraps around [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and works more conveniently with Flax NNX `Module`s. -JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and `jax.jit` will automatically compile and run it on multiple devices. +JAX compilation follows the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm. This means you write Python code as if it runs only on one device, and [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) will automatically compile and run it on multiple devices. To ensure the compilation performance, you often need to instruct JAX how your model's variables need to be sharded across devices. This is where Flax NNX's Sharding Metadata API - [`flax.nnx.spmd`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html) - comes in. It helps you annotate your model variables with this information. @@ -26,10 +26,10 @@ To ensure the compilation performance, you often need to instruct JAX how your m If you are new parallelization in JAX, you can learn more about its APIs for scaling up in the following tutorials: -- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with `jax.jit`, semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map). +- [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html): A 101 level tutorial covering the basics of automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), semi-automatic parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html), and manual sharding with [`shard_map`](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html#jax.experimental.shard_map.shard_map). - [JAX in multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html). -- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with `jax.jit` and `jax.lax.with_sharding_constraint`. Study it after the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html). -- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101]([Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html). +- [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html): A more detailed tutorial about parallelization with [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) and [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html). Study it after the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html). +- [Manual parallelism with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html): Another more in-depth doc that follows the [101](https://jax.readthedocs.io/en/latest/sharded-computation.html). +++ @@ -243,7 +243,7 @@ Now, from initialization or from checkpoint, you have a sharded model. To carry Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without JIT compilation. In the example below even without [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) on the output `y`, it was still sharded as `('data', None)`. -> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happened at low level. +> If you are interested in why: The second matmul of `DotReluDot.__call__` has two inputs of sharding `('data', 'model')` and `('model', None)`, in which both inputs' contraction axis are `model`. So a reduce-scatter matmul happened and will naturally shard the output as `('data', None)`. Check out the [JAX shard map collective guide](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#example-2-psum-scatter-the-result) and its examples if you want to learn mathematically how it happens at a low level. ```{code-cell} ipython3 # In data parallelism, the first dimension (batch) will be sharded on `data` axis. @@ -384,6 +384,6 @@ Choosing when to use a device or logical axis depends on how much you want to co * For a simpler model, this can save you a few extra lines of code of converting the logical naming back to the device naming. - * Shardings of intermediate *activation* values can only be done via `jax.lax.with_sharding_constraint` and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. + * Shardings of intermediate *activation* values can only be done via [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) and device mesh axis. So if you want super fine-grained control over your model's sharding, directly using device mesh axis names everywhere might be less confusing. * **Logical naming**: Helpful if you want to experiment around and find the most optimal partition layout for your *model weights*. diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index 693c74a919..e326f80585 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -6,13 +6,13 @@ "source": [ "# Flax basics\n", "\n", - "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.\n", + "Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.\n", "\n", "In this guide you will learn about:\n", "\n", "- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.\n", " - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).\n", - " - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm layers.\n", + " - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.\n", " - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.\n", "- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.\n", " - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.\n", @@ -64,7 +64,7 @@ "\n", "Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The following code shows that:\n", "\n", - "- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly.\n", + "- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly.\n", "- Attributes of type [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) and `numpy.ndarray` are also treated as dynamic states, although storing them inside [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s, such as `Param`, is preferred.\n", "- The [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor." ] @@ -429,7 +429,7 @@ "\n", "Below is an example of of `StatefulLinear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) that uses the Functional API. It contains:\n", "\n", - "- Some [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)sl and\n", + "- Some [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s; and\n", "- A custom `Count()` [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type, which is used to track the integer scalar state that increases on every forward pass." ] }, @@ -478,7 +478,7 @@ "\n", "A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function:\n", "\n", - "- [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) is a `Mapping` from strings to `Variable`s or nested [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", + "- [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) is a `Mapping` from strings to [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s or nested [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.\n", "- [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) contains all the static information needed to reconstruct a [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) graph, it is analogous to [JAX's `PyTreeDef`](https://jax.readthedocs.io/en/latest/pytrees.html#internal-pytree-handling)." ] }, diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index c0d37cd98e..ba422da013 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -10,13 +10,13 @@ jupytext: # Flax basics -Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home. +Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home. In this guide you will learn about: - The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer. - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass). - - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm layers. + - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers. - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers. - Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management. - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers. @@ -51,7 +51,7 @@ The main difference between the Flax[`nnx.Module`](https://flax.readthedocs.io/e Let's begin by creating a `Linear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The following code shows that: -- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly. +- Dynamic state is usually stored in [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s, and static state (all types not handled by NNX), such as integers or strings are stored directly. - Attributes of type [`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html) and `numpy.ndarray` are also treated as dynamic states, although storing them inside [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s, such as `Param`, is preferred. - The [`nnx.Rngs`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.Rngs) object can be used to get new unique keys based on a root PRNG key passed to the constructor. @@ -250,7 +250,7 @@ The Flax NNX Functional API establishes a clear boundary between reference/objec Below is an example of of `StatefulLinear` [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) that uses the Functional API. It contains: -- Some [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)sl and +- Some [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s; and - A custom `Count()` [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type, which is used to track the integer scalar state that increases on every forward pass. ```{code-cell} ipython3 @@ -276,7 +276,7 @@ nnx.display(model) A Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) can be decomposed into [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) using the [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split) function: -- [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) is a `Mapping` from strings to `Variable`s or nested [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. +- [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) is a `Mapping` from strings to [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)s or nested [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s. - [`nnx.GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef) contains all the static information needed to reconstruct a [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) graph, it is analogous to [JAX's `PyTreeDef`](https://jax.readthedocs.io/en/latest/pytrees.html#internal-pytree-handling). ```{code-cell} ipython3