From c66a4df54cec25da0bf431cb4f34b77a948a7a86 Mon Sep 17 00:00:00 2001
From: Marcus Chiam
Date: Mon, 18 Mar 2024 17:38:53 -0700
Subject: [PATCH] use note box in docstrings

---
 flax/cursor.py                            |  5 ++-
 flax/experimental/nnx/nnx/nn/attention.py | 15 ++++---
 flax/linen/attention.py                   |  3 +-
 flax/linen/module.py                      |  9 ++--
 flax/linen/normalization.py               | 53 ++++++++++++-----------
 flax/linen/pooling.py                     |  8 ++--
 flax/linen/stochastic.py                  |  9 ++--
 flax/struct.py                            |  5 ++-
 8 files changed, 60 insertions(+), 47 deletions(-)

diff --git a/flax/cursor.py b/flax/cursor.py
index e293c17696..b7792eb2c8 100644
--- a/flax/cursor.py
+++ b/flax/cursor.py
@@ -229,8 +229,9 @@ def build(self) -> A:
     """Create and return a copy of the original object with accumulated changes.
     This method is to be called after making changes to the Cursor object.
 
-    NOTE: The new object is built bottom-up, the changes will be first applied
-    to the leaf nodes, and then its parent, all the way up to the root.
+    .. note::
+      The new object is built bottom-up, the changes will be first applied
+      to the leaf nodes, and then its parent, all the way up to the root.
 
     Example::
 
diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py
index a48c7df90a..aad0c3772a 100644
--- a/flax/experimental/nnx/nnx/nn/attention.py
+++ b/flax/experimental/nnx/nnx/nn/attention.py
@@ -153,15 +153,16 @@ def dot_product_attention(
   https://arxiv.org/abs/1706.03762. It calculates the attention weights given
   query and key and combines the values using the attention weights.
 
-  Note: query, key, value needn't have any batch dimensions.
+  .. note::
+    ``query``, ``key``, ``value`` needn't have any batch dimensions.
 
   Args:
-    query: queries for calculating attention with shape of `[batch..., q_length,
-      num_heads, qk_depth_per_head]`.
-    key: keys for calculating attention with shape of `[batch..., kv_length,
-      num_heads, qk_depth_per_head]`.
-    value: values to be used in attention with shape of `[batch..., kv_length,
-      num_heads, v_depth_per_head]`.
+    query: queries for calculating attention with shape of ``[batch..., q_length,
+      num_heads, qk_depth_per_head]``.
+    key: keys for calculating attention with shape of ``[batch..., kv_length,
+      num_heads, qk_depth_per_head]``.
+    value: values to be used in attention with shape of ``[batch..., kv_length,
+      num_heads, v_depth_per_head]``.
     bias: bias for the attention weights. This should be broadcastable to the
       shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
       incorporating causal masks, padding masks, proximity bias, etc.
diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index efcf2b78d8..4802536a0b 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -152,7 +152,8 @@ def dot_product_attention(
   https://arxiv.org/abs/1706.03762. It calculates the attention weights given
   query and key and combines the values using the attention weights.
 
-  Note: query, key, value needn't have any batch dimensions.
+  .. note::
+    ``query``, ``key``, ``value`` needn't have any batch dimensions.
 
   Args:
     query: queries for calculating attention with shape of ``[batch..., q_length,
diff --git a/flax/linen/module.py b/flax/linen/module.py
index 15ff81cbe1..f532c28f3c 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -2653,10 +2653,11 @@ def perturb(
     intermediate gradients of ``value`` by running ``jax.grad`` on the
     perturbation argument.
 
-    Note: this is an experimental API and may be tweaked later for better
-    performance and usability.
-    At its current stage, it creates extra dummy variables that occupies extra
-    memory space. Use it only to debug gradients in training.
+    .. note::
+      This is an experimental API and may be tweaked later for better
+      performance and usability.
+      At its current stage, it creates extra dummy variables that occupies extra
+      memory space. Use it only to debug gradients in training.
 
     Example::
 
diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index 0eb8b69761..b6bd992fb2 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -310,12 +310,12 @@ def __call__(
   ):
     """Normalizes the input using batch statistics.
 
-    NOTE:
-    During initialization (when ``self.is_initializing()`` is ``True``) the running
-    average of the batch statistics will not be updated. Therefore, the inputs
-    fed during initialization don't need to match that of the actual input
-    distribution and the reduction axis (set with ``axis_name``) does not have
-    to exist.
+    .. note::
+      During initialization (when ``self.is_initializing()`` is ``True``) the running
+      average of the batch statistics will not be updated. Therefore, the inputs
+      fed during initialization don't need to match that of the actual input
+      distribution and the reduction axis (set with ``axis_name``) does not have
+      to exist.
 
     Args:
       x: the input to be normalized.
@@ -389,9 +389,10 @@ class LayerNorm(Module):
   i.e. applies a transformation that maintains the mean activation within
   each example close to 0 and the activation standard deviation close to 1.
 
-  NOTE: This normalization operation is identical to InstanceNorm and GroupNorm;
-  the difference is simply which axes are reduced and the shape of the feature axes
-  (i.e. the shape of the learnable scale and bias parameters).
+  .. note::
+    This normalization operation is identical to InstanceNorm and GroupNorm;
+    the difference is simply which axes are reduced and the shape of the feature
+    axes (i.e. the shape of the learnable scale and bias parameters).
 
   Example usage::
 
@@ -602,8 +603,9 @@ class GroupNorm(Module):
   The user should either specify the total number of channel groups or the
   number of channels per group.
 
-  NOTE: LayerNorm is a special case of GroupNorm where ``num_groups=1``, and
-  InstanceNorm is a special case of GroupNorm where ``group_size=1``.
+  .. note::
+    LayerNorm is a special case of GroupNorm where ``num_groups=1``, and
+    InstanceNorm is a special case of GroupNorm where ``group_size=1``.
 
   Example usage::
 
@@ -778,9 +780,10 @@ class InstanceNorm(Module):
   within each channel within each example close to 0 and the activation standard
   deviation close to 1.
 
-  NOTE: This normalization operation is identical to LayerNorm and GroupNorm; the
-  difference is simply which axes are reduced and the shape of the feature axes
-  (i.e. the shape of the learnable scale and bias parameters).
+  .. note::
+    This normalization operation is identical to LayerNorm and GroupNorm; the
+    difference is simply which axes are reduced and the shape of the feature axes
+    (i.e. the shape of the learnable scale and bias parameters).
 
   Example usage::
 
@@ -903,17 +906,17 @@ class SpectralNorm(Module):
   where each wrapped layer will have its params spectral normalized before
   computing its ``__call__`` output.
 
-  Usage Note:
-    The initialized variables dict will contain, in addition to a 'params'
-    collection, a separate 'batch_stats' collection that will contain a
-    ``u`` vector and ``sigma`` value, which are intermediate values used
-    when performing spectral normalization. During training, we pass in
-    ``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u``
-    and ``sigma`` are updated with the most recently computed values using
-    power iteration. This will help the power iteration method approximate
-    the true singular value more accurately over time. During eval, we pass
-    in ``update_stats=False`` to ensure we get deterministic behavior from
-    the model. For example::
+  .. note::
+    The initialized variables dict will contain, in addition to a 'params'
+    collection, a separate 'batch_stats' collection that will contain a
+    ``u`` vector and ``sigma`` value, which are intermediate values used
+    when performing spectral normalization. During training, we pass in
+    ``update_stats=True`` and ``mutable=['batch_stats']`` so that ``u``
+    and ``sigma`` are updated with the most recently computed values using
+    power iteration. This will help the power iteration method approximate
+    the true singular value more accurately over time. During eval, we pass
+    in ``update_stats=False`` to ensure we get deterministic behavior from
+    the model.
 
   Example usage::
 
diff --git a/flax/linen/pooling.py b/flax/linen/pooling.py
index 900fa6f93f..aac25a2cfb 100644
--- a/flax/linen/pooling.py
+++ b/flax/linen/pooling.py
@@ -23,9 +23,11 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
   """Helper function to define pooling functions.
 
   Pooling functions are implemented using the ReduceWindow XLA op.
-  NOTE: Be aware that pooling is not generally differentiable.
-  That means providing a reduce_fn that is differentiable does not imply that
-  pool is differentiable.
+
+  .. note::
+    Be aware that pooling is not generally differentiable.
+    That means providing a reduce_fn that is differentiable does not imply that
+    pool is differentiable.
 
   Args:
     inputs: input data with dimensions (batch, window dims..., features).
diff --git a/flax/linen/stochastic.py b/flax/linen/stochastic.py
index 67be8f82d9..c245ebb2a9 100644
--- a/flax/linen/stochastic.py
+++ b/flax/linen/stochastic.py
@@ -26,9 +26,12 @@ class Dropout(Module):
   """Create a dropout layer.
 
-  Note: When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
-  to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
-  variable initialization. Example usage::
+  .. note::
+    When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
+    to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
+    variable initialization.
+
+  Example usage::
 
     >>> import flax.linen as nn
     >>> import jax, jax.numpy as jnp
 
diff --git a/flax/struct.py b/flax/struct.py
index 313a6ee664..7a8283a9d9 100644
--- a/flax/struct.py
+++ b/flax/struct.py
@@ -35,8 +35,9 @@ def field(pytree_node=True, **kwargs):
 def dataclass(clz: _T, **kwargs) -> _T:
   """Create a class which can be passed to functional transformations.
 
-  NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
-  using PyType.
+  .. note::
+    Inherit from ``PyTreeNode`` instead to avoid type checking issues when
+    using PyType.
 
   Jax transformations such as ``jax.jit`` and ``jax.grad`` require objects that are
   immutable and can be mapped over using the ``jax.tree_util`` methods.
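
The reStructuredText ``.. note::`` directive used throughout this patch is rendered by Sphinx as a highlighted admonition box, whereas a bare ``NOTE:`` prefix stays inline prose. A minimal sketch of the layout the directive expects (the ``scale`` function is hypothetical and exists only to illustrate the indentation)::

  def scale(x, factor=2.0):
    """Scale ``x`` by ``factor``.

    .. note::
      The directive body must be indented under ``.. note::`` and separated
      from the surrounding text by blank lines; otherwise the text is not
      rendered inside the note box.

    Args:
      x: the value to scale.
      factor: the multiplier applied to ``x``.

    Returns:
      ``x * factor``.
    """
    return x * factor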
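
The ``flax/linen/stochastic.py`` hunk above says that ``Module.apply()`` needs an RNG stream named ``'dropout'`` when dropout is active. A short usage sketch of that requirement, assuming current Flax Linen and JAX APIs (the input shape and dropout rate are arbitrary)::

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  x = jnp.ones((2, 4))
  dropout = nn.Dropout(rate=0.5)

  # Training: dropout is stochastic, so an RNG named 'dropout' is required.
  # Dropout has no parameters, so an empty variables dict is enough.
  y_train = dropout.apply(
      {}, x, deterministic=False, rngs={'dropout': jax.random.PRNGKey(0)}
  )

  # Evaluation: deterministic=True disables dropout and needs no RNG.
  y_eval = dropout.apply({}, x, deterministic=True)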