
Commit

updated haiku guide
chiamp committed Apr 12, 2024
1 parent bf308ca commit b4da308
Showing 3 changed files with 48 additions and 26 deletions.
38 changes: 28 additions & 10 deletions docs/_ext/codediff.py
@@ -42,27 +42,45 @@ def parse(
self,
lines,
title_left='Base',
title_middle=None,
title_right='Diff',
code_sep='---',
sync=MISSING,
skip_left_test=MISSING,
):
sync = sync is not MISSING
# skip legacy code snippets in upgrade guides
skip_left_test=skip_left_test is not MISSING

if code_sep not in lines:
raise ValueError(
'Code separator not found! Code snippets should be '
f'separated by {code_sep}.'
)
idx = lines.index(code_sep)
code_left = self._code_block(lines[0:idx])
test_code = lines[idx + 1 :]
code_right = self._code_block(test_code)

output = self._tabs(
(title_left, code_left), (title_right, code_right), sync=sync
)
idxs = [i for i, x in enumerate(lines) if x == code_sep]
code_left = self._code_block(lines[0:idxs[0]])

test_codes = []
if not skip_left_test:
test_codes.append(lines[0:idxs[0]])
if title_middle is None:
assert len(idxs) == 1
test_codes.append(lines[idxs[0] + 1 :])
code_right = self._code_block(test_codes[-1])
output = self._tabs(
(title_left, code_left), (title_right, code_right), sync=sync
)
else:
assert len(idxs) == 2
test_codes.append(lines[idxs[0] + 1 : idxs[1]])
code_middle = self._code_block(test_codes[-1])
test_codes.append(lines[idxs[1] + 1 :])
code_right = self._code_block(test_codes[-1])
output = self._tabs(
(title_left, code_left), (title_middle, code_middle), (title_right, code_right), sync=sync
)

return output, test_code
return output, test_codes

def _code_block(self, lines):
"""Creates a codeblock."""
@@ -100,6 +118,7 @@ class CodeDiffDirective(SphinxDirective):
has_content = True
option_spec = {
'title_left': directives.unchanged,
'title_middle': directives.unchanged,
'title_right': directives.unchanged,
'code_sep': directives.unchanged,
'sync': directives.flag,
@@ -116,7 +135,6 @@ def run(self):
# in the source code:
# https://github.com/sphinx-doc/sphinx/blob/3.x/sphinx/ext/doctest.py
# (search for 'testnodetype').
test_code = '\n'.join(test_code)
test_node = nodes.comment(test_code, test_code, testnodetype='testcode')
# Set the source info so the error message is correct when testing.
self.set_source_info(test_node)
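For reference, below is a minimal standalone sketch (not part of this commit) of the splitting behavior the updated parse() implements: one code_sep line still yields two tabs, while two separator lines plus a middle title yield three. The helper name split_tabs is purely illustrative.

def split_tabs(lines, has_middle=False, code_sep='---'):
  """Split a codediff body into two or three snippets on `code_sep` lines."""
  idxs = [i for i, x in enumerate(lines) if x == code_sep]
  left = lines[:idxs[0]]
  if not has_middle:
    assert len(idxs) == 1, 'expected exactly one separator'
    return left, lines[idxs[0] + 1:]
  assert len(idxs) == 2, 'expected exactly two separators'
  middle = lines[idxs[0] + 1:idxs[1]]
  right = lines[idxs[1] + 1:]
  return left, middle, right

# Two separators plus a middle title produce three snippets:
print(split_tabs(['a = 1', '---', 'b = 2', '---', 'c = 3'], has_middle=True))
# (['a = 1'], ['b = 2'], ['c = 3'])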
29 changes: 17 additions & 12 deletions docs/guides/converting_and_upgrading/haiku_migration_guide.rst
@@ -26,6 +26,7 @@ signature and passed to the superclass constructor.

.. codediff::
:title_left: Haiku
:title_middle: test1
:title_right: Flax
:sync:

@@ -55,6 +56,10 @@ signature and passed to the superclass constructor.

---

x = 1

---

import flax.linen as nn

class Block(nn.Module):
@@ -190,12 +195,12 @@ the random dropout masks.

def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params,
key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
logits = model.apply(
params,
key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
@@ -206,12 +211,12 @@

def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
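As a runnable reference for the Flax side of the train_step comparison above (an illustration, not part of this commit): dropout randomness is supplied to Module.apply through the rngs dict rather than threaded through the call as in Haiku. TinyBlock below is a hypothetical module used only for this sketch.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyBlock(nn.Module):  # hypothetical module, for illustration only
  features: int = 4

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.features)(x)
    # Dropout draws its mask from the 'dropout' RNG stream given to apply().
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x

model = TinyBlock()
x = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), x, training=False)
y = model.apply(variables, x, training=True,
                rngs={'dropout': jax.random.PRNGKey(1)})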
7 changes: 3 additions & 4 deletions docs/guides/converting_and_upgrading/linen_upgrade_guide.rst
@@ -11,6 +11,7 @@ can use this upgrade guide to upgrade it to Linen.
from jax import random
import optax
import jax
import flax.linen as nn
from flax.linen import initializers

from jax import lax
@@ -51,8 +52,6 @@ Defining simple Flax Modules
'bias', (features,), bias_init)
y = y + bias
return y

return new_state, metrics
---
from flax import linen as nn # [1] #!

@@ -311,7 +310,7 @@ Non-trainable variables ("state"): Use within Modules
:sync:

class BatchNorm(nn.Module):
def apply(self, x, ...):
def apply(self, x):
# [...]
ra_mean = self.state(
'mean', (x.shape[-1], ), initializers.zeros_init())
@@ -360,7 +359,7 @@ Non-trainable variables ("state"): Top-level training code patterns

# reads immutable batch statistics during evaluation
def eval_step(model, model_state, batch):
with nn.stateful(model_state, mutable=False):
with nn.stateful(model_state, mutable=False):
logits = model(batch['image'], train=False)
return compute_metrics(logits, batch['label'])
---
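For the "non-trainable variables" hunks above, here is a small runnable sketch (not part of this commit) of the Linen-side pattern the guide upgrades to: state collections are declared with self.variable and updated by passing mutable to apply(). RunningMean is a hypothetical module used only for illustration.

import jax
import jax.numpy as jnp
import flax.linen as nn

class RunningMean(nn.Module):  # hypothetical module, for illustration only
  @nn.compact
  def __call__(self, x, train: bool):
    # Non-trainable state lives in its own collection ('batch_stats' here).
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s), (x.shape[-1],))
    if train and not self.is_initializing():
      ra_mean.value = 0.9 * ra_mean.value + 0.1 * x.mean(axis=0)
    return x - ra_mean.value

model = RunningMean()
x = jnp.ones((4, 3))
variables = model.init(jax.random.PRNGKey(0), x, train=False)
# Training step: allow the 'batch_stats' collection to be updated.
y, updates = model.apply(variables, x, train=True, mutable=['batch_stats'])
# Evaluation step: read the statistics without mutating them.
y_eval = model.apply({**variables, **updates}, x, train=False)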
