From 380b973881fd4683ab1344d12fcec6de829194f4 Mon Sep 17 00:00:00 2001 From: Alexander Hartl Date: Thu, 15 Aug 2024 11:17:49 +0200 Subject: [PATCH 1/3] Fix `ops.stop_gradient` for `KerasTensor` --- keras/src/ops/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/src/ops/core.py b/keras/src/ops/core.py index 88814ea5b2a..33f9eaf6874 100644 --- a/keras/src/ops/core.py +++ b/keras/src/ops/core.py @@ -609,6 +609,8 @@ def stop_gradient(variable): ... ) >>> var = keras.ops.stop_gradient(var) """ + if any_symbolic_tensors((variable,)): + return StopGradient().symbolic_call(variable) return backend.core.stop_gradient(variable) From 69309a2e35d1a8a9e8c8ae0260c4f8b2d0c7f0ca Mon Sep 17 00:00:00 2001 From: Alexander Hartl Date: Thu, 15 Aug 2024 12:55:39 +0200 Subject: [PATCH 2/3] Added test for `stop_gradient` in functional model --- keras/src/ops/core_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 9610ba06b84..765aba856ac 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -583,6 +583,15 @@ def test_stop_gradient_return(self): y = ops.stop_gradient(x) self.assertAllClose(x, y) + def test_stop_gradient_functional(self): + a = layers.Input(shape=(2,)) + b = layers.Dense(4, kernel_initializer="ones", use_bias=False)(a) + c = layers.Dense(4, kernel_initializer="ones", use_bias=False)(b) + d = ops.stop_gradient(b) + c + model = models.Model(inputs=a, outputs=d) + output = model(ops.convert_to_tensor([[1.0, 2.0]])) + self.assertAllClose(output.numpy(), 15.0) + def test_shape(self): x = ops.ones((2, 3, 7, 1)) self.assertEqual(core.shape(x).__class__, tuple) From 69e63e8ca20e5d900f780897b45ee7f541276240 Mon Sep 17 00:00:00 2001 From: Alexander Hartl Date: Thu, 15 Aug 2024 13:10:23 +0200 Subject: [PATCH 3/3] Fixed tests when using numpy as backend --- keras/src/ops/core_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 765aba856ac..675a2ab357f 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -590,7 +590,7 @@ def test_stop_gradient_functional(self): d = ops.stop_gradient(b) + c model = models.Model(inputs=a, outputs=d) output = model(ops.convert_to_tensor([[1.0, 2.0]])) - self.assertAllClose(output.numpy(), 15.0) + self.assertAllClose(ops.convert_to_numpy(output), 15.0) def test_shape(self): x = ops.ones((2, 3, 7, 1))