From f92c261a61f6c5446944bcc4cf4eb9dec8da4f35 Mon Sep 17 00:00:00 2001 From: Vision Date: Thu, 25 Apr 2024 23:56:37 +0530 Subject: [PATCH 1/2] PSNR --- keras/api/_tf_keras/keras/ops/__init__.py | 2 +- keras/api/_tf_keras/keras/ops/nn/__init__.py | 1 + keras/api/ops/__init__.py | 2 +- keras/api/ops/nn/__init__.py | 1 + keras/src/backend/jax/nn.py | 13 ++++ keras/src/backend/numpy/nn.py | 13 ++++ keras/src/backend/tensorflow/nn.py | 15 ++++ keras/src/backend/torch/nn.py | 17 +++++ keras/src/ops/nn.py | 74 ++++++++++++++++++++ keras/src/ops/nn_test.py | 31 ++++++++ 10 files changed, 167 insertions(+), 2 deletions(-) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 118c5a692ee..bfcf12392c0 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -72,6 +72,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu @@ -196,7 +197,6 @@ from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size -from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt diff --git a/keras/api/_tf_keras/keras/ops/nn/__init__.py b/keras/api/_tf_keras/keras/ops/nn/__init__.py index 61efc22a570..8c7e3d921b3 100644 --- a/keras/api/_tf_keras/keras/ops/nn/__init__.py +++ b/keras/api/_tf_keras/keras/ops/nn/__init__.py @@ -26,6 +26,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 118c5a692ee..bfcf12392c0 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -72,6 +72,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu @@ -196,7 +197,6 @@ from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size -from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt diff --git a/keras/api/ops/nn/__init__.py b/keras/api/ops/nn/__init__.py index 61efc22a570..8c7e3d921b3 100644 --- a/keras/api/ops/nn/__init__.py +++ b/keras/api/ops/nn/__init__.py @@ -26,6 +26,7 @@ from keras.src.ops.nn import multi_hot from keras.src.ops.nn import normalize from keras.src.ops.nn import one_hot +from keras.src.ops.nn import psnr from keras.src.ops.nn import relu from keras.src.ops.nn import relu6 from keras.src.ops.nn import selu diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 7fc623d831c..3cbc61126fc 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -926,3 +926,16 @@ def ctc_decode( f"Invalid strategy {strategy}. Supported values are " "'greedy' and 'beam_search'." ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = jnp.mean(jnp.square(x1 - x2)) + psnr = 20 * jnp.log10(max_val) - 10 * jnp.log10(mse) + return psnr diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index 2c27fad23b4..a2d89c323e6 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -967,3 +967,16 @@ def ctc_decode( f"Invalid strategy {strategy}. Supported values are " "'greedy' and 'beam_search'." ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = np.mean(np.square(x1 - x2)) + psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) + return psnr diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index 2a53c25f1f9..e7317599af0 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -846,3 +846,18 @@ def ctc_decode( decoded_dense = tf.stack(decoded_dense, axis=0) decoded_dense = tf.cast(decoded_dense, "int32") return decoded_dense, scores + + +def psnr(x1, x2, max_val): + from keras.src.backend.tensorflow.numpy import log10 + + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + max_val = convert_to_tensor(max_val, dtype=x2.dtype) + mse = tf.reduce_mean(tf.square(x1 - x2)) + psnr = 20 * log10(max_val) - 10 * log10(mse) + return psnr diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index 4eef2c18977..62749bc163b 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -848,3 +848,20 @@ def ctc_decode( f"Invalid strategy {strategy}. Supported values are " "'greedy' and 'beam_search'." ) + + +def psnr(x1, x2, max_val): + if x1.shape != x2.shape: + raise ValueError( + f"Input shapes {x1.shape} and {x2.shape} must " + "match for PSNR calculation. " + ) + + x1, x2 = ( + convert_to_tensor(x1), + convert_to_tensor(x2), + ) + max_val = convert_to_tensor(max_val, dtype=x1.dtype) + mse = torch.mean((x1 - x2) ** 2) + psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse) + return psnr diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index b2b248154e3..84d1bbd077e 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2042,3 +2042,77 @@ def _normalize(x, axis=-1, order=2): norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True) denom = backend.numpy.maximum(norm, epsilon) return backend.numpy.divide(x, denom) + + +class PSNR(Operation): + def __init__( + self, + max_val, + ): + super().__init__() + self.max_val = max_val + + def call(self, x1, x2): + return backend.nn.psnr( + x1=x1, + x2=x2, + max_val=self.max_val, + ) + + def compute_output_spec(self, x1, x2): + if len(x1.shape) != len(x2.shape): + raise ValueError("Inputs must have the same rank") + + return KerasTensor(shape=()) + + +@keras_export( + [ + "keras.ops.psnr", + "keras.ops.nn.psnr", + ] +) +def psnr( + x1, + x2, + max_val, +): + """Peak Signal-to-Noise Ratio (PSNR) calculation. + + This function calculates the Peak Signal-to-Noise Ratio between two signals, + `x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal. + The higher the PSNR, the closer the reconstructed signal is to the original + signal. + + Args: + x1: The first input signal. + x2: The second input signal. Must have the same shape as `x1`. + max_val: The maximum possible value in the signals. + + Returns: + float: The PSNR value between `x1` and `x2`. + + Examples: + >>> import numpy as np + >>> from keras import ops + >>> x = np.random.random((2, 4, 4, 3)) + >>> y = np.random.random((2, 4, 4, 3)) + >>> max_val = 1.0 + >>> psnr_value = ops.nn.psnr(x, y, max_val) + >>> psnr_value + 20.0 + """ + if any_symbolic_tensors( + ( + x1, + x2, + ) + ): + return PSNR( + max_val, + ).symbolic_call(x1, x2) + return backend.nn.psnr( + x1, + x2, + max_val, + ) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 5e52a6c263f..433bb0f46af 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -654,6 +654,12 @@ def test_normalize(self): x = KerasTensor([None, 2, 3]) self.assertEqual(knn.normalize(x).shape, (None, 2, 3)) + def test_psnr(self): + x1 = KerasTensor([None, 2, 3]) + x2 = KerasTensor([None, 5, 6]) + out = knn.psnr(x1, x2, max_val=224) + self.assertEqual(out.shape, ()) + class NNOpsStaticShapeTest(testing.TestCase): def test_relu(self): @@ -1114,6 +1120,12 @@ def test_normalize(self): x = KerasTensor([1, 2, 3]) self.assertEqual(knn.normalize(x).shape, (1, 2, 3)) + def test_psnr(self): + x1 = KerasTensor([1, 2, 3]) + x2 = KerasTensor([5, 6, 7]) + out = knn.psnr(x1, x2, max_val=224) + self.assertEqual(out.shape, ()) + class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): def test_relu(self): @@ -2032,6 +2044,25 @@ def test_normalize(self): ], ) + def test_psnr(self): + x1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x2 = np.array([[0.2, 0.2, 0.3], [0.4, 0.6, 0.6]]) + max_val = 1.0 + expected_psnr_1 = 20 * np.log10(max_val) - 10 * np.log10( + np.mean(np.square(x1 - x2)) + ) + psnr_1 = knn.psnr(x1, x2, max_val) + self.assertAlmostEqual(psnr_1, expected_psnr_1) + + x3 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + x4 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) + max_val = 1.0 + expected_psnr_2 = 20 * np.log10(max_val) - 10 * np.log10( + np.mean(np.square(x3 - x4)) + ) + psnr_2 = knn.psnr(x3, x4, max_val) + self.assertAlmostEqual(psnr_2, expected_psnr_2) + class NNOpsDtypeTest(testing.TestCase, parameterized.TestCase): """Test the dtype to verify that the behavior matches JAX.""" From 0d31ae1432d21ce3d0a9477f24e7f77d91911299 Mon Sep 17 00:00:00 2001 From: Vision Date: Fri, 26 Apr 2024 00:08:36 +0530 Subject: [PATCH 2/2] Fix --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + keras/api/ops/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index bfcf12392c0..386730deb33 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -197,6 +197,7 @@ from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size +from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index bfcf12392c0..386730deb33 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -197,6 +197,7 @@ from keras.src.ops.numpy import sin from keras.src.ops.numpy import sinh from keras.src.ops.numpy import size +from keras.src.ops.numpy import slogdet from keras.src.ops.numpy import sort from keras.src.ops.numpy import split from keras.src.ops.numpy import sqrt