
Commit

added example docstrings
chiamp committed Nov 3, 2023
1 parent c008753 commit 2bf8981
Showing 4 changed files with 206 additions and 29 deletions.
9 changes: 8 additions & 1 deletion flax/linen/activation.py
@@ -62,7 +62,14 @@ class PReLU(Module):
it needs to be initialized before being called.
Example usage::
x = nn.PReLU()(x)
>>> import flax.linen as nn
>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(2)(x)
...     x = nn.PReLU()(x) # initialized
...     return x
Attributes:
param_dtype: the dtype passed to parameter initializers (default: float32).
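The PReLU example above creates the parameter the surrounding text refers to: `negative_slope` only comes into existence once the enclosing module is initialized. A minimal sketch, not part of this commit, reusing the `MLP` class from the docstring example (the structure in the comment is an expectation, not verified doctest output):

    import jax, jax.numpy as jnp

    model = MLP()
    variables = model.init(jax.random.key(0), jnp.ones((1, 3)))
    # Expected parameter structure (assumption): the PReLU submodule adds a
    # scalar 'negative_slope' next to the Dense kernel and bias, roughly
    # {'params': {'Dense_0': {'bias': (2,), 'kernel': (3, 2)},
    #             'PReLU_0': {'negative_slope': ()}}}
    print(jax.tree_map(jnp.shape, variables))
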
57 changes: 56 additions & 1 deletion flax/linen/attention.py
@@ -203,6 +203,54 @@ def dot_product_attention(
class MultiHeadDotProductAttention(Module):
"""Multi-head dot-product attention.
Example usage::
>>> import flax.linen as nn
>>> import jax
>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
... )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)
>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
@@ -481,7 +529,14 @@ def __call__(


class SelfAttention(MultiHeadDotProductAttention):
"""Self-attention special case of multi-head dot-product attention."""
"""Self-attention special case of multi-head dot-product attention.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.SelfAttention(num_heads=8, qkv_features=16)
>>> params = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
"""

@compact
def __call__( # type: ignore
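The "special case" phrasing above can be made concrete: `SelfAttention` subclasses `MultiHeadDotProductAttention` and adds no parameters of its own, so a parameter tree initialized for one should be usable with the other, and a single-input call should match multi-head attention with `inputs_q = inputs_k = inputs_v`. A minimal sketch of that check, stated as an assumption rather than something the diff itself asserts:

    import flax.linen as nn
    import jax, jax.numpy as jnp

    x = jnp.ones((4, 3, 2, 5))
    sa = nn.SelfAttention(num_heads=8, qkv_features=16)
    mha = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)

    variables = sa.init(jax.random.key(0), x)
    out_sa = sa.apply(variables, x)
    # Assumption: the same params apply because SelfAttention reuses the
    # parent's 'query'/'key'/'value'/'out' submodules unchanged.
    out_mha = mha.apply(variables, x)
    assert jnp.allclose(out_sa, out_mha)
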
55 changes: 28 additions & 27 deletions flax/linen/combinators.py
@@ -32,15 +32,15 @@ class Sequential(Module):
the next module and returns the output of the final module.
Example usage::
>>> import flax.linen as nn
class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Sequential([nn.Dense(4),
                          nn.relu,
                          nn.Dense(2),
                          nn.log_softmax])(x)
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     return nn.Sequential([nn.Dense(4),
...                           nn.relu,
...                           nn.Dense(2),
...                           nn.log_softmax])(x)
This combinator also supports layers that return multiple outputs if returned
as a tuple or a dictionary. If the output of a layer is a ``tuple`` it will be
@@ -49,25 +49,26 @@ def __call__(self, x):
Example usage::
class CrossAttentionBlock(nn.Module):
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    output = nn.MultiHeadDotProductAttention(
      num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
                                                                 key_value)
    output = nn.Dense(self.qkv_features)(output)
    return dict(query=output, key_value=key_value) # also works for tuples

class CrossAttentionNetwork(nn.Module):
  num_layers: Sequence[int]

  @nn.compact
  def __call__(self, x):
    return nn.Sequential([CrossAttentionBlock() for _ in
                          range(self.num_layers)])(query, key_value)
>>> class CrossAttentionBlock(nn.Module):
...   num_heads: int = 2
...   qkv_features: int = 16
...
...   @nn.compact
...   def __call__(self, query, key_value):
...     output = nn.MultiHeadDotProductAttention(
...       num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
...                                                                  key_value)
...     output = nn.Dense(self.qkv_features)(output)
...     return dict(query=output, key_value=key_value) # also works for tuples
>>> class CrossAttentionNetwork(nn.Module):
...   num_layers: int
...
...   @nn.compact
...   def __call__(self, query, key_value):
...     return nn.Sequential([CrossAttentionBlock() for _ in
...                           range(self.num_layers)])(query, key_value)
"""

layers: Sequence[Callable[..., Any]]
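The dictionary case is covered by the cross-attention example above, but the ``tuple`` case mentioned in the docstring text is not. A minimal sketch of it, not part of this commit; `Split` and `Combine` are hypothetical module names, and the sketch assumes `Sequential` unpacks a tuple output as positional arguments for the next layer:

    import flax.linen as nn
    import jax, jax.numpy as jnp

    class Split(nn.Module):
      @nn.compact
      def __call__(self, x):
        # Returning a tuple: the next layer receives it as (a, b).
        return nn.Dense(4)(x), nn.Dense(4)(x)

    class Combine(nn.Module):
      @nn.compact
      def __call__(self, a, b):
        return nn.Dense(2)(a + b)

    model = nn.Sequential([Split(), Combine()])
    y, variables = model.init_with_output(jax.random.key(0), jnp.ones((1, 3)))
    print(y.shape)  # expected: (1, 2)
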
114 changes: 114 additions & 0 deletions flax/linen/linear.py
@@ -71,6 +71,24 @@ def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]:
class DenseGeneral(Module):
"""A linear transformation with flexible axes.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # equivalent to `nn.Dense(features=4)`
>>> layer = nn.DenseGeneral(features=4)
>>> # output features (4, 5)
>>> layer = nn.DenseGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the second and last axes
>>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}
Attributes:
features: int or tuple with number of output features.
axis: int or tuple with axes to apply the transformation on. For instance,
@@ -198,6 +216,16 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
class Dense(Module):
"""A linear transformation applied over the last dimension of the input.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Dense(features=4)
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_map(jnp.shape, params)
{'params': {'bias': (4,), 'kernel': (3, 4)}}
Attributes:
features: the number of output features.
use_bias: whether to add a bias to the output (default: True).
@@ -563,6 +591,30 @@ def maybe_broadcast(
class Conv(_Conv):
"""Convolution Module wrapping `lax.conv_general_dilated`.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # valid padding
>>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
Attributes:
features: number of convolution filters.
kernel_size: shape of the convolutional kernel.
@@ -604,6 +656,30 @@ def shared_weights(self) -> bool:
class ConvLocal(_Conv):
"""Local convolution Module wrapping `lax.conv_general_dilated_local`.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # valid padding
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
>>> out.shape
(1, 6, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((6, 9, 4)))
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
Attributes:
features: number of convolution filters.
kernel_size: shape of the convolutional kernel.
@@ -645,6 +721,30 @@ def shared_weights(self) -> bool:
class ConvTranspose(Module):
"""Convolution Module wrapping lax.conv_transpose.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> # valid padding
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 10, 4)
>>> # circular padding with stride 2
>>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
>>> out.shape
(1, 30, 30, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
Attributes:
features: number of convolution filters.
kernel_size: shape of the convolutional kernel. For 1D convolution,
@@ -840,6 +940,20 @@ class Embed(Module):
A parameterized function from integers [0, n) to d-dimensional vectors.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Embed(num_embeddings=4, features=3)
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 5), dtype=int))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'embedding': (4, 3)}}
>>> layer.apply(variables, jnp.ones((5,), dtype=int)).shape
(5, 3)
>>> layer.apply(variables, jnp.ones((5, 6), dtype=int)).shape
(5, 6, 3)
Attributes:
num_embeddings: number of embeddings.
features: number of feature dimensions for each embedding.
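Beyond the integer lookup shown in the Embed example, a query of shape `(..., features)` can also be scored against the embedding table via `Embed.attend`, which is how weight-tied output logits are typically produced. A minimal sketch reusing the layer from the example above; the shape in the comment is an expectation, not doctest output:

    import flax.linen as nn
    import jax, jax.numpy as jnp

    layer = nn.Embed(num_embeddings=4, features=3)
    variables = layer.init(jax.random.key(0), jnp.ones((1, 5), dtype=int))

    query = jnp.ones((2, 3))  # trailing dim must equal `features`
    logits = layer.apply(variables, query, method=layer.attend)
    print(logits.shape)  # expected: (2, 4), one score per embedding
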
