Update RandomCrop KPL to use the new BaseImageAugmentationLayer
PiperOrigin-RevId: 444336893
tensorflower-gardener committed Apr 25, 2022
1 parent 9613004 commit 4087259
Showing 4 changed files with 123 additions and 16 deletions.
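This commit moves RandomCrop off keras.engine.base_layer.BaseRandomLayer and onto the BaseImageAugmentationLayer base class for Keras preprocessing layers (KPLs). Instead of doing all of its work inside call(), the layer now supplies per-sample hooks that the base class drives: get_random_transformation() draws the crop offsets and augment_image() applies them, while the base class handles batched vs. unbatched inputs, dict-formatted inputs, and (via the auto_vectorize property surfaced in the golden API below) how the hooks are mapped over a batch. The sketch below illustrates that contract; it is a minimal, hypothetical subclass (RandomChannelShift, max_shift, and the pixel-shift logic are not part of this commit), and it imports the base class from the internal module named in the is_instance entries, which is not a stable public path.

    import tensorflow as tf
    # Internal path taken from the is_instance strings in the golden files below;
    # not a stable public import.
    from keras.layers.preprocessing.image_preprocessing import BaseImageAugmentationLayer


    class RandomChannelShift(BaseImageAugmentationLayer):
      """Hypothetical layer: adds one random scalar offset to every pixel."""

      def __init__(self, max_shift=0.1, **kwargs):
        super().__init__(**kwargs)
        self.max_shift = max_shift

      def get_random_transformation(self, image=None, label=None,
                                    bounding_box=None):
        # Called once per sample by the base class; the return value is handed
        # back to the augment_* hooks as `transformation`.
        return self._random_generator.random_uniform(
            [], -self.max_shift, self.max_shift)

      def augment_image(self, image, transformation=None):
        return image + transformation

      def augment_label(self, label, transformation=None):
        # A pixel-value shift does not change the label.
        return label


    # Works on HWC and NHWC inputs alike; batching is handled by the base class.
    shifted = RandomChannelShift(0.1)(tf.random.uniform((4, 32, 32, 3)), training=True)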
21 changes: 21 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.layers.-random-crop.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.RandomCrop"
 tf_class {
   is_instance: "<class \'keras.layers.preprocessing.image_preprocessing.RandomCrop\'>"
+  is_instance: "<class \'keras.layers.preprocessing.image_preprocessing.BaseImageAugmentationLayer\'>"
   is_instance: "<class \'keras.engine.base_layer.BaseRandomLayer\'>"
   is_instance: "<class \'keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -12,6 +13,10 @@ tf_class {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "auto_vectorize"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "compute_dtype"
     mtype: "<type \'property\'>"
@@ -152,6 +157,18 @@ tf_class {
     name: "add_weight"
     argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregationV2.NONE\'], "
   }
+  member_method {
+    name: "augment_bounding_box"
+    argspec: "args=[\'self\', \'bounding_box\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "augment_image"
+    argspec: "args=[\'self\', \'image\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "augment_label"
+    argspec: "args=[\'self\', \'label\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "build"
     argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
@@ -212,6 +229,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_random_transformation"
+    argspec: "args=[\'self\', \'image\', \'label\', \'bounding_box\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
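With these member_methods public, the crop transformation can also be inspected and applied outside of call(). A hedged example follows: `img` is an arbitrary HWC tensor, the offsets are random per call, and depending on the Keras version the private random generator may only be fully initialized after the layer has been called once.

    import tensorflow as tf

    layer = tf.keras.layers.RandomCrop(height=8, width=8, seed=42)
    img = tf.random.uniform((16, 16, 3))  # arbitrary example image (HWC)

    # Per the implementation diff further below, the transformation is a dict
    # of crop offsets of the form {'top': ..., 'left': ...}.
    transformation = layer.get_random_transformation(image=img)
    cropped = layer.augment_image(img, transformation)
    print(cropped.shape)  # (8, 8, 3)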
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.experimental.preprocessing.RandomCrop"
 tf_class {
   is_instance: "<class \'keras.layers.preprocessing.image_preprocessing.RandomCrop\'>"
+  is_instance: "<class \'keras.layers.preprocessing.image_preprocessing.BaseImageAugmentationLayer\'>"
   is_instance: "<class \'keras.engine.base_layer.BaseRandomLayer\'>"
   is_instance: "<class \'keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -12,6 +13,10 @@ tf_class {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "auto_vectorize"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "compute_dtype"
     mtype: "<type \'property\'>"
@@ -152,6 +157,18 @@ tf_class {
     name: "add_weight"
     argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregationV2.NONE\'], "
   }
+  member_method {
+    name: "augment_bounding_box"
+    argspec: "args=[\'self\', \'bounding_box\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "augment_image"
+    argspec: "args=[\'self\', \'image\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "augment_label"
+    argspec: "args=[\'self\', \'label\', \'transformation\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "build"
     argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
@@ -212,6 +229,10 @@ tf_class {
     name: "get_output_shape_at"
     argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_random_transformation"
+    argspec: "args=[\'self\', \'image\', \'label\', \'bounding_box\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
   member_method {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
57 changes: 41 additions & 16 deletions keras/layers/preprocessing/image_preprocessing.py
@@ -444,7 +444,7 @@ def _ensure_inputs_are_compute_dtype(self, inputs):
 @keras_export('keras.layers.RandomCrop',
               'keras.layers.experimental.preprocessing.RandomCrop',
               v1=[])
-class RandomCrop(base_layer.BaseRandomLayer):
+class RandomCrop(BaseImageAugmentationLayer):
   """A preprocessing layer which randomly crops images during training.
 
   During training, this layer will randomly choose a location to crop images
@@ -486,31 +486,56 @@ def __init__(self, height, width, seed=None, **kwargs):
     self.seed = seed
 
   def call(self, inputs, training=True):
-    inputs = utils.ensure_tensor(inputs, dtype=self.compute_dtype)
+    inputs = self._ensure_inputs_are_compute_dtype(inputs)
+    inputs, is_dict = self._format_inputs(inputs)
     if training:
-      input_shape = tf.shape(inputs)
-      h_diff = input_shape[H_AXIS] - self.height
-      w_diff = input_shape[W_AXIS] - self.width
-      return tf.cond(
-          tf.reduce_all((h_diff >= 0, w_diff >= 0)),
-          lambda: self._random_crop(inputs),
-          lambda: self._resize(inputs))
+      images = inputs['images']
+      if images.shape.rank == 3:
+        return self._format_output(self._augment(inputs), is_dict)
+      elif images.shape.rank == 4:
+        return self._format_output(self._batch_augment(inputs), is_dict)
+      else:
+        raise ValueError('Image augmentation layers are expecting inputs to be '
                         'rank 3 (HWC) or 4D (NHWC) tensors. Got shape: '
                         f'{images.shape}')
     else:
-      return self._resize(inputs)
+      output = inputs
+      output['images'] = self._resize(inputs['images'])
+      # self._resize() returns valid results for both batched and unbatched
+      # input
+      return self._format_output(output, is_dict)
 
-  def _random_crop(self, inputs):
-    input_shape = tf.shape(inputs)
+  def get_random_transformation(self,
+                                image=None,
+                                label=None,
+                                bounding_box=None):
+    input_shape = tf.shape(image)
     h_diff = input_shape[H_AXIS] - self.height
     w_diff = input_shape[W_AXIS] - self.width
     dtype = input_shape.dtype
     rands = self._random_generator.random_uniform([2], 0, dtype.max, dtype)
     h_start = rands[0] % (h_diff + 1)
     w_start = rands[1] % (w_diff + 1)
-    return tf.image.crop_to_bounding_box(inputs, h_start, w_start,
-                                         self.height, self.width)
+    return {'top': h_start, 'left': w_start}
 
+  def augment_image(self, image, transformation=None):
+    if transformation is None:
+      transformation = self.get_random_transformation(image)
+    input_shape = tf.shape(image)
+    h_diff = input_shape[H_AXIS] - self.height
+    w_diff = input_shape[W_AXIS] - self.width
+    return tf.cond(
+        tf.reduce_all((h_diff >= 0, w_diff >= 0)),
+        lambda: self._crop(image, transformation), lambda: self._resize(image))
+
+  def _crop(self, image, transformation):
+    top = transformation['top']
+    left = transformation['left']
+    return tf.image.crop_to_bounding_box(image, top, left, self.height,
+                                         self.width)
+
-  def _resize(self, inputs):
-    outputs = image_utils.smart_resize(inputs, [self.height, self.width])
+  def _resize(self, image):
+    outputs = image_utils.smart_resize(image, [self.height, self.width])
     # smart_resize will always output float32, so we need to re-cast.
     return tf.cast(outputs, self.compute_dtype)
 
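For reference, a short usage sketch of the call() behavior implemented above: in training mode, rank-3 (HWC) and rank-4 (NHWC) inputs are routed through the per-sample augmentation hooks, while training=False falls back to a resize; dict inputs with an 'images' key keep their dict structure. The input shapes below are arbitrary illustrations.

    import tensorflow as tf

    layer = tf.keras.layers.RandomCrop(height=8, width=8, seed=1337)
    batch = tf.random.uniform((4, 16, 16, 3))   # NHWC
    single = tf.random.uniform((16, 16, 3))     # HWC

    print(layer(batch, training=True).shape)    # (4, 8, 8, 3): randomly cropped
    print(layer(single, training=True).shape)   # (8, 8, 3)
    print(layer(batch, training=False).shape)   # (4, 8, 8, 3): resized, not cropped

    # Dict-formatted inputs come back as dicts; 'images' holds the processed tensor.
    out = layer({'images': batch}, training=False)
    print(out['images'].shape)                  # (4, 8, 8, 3)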
40 changes: 40 additions & 0 deletions keras/layers/preprocessing/image_preprocessing_test.py
@@ -423,6 +423,46 @@ def test_unbatched_image(self):
         actual_output = layer(inp, training=True)
       self.assertAllClose(inp[2:10, 2:10, :], actual_output)
 
+  def test_batched_input(self):
+    np.random.seed(1337)
+    inp = np.random.random((20, 16, 16, 3))
+    mock_offset = [2, 2]
+    with test_utils.use_gpu():
+      layer = image_preprocessing.RandomCrop(8, 8)
+      with tf.compat.v1.test.mock.patch.object(
+          layer._random_generator, 'random_uniform', return_value=mock_offset):
+        actual_output = layer(inp, training=True)
+      self.assertAllClose(inp[:, 2:10, 2:10, :], actual_output)
+
+  def test_augment_image(self):
+    np.random.seed(1337)
+    inp = np.random.random((16, 16, 3))
+    mock_offset = [2, 2]
+    with test_utils.use_gpu():
+      layer = image_preprocessing.RandomCrop(8, 8)
+      with tf.compat.v1.test.mock.patch.object(
+          layer._random_generator, 'random_uniform', return_value=mock_offset):
+        actual_output = layer.augment_image(inp)
+      self.assertAllClose(inp[2:10, 2:10, :], actual_output)
+
+  def test_training_false(self):
+    np.random.seed(1337)
+    height, width = 4, 6
+    inp = np.random.random((12, 8, 16, 3))
+    inp_dict = {'images': inp}
+    with test_utils.use_gpu():
+      layer = image_preprocessing.RandomCrop(height, width)
+      # test with tensor input
+      actual_output = layer(inp, training=False)
+      resized_inp = tf.image.resize(inp, size=[4, 8])
+      expected_output = resized_inp[:, :, 1:7, :]
+      self.assertAllClose(expected_output, actual_output)
+      # test with dictionary input
+      actual_output = layer(inp_dict, training=False)
+      resized_inp = tf.image.resize(inp, size=[4, 8])
+      expected_output = resized_inp[:, :, 1:7, :]
+      self.assertAllClose(expected_output, actual_output['images'])
+
   @test_utils.run_v2_only
   def test_uint8_input(self):
     inputs = keras.Input((128, 128, 3), batch_size=2, dtype=tf.uint8)
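The new tests make the random crop deterministic by patching the layer's private random generator, which is what lets them assert the output against a fixed slice. A hedged sketch of the same pattern outside the Keras test harness (plain unittest.mock in place of tf.compat.v1.test.mock and test_utils.use_gpu; `_random_generator` is an implementation detail and may change):

    from unittest import mock

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.RandomCrop(8, 8)
    inp = np.random.random((16, 16, 3)).astype('float32')

    # Force the crop offset to (2, 2) so the expected output is inp[2:10, 2:10, :].
    with mock.patch.object(layer._random_generator, 'random_uniform',
                           return_value=[2, 2]):
      out = layer(inp, training=True)

    np.testing.assert_allclose(inp[2:10, 2:10, :], out.numpy())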
