diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index f5b24628bb..f2777f4911 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -434,11 +434,13 @@ class LayerNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x):
+  def __call__(self, x, mask=None):
     """Applies layer normalization on the input.
 
     Args:
       x: the inputs
+      mask: Binary array of shape broadcastable to `inputs` tensor, indicating
+        the positions for which the mean and variance should be computed.
 
     Returns:
       Normalized inputs (the same shape as inputs).
@@ -450,6 +452,7 @@ def __call__(self, x):
         self.axis_name,
         self.axis_index_groups,
         use_fast_variance=self.use_fast_variance,
+        mask=mask,
     )
 
     return _normalize(
@@ -525,11 +528,13 @@ class RMSNorm(Module):
   axis_index_groups: Any = None
 
   @compact
-  def __call__(self, x):
+  def __call__(self, x, mask=None):
     """Applies layer normalization on the input.
 
     Args:
       x: the inputs
+      mask: Binary array of shape broadcastable to `inputs` tensor, indicating
+        the positions for which the mean and variance should be computed.
 
     Returns:
       Normalized inputs (the same shape as inputs).
@@ -541,6 +546,7 @@ def __call__(self, x):
         self.axis_name,
         self.axis_index_groups,
         use_mean=False,
+        mask=mask,
     )
 
     return _normalize(
@@ -625,13 +631,15 @@ class GroupNorm(Module):
   use_fast_variance: bool = True
 
   @compact
-  def __call__(self, x):
+  def __call__(self, x, mask=None):
     """Applies group normalization to the input (arxiv.org/abs/1803.08494).
 
     Args:
       x: the input of shape N...C, where N is a batch dimension and C is a
         channels dimensions. `...` represents an arbitrary number of extra
         dimensions that are used to accumulate statistics over.
+      mask: Binary array of shape broadcastable to `inputs` tensor, indicating
+        the positions for which the mean and variance should be computed.
 
     Returns:
       Normalized inputs (the same shape as inputs).
@@ -670,6 +678,9 @@ def __call__(self, x):
     group_size = x.shape[-1] // num_groups
     group_shape = x.shape[:-1] + (num_groups, group_size)
 
+    if mask is not None:
+      mask = mask.reshape(mask.shape[:-1] + (num_groups, group_size))
+
     mean, var = _compute_stats(
         x.reshape(group_shape),
         reduction_axes,
@@ -677,6 +688,7 @@ def __call__(self, x):
         self.axis_name,
         self.axis_index_groups,
         use_fast_variance=self.use_fast_variance,
+        mask=mask,
     )
     mean = jnp.repeat(mean, group_size, axis=-1)
     var = jnp.repeat(var, group_size, axis=-1)
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index 07a582e37b..20962b5a0a 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -137,6 +137,59 @@ def test_pooling_no_batch_dims(self):
 
 class NormalizationTest(parameterized.TestCase):
 
+  def test_layer_norm_mask(self):
+    key = random.key(0)
+    keys = random.split(key)
+    x = random.normal(keys[0], (3, 4, 5))
+    m = random.choice(keys[1], 2, x.shape).astype(bool)
+    m = m.at[..., :2].set(True)  # guarantee at least 2 elements
+    x = jnp.where(m, x, jnp.nan)
+
+    module = nn.LayerNorm()
+    y, w = module.init_with_output(key, x, m)
+
+    z = y.mean(-1, where=m)
+    np.testing.assert_allclose(z, 0, atol=1e-4)
+
+    z = y.var(-1, where=m)
+    np.testing.assert_allclose(z, 1, atol=1e-4)
+
+  def test_rms_norm_mask(self):
+    key = random.key(0)
+    keys = random.split(key)
+    x = random.normal(keys[0], (3, 4, 5))
+    m = random.choice(keys[1], 2, x.shape).astype(bool)
+    m = m.at[..., :1].set(True)  # guarantee at least 1 element
+    x = jnp.where(m, x, jnp.nan)
+
+    module = nn.RMSNorm()
+    y, w = module.init_with_output(key, x, m)
+
+    z = np.square(y).mean(-1, where=m)
+    np.testing.assert_allclose(z, 1, atol=1e-4)
+
+  def test_group_norm_mask(self):
+    key = random.key(0)
+    keys = random.split(key)
+    x = random.normal(keys[0], (13, 3, 5, 7 * 11))
+    m = random.choice(keys[1], 2, x.shape).astype(bool)
+    m = m.at[..., :2].set(True)  # guarantee at least 2 elements
+    x = jnp.where(m, x, jnp.nan)
+
+    module = nn.GroupNorm(7, use_bias=False, use_scale=False)
+    y, w = module.init_with_output(key, x, m)
+
+    yr = y.reshape((13, 3, 5, 7, 11))
+    mr = m.reshape((13, 3, 5, 7, 11))
+
+    axes = list(range(1, x.ndim - 1)) + [-1]
+
+    z = yr.mean(axes, where=mr)
+    np.testing.assert_allclose(z, 0, atol=1e-4)
+
+    z = yr.var(axes, where=mr)
+    np.testing.assert_allclose(z, 1, atol=1e-4)
+
   @parameterized.parameters({'test_mask': True}, {'test_mask': False})
   def test_batch_norm(self, test_mask):
     rng = random.key(0)
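For reference, a minimal usage sketch of the new `mask` argument (not part of the change itself; the shapes, key values, and padding pattern below are illustrative assumptions):

```python
import jax.numpy as jnp
from jax import random
import flax.linen as nn

# 3 examples with 8 feature slots each; pretend the last 3 slots of every
# example are padding that should not contribute to the layer statistics.
x = random.normal(random.key(0), (3, 8))
mask = jnp.arange(8) < 5           # shape (8,), broadcasts against x
x = jnp.where(mask, x, 0.0)        # padded slots can hold arbitrary values

layer = nn.LayerNorm()
y, variables = layer.init_with_output(random.key(1), x, mask)

# Masked-in positions come out with zero mean and unit variance.
print(y.mean(-1, where=mask), y.var(-1, where=mask))
```

Positions where the mask is `False` are excluded when computing the mean and variance, as described in the added docstrings.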