use note box in docstrings
chiamp committed Mar 19, 2024
1 parent 79851d4 commit c66a4df
Showing 8 changed files with 60 additions and 47 deletions.
5 changes: 3 additions & 2 deletions flax/cursor.py
@@ -229,8 +229,9 @@ def build(self) -> A:
"""Create and return a copy of the original object with accumulated changes.
This method is to be called after making changes to the Cursor object.
- NOTE: The new object is built bottom-up, the changes will be first applied
- to the leaf nodes, and then its parent, all the way up to the root.
+ .. note::
+   The new object is built bottom-up, the changes will be first applied
+   to the leaf nodes, and then its parent, all the way up to the root.
Example::
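To make the updated note concrete, here is a minimal sketch of typical ``flax.cursor`` usage; the nested dict, its keys, and the new values are illustrative, not taken from the commit::

    from flax.cursor import cursor

    config = {'opt': {'lr': 1e-3, 'momentum': 0.9}, 'layers': [64, 64]}

    # Stage changes on a Cursor, then materialize them with .build().
    c = cursor(config)
    c['opt']['lr'] = 3e-4
    c['layers'][1] = 128
    new_config = c.build()  # changes are applied leaf-first, up to the root

    assert config['opt']['lr'] == 1e-3      # the original object is untouched
    assert new_config['opt']['lr'] == 3e-4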
15 changes: 8 additions & 7 deletions flax/experimental/nnx/nnx/nn/attention.py
@@ -153,15 +153,16 @@ def dot_product_attention(
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
- Note: query, key, value needn't have any batch dimensions.
+ .. note::
+   ``query``, ``key``, ``value`` needn't have any batch dimensions.
Args:
- query: queries for calculating attention with shape of `[batch..., q_length,
-   num_heads, qk_depth_per_head]`.
- key: keys for calculating attention with shape of `[batch..., kv_length,
-   num_heads, qk_depth_per_head]`.
- value: values to be used in attention with shape of `[batch..., kv_length,
-   num_heads, v_depth_per_head]`.
+ query: queries for calculating attention with shape of ``[batch..., q_length,
+   num_heads, qk_depth_per_head]``.
+ key: keys for calculating attention with shape of ``[batch..., kv_length,
+   num_heads, qk_depth_per_head]``.
+ value: values to be used in attention with shape of ``[batch..., kv_length,
+   num_heads, v_depth_per_head]``.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
3 changes: 2 additions & 1 deletion flax/linen/attention.py
@@ -152,7 +152,8 @@ def dot_product_attention(
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
- Note: query, key, value needn't have any batch dimensions.
+ .. note::
+   ``query``, ``key``, ``value`` needn't have any batch dimensions.
Args:
query: queries for calculating attention with shape of ``[batch..., q_length,
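A small sketch of how ``dot_product_attention`` (the ``linen`` variant here; the ``nnx`` variant above documents the same shapes) consumes these arguments; the concrete sizes are illustrative::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    key = jax.random.PRNGKey(0)
    # Shapes follow the docstring: [batch..., length, num_heads, depth_per_head].
    q = jax.random.normal(key, (2, 5, 4, 8))    # [batch, q_length, num_heads, qk_depth_per_head]
    k = jax.random.normal(key, (2, 7, 4, 8))    # [batch, kv_length, num_heads, qk_depth_per_head]
    v = jax.random.normal(key, (2, 7, 4, 16))   # [batch, kv_length, num_heads, v_depth_per_head]

    out = nn.dot_product_attention(q, k, v)
    print(out.shape)  # (2, 5, 4, 16)

    # As the note says, the batch dimensions are optional:
    out_unbatched = nn.dot_product_attention(q[0], k[0], v[0])
    print(out_unbatched.shape)  # (5, 4, 16)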
9 changes: 5 additions & 4 deletions flax/linen/module.py
@@ -2653,10 +2653,11 @@ def perturb(
intermediate gradients of ``value`` by running ``jax.grad`` on the perturbation
argument.
- Note: this is an experimental API and may be tweaked later for better
- performance and usability.
- At its current stage, it creates extra dummy variables that occupies extra
- memory space. Use it only to debug gradients in training.
+ .. note::
+   This is an experimental API and may be tweaked later for better
+   performance and usability.
+   At its current stage, it creates extra dummy variables that occupies extra
+   memory space. Use it only to debug gradients in training.
Example::
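For context, the documented ``perturb`` pattern looks roughly like the sketch below; the module name ``Model``, the label ``'dense_out'``, and the loss are illustrative. Differentiating with respect to the zero-valued dummy variable in the ``'perturbations'`` collection recovers the gradient of the loss at that intermediate::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense_out', x)  # registers a zero dummy variable
        return nn.Dense(1)(x)

    x = jnp.ones((2, 4))
    y = jnp.ones((2, 1))
    model = Model()
    variables = model.init(jax.random.PRNGKey(0), x)

    def loss_fn(params, perturbations):
      preds = model.apply({'params': params, 'perturbations': perturbations}, x)
      return jnp.mean((preds - y) ** 2)

    # Gradient w.r.t. the perturbation == gradient of the loss w.r.t. the intermediate.
    intermediate_grads = jax.grad(loss_fn, argnums=1)(
        variables['params'], variables['perturbations'])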
53 changes: 28 additions & 25 deletions flax/linen/normalization.py
@@ -310,12 +310,12 @@ def __call__(
):
"""Normalizes the input using batch statistics.
- NOTE:
- During initialization (when ``self.is_initializing()`` is ``True``) the running
- average of the batch statistics will not be updated. Therefore, the inputs
- fed during initialization don't need to match that of the actual input
- distribution and the reduction axis (set with ``axis_name``) does not have
- to exist.
+ .. note::
+   During initialization (when ``self.is_initializing()`` is ``True``) the running
+   average of the batch statistics will not be updated. Therefore, the inputs
+   fed during initialization don't need to match that of the actual input
+   distribution and the reduction axis (set with ``axis_name``) does not have
+   to exist.
Args:
x: the input to be normalized.
@@ -389,9 +389,10 @@ class LayerNorm(Module):
i.e. applies a transformation that maintains the mean activation within
each example close to 0 and the activation standard deviation close to 1.
- NOTE: This normalization operation is identical to InstanceNorm and GroupNorm;
- the difference is simply which axes are reduced and the shape of the feature axes
- (i.e. the shape of the learnable scale and bias parameters).
+ .. note::
+   This normalization operation is identical to InstanceNorm and GroupNorm;
+   the difference is simply which axes are reduced and the shape of the feature
+   axes (i.e. the shape of the learnable scale and bias parameters).
Example usage::
@@ -602,8 +603,9 @@ class GroupNorm(Module):
The user should either specify the total number of channel groups or the
number of channels per group.
- NOTE: LayerNorm is a special case of GroupNorm where ``num_groups=1``, and
- InstanceNorm is a special case of GroupNorm where ``group_size=1``.
+ .. note::
+   LayerNorm is a special case of GroupNorm where ``num_groups=1``, and
+   InstanceNorm is a special case of GroupNorm where ``group_size=1``.
Example usage::
@@ -778,9 +780,10 @@ class InstanceNorm(Module):
within each channel within each example close to 0 and the activation standard
deviation close to 1.
- NOTE: This normalization operation is identical to LayerNorm and GroupNorm; the
- difference is simply which axes are reduced and the shape of the feature axes
- (i.e. the shape of the learnable scale and bias parameters).
+ .. note::
+   This normalization operation is identical to LayerNorm and GroupNorm; the
+   difference is simply which axes are reduced and the shape of the feature axes
+   (i.e. the shape of the learnable scale and bias parameters).
Example usage::
@@ -903,17 +906,17 @@ class SpectralNorm(Module):
where each wrapped layer will have its params spectral normalized before
computing its ``__call__`` output.
- Usage Note:
- The initialized variables dict will contain, in addition to a 'params'
- collection, a separate 'batch_stats' collection that will contain a
- ``u`` vector and ``sigma`` value, which are intermediate values used
- when performing spectral normalization. During training, we pass in
- ``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u``
- and ``sigma`` are updated with the most recently computed values using
- power iteration. This will help the power iteration method approximate
- the true singular value more accurately over time. During eval, we pass
- in ``update_stats=False`` to ensure we get deterministic behavior from
- the model. For example::
+ .. note::
+   The initialized variables dict will contain, in addition to a 'params'
+   collection, a separate 'batch_stats' collection that will contain a
+   ``u`` vector and ``sigma`` value, which are intermediate values used
+   when performing spectral normalization. During training, we pass in
+   ``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u``
+   and ``sigma`` are updated with the most recently computed values using
+   power iteration. This will help the power iteration method approximate
+   the true singular value more accurately over time. During eval, we pass
+   in ``update_stats=False`` to ensure we get deterministic behavior from
+   the model.
+ Example usage::
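The "special case" notes above can be checked numerically. The sketch below is illustrative, not from the commit: the input shape, RNG seeds, and ``atol`` tolerance are arbitrary, and the keyword combinations (``num_groups=1``, ``num_groups=None`` with ``group_size=1``) assume the current ``flax.linen`` normalization API with its default parameter initializers::

    import jax
    import jax.numpy as jnp
    import numpy as np
    import flax.linen as nn

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 5))

    # A single group normalizes over all non-batch axes, which matches LayerNorm
    # once LayerNorm is told to reduce over those same axes.
    y_group, _ = nn.GroupNorm(num_groups=1).init_with_output(jax.random.PRNGKey(1), x)
    y_layer, _ = nn.LayerNorm(reduction_axes=(1, 2, 3)).init_with_output(jax.random.PRNGKey(1), x)
    np.testing.assert_allclose(y_group, y_layer, atol=1e-5)

    # Likewise, one channel per group reproduces InstanceNorm.
    y_inst, _ = nn.InstanceNorm().init_with_output(jax.random.PRNGKey(1), x)
    y_group1, _ = nn.GroupNorm(num_groups=None, group_size=1).init_with_output(
        jax.random.PRNGKey(1), x)
    np.testing.assert_allclose(y_inst, y_group1, atol=1e-5)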
8 changes: 5 additions & 3 deletions flax/linen/pooling.py
@@ -23,9 +23,11 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
"""Helper function to define pooling functions.
Pooling functions are implemented using the ReduceWindow XLA op.
- NOTE: Be aware that pooling is not generally differentiable.
- That means providing a reduce_fn that is differentiable does not imply that
- pool is differentiable.
+ .. note::
+   Be aware that pooling is not generally differentiable.
+   That means providing a reduce_fn that is differentiable does not imply that
+   pool is differentiable.
Args:
inputs: input data with dimensions (batch, window dims..., features).
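For reference, the ``pool`` helper is normally reached through thin wrappers such as ``max_pool`` and ``avg_pool``; a minimal sketch with an illustrative input shape following the documented layout (batch, window dims..., features)::

    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.arange(1 * 4 * 4 * 3, dtype=jnp.float32).reshape(1, 4, 4, 3)

    # max_pool wraps pool() with jnp.maximum as the reduce_fn.
    y = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    print(y.shape)  # (1, 2, 2, 3)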
9 changes: 6 additions & 3 deletions flax/linen/stochastic.py
@@ -26,9 +26,12 @@
class Dropout(Module):
"""Create a dropout layer.
- Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
- to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
- variable initialization. Example usage::
+ .. note::
+   When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
+   to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
+   variable initialization.
+ Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
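A minimal sketch of the pattern the note describes: no ``'dropout'`` RNG is needed at initialization, but one must be passed to ``apply`` whenever dropout is active. The module ``Block`` and the sizes are illustrative::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Block(nn.Module):
      @nn.compact
      def __call__(self, x, train: bool):
        x = nn.Dense(4)(x)
        return nn.Dropout(rate=0.5, deterministic=not train)(x)

    x = jnp.ones((2, 3))
    model = Block()
    variables = model.init(jax.random.PRNGKey(0), x, train=False)  # no dropout RNG needed here

    # At train time, apply() needs an RNG under the 'dropout' name.
    y = model.apply(variables, x, train=True, rngs={'dropout': jax.random.PRNGKey(1)})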
5 changes: 3 additions & 2 deletions flax/struct.py
@@ -35,8 +35,9 @@ def field(pytree_node=True, **kwargs):
def dataclass(clz: _T, **kwargs) -> _T:
"""Create a class which can be passed to functional transformations.
- NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
- using PyType.
+ .. note::
+   Inherit from ``PyTreeNode`` instead to avoid type checking issues when
+   using PyType.
Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
immutable and can be mapped over using the ``jax.tree_util`` methods.
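A short sketch of the recommended option from the note, subclassing ``struct.PyTreeNode`` rather than applying the ``@struct.dataclass`` decorator; the class and field names are illustrative::

    import jax
    import jax.numpy as jnp
    from flax import struct

    class Params(struct.PyTreeNode):  # subclassing avoids the PyType issue noted above
      weights: jnp.ndarray
      step: int
      name: str = struct.field(pytree_node=False)  # static field, not traced by JAX

    p = Params(weights=jnp.ones(3), step=0, name='demo')
    p = p.replace(step=p.step + 1)        # instances are immutable; replace() copies
    print(jax.tree_util.tree_leaves(p))   # only weights and step are pytree leaves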
