diff --git a/flax/linen/activation.py b/flax/linen/activation.py
index b5d22cc460..f019435173 100644
--- a/flax/linen/activation.py
+++ b/flax/linen/activation.py
@@ -62,7 +62,14 @@ class PReLU(Module):
   it needs to be initialized before being called.
 
   Example usage::
-    x = nn.PReLU()(x)
+    >>> import flax.linen as nn
+
+    >>> class MLP(nn.Module):
+    ...   @nn.compact
+    ...   def __call__(self, x):
+    ...     x = nn.Dense(2)(x)
+    ...     x = nn.PReLU()(x) # initialized
+    ...     return x
 
   Attributes:
     param_dtype: the dtype passed to parameter initializers (default: float32).
diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index d6ca8c166d..5b0afa71a8 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -203,6 +203,54 @@ def dot_product_attention(
 class MultiHeadDotProductAttention(Module):
   """Multi-head dot-product attention.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax
+
+    >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
+    >>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
+    >>> shape = (4, 3, 2, 5)
+    >>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
+    >>> variables = layer.init(jax.random.key(0), q)
+
+    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
+    >>> out = layer.apply(variables, q)
+    >>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
+    >>> out = layer.apply(variables, q, k)
+    >>> # different inputs for inputs_q, inputs_k and inputs_v
+    >>> out = layer.apply(variables, q, k, v)
+
+    >>> attention_kwargs = dict(
+    ...     num_heads=8,
+    ...     qkv_features=16,
+    ...     kernel_init=nn.initializers.ones,
+    ...     bias_init=nn.initializers.zeros,
+    ...     dropout_rate=0.5,
+    ...     deterministic=False,
+    ...     )
+    >>> class Module(nn.Module):
+    ...   attention_kwargs: dict
+    ...
+    ...   @nn.compact
+    ...   def __call__(self, x, dropout_rng=None):
+    ...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
+    ...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
+    ...     return out1, out2
+    >>> module = Module(attention_kwargs)
+    >>> variables = module.init({'params': key1, 'dropout': key2}, q)
+
+    >>> # out1 and out2 are different.
+    >>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
+    >>> # out3 and out4 are different.
+    >>> # out1 and out3 are different. out2 and out4 are different.
+    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
+    >>> # out1 and out2 are the same.
+    >>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
+    >>> # out1 and out2 are the same as out3 and out4.
+    >>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
+    >>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
+
   Attributes:
     num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
       should be divisible by the number of heads.
@@ -481,7 +529,14 @@ def __call__(
 
 
 class SelfAttention(MultiHeadDotProductAttention):
-  """Self-attention special case of multi-head dot-product attention."""
+  """Self-attention special case of multi-head dot-product attention.
+
+  Example usage::
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+    >>> layer = nn.SelfAttention(num_heads=8, qkv_features=16)
+    >>> params = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
+  """
 
   @compact
   def __call__(  # type: ignore
diff --git a/flax/linen/combinators.py b/flax/linen/combinators.py
index 2161379d06..48cecec9d6 100644
--- a/flax/linen/combinators.py
+++ b/flax/linen/combinators.py
@@ -32,15 +32,15 @@ class Sequential(Module):
   the next module and returns the output of the final module.
 
   Example usage::
+    >>> import flax.linen as nn
 
-    class Foo(nn.Module):
-
-      @nn.compact
-      def __call__(self, x):
-        return nn.Sequential([nn.Dense(4),
-                              nn.relu,
-                              nn.Dense(2),
-                              nn.log_softmax])(x)
+    >>> class Foo(nn.Module):
+    ...   @nn.compact
+    ...   def __call__(self, x):
+    ...     return nn.Sequential([nn.Dense(4),
+    ...                           nn.relu,
+    ...                           nn.Dense(2),
+    ...                           nn.log_softmax])(x)
 
   This combinator supports also layers that return multiple outputs if returned
   as a tuple or a dictionary. If the output of a layer is a ``tuple`` it will be
@@ -49,25 +49,25 @@ def __call__(self, x):
 
   Example usage::
 
-    class CrossAttentionBlock(nn.Module):
-      num_heads: int = 2
-      qkv_features: int = 16
-
-      @nn.compact
-      def __call__(self, query, key_value):
-        output = nn.MultiHeadDotProductAttention(
-            num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
-                                                                      key_value)
-        output = nn.Dense(self.qkv_features)(output)
-        return dict(query=output, key_value=key_value) # also works for tuples
-
-    class CrossAttentionNetwork(nn.Module):
-      num_layers: Sequence[int]
-
-      @nn.compact
-      def __call__(self, x):
-        return nn.Sequential([CrossAttentionBlock() for _ in
-                              range(self.num_layers)])(query, key_value)
+    >>> class CrossAttentionBlock(nn.Module):
+    ...   num_heads: int = 2
+    ...   qkv_features: int = 16
+    ...
+    ...   @nn.compact
+    ...   def __call__(self, query, key_value):
+    ...     output = nn.MultiHeadDotProductAttention(
+    ...         num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
+    ...                                                                   key_value)
+    ...     output = nn.Dense(self.qkv_features)(output)
+    ...     return dict(query=output, key_value=key_value) # also works for tuples
+
+    >>> class CrossAttentionNetwork(nn.Module):
+    ...   num_layers: int
+    ...
+    ...   @nn.compact
+    ...   def __call__(self, query, key_value):
+    ...     return nn.Sequential([CrossAttentionBlock() for _ in
+    ...                           range(self.num_layers)])(query, key_value)
   """
 
   layers: Sequence[Callable[..., Any]]
diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index 77d3e65f87..fde959d8c0 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -71,6 +71,24 @@ def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]:
 class DenseGeneral(Module):
   """A linear transformation with flexible axes.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> # equivalent to `nn.Dense(features=4)`
+    >>> layer = nn.DenseGeneral(features=4)
+    >>> # output features (4, 5)
+    >>> layer = nn.DenseGeneral(features=(4, 5))
+    >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
+    >>> jax.tree_map(jnp.shape, params)
+    {'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
+    >>> # apply transformation on the second and last axes
+    >>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1))
+    >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
+    >>> jax.tree_map(jnp.shape, params)
+    {'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
+
   Attributes:
     features: int or tuple with number of output features.
     axis: int or tuple with axes to apply the transformation on. For instance,
@@ -198,6 +216,16 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
 class Dense(Module):
   """A linear transformation applied over the last dimension of the input.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> layer = nn.Dense(features=4)
+    >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
+    >>> jax.tree_map(jnp.shape, params)
+    {'params': {'bias': (4,), 'kernel': (3, 4)}}
+
   Attributes:
     features: the number of output features.
     use_bias: whether to add a bias to the output (default: True).
@@ -563,6 +591,30 @@ def maybe_broadcast(
 class Conv(_Conv):
   """Convolution Module wrapping `lax.conv_general_dilated`.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> # valid padding
+    >>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID')
+    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
+    >>> out.shape
+    (1, 6, 4)
+    >>> # circular padding with stride 2
+    >>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
+    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}
+    >>> out.shape
+    (1, 4, 4)
+    >>> # apply lower triangle mask
+    >>> mask = jnp.tril(jnp.ones((3, 3, 4)))
+    >>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID')
+    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
+
   Attributes:
     features: number of convolution filters.
     kernel_size: shape of the convolutional kernel.
@@ -604,6 +656,30 @@ def shared_weights(self) -> bool:
 class ConvLocal(_Conv):
   """Local convolution Module wrapping `lax.conv_general_dilated_local`.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> # valid padding
+    >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
+    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
+    >>> out.shape
+    (1, 6, 4)
+    >>> # circular padding with stride 2
+    >>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
+    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
+    >>> out.shape
+    (1, 4, 4)
+    >>> # apply lower triangle mask
+    >>> mask = jnp.tril(jnp.ones((6, 9, 4)))
+    >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
+    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
+
   Attributes:
     features: number of convolution filters.
     kernel_size: shape of the convolutional kernel.
@@ -645,6 +721,30 @@ def shared_weights(self) -> bool:
 class ConvTranspose(Module):
   """Convolution Module wrapping lax.conv_transpose.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> # valid padding
+    >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
+    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
+    >>> out.shape
+    (1, 10, 4)
+    >>> # circular padding with stride 2
+    >>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
+    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
+    >>> out.shape
+    (1, 30, 30, 4)
+    >>> # apply lower triangle mask
+    >>> mask = jnp.tril(jnp.ones((3, 3, 4)))
+    >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
+    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
+
   Attributes:
     features: number of convolution filters.
     kernel_size: shape of the convolutional kernel. For 1D convolution,
@@ -840,6 +940,20 @@ class Embed(Module):
 
   A parameterized function from integers [0, n) to d-dimensional vectors.
 
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> layer = nn.Embed(num_embeddings=4, features=3)
+    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 5), dtype=int))
+    >>> jax.tree_map(jnp.shape, variables)
+    {'params': {'embedding': (4, 3)}}
+    >>> layer.apply(variables, jnp.ones((5,), dtype=int)).shape
+    (5, 3)
+    >>> layer.apply(variables, jnp.ones((5, 6), dtype=int)).shape
+    (5, 6, 3)
+
   Attributes:
     num_embeddings: number of embeddings.
     features: number of feature dimensions for each embedding.
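The examples added above are doctest-style (``>>>`` / ``...`` lines followed by the expected output), so they can be executed mechanically. The snippet below is a minimal sketch, not part of the diff, of one way to check them locally with the standard-library doctest module; it assumes flax and jax are installed and uses flax.linen.linear as a representative module.

import doctest

import flax.linen.linear

# doctest.testmod collects the `>>>` examples from the module docstrings,
# runs them, and compares what they print against the expected output lines.
results = doctest.testmod(flax.linen.linear, verbose=False)
print(f"attempted={results.attempted}, failed={results.failed}")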