Skip to content

Commit

Permalink
Fix drop block (#2250)
Browse files Browse the repository at this point in the history
* fix #2134

* make drop_block_2d keras 3 compatible

* update get_config
  • Loading branch information
divyashreepathihalli authored Dec 15, 2023
1 parent b980f68 commit ce88b46
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 78 deletions.
118 changes: 53 additions & 65 deletions keras_cv/layers/regularization/dropblock_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf

from keras_cv.backend import config

if config.keras_3():
base_layer = tf.keras.layers.Layer
else:
from tensorflow.keras.__internal__.layers import BaseRandomLayer

base_layer = BaseRandomLayer

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import random
from keras_cv.utils import conv_utils


@keras_cv_export("keras_cv.layers.DropBlock2D")
class DropBlock2D(base_layer):
class DropBlock2D(keras.layers.Layer):
"""Applies DropBlock regularization to input features.
DropBlock is a form of structured dropout, where units in a contiguous
Expand Down Expand Up @@ -153,100 +145,96 @@ def __init__(
seed=None,
**kwargs,
):
# TODO: remove this once the layer is ported to keras 3
# https://github.com/keras-team/keras-cv/issues/2136
if config.keras_3():
super().__init__(**kwargs)
if not 0.0 <= rate <= 1.0:
raise ValueError(
"This layer is not yet compatible with Keras 3."
"Please switch to Keras 2 to use this layer."
)
else:
super().__init__(seed=seed, **kwargs)
if not 0.0 <= rate <= 1.0:
raise ValueError(
f"rate must be a number between 0 and 1. "
f"Received: {rate}"
)

self._rate = rate
(
self._dropblock_height,
self._dropblock_width,
) = conv_utils.normalize_tuple(
value=block_size, n=2, name="block_size", allow_zero=False
f"rate must be a number between 0 and 1. " f"Received: {rate}"
)
self.seed = seed

self._rate = rate
(
self._dropblock_height,
self._dropblock_width,
) = conv_utils.normalize_tuple(
value=block_size, n=2, name="block_size", allow_zero=False
)
self.seed = seed
self._random_generator = random.SeedGenerator(self.seed)

def call(self, x, training=None):
    """Apply DropBlock to `x`.

    Args:
        x: 4D feature-map tensor, channels-last, i.e. shape
            `(batch, height, width, channels)` (the mask is reshaped to
            `[1, height, width, 1]` below, which fixes that layout).
        training: Python boolean. DropBlock is an identity op outside of
            training (and when `rate == 0.0`).

    Returns:
        A tensor of the same shape and dtype as `x`, with contiguous
        `block_size` regions zeroed out and the surviving activations
        rescaled to preserve the overall magnitude.
    """
    # Inference / no-op fast path: return the input untouched.
    if not training or self._rate == 0.0:
        return x

    _, height, width, _ = ops.split(ops.shape(x), 4)

    # Unnest scalar values
    height = ops.squeeze(height)
    width = ops.squeeze(width)

    # Clamp the block to the feature map so tiny inputs still work.
    dropblock_height = ops.minimum(self._dropblock_height, height)
    dropblock_width = ops.minimum(self._dropblock_width, width)

    # gamma is the per-seed drop probability that makes the expected
    # fraction of dropped units equal `rate` (eq. 1 of the DropBlock
    # paper), accounting for the valid seed region.
    gamma = (
        self._rate
        * ops.cast(width * height, dtype="float32")
        / ops.cast(dropblock_height * dropblock_width, dtype="float32")
        / ops.cast(
            (width - self._dropblock_width + 1)
            * (height - self._dropblock_height + 1),
            "float32",
        )
    )

    # Forces the block to be inside the feature map.
    w_i, h_i = ops.meshgrid(ops.arange(width), ops.arange(height))
    valid_block = ops.logical_and(
        ops.logical_and(
            w_i >= int(dropblock_width // 2),
            w_i < width - (dropblock_width - 1) // 2,
        ),
        ops.logical_and(
            h_i >= int(dropblock_height // 2),
            # Bug fix: the height coordinate must be bounded by `height`,
            # not `width` (the previous code broke non-square inputs).
            h_i < height - (dropblock_height - 1) // 2,
        ),
    )

    valid_block = ops.reshape(valid_block, [1, height, width, 1])

    random_noise = random.uniform(
        ops.shape(x), seed=self._random_generator, dtype="float32"
    )
    valid_block = ops.cast(valid_block, dtype="float32")
    seed_keep_rate = ops.cast(1 - gamma, dtype="float32")
    # A cell becomes a drop "seed" when its noise exceeds the keep rate
    # AND it lies in the valid region; everything else stays 1 (kept).
    block_pattern = (1 - valid_block + seed_keep_rate + random_noise) >= 1
    block_pattern = ops.cast(block_pattern, dtype="float32")

    window_size = [1, self._dropblock_height, self._dropblock_width, 1]

    # Double negative and max_pool is essentially min_pooling: expands
    # each seed into a full block_size x block_size zeroed block.
    block_pattern = -ops.max_pool(
        -block_pattern,
        pool_size=window_size,
        strides=[1, 1, 1, 1],
        padding="SAME",
    )

    # Slightly scale the values, to account for magnitude change
    percent_ones = ops.cast(ops.sum(block_pattern), "float32") / ops.cast(
        ops.size(block_pattern), "float32"
    )
    return (
        x
        / ops.cast(percent_ones, x.dtype)
        * ops.cast(block_pattern, x.dtype)
    )

def get_config(self):
    """Return the layer configuration for serialization.

    Extends the base ``Layer`` config with this layer's constructor
    arguments (`rate`, `block_size`, `seed`) so the layer can be
    reconstructed via ``from_config``. Note that ``block_size`` is
    serialized as the normalized `(height, width)` tuple.
    """
    config = super().get_config()
    config.update(
        {
            "rate": self._rate,
            "block_size": (self._dropblock_height, self._dropblock_width),
            "seed": self.seed,
        }
    )
    return config
13 changes: 0 additions & 13 deletions keras_cv/layers/regularization/dropblock_2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import tensorflow as tf

from keras_cv.backend.config import keras_3
from keras_cv.layers.regularization.dropblock_2d import DropBlock2D
from keras_cv.tests.test_case import TestCase


@pytest.mark.skipif(keras_3(), reason="not implemented in keras 3")
class DropBlock2DTest(TestCase):
FEATURE_SHAPE = (1, 14, 14, 256) # Shape of ResNet block group 3
rng = tf.random.Generator.from_non_deterministic_state()
Expand Down Expand Up @@ -87,13 +84,3 @@ def test_input_gets_partially_zeroed_out_with_non_square_block_size(self):
@staticmethod
def _count_zeros(tensor: tf.Tensor) -> tf.Tensor:
return tf.size(tensor) - tf.math.count_nonzero(tensor, dtype=tf.int32)

def test_works_with_xla(self):
dummy_inputs = self.rng.uniform(shape=self.FEATURE_SHAPE)
layer = DropBlock2D(rate=0.1, block_size=7)

@tf.function(jit_compile=True)
def apply(x):
return layer(x, training=True)

apply(dummy_inputs)

0 comments on commit ce88b46

Please sign in to comment.