diff --git a/keras_core/layers/preprocessing/rescaling.py b/keras_core/layers/preprocessing/rescaling.py index 50c90c23d..8772ce741 100644 --- a/keras_core/layers/preprocessing/rescaling.py +++ b/keras_core/layers/preprocessing/rescaling.py @@ -1,3 +1,4 @@ +from keras_core import backend from keras_core.api_export import keras_core_export from keras_core.layers.preprocessing.tf_data_layer import TFDataLayer @@ -40,6 +41,14 @@ def call(self, inputs): dtype = self.compute_dtype scale = self.backend.cast(self.scale, dtype) offset = self.backend.cast(self.offset, dtype) + scale_shape = self.backend.core.shape(scale) + if ( + len(scale_shape) > 0 + and backend.image_data_format() == "channels_first" + ): + scale = self.backend.numpy.reshape( + scale, scale_shape + (1,) * (3 - len(scale_shape)) + ) return self.backend.cast(inputs, dtype) * scale + offset def compute_output_shape(self, input_shape): diff --git a/keras_core/layers/preprocessing/rescaling_test.py b/keras_core/layers/preprocessing/rescaling_test.py index 6bdf61cff..2852204c6 100644 --- a/keras_core/layers/preprocessing/rescaling_test.py +++ b/keras_core/layers/preprocessing/rescaling_test.py @@ -2,6 +2,7 @@ import pytest import tensorflow as tf +from keras_core import backend from keras_core import layers from keras_core import testing @@ -73,3 +74,13 @@ def test_tf_data_compatibility(self): ds = tf.data.Dataset.from_tensor_slices(x).batch(3).map(layer) for output in ds.take(1): output.numpy() + + def test_rescaling_with_channels_first_and_vector_scale(self): + config = backend.image_data_format() + backend.set_image_data_format("channels_first") + layer = layers.Rescaling( + scale=[1.0 / 255, 1.5 / 255, 2.0 / 255], offset=0.5 + ) + x = np.random.random((2, 3, 10, 10)) * 255 + layer(x) + backend.set_image_data_format(config)