Skip to content

Commit

Permalink
Merge pull request #3733 from google:nnx-remove-flags-context
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615535549
  • Loading branch information
Flax Authors committed Mar 13, 2024
2 parents 3da44bc + 2206c40 commit 0280160
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 129 deletions.
1 change: 0 additions & 1 deletion flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from .nnx.errors import TraceContextError as TraceContextError
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
from .nnx.flaglib import flags as flags
from .nnx.graph_utils import GraphDef as GraphDef
from .nnx.helpers import Dict as Dict
from .nnx.helpers import Sequence as Sequence
Expand Down
16 changes: 11 additions & 5 deletions flax/experimental/nnx/examples/lm1b/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from jax import lax

from flax.experimental import nnx
from flax.experimental.nnx.nnx import flaglib
from flax.experimental.nnx.examples.lm1b.configs import default

Shape = tuple[int, ...]
Expand Down Expand Up @@ -126,9 +125,11 @@ def __init__(
self,
config: TransformerConfig,
*,
decode: bool = False,
rngs: nnx.Rngs,
):
self.config = config
self.decode = decode
self.pos_emb_shape = (1, config.max_len, config.emb_dim)

if config.posemb_init is not None:
Expand Down Expand Up @@ -168,7 +169,7 @@ def __call__(self, inputs: jax.Array, inputs_positions=None):
pos_embedding = self.pos_embedding.value

# We use a cache position index for tracking decoding position.
if flaglib.flags.get('decode', False):
if self.decode:
_, _, df = pos_embedding.shape
# equivalent to pos_embedding[:, i:i+1] but traceable
pos_embedding = lax.dynamic_slice(
Expand Down Expand Up @@ -336,9 +337,11 @@ def __init__(
config: TransformerConfig,
shared_embedding: nnx.Embed | None = None,
*,
decode: bool = False,
rngs: nnx.Rngs,
):
self.config = config
self.decode = decode
self.shared_embedding = shared_embedding

# Target Embedding
Expand Down Expand Up @@ -413,7 +416,7 @@ def __call__(
assert inputs.ndim == 2 # (batch, len)

y = inputs.astype('int32')
if not flaglib.flags.get('decode', False):
if not self.decode:
y = shift_inputs(y, segment_ids=inputs_segmentation)
y = self.output_embed(y)
y = self.posembed_output(y, inputs_positions=inputs_positions)
Expand Down Expand Up @@ -450,8 +453,11 @@ class TransformerLM(nnx.Module):
config: TransformerConfig dataclass containing hyperparameters.
"""

def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
def __init__(
self, config: TransformerConfig, *, decode: bool = False, rngs: nnx.Rngs
):
self.config = config
self.decode = decode
self.decoder = Decoder(config=config, shared_embedding=None, rngs=rngs)

def __call__(
Expand All @@ -475,7 +481,7 @@ def __call__(
config = self.config

# Make padding attention masks.
if flaglib.flags.get('decode', False):
if self.decode:
# for fast autoregressive decoding we use no decoder mask
decoder_mask = None
else:
Expand Down
8 changes: 4 additions & 4 deletions flax/experimental/nnx/examples/lm1b/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def test_forward_eval(self):
self.transfer_params(config, params_nnx, params_linen)
model_nnx.update(params_nnx)

with nnx.flags(deterministic=True, decode=False):
output_nnx = model_nnx(sample_inputs)
model_nnx.set_attributes(deterministic=True, decode=False)
output_nnx = model_nnx(sample_inputs)

output_linen: jax.Array = model_linen.apply(
{'params': params_linen}, sample_inputs
Expand Down Expand Up @@ -263,13 +263,13 @@ def test_forward_decode(self):
self.transfer_params(config, params_nnx, params_linen)
self.transfer_cache(config, cache_nnx, cache_linen)
model_nnx.update(params_nnx, cache_nnx)
model_nnx.set_attributes(deterministic=True, decode=True)

outputs_nnx = []
outputs_linen = []

for inputs in ar_decode_inputs:
with nnx.flags(deterministic=True, decode=True):
output_nnx = model_nnx(inputs)
output_nnx = model_nnx(inputs)
outputs_nnx.append(output_nnx)

output_linen: jax.Array
Expand Down
24 changes: 12 additions & 12 deletions flax/experimental/nnx/examples/lm1b/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ def train_step(
def loss_fn(params):
"""loss function used for training."""
module = state.graphdef.merge(params)
with nnx.flags(deterministic=False, decode=False):
logits = module(
inputs,
inputs_positions=inputs_positions,
inputs_segmentation=inputs_segmentation,
rngs=nnx.Rngs(dropout=dropout_rng),
)
module.set_attributes(deterministic=False, decode=False)
logits = module(
inputs,
inputs_positions=inputs_positions,
inputs_segmentation=inputs_segmentation,
rngs=nnx.Rngs(dropout=dropout_rng),
)

loss, weight_sum = compute_weighted_cross_entropy(
logits, inputs, weights, label_smoothing
Expand Down Expand Up @@ -229,8 +229,8 @@ def eval_step(
inputs = batch['inputs']
weights = jnp.where(inputs > 0, 1.0, 0.0)
module = static.merge(params)
with nnx.flags(deterministic=True, decode=False):
logits = module(inputs)
module.set_attributes(deterministic=True, decode=False)
logits = module(inputs)

return compute_metrics(logits, inputs, weights, label_smoothing)

Expand Down Expand Up @@ -261,8 +261,8 @@ def tokens_ids_to_logits(flat_ids, cache: nnx.State):
"""Token slice to logits from decoder model."""
# --> [batch * beam, 1, vocab]
module = static.merge(params, cache)
with nnx.flags(deterministic=True, decode=True):
logits = module(flat_ids)
module.set_attributes(deterministic=True, decode=True)
logits = module(flat_ids)
cache = module.extract(nnx.Cache)
# Remove singleton sequence-length dimension:
# [batch, 1, vocab] --> [batch, vocab]
Expand Down Expand Up @@ -538,7 +538,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
predict_step,
in_axes=(
0,
jax.tree_map(lambda x: None, state.params),
jax.tree_util.tree_map(lambda x: None, state.params),
0,
None,
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def scan_fn(
model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0))

x = jnp.ones((3, 10))
with nnx.flags(deterministic=False):
y = model(x, rngs=nnx.Rngs(dropout=1))
model.set_attributes(deterministic=False)
y = model(x, rngs=nnx.Rngs(dropout=1))

print(jax.tree_map(jnp.shape, model.get_state()))
print(y.shape)
79 changes: 0 additions & 79 deletions flax/experimental/nnx/nnx/flaglib.py

This file was deleted.

59 changes: 59 additions & 0 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,65 @@ def modules(self) -> tp.Iterator[tuple[Path, Module]]:
if isinstance(value, Module):
yield path, value

def set_attributes(
self,
*filters: filterlib.Filter,
raise_if_not_found: bool = True,
**attributes: tp.Any,
) -> None:
"""Sets the attributes of nested Modules including the current Module.
If the attribute is not found in the Module, it is ignored.
Example::
>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5, deterministic=False)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
``Filter``s can be used to set the attributes of specific Modules::
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Args:
*filters: Filters to select the Modules to set the attributes of.
raise_if_not_found: If True (default), raises a ValueError if at least one attribute
instance is not found in one of the selected Modules.
**attributes: The attributes to set.
"""
remaining_attributes = set(attributes.keys())
if not filters:
filters = (True,)
predicates = tuple(map(filterlib.to_predicate, filters))
for path, module in self.modules():
for predicate in predicates:
if predicate(path, module):
for name, value in attributes.items():
if hasattr(module, name):
if name in remaining_attributes:
remaining_attributes.remove(name)
setattr(module, name, value)
break

if remaining_attributes and raise_if_not_found:
raise ValueError(
f'Could not find at least one instance of the following attributes: {remaining_attributes}'
)

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()

Expand Down
3 changes: 0 additions & 3 deletions flax/experimental/nnx/nnx/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from flax.experimental import nnx
from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx import flaglib
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import initializers
from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype
Expand Down Expand Up @@ -510,7 +509,6 @@ def __call__(
decode = first_from(
decode,
self.decode,
flaglib.flags.get('decode'),
error_msg="""No `decode` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)
Expand Down Expand Up @@ -557,7 +555,6 @@ def __call__(
deterministic = first_from(
deterministic,
self.deterministic,
flaglib.flags.get('deterministic'),
error_msg="""No `deterministic` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)
Expand Down
3 changes: 1 addition & 2 deletions flax/experimental/nnx/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from jax import lax

from flax.experimental import nnx
from flax.experimental.nnx.nnx import flaglib, rnglib
from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import dtypes, initializers
from flax.typing import (
Expand Down Expand Up @@ -283,7 +283,6 @@ def __call__(
use_running_average = first_from(
use_running_average,
self.use_running_average,
flaglib.flags.get('use_running_average'),
error_msg="""No `use_running_average` argument was provided to BatchNorm
as either a __call__ argument, class attribute, or nnx.flag.""",
)
Expand Down
3 changes: 1 addition & 2 deletions flax/experimental/nnx/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax.numpy as jnp
from jax import lax, random

from flax.experimental.nnx.nnx import flaglib, rnglib
from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.module import Module, first_from
import dataclasses

Expand Down Expand Up @@ -61,7 +61,6 @@ def __call__(
deterministic = first_from(
deterministic,
self.deterministic,
flaglib.flags.get('deterministic'),
error_msg="""No `deterministic` argument was provided to Dropout
as either a __call__ argument, class attribute, or nnx.flag.""",
)
Expand Down
7 changes: 3 additions & 4 deletions flax/experimental/nnx/tests/nn/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def __call__(self, x, sow_weights=False):
),
rng,
)
module.set_attributes(decode=False)

with nnx.flags(decode=False):
_ = module(x, True)
_ = module(x, True)
intermediates = module.pop(nnx.Intermediate)
assert intermediates['attention_layers/0/attention_weights'].raw_value[
0
Expand All @@ -77,8 +77,7 @@ def __call__(self, x, sow_weights=False):
0
].shape == (4, 8, 6, 6)

with nnx.flags(decode=False):
_ = module(x)
_ = module(x)
intermediates = module.pop(nnx.Intermediate)
assert not intermediates # empty

Expand Down
Loading

0 comments on commit 0280160

Please sign in to comment.