From 40872596b3a7e8c0f1c9556844ce9c1dcce1dfaa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 25 Apr 2022 12:12:48 -0700 Subject: [PATCH] Update RandomCrop KPL to use the new BaseImageAugmentationLayer PiperOrigin-RevId: 444336893 --- ...tensorflow.keras.layers.-random-crop.pbtxt | 21 +++++++ ...erimental.preprocessing.-random-crop.pbtxt | 21 +++++++ .../preprocessing/image_preprocessing.py | 57 +++++++++++++------ .../preprocessing/image_preprocessing_test.py | 40 +++++++++++++ 4 files changed, 123 insertions(+), 16 deletions(-) diff --git a/keras/api/golden/v2/tensorflow.keras.layers.-random-crop.pbtxt b/keras/api/golden/v2/tensorflow.keras.layers.-random-crop.pbtxt index fe62bb33508..12e412fc9c8 100644 --- a/keras/api/golden/v2/tensorflow.keras.layers.-random-crop.pbtxt +++ b/keras/api/golden/v2/tensorflow.keras.layers.-random-crop.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.RandomCrop" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -12,6 +13,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "auto_vectorize" + mtype: "" + } member { name: "compute_dtype" mtype: "" @@ -152,6 +157,18 @@ tf_class { name: "add_weight" argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregationV2.NONE\'], " } + member_method { + name: "augment_bounding_box" + argspec: "args=[\'self\', \'bounding_box\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "augment_image" + argspec: "args=[\'self\', \'image\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "augment_label" + argspec: "args=[\'self\', \'label\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "build" argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" @@ -212,6 +229,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_random_transformation" + argspec: "args=[\'self\', \'image\', \'label\', \'bounding_box\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } member_method { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt b/keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt index b6e2b67f727..cc14f7fd9ac 100644 --- a/keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt +++ b/keras/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-random-crop.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.keras.layers.experimental.preprocessing.RandomCrop" tf_class { is_instance: "" + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -12,6 +13,10 @@ tf_class { name: "activity_regularizer" mtype: "" } + member { + name: "auto_vectorize" + mtype: "" + } member { name: "compute_dtype" mtype: "" @@ -152,6 +157,18 @@ tf_class { name: "add_weight" argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregationV2.NONE\'], " } + member_method { + name: "augment_bounding_box" + argspec: "args=[\'self\', \'bounding_box\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "augment_image" + argspec: "args=[\'self\', \'image\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "augment_label" + argspec: "args=[\'self\', \'label\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "build" argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" @@ -212,6 +229,10 @@ tf_class { name: "get_output_shape_at" argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "get_random_transformation" + argspec: "args=[\'self\', \'image\', \'label\', \'bounding_box\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } member_method { name: "get_weights" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" diff --git a/keras/layers/preprocessing/image_preprocessing.py b/keras/layers/preprocessing/image_preprocessing.py index e7da231dbd3..6e40cb2584e 100644 --- a/keras/layers/preprocessing/image_preprocessing.py +++ b/keras/layers/preprocessing/image_preprocessing.py @@ -444,7 +444,7 @@ def _ensure_inputs_are_compute_dtype(self, inputs): @keras_export('keras.layers.RandomCrop', 'keras.layers.experimental.preprocessing.RandomCrop', v1=[]) -class RandomCrop(base_layer.BaseRandomLayer): +class RandomCrop(BaseImageAugmentationLayer): """A preprocessing layer which randomly crops images during training. During training, this layer will randomly choose a location to crop images @@ -486,31 +486,56 @@ def __init__(self, height, width, seed=None, **kwargs): self.seed = seed def call(self, inputs, training=True): - inputs = utils.ensure_tensor(inputs, dtype=self.compute_dtype) + inputs = self._ensure_inputs_are_compute_dtype(inputs) + inputs, is_dict = self._format_inputs(inputs) if training: - input_shape = tf.shape(inputs) - h_diff = input_shape[H_AXIS] - self.height - w_diff = input_shape[W_AXIS] - self.width - return tf.cond( - tf.reduce_all((h_diff >= 0, w_diff >= 0)), - lambda: self._random_crop(inputs), - lambda: self._resize(inputs)) + images = inputs['images'] + if images.shape.rank == 3: + return self._format_output(self._augment(inputs), is_dict) + elif images.shape.rank == 4: + return self._format_output(self._batch_augment(inputs), is_dict) + else: + raise ValueError('Image augmentation layers are expecting inputs to be ' + 'rank 3 (HWC) or 4D (NHWC) tensors. Got shape: ' + f'{images.shape}') else: - return self._resize(inputs) + output = inputs + output['images'] = self._resize(inputs['images']) + # self._resize() returns valid results for both batched and unbatched + # input + return self._format_output(output, is_dict) - def _random_crop(self, inputs): - input_shape = tf.shape(inputs) + def get_random_transformation(self, + image=None, + label=None, + bounding_box=None): + input_shape = tf.shape(image) h_diff = input_shape[H_AXIS] - self.height w_diff = input_shape[W_AXIS] - self.width dtype = input_shape.dtype rands = self._random_generator.random_uniform([2], 0, dtype.max, dtype) h_start = rands[0] % (h_diff + 1) w_start = rands[1] % (w_diff + 1) - return tf.image.crop_to_bounding_box(inputs, h_start, w_start, - self.height, self.width) + return {'top': h_start, 'left': w_start} + + def augment_image(self, image, transformation=None): + if transformation is None: + transformation = self.get_random_transformation(image) + input_shape = tf.shape(image) + h_diff = input_shape[H_AXIS] - self.height + w_diff = input_shape[W_AXIS] - self.width + return tf.cond( + tf.reduce_all((h_diff >= 0, w_diff >= 0)), + lambda: self._crop(image, transformation), lambda: self._resize(image)) + + def _crop(self, image, transformation): + top = transformation['top'] + left = transformation['left'] + return tf.image.crop_to_bounding_box(image, top, left, self.height, + self.width) - def _resize(self, inputs): - outputs = image_utils.smart_resize(inputs, [self.height, self.width]) + def _resize(self, image): + outputs = image_utils.smart_resize(image, [self.height, self.width]) # smart_resize will always output float32, so we need to re-cast. return tf.cast(outputs, self.compute_dtype) diff --git a/keras/layers/preprocessing/image_preprocessing_test.py b/keras/layers/preprocessing/image_preprocessing_test.py index 1b2e506a09f..8fe8e0db451 100644 --- a/keras/layers/preprocessing/image_preprocessing_test.py +++ b/keras/layers/preprocessing/image_preprocessing_test.py @@ -423,6 +423,46 @@ def test_unbatched_image(self): actual_output = layer(inp, training=True) self.assertAllClose(inp[2:10, 2:10, :], actual_output) + def test_batched_input(self): + np.random.seed(1337) + inp = np.random.random((20, 16, 16, 3)) + mock_offset = [2, 2] + with test_utils.use_gpu(): + layer = image_preprocessing.RandomCrop(8, 8) + with tf.compat.v1.test.mock.patch.object( + layer._random_generator, 'random_uniform', return_value=mock_offset): + actual_output = layer(inp, training=True) + self.assertAllClose(inp[:, 2:10, 2:10, :], actual_output) + + def test_augment_image(self): + np.random.seed(1337) + inp = np.random.random((16, 16, 3)) + mock_offset = [2, 2] + with test_utils.use_gpu(): + layer = image_preprocessing.RandomCrop(8, 8) + with tf.compat.v1.test.mock.patch.object( + layer._random_generator, 'random_uniform', return_value=mock_offset): + actual_output = layer.augment_image(inp) + self.assertAllClose(inp[2:10, 2:10, :], actual_output) + + def test_training_false(self): + np.random.seed(1337) + height, width = 4, 6 + inp = np.random.random((12, 8, 16, 3)) + inp_dict = {'images': inp} + with test_utils.use_gpu(): + layer = image_preprocessing.RandomCrop(height, width) + # test wih tensor input + actual_output = layer(inp, training=False) + resized_inp = tf.image.resize(inp, size=[4, 8]) + expected_output = resized_inp[:, :, 1:7, :] + self.assertAllClose(expected_output, actual_output) + # test with dictionary input + actual_output = layer(inp_dict, training=False) + resized_inp = tf.image.resize(inp, size=[4, 8]) + expected_output = resized_inp[:, :, 1:7, :] + self.assertAllClose(expected_output, actual_output['images']) + @test_utils.run_v2_only def test_uint8_input(self): inputs = keras.Input((128, 128, 3), batch_size=2, dtype=tf.uint8)