From c16e1c2e18fbdfe2422f24428ff995827320dc33 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 15 Sep 2023 03:00:59 +0000 Subject: [PATCH 1/6] Add `rsqrt` to numpy backend --- keras_core/backend/numpy/math.py | 5 +++++ keras_core/ops/math_test.py | 4 ---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py index fb51f44e2..23746fd7a 100644 --- a/keras_core/backend/numpy/math.py +++ b/keras_core/backend/numpy/math.py @@ -3,6 +3,7 @@ from keras_core.backend import standardize_dtype from keras_core.backend.jax.math import fft as jax_fft from keras_core.backend.jax.math import fft2 as jax_fft2 +from keras_core.backend.jax.math import rsqrt as jax_rsqrt from keras_core.backend.numpy.core import convert_to_tensor from keras_core.utils.module_utils import scipy @@ -298,3 +299,7 @@ def istft( else: end = expected_output_len return x[..., start:end] + + +def rsqrt(x): + return np.array(jax_rsqrt(x)) diff --git a/keras_core/ops/math_test.py b/keras_core/ops/math_test.py index dddb67103..9fdac5106 100644 --- a/keras_core/ops/math_test.py +++ b/keras_core/ops/math_test.py @@ -831,10 +831,6 @@ def test_istft( ref = ref[..., truncated_len:-truncated_len] self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5) - @pytest.mark.skipif( - backend.backend() == "numpy", - reason="Numpy does not support rsqrt.", - ) def test_rsqrt(self): x = np.array([[1, 4, 9], [16, 25, 36]], dtype="float32") self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x)) From c1cbbf03c387bdd767caf96bba383c56cc52d3e8 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 15 Sep 2023 03:01:39 +0000 Subject: [PATCH 2/6] Improve normalization --- .../normalization/batch_normalization.py | 31 ++++++++------ .../normalization/group_normalization.py | 31 +++++--------- .../normalization/layer_normalization.py | 41 +++++++++---------- .../normalization/unit_normalization.py | 2 +- 4 files changed, 50 insertions(+), 55 deletions(-) diff --git a/keras_core/layers/normalization/batch_normalization.py b/keras_core/layers/normalization/batch_normalization.py index 0812f6843..652db9e04 100644 --- a/keras_core/layers/normalization/batch_normalization.py +++ b/keras_core/layers/normalization/batch_normalization.py @@ -201,21 +201,21 @@ def call(self, inputs, training=None, mask=None): mean, variance = ops.moments( inputs, axes=self._reduction_axes, keepdims=True ) - outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon) - mean = ops.squeeze(mean, self._reduction_axes) - variance = ops.squeeze(variance, self._reduction_axes) moving_mean = ops.cast(self.moving_mean, inputs.dtype) moving_variance = ops.cast(self.moving_variance, inputs.dtype) self.moving_mean.assign( ops.cast( - moving_mean * self.momentum + mean * (1.0 - self.momentum), + moving_mean * self.momentum + + ops.squeeze(mean, self._reduction_axes) + * (1.0 - self.momentum), inputs.dtype, ) ) self.moving_variance.assign( ops.cast( moving_variance * self.momentum - + variance * (1.0 - self.momentum), + + ops.squeeze(variance, self._reduction_axes) + * (1.0 - self.momentum), inputs.dtype, ) ) @@ -224,17 +224,24 @@ def call(self, inputs, training=None, mask=None): moving_variance = ops.cast(self.moving_variance, inputs.dtype) moving_mean = ops.reshape(moving_mean, broadcast_shape) moving_variance = ops.reshape(moving_variance, broadcast_shape) - outputs = (inputs - moving_mean) / ops.sqrt( - moving_variance + self.epsilon 
- ) + mean = moving_mean + variance = moving_variance + + inv = ops.rsqrt(variance + self.epsilon) + res = -mean * inv + if self.scale: gamma = ops.reshape(self.gamma, broadcast_shape) - gamma = ops.cast(gamma, outputs.dtype) - outputs = outputs * gamma + gamma = ops.cast(gamma, inputs.dtype) + inv = inv * gamma if self.center: beta = ops.reshape(self.beta, broadcast_shape) - beta = ops.cast(beta, outputs.dtype) - outputs = outputs + beta + beta = ops.cast(beta, inputs.dtype) + res = res + beta + + # Note: Folding BatchNormalization depends on the precise order of ops + # that are generated by the expression below + outputs = inputs * inv + res return ops.cast(outputs, input_dtype) def get_config(self): diff --git a/keras_core/layers/normalization/group_normalization.py b/keras_core/layers/normalization/group_normalization.py index 94b56b05f..8dd42ce52 100644 --- a/keras_core/layers/normalization/group_normalization.py +++ b/keras_core/layers/normalization/group_normalization.py @@ -171,37 +171,26 @@ def _apply_normalization(self, reshaped_inputs, input_shape): axis = -2 if self.axis == -1 else self.axis - 1 group_reduction_axes.pop(axis) + broadcast_shape = self._create_broadcast_shape(input_shape) mean, variance = ops.moments( reshaped_inputs, axes=group_reduction_axes, keepdims=True ) - gamma, beta = self._get_reshaped_weights(input_shape) # Compute the batch normalization. - inv = 1 / ops.sqrt(variance + self.epsilon) - - if gamma is not None: - inv = ops.multiply(inv, gamma) + inv = ops.rsqrt(variance + self.epsilon) + res = -mean * inv - if beta is not None: - x = beta - ops.multiply(mean, inv) - else: - x = -ops.multiply(mean, inv) - - normalized_inputs = reshaped_inputs * ops.cast( - inv, reshaped_inputs.dtype - ) + ops.cast(x, reshaped_inputs.dtype) - normalized_inputs = ops.cast(normalized_inputs, reshaped_inputs.dtype) - return normalized_inputs - - def _get_reshaped_weights(self, input_shape): - broadcast_shape = self._create_broadcast_shape(input_shape) - gamma = None - beta = None if self.scale: gamma = ops.reshape(self.gamma, broadcast_shape) + gamma = ops.cast(gamma, reshaped_inputs.dtype) + inv = inv * gamma if self.center: beta = ops.reshape(self.beta, broadcast_shape) - return gamma, beta + beta = ops.cast(beta, reshaped_inputs.dtype) + res = res + beta + + normalized_inputs = reshaped_inputs * inv + res + return normalized_inputs def _create_broadcast_shape(self, input_shape): broadcast_shape = [1] * len(input_shape) diff --git a/keras_core/layers/normalization/layer_normalization.py b/keras_core/layers/normalization/layer_normalization.py index 7d8381c27..01e6c0d5d 100644 --- a/keras_core/layers/normalization/layer_normalization.py +++ b/keras_core/layers/normalization/layer_normalization.py @@ -206,33 +206,32 @@ def _broadcast(v): if self.rms_scaling: # Calculate outputs with only variance and gamma if rms scaling # is enabled - # Calculate the variance along last axis (layer activations). + # Calculate the variance along self.axis (layer activations). variance = ops.var(inputs, axis=self.axis, keepdims=True) - inv = 1 / ops.sqrt(variance + self.epsilon) - outputs = inputs * ops.cast(inv, inputs.dtype) * self.gamma + inv = ops.rsqrt(variance + self.epsilon) + + gamma = _broadcast(self.gamma) + gamma = ops.cast(gamma, inputs.dtype) + + outputs = inputs * inv * gamma else: - # Calculate the mean & variance along last axis (layer activations). + # Calculate the mean & variance along self.axis (layer activations). 
mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True) - inv = 1 / ops.sqrt(variance + self.epsilon) - scale, offset = _broadcast(self.gamma), _broadcast(self.beta) - if scale is not None: - scale = ops.cast(scale, inputs.dtype) - inv = inv * scale - x = -mean * inv - if offset is not None: - offset = ops.cast(offset, inputs.dtype) - x = offset + x - - outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast( - x, inputs.dtype - ) + inv = ops.rsqrt(variance + self.epsilon) + res = -mean * inv - outputs = ops.cast(outputs, input_dtype) + if self.gamma is not None: + gamma = _broadcast(self.gamma) + gamma = ops.cast(gamma, inputs.dtype) + inv = inv * gamma + if self.beta is not None: + beta = _broadcast(self.beta) + beta = ops.cast(beta, inputs.dtype) + res = res + beta - # If some components of the shape got lost due to adjustments, fix that. - outputs = ops.reshape(outputs, ops.shape(inputs)) + outputs = inputs * inv + res - return outputs + return ops.cast(outputs, input_dtype) def compute_output_shape(self, input_shape): return input_shape diff --git a/keras_core/layers/normalization/unit_normalization.py b/keras_core/layers/normalization/unit_normalization.py index 7b44290fc..33553ada0 100644 --- a/keras_core/layers/normalization/unit_normalization.py +++ b/keras_core/layers/normalization/unit_normalization.py @@ -45,7 +45,7 @@ def call(self, inputs): x = ops.cast(inputs, self.compute_dtype) square_sum = ops.sum(ops.square(x), axis=self.axis, keepdims=True) - x_inv_norm = 1 / ops.sqrt(ops.maximum(square_sum, 1e-12)) + x_inv_norm = ops.rsqrt(ops.maximum(square_sum, 1e-12)) return ops.multiply(x, x_inv_norm) def compute_output_shape(self, input_shape): From bf0fca86aac289e9a6a2012122a2c40e88968f4b Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 15 Sep 2023 04:31:15 +0000 Subject: [PATCH 3/6] Fix order bug --- keras_core/layers/normalization/batch_normalization.py | 4 ++-- keras_core/layers/normalization/group_normalization.py | 4 ++-- keras_core/layers/normalization/layer_normalization.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_core/layers/normalization/batch_normalization.py b/keras_core/layers/normalization/batch_normalization.py index 652db9e04..5bb3a4eb7 100644 --- a/keras_core/layers/normalization/batch_normalization.py +++ b/keras_core/layers/normalization/batch_normalization.py @@ -228,12 +228,12 @@ def call(self, inputs, training=None, mask=None): variance = moving_variance inv = ops.rsqrt(variance + self.epsilon) - res = -mean * inv - if self.scale: gamma = ops.reshape(self.gamma, broadcast_shape) gamma = ops.cast(gamma, inputs.dtype) inv = inv * gamma + + res = -mean * inv if self.center: beta = ops.reshape(self.beta, broadcast_shape) beta = ops.cast(beta, inputs.dtype) diff --git a/keras_core/layers/normalization/group_normalization.py b/keras_core/layers/normalization/group_normalization.py index 8dd42ce52..7931f3fb9 100644 --- a/keras_core/layers/normalization/group_normalization.py +++ b/keras_core/layers/normalization/group_normalization.py @@ -178,12 +178,12 @@ def _apply_normalization(self, reshaped_inputs, input_shape): # Compute the batch normalization. 
inv = ops.rsqrt(variance + self.epsilon) - res = -mean * inv - if self.scale: gamma = ops.reshape(self.gamma, broadcast_shape) gamma = ops.cast(gamma, reshaped_inputs.dtype) inv = inv * gamma + + res = -mean * inv if self.center: beta = ops.reshape(self.beta, broadcast_shape) beta = ops.cast(beta, reshaped_inputs.dtype) diff --git a/keras_core/layers/normalization/layer_normalization.py b/keras_core/layers/normalization/layer_normalization.py index 01e6c0d5d..a4876f69b 100644 --- a/keras_core/layers/normalization/layer_normalization.py +++ b/keras_core/layers/normalization/layer_normalization.py @@ -218,12 +218,12 @@ def _broadcast(v): # Calculate the mean & variance along self.axis (layer activations). mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True) inv = ops.rsqrt(variance + self.epsilon) - res = -mean * inv - if self.gamma is not None: gamma = _broadcast(self.gamma) gamma = ops.cast(gamma, inputs.dtype) inv = inv * gamma + + res = -mean * inv if self.beta is not None: beta = _broadcast(self.beta) beta = ops.cast(beta, inputs.dtype) From f134fdf612681d60eade1a2b2a8f0725bb26b7b7 Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Fri, 15 Sep 2023 05:17:48 +0000 Subject: [PATCH 4/6] Update LayerNormalization --- .../layers/normalization/layer_normalization.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/keras_core/layers/normalization/layer_normalization.py b/keras_core/layers/normalization/layer_normalization.py index a4876f69b..f4cb231d4 100644 --- a/keras_core/layers/normalization/layer_normalization.py +++ b/keras_core/layers/normalization/layer_normalization.py @@ -210,22 +210,19 @@ def _broadcast(v): variance = ops.var(inputs, axis=self.axis, keepdims=True) inv = ops.rsqrt(variance + self.epsilon) - gamma = _broadcast(self.gamma) - gamma = ops.cast(gamma, inputs.dtype) - - outputs = inputs * inv * gamma + outputs = inputs * inv * ops.cast(self.gamma, inputs.dtype) else: # Calculate the mean & variance along self.axis (layer activations). 
             mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
+            gamma, beta = _broadcast(self.gamma), _broadcast(self.beta)
+
             inv = ops.rsqrt(variance + self.epsilon)
-            if self.gamma is not None:
-                gamma = _broadcast(self.gamma)
+            if gamma is not None:
                 gamma = ops.cast(gamma, inputs.dtype)
                 inv = inv * gamma
 
             res = -mean * inv
-            if self.beta is not None:
-                beta = _broadcast(self.beta)
+            if beta is not None:
                 beta = ops.cast(beta, inputs.dtype)
                 res = res + beta
 

From 254ba1073d4254b6f643b8675fd3dadc4e991e02 Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 15 Sep 2023 06:08:00 +0000
Subject: [PATCH 5/6] Improve unit test coverage

---
 .../normalization/group_normalization_test.py | 45 +++++++++++++++++++
 .../normalization/layer_normalization_test.py | 10 +++++
 .../spectral_normalization_test.py            | 26 +++++++++++
 .../normalization/unit_normalization_test.py  | 10 +++++
 4 files changed, 91 insertions(+)

diff --git a/keras_core/layers/normalization/group_normalization_test.py b/keras_core/layers/normalization/group_normalization_test.py
index 59a65edd9..1780f62a7 100644
--- a/keras_core/layers/normalization/group_normalization_test.py
+++ b/keras_core/layers/normalization/group_normalization_test.py
@@ -41,6 +41,51 @@ def test_groupnorm(self):
             supports_masking=True,
         )
 
+    def test_undefined_dim_error(self):
+        inputs = layers.Input(shape=(2, 2, 2, None))
+        layer = layers.GroupNormalization()
+        with self.assertRaisesRegex(
+            ValueError,
+            (
+                "input tensor should have a defined dimension but the layer "
+                "received an input with shape"
+            ),
+        ):
+            _ = layer(inputs)
+
+    def test_groups_bigger_than_dim_error(self):
+        inputs = np.ones(shape=(2, 2, 2, 4))
+        layer = layers.GroupNormalization(groups=5)
+        with self.assertRaisesRegex(
+            ValueError,
+            "cannot be more than the number of channels",
+        ):
+            _ = layer(inputs)
+
+    def test_groups_not_a_multiple_of_dim_error(self):
+        inputs = np.ones(shape=(2, 2, 2, 4))
+        layer = layers.GroupNormalization(groups=3)
+        with self.assertRaisesRegex(
+            ValueError,
+            "must be a multiple of the number of channels",
+        ):
+            _ = layer(inputs)
+
+    def test_groups_instance_norm(self):
+        # GroupNormalization with groups=-1 will become InstanceNormalization
+        instance_norm_layer_1 = layers.GroupNormalization(
+            groups=-1, axis=-1, scale=False, center=False
+        )
+        instance_norm_layer_2 = layers.GroupNormalization(
+            groups=4, axis=-1, scale=False, center=False
+        )
+        inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])
+
+        outputs_1 = instance_norm_layer_1(inputs)
+        outputs_2 = instance_norm_layer_2(inputs)
+
+        self.assertAllClose(outputs_1, outputs_2)
+
     def test_correctness_instance_norm(self):
         instance_norm_layer = layers.GroupNormalization(
             groups=4, axis=-1, scale=False, center=False
diff --git a/keras_core/layers/normalization/layer_normalization_test.py b/keras_core/layers/normalization/layer_normalization_test.py
index c5d66f89f..94b039db0 100644
--- a/keras_core/layers/normalization/layer_normalization_test.py
+++ b/keras_core/layers/normalization/layer_normalization_test.py
@@ -83,6 +83,16 @@ def test_ln_basics(self):
             supports_masking=True,
         )
 
+    def test_invalid_axis(self):
+        with self.assertRaisesRegex(
+            TypeError,
+            (
+                "Expected an int or a list/tuple of ints for the argument "
+                "'axis'"
+            ),
+        ):
+            layers.LayerNormalization(axis={"axis": -1})
+
     def test_correctness(self):
         layer = layers.LayerNormalization(dtype="float32")
         layer.build(input_shape=(2, 2, 2))
diff --git a/keras_core/layers/normalization/spectral_normalization_test.py b/keras_core/layers/normalization/spectral_normalization_test.py
index ae923e379..488beae24 100644
--- a/keras_core/layers/normalization/spectral_normalization_test.py
+++ b/keras_core/layers/normalization/spectral_normalization_test.py
@@ -20,6 +20,32 @@ def test_basic_spectralnorm(self):
             expected_num_losses=0,
             supports_masking=False,
         )
+        self.run_layer_test(
+            layers.SpectralNormalization,
+            init_kwargs={"layer": layers.Embedding(10, 4)},
+            input_data=np.random.randint(10, size=(10,)),
+            expected_output_shape=(10, 4),
+            expected_num_trainable_weights=1,
+            expected_num_non_trainable_weights=1,
+            expected_num_seed_generators=0,
+            expected_num_losses=0,
+            supports_masking=False,
+            run_training_check=False,
+        )
+
+    def test_invalid_power_iterations(self):
+        with self.assertRaisesRegex(
+            ValueError, "`power_iterations` should be greater than zero."
+        ):
+            layers.SpectralNormalization(layers.Dense(2), power_iterations=0)
+
+    def test_invalid_layer(self):
+        layer = layers.SpectralNormalization(layers.ReLU())
+        inputs = np.ones(shape=(4, 2))
+        with self.assertRaisesRegex(
+            ValueError, "object has no attribute 'kernel' nor 'embeddings'"
+        ):
+            layer(inputs)
 
     def test_apply_layer(self):
         images = np.ones((1, 2, 2, 1))
diff --git a/keras_core/layers/normalization/unit_normalization_test.py b/keras_core/layers/normalization/unit_normalization_test.py
index 8a3e6b027..94235e855 100644
--- a/keras_core/layers/normalization/unit_normalization_test.py
+++ b/keras_core/layers/normalization/unit_normalization_test.py
@@ -29,6 +29,16 @@ def test_un_basics(self):
             supports_masking=True,
         )
 
+    def test_invalid_axis(self):
+        with self.assertRaisesRegex(
+            TypeError,
+            (
+                "Invalid value for `axis` argument: expected an int or a "
+                "list/tuple of ints."
+            ),
+        ):
+            layers.UnitNormalization(axis={"axis": -1})
+
     def test_correctness(self):
         layer = layers.UnitNormalization(axis=-1)
         inputs = np.random.normal(size=(2, 3))

From 10e4a031cf3336d7c7c6612b6cce82c1be2c1f7e Mon Sep 17 00:00:00 2001
From: chiuhongyu <20734616+james77777778@users.noreply.github.com>
Date: Sat, 16 Sep 2023 14:21:39 +0800
Subject: [PATCH 6/6] Use native NumPy implementation for `rsqrt`

---
 keras_core/backend/numpy/math.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/keras_core/backend/numpy/math.py b/keras_core/backend/numpy/math.py
index 23746fd7a..2e1c2cfcd 100644
--- a/keras_core/backend/numpy/math.py
+++ b/keras_core/backend/numpy/math.py
@@ -3,7 +3,6 @@ from keras_core.backend import standardize_dtype
 from keras_core.backend.jax.math import fft as jax_fft
 from keras_core.backend.jax.math import fft2 as jax_fft2
-from keras_core.backend.jax.math import rsqrt as jax_rsqrt
 from keras_core.backend.numpy.core import convert_to_tensor
 from keras_core.utils.module_utils import scipy
 
@@ -302,4 +301,4 @@ def istft(
 
 
 def rsqrt(x):
-    return np.array(jax_rsqrt(x))
+    return 1.0 / np.sqrt(x)
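
Editor's note, appended for review and not part of any patch above: patches
2-4 rewrite each normalization as `outputs = inputs * inv + res`, folding
gamma into `inv = ops.rsqrt(variance + self.epsilon)` and folding beta,
together with the `-mean` term, into `res`. PATCH 3 ("Fix order bug") exists
because `res = -mean * inv` must be derived *after* gamma has been folded
into `inv`; otherwise the mean term is left unscaled. The standalone NumPy
sketch below checks that equivalence. All names are illustrative; it assumes
nothing about keras_core beyond what the hunks above show.

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 4)).astype("float32")
    gamma = rng.normal(size=(4,)).astype("float32")
    beta = rng.normal(size=(4,)).astype("float32")
    epsilon = 1e-3

    mean = x.mean(axis=0, keepdims=True)
    variance = x.var(axis=0, keepdims=True)

    # Reference form: normalize, then scale and shift.
    reference = (x - mean) / np.sqrt(variance + epsilon) * gamma + beta

    # Fused form after PATCH 3: fold gamma into `inv` first, then derive
    # `res`, so the mean term is scaled by gamma as well.
    inv = 1.0 / np.sqrt(variance + epsilon)  # ops.rsqrt in the patches
    inv = inv * gamma
    res = -mean * inv + beta
    fused = x * inv + res
    np.testing.assert_allclose(fused, reference, rtol=1e-5, atol=1e-5)

    # The PATCH 2 ordering derived `res` before folding gamma into `inv`,
    # which dropped gamma from the mean term; this is the bug PATCH 3 fixes.
    inv_bad = 1.0 / np.sqrt(variance + epsilon)
    res_bad = -mean * inv_bad + beta  # mean term not scaled by gamma
    buggy = x * (inv_bad * gamma) + res_bad
    # assert_allclose(buggy, reference) fails whenever gamma != 1.

The multiply-add form is also why PATCH 2 adds the BatchNormalization comment
that folding depends on the precise order of ops the expression generates.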
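
A second sketch, same caveats, for the `rms_scaling` branch that PATCH 4
touches: with scaling folded the same way, the branch reduces to
`inputs * inv * gamma`, with no mean subtraction and no beta. It is built on
the variance rather than the raw mean square, so it coincides with classic
RMS normalization only for zero-mean activations.

    import numpy as np

    x = np.random.default_rng(1).normal(size=(2, 8)).astype("float32")
    gamma = np.full((8,), 0.5, dtype="float32")  # illustrative scale weights
    epsilon = 1e-6

    variance = x.var(axis=-1, keepdims=True)
    inv = 1.0 / np.sqrt(variance + epsilon)  # ops.rsqrt(variance + epsilon)
    outputs = x * inv * gamma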