From 4cc60645dd7464b12eaf98f49db1058becb746ec Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Mon, 29 Apr 2024 12:14:18 -0700 Subject: [PATCH] Return a tuple from `ops.shape` with the Torch backend. With Torch, `x.shape` returns a `torch.Size`, which is a subclass of `tuple` but can cause different behaviors. In particular `convert_to_tensor` does not work on `torch.Size`. This fixes https://github.com/keras-team/keras/issues/18900 --- keras/src/backend/torch/core.py | 3 ++- keras/src/ops/core_test.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 8cc6c5b5b56..68453255b1f 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -234,7 +234,8 @@ def is_tensor(x): def shape(x): - return x.shape + # Convert from `torch.Size` to plain tuple. + return tuple(x.shape) def cast(x, dtype): diff --git a/keras/src/ops/core_test.py b/keras/src/ops/core_test.py index 31d553853b6..29526c20ad4 100644 --- a/keras/src/ops/core_test.py +++ b/keras/src/ops/core_test.py @@ -338,10 +338,12 @@ def test_stop_gradient_return(self): self.assertAllClose(x, y) def test_shape(self): - x = np.ones((2, 3, 7, 1)) + x = ops.ones((2, 3, 7, 1)) + self.assertEqual(core.shape(x).__class__, tuple) self.assertAllEqual(core.shape(x), (2, 3, 7, 1)) x = KerasTensor((None, 3, None, 1)) + self.assertEqual(core.shape(x).__class__, tuple) self.assertAllEqual(core.shape(x), (None, 3, None, 1)) @pytest.mark.skipif(