From 59e037f0ed2a8ddd12bced4efea5b21bc1d1fe27 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Thu, 11 Apr 2024 11:56:05 -0700 Subject: [PATCH] updated haiku guide --- docs/_ext/codediff.py | 143 ++++++++++++++---- docs/_ext/codediff_test.py | 96 +++++++++--- docs/experimental/nnx/transforms.rst | 10 +- .../haiku_migration_guide.rst | 62 +++----- .../linen_upgrade_guide.rst | 41 +++-- .../optax_update_guide.rst | 26 ++-- .../orbax_upgrade_guide.rst | 20 +-- .../regular_dict_upgrade_guide.rst | 14 +- .../rnncell_upgrade_guide.rst | 22 +-- .../flax_fundamentals/setup_or_nncompact.rst | 5 +- .../extracting_intermediates.rst | 14 +- docs/guides/parallel_training/ensembling.rst | 14 +- .../guides/training_techniques/batch_norm.rst | 27 ++-- docs/guides/training_techniques/dropout.rst | 22 +-- .../training_techniques/lr_schedule.rst | 15 +- tests/run_all_tests.sh | 1 + 16 files changed, 311 insertions(+), 221 deletions(-) diff --git a/docs/_ext/codediff.py b/docs/_ext/codediff.py index 0d2bcd27c1..a459137428 100644 --- a/docs/_ext/codediff.py +++ b/docs/_ext/codediff.py @@ -17,8 +17,7 @@ Use directive as follows: .. codediff:: - :title_left: - :title_right: + :title: , --- @@ -26,7 +25,8 @@ In order to highlight a line of code, append "#!" to it. """ -from typing import List, Tuple + +from typing import List, Optional, Tuple import sphinx from docutils import nodes @@ -40,29 +40,103 @@ class CodeDiffParser: def parse( self, - lines, - title_left='Base', - title_right='Diff', - code_sep='---', - sync=MISSING, + lines: List[str], + title: str, + groups: Optional[List[str]] = None, + skip_test: Optional[str] = None, + code_sep: str = '---', + sync: object = MISSING, ): - sync = sync is not MISSING + """Parse the code diff block and format it so that it + renders in different tabs and is tested by doctest. + + For example: + + .. testcode:: tab0, tab2, tab3 + + + + .. codediff:: + :title: Tab 0, Tab 1, Tab 2, Tab 3 + :groups: tab0, tab1, tab2, tab3 + :skip_test: tab1, tab3 + + + + --- + + + + --- - if code_sep not in lines: + + + --- + + + + For group tab0: and are executed. + For group tab1: Nothing is executed. + For group tab2: and are executed. + For group tab3: is executed. + + Arguments: + lines: a string list, where each element is a single string code line + title: a single string that contains the titles of each tab (they should + be separated by commas) + groups: a single string that contains the group of each tab (they should + be separated by commas). Code snippets that are part of the same group + will be executed together. If groups=None, then the group names will + default to the tab title names. + skip_test: a single string denoting which group(s) to skip testing (they + should be separated by commas). This is useful for legacy code snippets + that no longer run correctly anymore. If skip_test=None, then no tests + are skipped. + code_sep: the separator character(s) used to denote a separate code block + for a new tab. The default code separator is '---'. + sync: an option for Sphinx directives, that will sync all tabs together. + This means that if the user clicks to switch to another tab, all tabs + will switch to the new tab. + """ + titles = [t.strip() for t in title.split(',')] + num_tabs = len(titles) + + sync = sync is not MISSING + # skip legacy code snippets in upgrade guides + if skip_test is not None: + skip_tests = set([index.strip() for index in skip_test.split(',')]) + else: + skip_tests = set() + + code_blocks = '\n'.join(lines) + if code_blocks.count(code_sep) != num_tabs - 1: + raise ValueError( + f'Expected {num_tabs-1} code separator(s) for {num_tabs} tab(s), but got {code_blocks.count(code_sep)} code separator(s) instead.' + ) + code_blocks = [ + code_block.split('\n') + for code_block in code_blocks.split(code_sep + '\n') + ] # list[code_tab_list1[string_line1, ...], ...] + + # by default, put each code snippet in a different group denoted by an index number, to be executed separately + if groups is not None: + groups = [group_name.strip() for group_name in groups.split(',')] + else: + groups = titles + if len(groups) != num_tabs: raise ValueError( - 'Code separator not found! Code snippets should be ' - f'separated by {code_sep}.' + f'Expected {num_tabs} group assignment(s) for {num_tabs} tab(s), but got {len(groups)} group assignment(s) instead.' ) - idx = lines.index(code_sep) - code_left = self._code_block(lines[0:idx]) - test_code = lines[idx + 1 :] - code_right = self._code_block(test_code) - output = self._tabs( - (title_left, code_left), (title_right, code_right), sync=sync - ) + tabs = [] + test_codes = [] + for i, code_block in enumerate(code_blocks): + if groups[i] not in skip_tests: + test_codes.append((code_block, groups[i])) + tabs.append((titles[i], self._code_block(code_block))) + output = self._tabs(*tabs, sync=sync) - return output, test_code + return output, test_codes def _code_block(self, lines): """Creates a codeblock.""" @@ -99,14 +173,15 @@ def _tabs(self, *contents: Tuple[str, List[str]], sync): class CodeDiffDirective(SphinxDirective): has_content = True option_spec = { - 'title_left': directives.unchanged, - 'title_right': directives.unchanged, + 'title': directives.unchanged, + 'groups': directives.unchanged, + 'skip_test': directives.unchanged, 'code_sep': directives.unchanged, 'sync': directives.flag, } def run(self): - table_code, test_code = CodeDiffParser().parse( + table_code, test_codes = CodeDiffParser().parse( list(self.content), **self.options ) @@ -114,21 +189,27 @@ def run(self): # We add attribute "testnodetype" so it is be picked up by the doctest # builder. This functionality is not officially documented but can be found # in the source code: - # https://github.com/sphinx-doc/sphinx/blob/3.x/sphinx/ext/doctest.py + # https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py # (search for 'testnodetype'). - test_code = '\n'.join(test_code) - test_node = nodes.comment(test_code, test_code, testnodetype='testcode') - # Set the source info so the error message is correct when testing. - self.set_source_info(test_node) - test_node['options'] = {} - test_node['language'] = 'python3' + test_nodes = [] + for test_code, group in test_codes: + test_node = nodes.comment( + '\n'.join(test_code), + '\n'.join(test_code), + testnodetype='testcode', + groups=[group], + ) + self.set_source_info(test_node) + test_node['options'] = {} + test_node['language'] = 'python3' + test_nodes.append(test_node) # The table node is the side-by-side diff view that will be shown on RTD. table_node = nodes.paragraph() self.content = ViewList(table_code, self.content.parent) self.state.nested_parse(self.content, self.content_offset, table_node) - return [table_node, test_node] + return [table_node] + test_nodes def setup(app): diff --git a/docs/_ext/codediff_test.py b/docs/_ext/codediff_test.py index 30373e793a..83e15733c9 100644 --- a/docs/_ext/codediff_test.py +++ b/docs/_ext/codediff_test.py @@ -14,11 +14,11 @@ """Tests for codediff Sphinx extension.""" -from absl.testing import absltest +from absl.testing import parameterized from codediff import CodeDiffParser -class CodeDiffTest(absltest.TestCase): +class CodeDiffTest(parameterized.TestCase): def test_parse(self): input_text = r"""@jax.jit #! def get_initial_params(key): #! @@ -33,36 +33,88 @@ def get_initial_params(key): initial_params = CNN().init(key, init_val)['params'] return initial_params""" - expected_table = r"""+----------------------------------------------------------+----------------------------------------------------------+ -| Single device | Ensembling on multiple devices | -+----------------------------------------------------------+----------------------------------------------------------+ -| .. code-block:: python | .. code-block:: python | -| :emphasize-lines: 1,2 | :emphasize-lines: 1 | -| | | -| @jax.jit | @jax.pmap | -| def get_initial_params(key): | def get_initial_params(key): | -| init_val = jnp.ones((1, 28, 28, 1), jnp.float32) | init_val = jnp.ones((1, 28, 28, 1), jnp.float32) | -| initial_params = CNN().init(key, init_val)['params'] | initial_params = CNN().init(key, init_val)['params'] | -| extra_line | return initial_params | -| return initial_params | | -+----------------------------------------------------------+----------------------------------------------------------+""" + expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params""" - expected_testcode = r"""@jax.pmap #! + expected_testcodes = [ + r"""@jax.jit #! +def get_initial_params(key): #! + init_val = jnp.ones((1, 28, 28, 1), jnp.float32) + initial_params = CNN().init(key, init_val)['params'] + extra_line + return initial_params +""", + r"""@jax.pmap #! def get_initial_params(key): init_val = jnp.ones((1, 28, 28, 1), jnp.float32) initial_params = CNN().init(key, init_val)['params'] - return initial_params""" + return initial_params""", + ] title_left = 'Single device' title_right = 'Ensembling on multiple devices' - actual_table, actual_testcode = CodeDiffParser().parse( + actual_table, actual_testcodes = CodeDiffParser().parse( lines=input_text.split('\n'), - title_left=title_left, - title_right=title_right, + title=f'{title_left}, {title_right}', ) + actual_table = '\n'.join(actual_table) - actual_testcode = '\n'.join(actual_testcode) + actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes] self.assertEqual(expected_table, actual_table) - self.assertEqual(expected_testcode, actual_testcode) + self.assertEqual(expected_testcodes[0], actual_testcodes[0]) + self.assertEqual(expected_testcodes[1], actual_testcodes[1]) + + @parameterized.parameters( + { + 'input_text': r"""x = 1 + --- + x = 2 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': None, + 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 1 code separator\\(s\\) instead.', + }, + { + 'input_text': r"""x = 1 + --- + x = 2 + --- + x = 3 + --- + x = 4 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': None, + 'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 3 code separator\\(s\\) instead.', + }, + { + 'input_text': r"""x = 1 + --- + x = 2 + --- + x = 3 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': 'tab0, tab2', + 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 2 group assignment\\(s\\) instead.', + }, + { + 'input_text': r"""x = 1 + --- + x = 2 + --- + x = 3 +""", + 'title': 'Tab 0, Tab1, Tab2', + 'groups': 'tab0, tab1, tab2, tab3', + 'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 4 group assignment\\(s\\) instead.', + }, + ) + def test_parse_errors(self, input_text, title, groups, error_msg): + with self.assertRaisesRegex(ValueError, error_msg): + _, _ = CodeDiffParser().parse( + lines=input_text.split('\n'), + title=title, + groups=groups, + ) diff --git a/docs/experimental/nnx/transforms.rst b/docs/experimental/nnx/transforms.rst index 1e6bee1c59..9f35afcc26 100644 --- a/docs/experimental/nnx/transforms.rst +++ b/docs/experimental/nnx/transforms.rst @@ -7,7 +7,7 @@ seamlessly switch between them or use them together. We will be focusing on the First, let's set up imports and generate some dummy data: -.. testcode:: +.. testcode:: NNX, JAX from flax.experimental import nnx import jax @@ -38,8 +38,8 @@ whereas the function signature of JAX-transformed functions can only accept the the transformed function. .. codediff:: - :title_left: NNX transforms - :title_right: JAX transforms + :title: NNX transforms, JAX transforms + :groups: NNX, JAX :sync: @nnx.jit @@ -83,8 +83,8 @@ NNX and JAX transformations can be mixed together, so long as the JAX-transforme pure and has valid argument types that are recognized by JAX. .. codediff:: - :title_left: Using ``nnx.jit`` with ``jax.grad`` - :title_right: Using ``jax.jit`` with ``nnx.grad`` + :title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad`` + :groups: NNX, JAX :sync: @nnx.jit diff --git a/docs/guides/converting_and_upgrading/haiku_migration_guide.rst b/docs/guides/converting_and_upgrading/haiku_migration_guide.rst index e3eb1a032c..ce93f62eb4 100644 --- a/docs/guides/converting_and_upgrading/haiku_migration_guide.rst +++ b/docs/guides/converting_and_upgrading/haiku_migration_guide.rst @@ -5,7 +5,7 @@ Migrating from Haiku to Flax This guide will walk through the process of migrating Haiku models to Flax, and highlight the differences between the two libraries. -.. testsetup:: +.. testsetup:: Haiku, Flax import jax import jax.numpy as jnp @@ -25,8 +25,7 @@ whereas in Haiku ``name`` must be explicitly defined in the constructor signature and passed to the superclass constructor. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: import haiku as hk @@ -89,8 +88,7 @@ that calls your Module, ``transform`` will return an object with ``init`` and ``apply`` methods. In Flax, you simply instantiate your Module. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def forward(x, training: bool): @@ -112,8 +110,7 @@ that Flax returns a mapping from collection names to nested array dictionaries, structure directly. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: sample_x = jax.numpy.ones((1, 784)) @@ -184,8 +181,7 @@ both cases we must provide a ``key`` to ``apply`` in order to generate the random dropout masks. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def train_step(key, params, inputs, labels): @@ -218,7 +214,7 @@ the random dropout masks. return params -.. testcode:: +.. testcode:: Haiku, Flax :hide: train_step(random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) @@ -236,8 +232,7 @@ Now let's see how mutable state is handled in both libraries. We will take the same model as before, but now we will replace Dropout with BatchNorm. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: class Block(hk.Module): @@ -278,8 +273,7 @@ which changes the signature for ``init`` and ``apply`` to accept and return state. As before, in Flax you construct the Module directly. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def forward(x, training: bool): @@ -304,8 +298,7 @@ of a Haiku model with an ``hk.BatchNorm`` layer. In Flax, we can set ``training=False`` as usual. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: sample_x = jax.numpy.ones((1, 784)) @@ -340,8 +333,7 @@ input dictionary, and get the ``updates`` variables dictionary as the second return value. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def train_step(params, state, inputs, labels): @@ -375,7 +367,7 @@ return value. return params, batch_stats -.. testcode:: +.. testcode:: Flax :hide: train_step(params, batch_stats, sample_x, jnp.ones((1,), dtype=jnp.int32)) @@ -406,8 +398,7 @@ In Flax, we will define an ``encoder`` and a ``decoder`` Module ahead of time in ``setup``, and use them in the ``encode`` and ``decode`` respectively. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: class AutoEncoder(hk.Module): @@ -459,8 +450,7 @@ function passed to ``multi_transform`` defines how to initialize the module and different apply methods to generate. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def forward(): @@ -485,8 +475,7 @@ To initialize the parameters of our model, ``init`` can be used to trigger the method. This will create all the necessary parameters for the model. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: params = model.init( @@ -543,8 +532,7 @@ This generates the following parameter structure. Finally, let's explore how we can employ the ``apply`` function to invoke the ``encode`` method: .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: encode, decode = model.apply @@ -593,8 +581,7 @@ method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: class RNNCell(hk.Module): @@ -640,8 +627,7 @@ to create an instance of the lifted ``RNNCell`` and use it to create the ``carry the run the ``__call__`` method which will ``scan`` over the sequence. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: class RNN(hk.Module): @@ -680,8 +666,7 @@ according to the transform's semantics. Finally, let's quickly view how the ``RNN`` Module would be used in both Haiku and Flax. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def forward(x): @@ -740,8 +725,7 @@ we are telling ``nn.scan`` create different parameters for each step and slice t .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: class Block(hk.Module): @@ -762,7 +746,7 @@ we are telling ``nn.scan`` create different parameters for each step and slice t self.num_layers = num_layers def __call__(self, x, training: bool): - @hk.experimental.layer_stack(self.num_layers) + @hk.experimental.layer_stack(self.num_layers) def stack_block(x): return Block(self.features)(x, training) @@ -803,8 +787,7 @@ Initializing each model is the same as in previous examples. In this case, we will be specifying that we want to use ``5`` layers each with ``64`` features. .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def forward(x, training: bool): @@ -876,8 +859,7 @@ In Haiku, it is possible to write the entire model as a single function by using The Flax team recommends a more Module-centric approach that uses `__call__` to define the forward function. The corresponding accessor will be `nn.module.param` and `nn.module.variable` (go to `Handling State <#handling-state>`__ for an explanaion on collections). .. codediff:: - :title_left: Haiku - :title_right: Flax + :title: Haiku, Flax :sync: def forward(x): diff --git a/docs/guides/converting_and_upgrading/linen_upgrade_guide.rst b/docs/guides/converting_and_upgrading/linen_upgrade_guide.rst index 1b7eb1bd61..06bd677b3f 100644 --- a/docs/guides/converting_and_upgrading/linen_upgrade_guide.rst +++ b/docs/guides/converting_and_upgrading/linen_upgrade_guide.rst @@ -5,12 +5,13 @@ As of Flax v0.4.0, ``flax.nn`` no longer exists, and is replaced with the new Linen API at ``flax.linen``. If your codebase is still using the old API, you can use this upgrade guide to upgrade it to Linen. -.. testsetup:: +.. testsetup:: Linen from flax.training import train_state from jax import random import optax import jax + import flax.linen as nn from flax.linen import initializers from jax import lax @@ -29,8 +30,8 @@ Defining simple Flax Modules ---------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: from flax import nn @@ -51,8 +52,6 @@ Defining simple Flax Modules 'bias', (features,), bias_init) y = y + bias return y - - return new_state, metrics --- from flax import linen as nn # [1] #! @@ -95,8 +94,8 @@ Using Flax Modules inside other Modules --------------------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: class Encoder(nn.Module): @@ -127,8 +126,8 @@ Sharing submodules and defining multiple methods -------------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: class AutoEncoder(nn.Module): @@ -179,8 +178,8 @@ Sharing submodules and defining multiple methods --------------------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: # no import #! @@ -236,8 +235,8 @@ Top-level training code patterns -------------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: def create_model(key): @@ -306,12 +305,12 @@ Non-trainable variables ("state"): Use within Modules ----------------------------------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: class BatchNorm(nn.Module): - def apply(self, x, ...): + def apply(self, x): # [...] ra_mean = self.state( 'mean', (x.shape[-1], ), initializers.zeros_init()) @@ -338,8 +337,8 @@ Non-trainable variables ("state"): Top-level training code patterns ------------------------------------------------------------------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: # initial params and state @@ -360,7 +359,7 @@ Non-trainable variables ("state"): Top-level training code patterns # reads immutable batch statistics during evaluation def eval_step(model, model_state, batch): - with nn.stateful(model_state, mutable=False): + with nn.stateful(model_state, mutable=False): logits = model(batch['image'], train=False) return compute_metrics(logits, batch['label']) --- @@ -422,8 +421,8 @@ Randomness ---------- .. codediff:: - :title_left: Old Flax - :title_right: Linen + :title: Old Flax, Linen + :skip_test: Old Flax :sync: def dropout(inputs, rate, deterministic=False): diff --git a/docs/guides/converting_and_upgrading/optax_update_guide.rst b/docs/guides/converting_and_upgrading/optax_update_guide.rst index e42a507a87..d531b781f9 100644 --- a/docs/guides/converting_and_upgrading/optax_update_guide.rst +++ b/docs/guides/converting_and_upgrading/optax_update_guide.rst @@ -10,7 +10,7 @@ towards :py:mod:`flax.optim` users to help them update their code to Optax. See also Optax's quick start documentation: https://optax.readthedocs.io/en/latest/getting_started.html -.. testsetup:: +.. testsetup:: default, flax.optim, optax import flax import jax @@ -42,8 +42,8 @@ optimizer state, parameters, and other associated data in a single dataclass (not used in code below). .. codediff:: - :title_left: flax.optim - :title_right: optax + :title: flax.optim, optax + :skip_test: flax.optim :sync: @jax.jit @@ -91,8 +91,8 @@ generic building blocks. .. _optax.chain(): https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.chain .. codediff:: - :title_left: Pre-defined alias - :title_right: Combining transformations + :title: Pre-defined alias, Combining transformations + :groups: default, default # Note that the aliases follow the convention to use positive # values for the learning rate by default. @@ -126,8 +126,8 @@ weight decay can be added as another "gradient transformation" .. _optax.add_decayed_weights(): https://optax.readthedocs.io/en/latest/api/transformations.html#optax.add_decayed_weights .. codediff:: - :title_left: flax.optim - :title_right: optax + :title: flax.optim, optax + :skip_test: flax.optim :sync: optimizer_def = flax.optim.Adam( @@ -158,8 +158,8 @@ becomes just another gradient transformation |optax.clip_by_global_norm()|_. .. _optax.clip_by_global_norm(): https://optax.readthedocs.io/en/latest/api/transformations.html#optax.clip_by_global_norm .. codediff:: - :title_left: flax.optim - :title_right: optax + :title: flax.optim, optax + :skip_test: flax.optim :sync: def train_step(optimizer, batch): @@ -202,8 +202,8 @@ learning rate schedule as a parameter for ``learning_rate``. .. _optax.inject_hyperparams(): https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.inject_hyperparams .. codediff:: - :title_left: flax.optim - :title_right: optax + :title: flax.optim, optax + :skip_test: flax.optim :sync: def train_step(step, optimizer, batch): @@ -244,8 +244,8 @@ that is not readily available outside the outer mask). .. _optax.multi_transform(): https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.multi_transform .. codediff:: - :title_left: flax.optim - :title_right: optax + :title: flax.optim, optax + :skip_test: flax.optim :sync: kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p) diff --git a/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst b/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst index 91f3cbb0f2..bf3c16f504 100644 --- a/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst +++ b/docs/guides/converting_and_upgrading/orbax_upgrade_guide.rst @@ -17,7 +17,7 @@ You can click on "Open in Colab" above to run the code from this guide. Throughout the guide, you will be able to compare code examples with and without the Orbax code. -.. testsetup:: +.. testsetup:: orbax.checkpoint import flax from flax.training import checkpoints, orbax_utils @@ -41,7 +41,7 @@ Throughout the guide, you will be able to compare code examples with and without Setup ***** -.. testcode:: +.. testcode:: orbax.checkpoint # Create some dummy variables for this example. MAX_STEPS = 5 @@ -71,8 +71,8 @@ To upgrade your code: For example: .. codediff:: - :title_left: flax.checkpoints - :title_right: orbax.checkpoint + :title: flax.checkpoints, orbax.checkpoint + :skip_test: flax.checkpoints :sync: CKPT_DIR = '/tmp/orbax_upgrade/' @@ -119,8 +119,8 @@ To migrate to Orbax code, instead of using the ``overwrite`` argument in ``flax. For example: .. codediff:: - :title_left: flax.checkpoints - :title_right: orbax.checkpoint + :title: flax.checkpoints, orbax.checkpoint + :skip_test: flax.checkpoints :sync: PURE_CKPT_DIR = '/tmp/orbax_upgrade/pure' @@ -149,8 +149,8 @@ If you need to restore your checkpoints without a target pytree, pass ``item=Non For example: .. codediff:: - :title_left: flax.checkpoints - :title_right: orbax.checkpoint + :title: flax.checkpoints, orbax.checkpoint + :skip_test: flax.checkpoints :sync: NOTARGET_CKPT_DIR = '/tmp/orbax_upgrade/no_target' @@ -188,8 +188,8 @@ The ``orbax.checkpoint.PyTreeCheckpointHandler`` class, as the name suggests, ca For example: .. codediff:: - :title_left: flax.checkpoints - :title_right: orbax.checkpoint + :title: flax.checkpoints, orbax.checkpoint + :skip_test: flax.checkpoints :sync: ARR_CKPT_DIR = '/tmp/orbax_upgrade/singleton' diff --git a/docs/guides/converting_and_upgrading/regular_dict_upgrade_guide.rst b/docs/guides/converting_and_upgrading/regular_dict_upgrade_guide.rst index dda795f0de..dcb46b4692 100644 --- a/docs/guides/converting_and_upgrading/regular_dict_upgrade_guide.rst +++ b/docs/guides/converting_and_upgrading/regular_dict_upgrade_guide.rst @@ -24,7 +24,7 @@ To accommodate the regular dict change, replace usage of ``FrozenDict`` methods These utility functions mimic the behavior of their corresponding ``FrozenDict`` method, and can be called on either ``FrozenDicts`` or regular dicts. The following are the utility functions and example upgrade patterns: -.. testsetup:: +.. testsetup:: default, Only ``FrozenDict``, Both ``FrozenDict`` and regular dict import flax import flax.linen as nn @@ -40,8 +40,7 @@ The following are the utility functions and example upgrade patterns: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: - :title_left: Only ``FrozenDict`` - :title_right: Both ``FrozenDict`` and regular dict + :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: variables = variables.copy(add_or_replace={'other_variables': other_variables}) @@ -54,8 +53,7 @@ The following are the utility functions and example upgrade patterns: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: - :title_left: Only ``FrozenDict`` - :title_right: Both ``FrozenDict`` and regular dict + :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: state, params = variables.pop('params') @@ -68,8 +66,7 @@ The following are the utility functions and example upgrade patterns: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: - :title_left: Only ``FrozenDict`` - :title_right: Both ``FrozenDict`` and regular dict + :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: str_repr = variables.pretty_repr() @@ -82,8 +79,7 @@ The following are the utility functions and example upgrade patterns: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. codediff:: - :title_left: Only ``FrozenDict`` - :title_right: Both ``FrozenDict`` and regular dict + :title: Only ``FrozenDict``, Both ``FrozenDict`` and regular dict :sync: variables = variables.unfreeze() diff --git a/docs/guides/converting_and_upgrading/rnncell_upgrade_guide.rst b/docs/guides/converting_and_upgrading/rnncell_upgrade_guide.rst index 7629c3bb35..5282c8df45 100644 --- a/docs/guides/converting_and_upgrading/rnncell_upgrade_guide.rst +++ b/docs/guides/converting_and_upgrading/rnncell_upgrade_guide.rst @@ -11,7 +11,7 @@ This guide will walk you through these changes, demonstrating how to update your Basic Usage ----------- -.. testsetup:: +.. testsetup:: New import flax.linen as nn import jax.numpy as jnp @@ -21,7 +21,7 @@ Basic Usage Let's begin by defining some variables and a sample input that represents a batch of sequences: -.. testcode:: +.. testcode:: New batch_size = 32 seq_len = 10 @@ -34,8 +34,8 @@ First and foremost, it's important to note that all metadata, including the numb carry initializer, and so on, is now stored within the cell instance: .. codediff:: - :title_left: Legacy - :title_right: New + :title: Legacy, New + :skip_test: Legacy :sync: cell = nn.LSTMCell() @@ -49,8 +49,8 @@ the cell instance now contains all metadata, the ``initialize_carry`` method's signature only requires a PRNG key and a sample input: .. codediff:: - :title_left: Legacy - :title_right: New + :title: Legacy, New + :skip_test: Legacy :sync: carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features) @@ -62,7 +62,7 @@ signature only requires a PRNG key and a sample input: Here, ``x[:, 0].shape`` represents the input for the cell (without the time dimension). You can also just create the input shape directly when its more convenient: -.. testcode:: +.. testcode:: New carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features)) @@ -80,8 +80,8 @@ to make the minimal amount of changes to the code to get it working, albeit not in the most idiomatic way: .. codediff:: - :title_left: Legacy - :title_right: New + :title: Legacy, New + :skip_test: Legacy :sync: class SimpleLSTM(nn.Module): @@ -129,8 +129,8 @@ here will be that we will add a ``features`` attribute to the module and use it a ``nn.scan``-ed version of the cell in the ``setup`` method: .. codediff:: - :title_left: Legacy - :title_right: New + :title: Legacy, New + :skip_test: Legacy :sync: class SimpleLSTM(nn.Module): diff --git a/docs/guides/flax_fundamentals/setup_or_nncompact.rst b/docs/guides/flax_fundamentals/setup_or_nncompact.rst index 4cc2b2394c..4e296d485b 100644 --- a/docs/guides/flax_fundamentals/setup_or_nncompact.rst +++ b/docs/guides/flax_fundamentals/setup_or_nncompact.rst @@ -24,13 +24,12 @@ can be defined in two ways: Here is a short example of a module defined in both ways, with exactly the same functionality. -.. testsetup:: +.. testsetup:: Using ``setup``, Using ``nn.compact`` import flax.linen as nn .. codediff:: - :title_left: Using ``setup`` - :title_right: Using ``nn.compact`` + :title: Using ``setup``, Using ``nn.compact`` class MLP(nn.Module): def setup(self): diff --git a/docs/guides/model_inspection/extracting_intermediates.rst b/docs/guides/model_inspection/extracting_intermediates.rst index 34b98d1685..1597e3ebfd 100644 --- a/docs/guides/model_inspection/extracting_intermediates.rst +++ b/docs/guides/model_inspection/extracting_intermediates.rst @@ -4,7 +4,7 @@ Extracting intermediate values This guide will show you how to extract intermediate values from a module. Let's start with this simple CNN that uses :code:`nn.compact`. -.. testsetup:: +.. testsetup:: default, sow import flax import flax.linen as nn @@ -50,8 +50,8 @@ The CNN can be augmented with calls to ``sow`` to store intermediates as followi .. codediff:: - :title_left: Default CNN - :title_right: CNN using sow API + :title: Default CNN, CNN using sow API + :groups: default, sow class CNN(nn.Module): @nn.compact @@ -101,7 +101,7 @@ Note that, by default ``sow`` appends values every time it is called: * To override the default append behavior, specify ``init_fn`` and ``reduce_fn`` - see :meth:`Module.sow() `. -.. testcode:: +.. testcode:: sow class SowCNN2(nn.Module): @nn.compact @@ -229,8 +229,8 @@ Note that ``capture_intermediates`` will only apply to layers. You can use ``sel non-layer intermediates, but the filter function won't be applied to it. .. codediff:: - :title_left: Capturing all layer intermediates - :title_right: Using filter function and ``self.sow()`` + :title: Capturing all layer intermediates, Using filter function and ``self.sow()`` + :groups: default, sow class Model(nn.Module): @nn.compact @@ -285,7 +285,7 @@ To separate the intermediates extracted from ``self.sow`` from the intermediates we can either define a separate collection like ``self.sow('sow_intermediates', 'c', c)``, or manually filter out the intermediates after calling ``.apply()``. For example: -.. testcode:: +.. testcode:: sow flattened_dict = flax.traverse_util.flatten_dict(feats['intermediates'], sep='/') flattened_dict['c'] diff --git a/docs/guides/parallel_training/ensembling.rst b/docs/guides/parallel_training/ensembling.rst index 89b2dbcdfd..62d242c778 100644 --- a/docs/guides/parallel_training/ensembling.rst +++ b/docs/guides/parallel_training/ensembling.rst @@ -13,7 +13,7 @@ be described as: In this HOWTO we omit some of the code such as imports, the CNN module, and metrics computation, but they can be found in the `MNIST example`_. -.. testsetup:: +.. testsetup:: Single-model, Ensemble import functools from flax import jax_utils @@ -69,8 +69,7 @@ XLA (similar to |jax.jit()|_), but execute it in parallel on XLA devices (e.g., GPUs/TPUs). .. codediff:: - :title_left: Single-model - :title_right: Ensemble + :title: Single-model, Ensemble :sync: #! @@ -108,8 +107,7 @@ the average *across devices*. This also requires us to specify the ``axis_name`` to both |jax.pmap()|_ and |jax.lax.pmean()|_. .. codediff:: - :title_left: Single-model - :title_right: Ensemble + :title: Single-model, Ensemble :sync: @jax.jit #! @@ -156,8 +154,7 @@ functions from above, we mainly need to take care of duplicating the arguments for all devices where necessary, and de-duplicating the return values. .. codediff:: - :title_left: Single-model - :title_right: Ensemble + :title: Single-model, Ensemble :sync: def train_epoch(state, train_ds, batch_size, rng): @@ -218,8 +215,7 @@ smaller than the train dataset so we can do this for the entire dataset directly. .. codediff:: - :title_left: Single-model - :title_right: Ensemble + :title: Single-model, Ensemble :sync: train_ds, test_ds = get_datasets() diff --git a/docs/guides/training_techniques/batch_norm.rst b/docs/guides/training_techniques/batch_norm.rst index 65c95c2e3c..135b36fae6 100644 --- a/docs/guides/training_techniques/batch_norm.rst +++ b/docs/guides/training_techniques/batch_norm.rst @@ -10,7 +10,7 @@ of non-differentiable state that must be handled appropriately. Throughout the guide, you will be able to compare code examples with and without Flax ``BatchNorm``. -.. testsetup:: +.. testsetup:: No BatchNorm, With BatchNorm import flax.linen as nn import jax.numpy as jnp @@ -36,8 +36,7 @@ or ``tf.keras.Model`` by setting the `training `__ flag). .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: class MLP(nn.Module): @@ -75,8 +74,7 @@ API documentation. The ``batch_stats`` collection must be extracted from the ``variables`` for later use. .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: mlp = MLP() @@ -101,8 +99,7 @@ Flax ``BatchNorm`` adds a total of 4 variables: ``mean`` and ``var`` that live i collection. .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: FrozenDict({ @@ -154,8 +151,7 @@ need to consider the following: The updated ``batch_stats`` must be extracted from here. .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: y = mlp.apply( @@ -182,8 +178,7 @@ is handling the additional ``batch_stats`` state. To do this, you need to: * Pass the ``batch_stats`` values to the :meth:`train_state.TrainState.create ` method. .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: from flax.training import train_state @@ -215,12 +210,11 @@ In addition, update your ``train_step`` function to reflect these changes: * The ``batch_stats`` from the ``TrainState`` must be updated. .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: @jax.jit - def train_step(state: TrainState, batch): + def train_step(state: train_state.TrainState, batch): """Train for a single step.""" def loss_fn(params): logits = state.apply_fn( @@ -265,12 +259,11 @@ need to be propagated. Make sure you pass the ``batch_stats`` to ``flax.linen.ap and the ``train`` argument is set to ``False``: .. codediff:: - :title_left: No BatchNorm - :title_right: With BatchNorm + :title: No BatchNorm, With BatchNorm :sync: @jax.jit - def eval_step(state: TrainState, batch): + def eval_step(state: train_state.TrainState, batch): """Train for a single step.""" logits = state.apply_fn( {'params': params}, diff --git a/docs/guides/training_techniques/dropout.rst b/docs/guides/training_techniques/dropout.rst index 8a6420c997..f58dbf22ea 100644 --- a/docs/guides/training_techniques/dropout.rst +++ b/docs/guides/training_techniques/dropout.rst @@ -11,7 +11,7 @@ and visible units in a network. Throughout the guide, you will be able to compare code examples with and without Flax ``Dropout``. -.. testsetup:: +.. testsetup:: No Dropout, With Dropout import flax.linen as nn import jax.numpy as jnp @@ -37,8 +37,7 @@ Begin by splitting the PRNG key using into three keys, including one for Flax Linen ``Dropout``. .. codediff:: - :title_left: No Dropout - :title_right: With Dropout + :title: No Dropout, With Dropout :sync: root_key = jax.random.key(seed=0) @@ -98,8 +97,7 @@ check out the In short, :meth:`flax.linen.Module.make_rng` *guarantees full reproducibility*. .. codediff:: - :title_left: No Dropout - :title_right: With Dropout + :title: No Dropout, With Dropout :sync: class MyModel(nn.Module): @@ -137,8 +135,7 @@ and with ``Dropout`` is that the ``training`` (or ``train``) argument must be provided if you need dropout enabled. .. codediff:: - :title_left: No Dropout - :title_right: With Dropout + :title: No Dropout, With Dropout :sync: my_model = MyModel(num_neurons=3) @@ -163,8 +160,7 @@ When using :meth:`flax.linen.apply()` to run your model: to seed the ``'dropout'`` stream when you call :meth:`flax.linen.apply()`. .. codediff:: - :title_left: No Dropout - :title_right: With Dropout + :title: No Dropout, With Dropout :sync: # No need to pass the `training` and `rngs` flags. @@ -196,8 +192,7 @@ the training step function. Refer to the * Then, pass the ``key`` value—in this case, the ``dropout_key``—to the :meth:`train_state.TrainState.create` method. .. codediff:: - :title_left: No Dropout - :title_right: With Dropout + :title: No Dropout, With Dropout :sync: from flax.training import train_state @@ -236,12 +231,11 @@ the training step function. Refer to the as an extra parameter. .. codediff:: - :title_left: No Dropout - :title_right: With Dropout + :title: No Dropout, With Dropout :sync: @jax.jit - def train_step(state: TrainState, batch): + def train_step(state: train_state.TrainState, batch): def loss_fn(params): logits = state.apply_fn( diff --git a/docs/guides/training_techniques/lr_schedule.rst b/docs/guides/training_techniques/lr_schedule.rst index fa530c66ab..8c0f3a2c39 100644 --- a/docs/guides/training_techniques/lr_schedule.rst +++ b/docs/guides/training_techniques/lr_schedule.rst @@ -16,7 +16,7 @@ We will show you how to... * train a simple model using that schedule -.. testsetup:: +.. testsetup:: Default learning rate, Learning rate schedule import jax import jax.numpy as jnp @@ -75,7 +75,7 @@ We will show you how to... return metrics -.. testcode:: +.. testcode:: Default learning rate, Learning rate schedule def create_learning_rate_fn(config, base_learning_rate, steps_per_epoch): """Creates learning rate schedule.""" @@ -99,8 +99,7 @@ For example using this schedule on MNIST would require changing the ``train_step .. _Optax: https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules .. codediff:: - :title_left: Default learning rate - :title_right: Learning rate schedule + :title: Default learning rate, Learning rate schedule :sync: @jax.jit @@ -136,8 +135,7 @@ For example using this schedule on MNIST would require changing the ``train_step And the ``train_epoch`` function: .. codediff:: - :title_left: Default learning rate - :title_right: Learning rate schedule + :title: Default learning rate, Learning rate schedule :sync: def train_epoch(state, train_ds, batch_size, epoch, rng): @@ -193,8 +191,7 @@ And the ``create_train_state`` function: .. codediff:: - :title_left: Default learning rate - :title_right: Learning rate schedule + :title: Default learning rate, Learning rate schedule :sync: def create_train_state(rng, config): @@ -214,7 +211,7 @@ And the ``create_train_state`` function: apply_fn=cnn.apply, params=params, tx=tx) -.. testcleanup:: +.. testcleanup:: Learning rate schedule config = get_config() diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 9bc9cb853a..67a283fc42 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -113,6 +113,7 @@ if $RUN_PYTEST; then pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE # Run nnx tests pytest -n auto flax/experimental/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE + pytest -n auto docs/_ext/codediff_test.py $PYTEST_OPTS $PYTEST_IGNORE # Per-example tests. #