Skip to content

Commit

Permalink
updated haiku guide
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed May 7, 2024
1 parent 3d98eda commit 59e037f
Show file tree
Hide file tree
Showing 16 changed files with 311 additions and 221 deletions.
143 changes: 112 additions & 31 deletions docs/_ext/codediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
Use directive as follows:
.. codediff::
:title_left: <LEFT_CODE_BLOCK_TITLE>
:title_right: <RIGHT_CODE_BLOCK_TITLE>
:title: <LEFT_CODE_BLOCK_TITLE>, <RIGHT_CODE_BLOCK_TITLE>
<CODE_BLOCK_LEFT>
---
<CODE_BLOCK_RIGHT>
In order to highlight a line of code, append "#!" to it.
"""
from typing import List, Tuple

from typing import List, Optional, Tuple

import sphinx
from docutils import nodes
Expand All @@ -40,29 +40,103 @@
class CodeDiffParser:
def parse(
self,
lines,
title_left='Base',
title_right='Diff',
code_sep='---',
sync=MISSING,
lines: List[str],
title: str,
groups: Optional[List[str]] = None,
skip_test: Optional[str] = None,
code_sep: str = '---',
sync: object = MISSING,
):
sync = sync is not MISSING
"""Parse the code diff block and format it so that it
renders in different tabs and is tested by doctest.
For example:
.. testcode:: tab0, tab2, tab3
<CODE_BLOCK_A>
.. codediff::
:title: Tab 0, Tab 1, Tab 2, Tab 3
:groups: tab0, tab1, tab2, tab3
:skip_test: tab1, tab3
<CODE_BLOCK_B0>
---
<CODE_BLOCK_B1>
---
if code_sep not in lines:
<CODE_BLOCK_B2>
---
<CODE_BLOCK_B3>
For group tab0: <CODE_BLOCK_A> and <CODE_BLOCK_B0> are executed.
For group tab1: Nothing is executed.
For group tab2: <CODE_BLOCK_A> and <CODE_BLOCK_B2> are executed.
For group tab3: <CODE_BLOCK_A> is executed.
Arguments:
lines: a string list, where each element is a single string code line
title: a single string that contains the titles of each tab (they should
be separated by commas)
groups: a single string that contains the group of each tab (they should
be separated by commas). Code snippets that are part of the same group
will be executed together. If groups=None, then the group names will
default to the tab title names.
skip_test: a single string denoting which group(s) to skip testing (they
should be separated by commas). This is useful for legacy code snippets
that no longer run correctly anymore. If skip_test=None, then no tests
are skipped.
code_sep: the separator character(s) used to denote a separate code block
for a new tab. The default code separator is '---'.
sync: an option for Sphinx directives, that will sync all tabs together.
This means that if the user clicks to switch to another tab, all tabs
will switch to the new tab.
"""
titles = [t.strip() for t in title.split(',')]
num_tabs = len(titles)

sync = sync is not MISSING
# skip legacy code snippets in upgrade guides
if skip_test is not None:
skip_tests = set([index.strip() for index in skip_test.split(',')])
else:
skip_tests = set()

code_blocks = '\n'.join(lines)
if code_blocks.count(code_sep) != num_tabs - 1:
raise ValueError(
f'Expected {num_tabs-1} code separator(s) for {num_tabs} tab(s), but got {code_blocks.count(code_sep)} code separator(s) instead.'
)
code_blocks = [
code_block.split('\n')
for code_block in code_blocks.split(code_sep + '\n')
] # list[code_tab_list1[string_line1, ...], ...]

# by default, put each code snippet in a different group denoted by an index number, to be executed separately
if groups is not None:
groups = [group_name.strip() for group_name in groups.split(',')]
else:
groups = titles
if len(groups) != num_tabs:
raise ValueError(
'Code separator not found! Code snippets should be '
f'separated by {code_sep}.'
f'Expected {num_tabs} group assignment(s) for {num_tabs} tab(s), but got {len(groups)} group assignment(s) instead.'
)
idx = lines.index(code_sep)
code_left = self._code_block(lines[0:idx])
test_code = lines[idx + 1 :]
code_right = self._code_block(test_code)

output = self._tabs(
(title_left, code_left), (title_right, code_right), sync=sync
)
tabs = []
test_codes = []
for i, code_block in enumerate(code_blocks):
if groups[i] not in skip_tests:
test_codes.append((code_block, groups[i]))
tabs.append((titles[i], self._code_block(code_block)))
output = self._tabs(*tabs, sync=sync)

return output, test_code
return output, test_codes

def _code_block(self, lines):
"""Creates a codeblock."""
Expand Down Expand Up @@ -99,36 +173,43 @@ def _tabs(self, *contents: Tuple[str, List[str]], sync):
class CodeDiffDirective(SphinxDirective):
has_content = True
option_spec = {
'title_left': directives.unchanged,
'title_right': directives.unchanged,
'title': directives.unchanged,
'groups': directives.unchanged,
'skip_test': directives.unchanged,
'code_sep': directives.unchanged,
'sync': directives.flag,
}

def run(self):
table_code, test_code = CodeDiffParser().parse(
table_code, test_codes = CodeDiffParser().parse(
list(self.content), **self.options
)

# Create a test node as a comment node so it won't show up in the docs.
# We add attribute "testnodetype" so it is be picked up by the doctest
# builder. This functionality is not officially documented but can be found
# in the source code:
# https://github.com/sphinx-doc/sphinx/blob/3.x/sphinx/ext/doctest.py
# https://github.com/sphinx-doc/sphinx/blob/master/sphinx/ext/doctest.py
# (search for 'testnodetype').
test_code = '\n'.join(test_code)
test_node = nodes.comment(test_code, test_code, testnodetype='testcode')
# Set the source info so the error message is correct when testing.
self.set_source_info(test_node)
test_node['options'] = {}
test_node['language'] = 'python3'
test_nodes = []
for test_code, group in test_codes:
test_node = nodes.comment(
'\n'.join(test_code),
'\n'.join(test_code),
testnodetype='testcode',
groups=[group],
)
self.set_source_info(test_node)
test_node['options'] = {}
test_node['language'] = 'python3'
test_nodes.append(test_node)

# The table node is the side-by-side diff view that will be shown on RTD.
table_node = nodes.paragraph()
self.content = ViewList(table_code, self.content.parent)
self.state.nested_parse(self.content, self.content_offset, table_node)

return [table_node, test_node]
return [table_node] + test_nodes


def setup(app):
Expand Down
96 changes: 74 additions & 22 deletions docs/_ext/codediff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

"""Tests for codediff Sphinx extension."""

from absl.testing import absltest
from absl.testing import parameterized
from codediff import CodeDiffParser


class CodeDiffTest(absltest.TestCase):
class CodeDiffTest(parameterized.TestCase):
def test_parse(self):
input_text = r"""@jax.jit #!
def get_initial_params(key): #!
Expand All @@ -33,36 +33,88 @@ def get_initial_params(key):
initial_params = CNN().init(key, init_val)['params']
return initial_params"""

expected_table = r"""+----------------------------------------------------------+----------------------------------------------------------+
| Single device | Ensembling on multiple devices |
+----------------------------------------------------------+----------------------------------------------------------+
| .. code-block:: python | .. code-block:: python |
| :emphasize-lines: 1,2 | :emphasize-lines: 1 |
| | |
| @jax.jit | @jax.pmap |
| def get_initial_params(key): | def get_initial_params(key): |
| init_val = jnp.ones((1, 28, 28, 1), jnp.float32) | init_val = jnp.ones((1, 28, 28, 1), jnp.float32) |
| initial_params = CNN().init(key, init_val)['params'] | initial_params = CNN().init(key, init_val)['params'] |
| extra_line | return initial_params |
| return initial_params | |
+----------------------------------------------------------+----------------------------------------------------------+"""
expected_table = """.. tab-set::\n \n .. tab-item:: Single device\n \n .. code-block:: python\n :emphasize-lines: 1,2\n \n @jax.jit\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n extra_line\n return initial_params\n \n .. tab-item:: Ensembling on multiple devices\n \n .. code-block:: python\n :emphasize-lines: 1\n \n @jax.pmap\n def get_initial_params(key):\n init_val = jnp.ones((1, 28, 28, 1), jnp.float32)\n initial_params = CNN().init(key, init_val)['params']\n return initial_params"""

expected_testcode = r"""@jax.pmap #!
expected_testcodes = [
r"""@jax.jit #!
def get_initial_params(key): #!
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
extra_line
return initial_params
""",
r"""@jax.pmap #!
def get_initial_params(key):
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = CNN().init(key, init_val)['params']
return initial_params"""
return initial_params""",
]

title_left = 'Single device'
title_right = 'Ensembling on multiple devices'

actual_table, actual_testcode = CodeDiffParser().parse(
actual_table, actual_testcodes = CodeDiffParser().parse(
lines=input_text.split('\n'),
title_left=title_left,
title_right=title_right,
title=f'{title_left}, {title_right}',
)

actual_table = '\n'.join(actual_table)
actual_testcode = '\n'.join(actual_testcode)
actual_testcodes = ['\n'.join(testcode) for testcode, _ in actual_testcodes]

self.assertEqual(expected_table, actual_table)
self.assertEqual(expected_testcode, actual_testcode)
self.assertEqual(expected_testcodes[0], actual_testcodes[0])
self.assertEqual(expected_testcodes[1], actual_testcodes[1])

@parameterized.parameters(
{
'input_text': r"""x = 1
---
x = 2
""",
'title': 'Tab 0, Tab1, Tab2',
'groups': None,
'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 1 code separator\\(s\\) instead.',
},
{
'input_text': r"""x = 1
---
x = 2
---
x = 3
---
x = 4
""",
'title': 'Tab 0, Tab1, Tab2',
'groups': None,
'error_msg': 'Expected 2 code separator\\(s\\) for 3 tab\\(s\\), but got 3 code separator\\(s\\) instead.',
},
{
'input_text': r"""x = 1
---
x = 2
---
x = 3
""",
'title': 'Tab 0, Tab1, Tab2',
'groups': 'tab0, tab2',
'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 2 group assignment\\(s\\) instead.',
},
{
'input_text': r"""x = 1
---
x = 2
---
x = 3
""",
'title': 'Tab 0, Tab1, Tab2',
'groups': 'tab0, tab1, tab2, tab3',
'error_msg': 'Expected 3 group assignment\\(s\\) for 3 tab\\(s\\), but got 4 group assignment\\(s\\) instead.',
},
)
def test_parse_errors(self, input_text, title, groups, error_msg):
with self.assertRaisesRegex(ValueError, error_msg):
_, _ = CodeDiffParser().parse(
lines=input_text.split('\n'),
title=title,
groups=groups,
)
10 changes: 5 additions & 5 deletions docs/experimental/nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ seamlessly switch between them or use them together. We will be focusing on the

First, let's set up imports and generate some dummy data:

.. testcode::
.. testcode:: NNX, JAX

from flax.experimental import nnx
import jax
Expand Down Expand Up @@ -38,8 +38,8 @@ whereas the function signature of JAX-transformed functions can only accept the
the transformed function.

.. codediff::
:title_left: NNX transforms
:title_right: JAX transforms
:title: NNX transforms, JAX transforms
:groups: NNX, JAX
:sync:

@nnx.jit
Expand Down Expand Up @@ -83,8 +83,8 @@ NNX and JAX transformations can be mixed together, so long as the JAX-transforme
pure and has valid argument types that are recognized by JAX.

.. codediff::
:title_left: Using ``nnx.jit`` with ``jax.grad``
:title_right: Using ``jax.jit`` with ``nnx.grad``
:title: Using ``nnx.jit`` with ``jax.grad``, Using ``jax.jit`` with ``nnx.grad``
:groups: NNX, JAX
:sync:

@nnx.jit
Expand Down
Loading

0 comments on commit 59e037f

Please sign in to comment.