Add Efficientnet Edge presets (#1976)
* WIP initially adding edge presets

* WIP el variant working

* added all hf timm edge presets

* removing irrelevant note

* format pass

* remove irrelevant old commented code

* fix unit test regression

* add presets to preset file

* added arg docstrings and version handles

* updated block specific docstring for batch_norm_epsilon
pkgoogle authored Nov 13, 2024
1 parent e3938f1 commit 4a287cd
Showing 7 changed files with 294 additions and 64 deletions.
27 changes: 27 additions & 0 deletions keras_hub/src/models/efficientnet/efficientnet_backbone.py
@@ -54,6 +54,13 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
MBConvBlock, but instead of using a depthwise convolution and a 1x1
output convolution, fused blocks use a single 3x3 convolution
block.
stackwise_force_input_filters: list of ints, overrides
stackwise_input_filters if > 0. Primarily used to parameterize stem
filters (usually stackwise_input_filters[0]) differently than stack
input filters.
stackwise_nores_option: list of bools, disables the residual
connection for a stack when True. If False (the default), the stack
uses residual connections.
min_depth: integer, minimum number of filters. Can be None and ignored
if use_depth_divisor_as_min_depth is set to True.
include_initial_padding: bool, whether to include initial zero padding
@@ -66,6 +73,8 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
stem_conv_padding: str, can be 'same' or 'valid'. Padding for the stem.
batch_norm_momentum: float, momentum for the moving average calcualtion
in the batch normalization layers.
batch_norm_epsilon: float, epsilon for batch norm calculations. Added
to the denominator to prevent division-by-zero errors.
Example:
```python
@@ -100,6 +109,8 @@ def __init__(
stackwise_squeeze_and_excite_ratios,
stackwise_strides,
stackwise_block_types,
stackwise_force_input_filters=[0] * 7,
stackwise_nores_option=[False] * 7,
dropout=0.2,
depth_divisor=8,
min_depth=8,
@@ -163,6 +174,8 @@ def __init__(
num_repeats = stackwise_num_repeats[i]
input_filters = stackwise_input_filters[i]
output_filters = stackwise_output_filters[i]
force_input_filters = stackwise_force_input_filters[i]
nores = stackwise_nores_option[i]

# Update block input and output filters based on depth multiplier.
input_filters = round_filters(
@@ -200,6 +213,16 @@ def __init__(
self._pyramid_outputs[f"P{curr_pyramid_level}"] = x
curr_pyramid_level += 1

if force_input_filters > 0:
input_filters = round_filters(
filters=force_input_filters,
width_coefficient=width_coefficient,
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
cap_round_filter_decrease=cap_round_filter_decrease,
)

# 97 is the start of the lowercase alphabet.
letter_identifier = chr(j + 97)
stackwise_block_type = stackwise_block_types[i]
@@ -232,6 +255,8 @@ def __init__(
activation=activation,
dropout=dropout * block_id / blocks,
batch_norm_momentum=batch_norm_momentum,
batch_norm_epsilon=batch_norm_epsilon,
nores=nores,
name=block_name,
)
x = block(x)
@@ -291,6 +316,7 @@ def __init__(
self.stackwise_strides = stackwise_strides
self.stackwise_block_types = stackwise_block_types

self.stackwise_force_input_filters = stackwise_force_input_filters
self.include_stem_padding = include_stem_padding
self.use_depth_divisor_as_min_depth = use_depth_divisor_as_min_depth
self.cap_round_filter_decrease = cap_round_filter_decrease
@@ -318,6 +344,7 @@ def get_config(self):
"stackwise_squeeze_and_excite_ratios": self.stackwise_squeeze_and_excite_ratios,
"stackwise_strides": self.stackwise_strides,
"stackwise_block_types": self.stackwise_block_types,
"stackwise_force_input_filters": self.stackwise_force_input_filters,
"include_stem_padding": self.include_stem_padding,
"use_depth_divisor_as_min_depth": self.use_depth_divisor_as_min_depth,
"cap_round_filter_decrease": self.cap_round_filter_decrease,
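For context, the two new arguments slot in alongside the existing per-stack lists. The sketch below constructs a backbone with them; it assumes `EfficientNetBackbone` is exported under `keras_hub.models`, fills in standard EfficientNet-B0 values for the per-stack lists elided in this diff (kernel sizes, repeats, input filters), and the non-zero `stackwise_force_input_filters` / `stackwise_nores_option` entries are purely illustrative, not a shipped preset configuration.

```python
import keras_hub

# A minimal sketch, assuming `EfficientNetBackbone` is exported under
# `keras_hub.models` and assuming standard EfficientNet-B0 values for the
# per-stack lists elided in this diff (kernel sizes, repeats, input filters).
backbone = keras_hub.models.EfficientNetBackbone(
    stackwise_kernel_sizes=[3, 3, 5, 3, 5, 5, 3],
    stackwise_num_repeats=[1, 2, 2, 3, 3, 4, 1],
    stackwise_input_filters=[32, 16, 24, 40, 80, 112, 192],
    stackwise_output_filters=[16, 24, 40, 80, 112, 192, 320],
    stackwise_expansion_ratios=[1, 6, 6, 6, 6, 6, 6],
    stackwise_squeeze_and_excite_ratios=[0.25] * 7,
    stackwise_strides=[1, 2, 2, 2, 1, 2, 1],
    stackwise_block_types=["v1"] * 7,
    # New in this commit: a value > 0 overrides that stack's input filters
    # (0 means "keep stackwise_input_filters[i]"); True disables that
    # stack's residual connections. These entries are illustrative only.
    stackwise_force_input_filters=[24, 0, 0, 0, 0, 0, 0],
    stackwise_nores_option=[True] + [False] * 6,
    width_coefficient=1.0,
    depth_coefficient=1.0,
)
```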
14 changes: 5 additions & 9 deletions keras_hub/src/models/efficientnet/efficientnet_backbone_test.py
@@ -26,6 +26,8 @@ def setUp(self):
],
"stackwise_strides": [1, 2, 2, 2, 1, 2],
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
"stackwise_force_input_filters": [0] * 6,
"stackwise_nores_option": [False] * 6,
"width_coefficient": 1.0,
"depth_coefficient": 1.0,
}
@@ -60,15 +62,9 @@ def test_valid_call_original_v1(self):
"stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
"stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
"stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
"stackwise_squeeze_and_excite_ratios": [
0.25,
0.25,
0.25,
0.25,
0.25,
0.25,
0.25,
],
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
"stackwise_force_input_filters": [0] * 7,
"stackwise_nores_option": [False] * 7,
"width_coefficient": 1.0,
"depth_coefficient": 1.0,
"stackwise_block_types": ["v1"] * 7,
45 changes: 42 additions & 3 deletions keras_hub/src/models/efficientnet/efficientnet_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,57 @@
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet",
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet/1",
},
"efficientnet_b1_ft_imagenet": {
"metadata": {
"description": (
"EfficientNet B1 model fine-trained on the ImageNet 1k dataset."
"EfficientNet B1 model fine-tuned on the ImageNet 1k dataset."
),
"params": 7794184,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet",
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/1",
},
"efficientnet_el_ra_imagenet": {
"metadata": {
"description": (
"EfficientNet-EdgeTPU Large model trained on the ImageNet 1k "
"dataset with RandAugment recipe."
),
"params": 10589712,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet/1",
},
"efficientnet_em_ra2_imagenet": {
"metadata": {
"description": (
"EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k "
"dataset with RandAugment2 recipe."
),
"params": 6899496,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet/1",
},
"efficientnet_es_ra_imagenet": {
"metadata": {
"description": (
"EfficientNet-EdgeTPU Small model trained on the ImageNet 1k "
"dataset with RandAugment recipe."
),
"params": 5438392,
"official_name": "EfficientNet",
"path": "efficientnet",
"model_card": "https://arxiv.org/abs/1905.11946",
},
"kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet/1",
},
}
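Once registered, the new presets load by name like any other. A quick sketch of the expected usage, assuming the standard KerasHub `from_preset` entry point and that the Kaggle handles above have been uploaded:

```python
import keras_hub

# A sketch of loading one of the new EfficientNet-EdgeTPU presets by name.
backbone = keras_hub.models.Backbone.from_preset("efficientnet_el_ra_imagenet")
# The parameter count should match the preset metadata (10,589,712 for "el").
print(backbone.count_params())
```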
27 changes: 18 additions & 9 deletions keras_hub/src/models/efficientnet/fusedmbconv.py
@@ -47,6 +47,9 @@ class FusedMBConvBlock(keras.layers.Layer):
se_ratio: default 0.0, the squeeze-and-excitation ratio; the filters
used in the Squeeze-Excitation phase are chosen as the maximum
between 1 and input_filters*se_ratio
batch_norm_momentum: default 0.9, the BatchNormalization momentum
batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
calculations. Added to the denominator to prevent division-by-zero
errors.
activation: default "swish", the activation function used between
convolution operations
dropout: float, the optional dropout rate to apply before the output
@@ -70,8 +73,10 @@ def __init__(
data_format="channels_last",
se_ratio=0.0,
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-3,
activation="swish",
dropout=0.2,
nores=False,
**kwargs
):
super().__init__(**kwargs)
@@ -83,8 +88,10 @@ def __init__(
self.data_format = data_format
self.se_ratio = se_ratio
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_epsilon = batch_norm_epsilon
self.activation = activation
self.dropout = dropout
self.nores = nores
self.filters = self.input_filters * self.expand_ratio
self.filters_se = max(1, int(input_filters * se_ratio))

@@ -101,18 +108,13 @@ def __init__(
self.bn1 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "expand_bn",
)
self.act = keras.layers.Activation(
self.activation, name=self.name + "expand_activation"
)

self.bn2 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
name=self.name + "bn",
)

self.se_conv1 = keras.layers.Conv2D(
self.filters_se,
1,
@@ -144,9 +146,10 @@ def __init__(
name=self.name + "project_conv",
)

self.bn3 = keras.layers.BatchNormalization(
self.bn2 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "project_bn",
)

@@ -192,12 +195,16 @@ def call(self, inputs):

# Output phase:
x = self.output_conv(x)
x = self.bn3(x)
x = self.bn2(x)
if self.expand_ratio == 1:
x = self.act(x)

# Residual:
if self.strides == 1 and self.input_filters == self.output_filters:
if (
self.strides == 1
and self.input_filters == self.output_filters
and not self.nores
):
if self.dropout:
x = self.dropout_layer(x)
x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -213,8 +220,10 @@ def get_config(self):
"data_format": self.data_format,
"se_ratio": self.se_ratio,
"batch_norm_momentum": self.batch_norm_momentum,
"batch_norm_epsilon": self.batch_norm_epsilon,
"activation": self.activation,
"dropout": self.dropout,
"nores": self.nores,
}

base_config = super().get_config()
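To see the new block-level knobs in isolation, here is a minimal sketch of the fused block with its residual connection disabled. The import path mirrors the file location above (an internal module path), the keyword names follow this diff, and all other constructor arguments are assumed to keep their defaults:

```python
import keras
from keras_hub.src.models.efficientnet.fusedmbconv import FusedMBConvBlock

# With strides=1 and input_filters == output_filters the block would normally
# add a residual connection; nores=True (new in this commit) disables it.
block = FusedMBConvBlock(
    input_filters=24,
    output_filters=24,
    batch_norm_epsilon=1e-3,  # new: forwarded to every BatchNormalization
    nores=True,
)
x = keras.random.normal((1, 32, 32, 24))
y = block(x)  # same shape as x, computed without the residual `Add`
```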
18 changes: 17 additions & 1 deletion keras_hub/src/models/efficientnet/mbconv.py
@@ -23,8 +23,10 @@ def __init__(
data_format="channels_last",
se_ratio=0.0,
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-3,
activation="swish",
dropout=0.2,
nores=False,
**kwargs
):
"""Implementation of the MBConv block
@@ -60,6 +62,9 @@ def __init__(
is above 0. The filters used in this phase are chosen as the
maximum between 1 and input_filters*se_ratio
batch_norm_momentum: default 0.9, the BatchNormalization momentum
batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
calculations. Added to the denominator to prevent
division-by-zero errors.
activation: default "swish", the activation function used between
convolution operations
dropout: float, the optional dropout rate to apply before the output
@@ -83,8 +88,10 @@ def __init__(
self.data_format = data_format
self.se_ratio = se_ratio
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_epsilon = batch_norm_epsilon
self.activation = activation
self.dropout = dropout
self.nores = nores
self.filters = self.input_filters * self.expand_ratio
self.filters_se = max(1, int(input_filters * se_ratio))

@@ -101,6 +108,7 @@ def __init__(
self.bn1 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "expand_bn",
)
self.act = keras.layers.Activation(
@@ -119,6 +127,7 @@ def __init__(
self.bn2 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "bn",
)

@@ -156,6 +165,7 @@ def __init__(
self.bn3 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "project_bn",
)

@@ -207,7 +217,11 @@ def call(self, inputs):
x = self.output_conv(x)
x = self.bn3(x)

if self.strides == 1 and self.input_filters == self.output_filters:
if (
self.strides == 1
and self.input_filters == self.output_filters
and not self.nores
):
if self.dropout:
x = self.dropout_layer(x)
x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -223,8 +237,10 @@ def get_config(self):
"data_format": self.data_format,
"se_ratio": self.se_ratio,
"batch_norm_momentum": self.batch_norm_momentum,
"batch_norm_epsilon": self.batch_norm_epsilon,
"activation": self.activation,
"dropout": self.dropout,
"nores": self.nores,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
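Both block types also serialize the new arguments through `get_config`, so they survive a config round-trip. A small sketch of that behavior; the non-default epsilon value and the internal import path are illustrative assumptions:

```python
from keras_hub.src.models.efficientnet.mbconv import MBConvBlock

# The new keys appear in get_config (per this diff), so a block rebuilt
# from its config keeps the same epsilon and residual behavior.
block = MBConvBlock(
    input_filters=32,
    output_filters=32,
    batch_norm_epsilon=1e-5,  # illustrative non-default value
    nores=True,
)
config = block.get_config()
assert config["batch_norm_epsilon"] == 1e-5
assert config["nores"] is True
```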