diff --git a/keras_hub/src/models/efficientnet/efficientnet_backbone.py b/keras_hub/src/models/efficientnet/efficientnet_backbone.py
index 4016bb01e..95f434149 100644
--- a/keras_hub/src/models/efficientnet/efficientnet_backbone.py
+++ b/keras_hub/src/models/efficientnet/efficientnet_backbone.py
@@ -54,6 +54,13 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
             MBConvBlock, but instead of using a depthwise convolution and a 1x1
             output convolution blocks fused blocks use a single 3x3 convolution
             block.
+        stackwise_force_input_filters: list of ints, overrides
+            stackwise_input_filters if > 0. Primarily used to parameterize stem
+            filters (usually stackwise_input_filters[0]) differently from stack
+            input filters.
+        stackwise_nores_option: list of bools, toggles whether the residual
+            connection is skipped. If False (default), the stack uses residual
+            connections; if True, it does not.
         min_depth: integer, minimum number of filters. Can be None and ignored
             if use_depth_divisor_as_min_depth is set to True.
         include_initial_padding: bool, whether to include initial zero padding
@@ -66,6 +73,8 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
         stem_conv_padding: str, can be 'same' or 'valid'. Padding for the stem.
         batch_norm_momentum: float, momentum for the moving average calcualtion
             in the batch normalization layers.
+        batch_norm_epsilon: float, epsilon for batch norm calculations, used
+            in the denominator to prevent divide-by-zero errors.

     Example:
     ```python
@@ -100,6 +109,8 @@ def __init__(
         stackwise_squeeze_and_excite_ratios,
         stackwise_strides,
         stackwise_block_types,
+        stackwise_force_input_filters=[0] * 7,
+        stackwise_nores_option=[False] * 7,
         dropout=0.2,
         depth_divisor=8,
         min_depth=8,
@@ -163,6 +174,8 @@ def __init__(
             num_repeats = stackwise_num_repeats[i]
             input_filters = stackwise_input_filters[i]
             output_filters = stackwise_output_filters[i]
+            force_input_filters = stackwise_force_input_filters[i]
+            nores = stackwise_nores_option[i]

             # Update block input and output filters based on depth multiplier.
             input_filters = round_filters(
@@ -200,6 +213,16 @@ def __init__(
                 self._pyramid_outputs[f"P{curr_pyramid_level}"] = x
                 curr_pyramid_level += 1

+            if force_input_filters > 0:
+                input_filters = round_filters(
+                    filters=force_input_filters,
+                    width_coefficient=width_coefficient,
+                    min_depth=min_depth,
+                    depth_divisor=depth_divisor,
+                    use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
+                    cap_round_filter_decrease=cap_round_filter_decrease,
+                )
+
             # 97 is the start of the lowercase alphabet.
             letter_identifier = chr(j + 97)
             stackwise_block_type = stackwise_block_types[i]
@@ -232,6 +255,8 @@
                     activation=activation,
                     dropout=dropout * block_id / blocks,
                     batch_norm_momentum=batch_norm_momentum,
+                    batch_norm_epsilon=batch_norm_epsilon,
+                    nores=nores,
                     name=block_name,
                 )
                 x = block(x)
@@ -291,6 +316,7 @@
         self.stackwise_strides = stackwise_strides
         self.stackwise_block_types = stackwise_block_types

+        self.stackwise_force_input_filters = stackwise_force_input_filters
         self.include_stem_padding = include_stem_padding
         self.use_depth_divisor_as_min_depth = use_depth_divisor_as_min_depth
         self.cap_round_filter_decrease = cap_round_filter_decrease
@@ -318,6 +344,7 @@ def get_config(self):
             "stackwise_squeeze_and_excite_ratios": self.stackwise_squeeze_and_excite_ratios,
             "stackwise_strides": self.stackwise_strides,
             "stackwise_block_types": self.stackwise_block_types,
+            "stackwise_force_input_filters": self.stackwise_force_input_filters,
             "include_stem_padding": self.include_stem_padding,
             "use_depth_divisor_as_min_depth": self.use_depth_divisor_as_min_depth,
             "cap_round_filter_decrease": self.cap_round_filter_decrease,
diff --git a/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py b/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py
index f31004b5d..c11e63654 100644
--- a/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py
+++ b/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py
@@ -26,6 +26,8 @@ def setUp(self):
             ],
             "stackwise_strides": [1, 2, 2, 2, 1, 2],
             "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+            "stackwise_force_input_filters": [0] * 6,
+            "stackwise_nores_option": [False] * 6,
             "width_coefficient": 1.0,
             "depth_coefficient": 1.0,
         }
@@ -60,15 +62,9 @@ def test_valid_call_original_v1(self):
             "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
             "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
             "stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
-            "stackwise_squeeze_and_excite_ratios": [
-                0.25,
-                0.25,
-                0.25,
-                0.25,
-                0.25,
-                0.25,
-                0.25,
-            ],
+            "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+            "stackwise_force_input_filters": [0] * 7,
+            "stackwise_nores_option": [False] * 7,
             "width_coefficient": 1.0,
             "depth_coefficient": 1.0,
             "stackwise_block_types": ["v1"] * 7,
diff --git a/keras_hub/src/models/efficientnet/efficientnet_presets.py b/keras_hub/src/models/efficientnet/efficientnet_presets.py
index 39c951481..860a48a68 100644
--- a/keras_hub/src/models/efficientnet/efficientnet_presets.py
+++ b/keras_hub/src/models/efficientnet/efficientnet_presets.py
@@ -12,18 +12,57 @@
             "path": "efficientnet",
             "model_card": "https://arxiv.org/abs/1905.11946",
         },
-        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet",
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet/1",
     },
     "efficientnet_b1_ft_imagenet": {
         "metadata": {
             "description": (
-                "EfficientNet B1 model fine-trained on the ImageNet 1k dataset."
+                "EfficientNet B1 model fine-tuned on the ImageNet 1k dataset."
             ),
             "params": 7794184,
             "official_name": "EfficientNet",
             "path": "efficientnet",
             "model_card": "https://arxiv.org/abs/1905.11946",
         },
-        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet",
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/1",
+    },
+    "efficientnet_el_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-EdgeTPU Large model trained on the ImageNet 1k "
+                "dataset with RandAugment recipe."
+            ),
+            "params": 10589712,
+            "official_name": "EfficientNet",
+            "path": "efficientnet",
+            "model_card": "https://arxiv.org/abs/1905.11946",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet/1",
+    },
+    "efficientnet_em_ra2_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k "
+                "dataset with RandAugment2 recipe."
+            ),
+            "params": 6899496,
+            "official_name": "EfficientNet",
+            "path": "efficientnet",
+            "model_card": "https://arxiv.org/abs/1905.11946",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet/1",
+    },
+    "efficientnet_es_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-EdgeTPU Small model trained on the ImageNet 1k "
+                "dataset with RandAugment recipe."
+            ),
+            "params": 5438392,
+            "official_name": "EfficientNet",
+            "path": "efficientnet",
+            "model_card": "https://arxiv.org/abs/1905.11946",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet/1",
     },
 }
diff --git a/keras_hub/src/models/efficientnet/fusedmbconv.py b/keras_hub/src/models/efficientnet/fusedmbconv.py
index 96f55a22b..51a7f95fe 100644
--- a/keras_hub/src/models/efficientnet/fusedmbconv.py
+++ b/keras_hub/src/models/efficientnet/fusedmbconv.py
@@ -47,6 +47,9 @@ class FusedMBConvBlock(keras.layers.Layer):
         se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase,
             and are chosen as the maximum between 1 and input_filters*se_ratio
         batch_norm_momentum: default 0.9, the BatchNormalization momentum
+        batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
+            calculations, used in the denominator to prevent
+            divide-by-zero errors.
         activation: default "swish", the activation function used between
             convolution operations
         dropout: float, the optional dropout rate to apply before the output
@@ -70,8 +73,10 @@ def __init__(
         data_format="channels_last",
         se_ratio=0.0,
         batch_norm_momentum=0.9,
+        batch_norm_epsilon=1e-3,
         activation="swish",
         dropout=0.2,
+        nores=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -83,8 +88,10 @@ def __init__(
         self.data_format = data_format
         self.se_ratio = se_ratio
         self.batch_norm_momentum = batch_norm_momentum
+        self.batch_norm_epsilon = batch_norm_epsilon
         self.activation = activation
         self.dropout = dropout
+        self.nores = nores
         self.filters = self.input_filters * self.expand_ratio
         self.filters_se = max(1, int(input_filters * se_ratio))

@@ -101,18 +108,13 @@ def __init__(
         self.bn1 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "expand_bn",
         )
         self.act = keras.layers.Activation(
             self.activation, name=self.name + "expand_activation"
         )

-        self.bn2 = keras.layers.BatchNormalization(
-            axis=BN_AXIS,
-            momentum=self.batch_norm_momentum,
-            name=self.name + "bn",
-        )
-
         self.se_conv1 = keras.layers.Conv2D(
             self.filters_se,
             1,
@@ -144,9 +146,10 @@ def __init__(
             name=self.name + "project_conv",
         )

-        self.bn3 = keras.layers.BatchNormalization(
+        self.bn2 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "project_bn",
         )

@@ -192,12 +195,16 @@ def call(self, inputs):

         # Output phase:
         x = self.output_conv(x)
-        x = self.bn3(x)
+        x = self.bn2(x)
         if self.expand_ratio == 1:
             x = self.act(x)

         # Residual:
-        if self.strides == 1 and self.input_filters == self.output_filters:
+        if (
+            self.strides == 1
+            and self.input_filters == self.output_filters
+            and not self.nores
+        ):
             if self.dropout:
                 x = self.dropout_layer(x)
             x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -213,8 +220,10 @@ def get_config(self):
             "data_format": self.data_format,
             "se_ratio": self.se_ratio,
             "batch_norm_momentum": self.batch_norm_momentum,
+            "batch_norm_epsilon": self.batch_norm_epsilon,
             "activation": self.activation,
             "dropout": self.dropout,
+            "nores": self.nores,
         }

         base_config = super().get_config()
diff --git a/keras_hub/src/models/efficientnet/mbconv.py b/keras_hub/src/models/efficientnet/mbconv.py
index 392e62c04..b4dc05f7c 100644
--- a/keras_hub/src/models/efficientnet/mbconv.py
+++ b/keras_hub/src/models/efficientnet/mbconv.py
@@ -23,8 +23,10 @@ def __init__(
         data_format="channels_last",
         se_ratio=0.0,
         batch_norm_momentum=0.9,
+        batch_norm_epsilon=1e-3,
         activation="swish",
         dropout=0.2,
+        nores=False,
         **kwargs
     ):
         """Implementation of the MBConv block
@@ -60,6 +62,9 @@ def __init__(
             is above 0. The filters used in this phase are chosen as the maximum
             between 1 and input_filters*se_ratio
         batch_norm_momentum: default 0.9, the BatchNormalization momentum
+        batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
+            calculations, used in the denominator to prevent
+            divide-by-zero errors.
         activation: default "swish", the activation function used between
             convolution operations
         dropout: float, the optional dropout rate to apply before the output
@@ -83,8 +88,10 @@ def __init__(
         self.data_format = data_format
         self.se_ratio = se_ratio
         self.batch_norm_momentum = batch_norm_momentum
+        self.batch_norm_epsilon = batch_norm_epsilon
         self.activation = activation
         self.dropout = dropout
+        self.nores = nores
         self.filters = self.input_filters * self.expand_ratio
         self.filters_se = max(1, int(input_filters * se_ratio))

@@ -101,6 +108,7 @@ def __init__(
         self.bn1 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "expand_bn",
         )
         self.act = keras.layers.Activation(
@@ -119,6 +127,7 @@ def __init__(
         self.bn2 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "bn",
         )

@@ -156,6 +165,7 @@ def __init__(
         self.bn3 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "project_bn",
         )

@@ -207,7 +217,11 @@ def call(self, inputs):
         x = self.output_conv(x)
         x = self.bn3(x)

-        if self.strides == 1 and self.input_filters == self.output_filters:
+        if (
+            self.strides == 1
+            and self.input_filters == self.output_filters
+            and not self.nores
+        ):
             if self.dropout:
                 x = self.dropout_layer(x)
             x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -223,8 +237,10 @@ def get_config(self):
             "data_format": self.data_format,
             "se_ratio": self.se_ratio,
             "batch_norm_momentum": self.batch_norm_momentum,
+            "batch_norm_epsilon": self.batch_norm_epsilon,
             "activation": self.activation,
             "dropout": self.dropout,
+            "nores": self.nores,
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
diff --git a/keras_hub/src/utils/timm/convert_efficientnet.py b/keras_hub/src/utils/timm/convert_efficientnet.py
index 609c26d35..5c58c7c04 100644
--- a/keras_hub/src/utils/timm/convert_efficientnet.py
+++ b/keras_hub/src/utils/timm/convert_efficientnet.py
@@ -13,10 +13,57 @@
     "b0": {
         "width_coefficient": 1.0,
         "depth_coefficient": 1.0,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
     },
     "b1": {
         "width_coefficient": 1.0,
         "depth_coefficient": 1.1,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "el": {
+        "width_coefficient": 1.2,
+        "depth_coefficient": 1.4,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "em": {
+        "width_coefficient": 1.0,
+        "depth_coefficient": 1.1,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "es": {
+        "width_coefficient": 1.0,
+        "depth_coefficient": 1.0,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
     },
 }

@@ -31,15 +78,6 @@ def convert_backbone_config(timm_config):
         "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
         "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
         "stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
-        "stackwise_squeeze_and_excite_ratios": [
-            0.25,
-            0.25,
-            0.25,
-            0.25,
-            0.25,
-            0.25,
-            0.25,
-        ],
         "stackwise_block_types": ["v1"] * 7,
         "min_depth": None,
         "include_stem_padding": True,
@@ -68,21 +106,21 @@ def convert_weights(backbone, loader, timm_config):
     timm_architecture = timm_config["architecture"]
     variant = "_".join(timm_architecture.split("_")[1:])

-    def port_conv2d(keras_layer_name, hf_weight_prefix, port_bias=True):
+    def port_conv2d(keras_layer, hf_weight_prefix, port_bias=True):
         loader.port_weight(
-            backbone.get_layer(keras_layer_name).kernel,
+            keras_layer.kernel,
             hf_weight_key=f"{hf_weight_prefix}.weight",
             hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
         )

         if port_bias:
             loader.port_weight(
-                backbone.get_layer(keras_layer_name).bias,
+                keras_layer.bias,
                 hf_weight_key=f"{hf_weight_prefix}.bias",
             )

     def port_depthwise_conv2d(
-        keras_layer_name,
+        keras_layer,
         hf_weight_prefix,
         port_bias=True,
         depth_multiplier=1,
@@ -99,39 +137,39 @@ def convert_pt_conv2d_kernel(pt_kernel):
         )

         loader.port_weight(
-            backbone.get_layer(keras_layer_name).kernel,
+            keras_layer.kernel,
             hf_weight_key=f"{hf_weight_prefix}.weight",
             hook_fn=lambda x, _: convert_pt_conv2d_kernel(x),
         )

         if port_bias:
             loader.port_weight(
-                backbone.get_layer(keras_layer_name).bias,
+                keras_layer.bias,
                 hf_weight_key=f"{hf_weight_prefix}.bias",
             )

-    def port_batch_normalization(keras_layer_name, hf_weight_prefix):
+    def port_batch_normalization(keras_layer, hf_weight_prefix):
         loader.port_weight(
-            backbone.get_layer(keras_layer_name).gamma,
+            keras_layer.gamma,
             hf_weight_key=f"{hf_weight_prefix}.weight",
         )
         loader.port_weight(
-            backbone.get_layer(keras_layer_name).beta,
+            keras_layer.beta,
             hf_weight_key=f"{hf_weight_prefix}.bias",
         )
         loader.port_weight(
-            backbone.get_layer(keras_layer_name).moving_mean,
+            keras_layer.moving_mean,
            hf_weight_key=f"{hf_weight_prefix}.running_mean",
         )
         loader.port_weight(
-            backbone.get_layer(keras_layer_name).moving_variance,
+            keras_layer.moving_variance,
             hf_weight_key=f"{hf_weight_prefix}.running_var",
         )
         # do we need num batches tracked?

     # Stem
-    port_conv2d("stem_conv", "conv_stem", port_bias=False)
-    port_batch_normalization("stem_bn", "bn1")
+    port_conv2d(backbone.get_layer("stem_conv"), "conv_stem", port_bias=False)
+    port_batch_normalization(backbone.get_layer("stem_bn"), "bn1")

     # Stages
     num_stacks = len(backbone.stackwise_kernel_sizes)
@@ -144,72 +182,168 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
         repeats = int(
             math.ceil(VARIANT_MAP[variant]["depth_coefficient"] * repeats)
         )
+        se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][
+            stack_index
+        ]

         for block_idx in range(repeats):
             conv_pw_count = 0
             bn_count = 1

-            conv_pw_name_map = ["conv_pw", "conv_pwl"]
             # 97 is the start of the lowercase alphabet.
             letter_identifier = chr(block_idx + 97)

-            if block_type == "v1":
-                keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
-                hf_block_prefix = f"blocks.{stack_index}.{block_idx}."
+            keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
+            hf_block_prefix = f"blocks.{stack_index}.{block_idx}."

+            if block_type == "v1":
+                conv_pw_name_map = ["conv_pw", "conv_pwl"]
                 # Initial Expansion Conv
                 if expansion_ratio != 1:
                     port_conv2d(
-                        keras_block_prefix + "expand_conv",
+                        backbone.get_layer(keras_block_prefix + "expand_conv"),
                         hf_block_prefix + conv_pw_name_map[conv_pw_count],
                         port_bias=False,
                     )
                     conv_pw_count += 1
                     port_batch_normalization(
-                        keras_block_prefix + "expand_bn",
+                        backbone.get_layer(keras_block_prefix + "expand_bn"),
                         hf_block_prefix + f"bn{bn_count}",
                     )
                     bn_count += 1

                 # Depthwise Conv
                 port_depthwise_conv2d(
-                    keras_block_prefix + "dwconv",
+                    backbone.get_layer(keras_block_prefix + "dwconv"),
                     hf_block_prefix + "conv_dw",
                     port_bias=False,
                 )
                 port_batch_normalization(
-                    keras_block_prefix + "dwconv_bn",
+                    backbone.get_layer(keras_block_prefix + "dwconv_bn"),
                     hf_block_prefix + f"bn{bn_count}",
                 )
                 bn_count += 1

-                # Squeeze and Excite
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_reduce"),
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_expand"),
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
                 port_conv2d(
-                    keras_block_prefix + "se_reduce",
-                    hf_block_prefix + "se.conv_reduce",
+                    backbone.get_layer(keras_block_prefix + "project"),
+                    hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                    port_bias=False,
                 )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "project_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "fused":
+                fused_block_layer = backbone.get_layer(keras_block_prefix)
+
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        fused_block_layer.conv1,
+                        hf_block_prefix + "conv_exp",
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        fused_block_layer.bn1,
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        fused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        fused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
                 port_conv2d(
-                    keras_block_prefix + "se_expand",
-                    hf_block_prefix + "se.conv_expand",
+                    fused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+            elif block_type == "unfused":
+                unfused_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        unfused_block_layer.conv1,
+                        hf_block_prefix + "conv_pw",
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        unfused_block_layer.bn1,
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    unfused_block_layer.depthwise,
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    unfused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
                 )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        unfused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        unfused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )

                 # Output/Projection
                 port_conv2d(
-                    keras_block_prefix + "project",
-                    hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                    unfused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
                     port_bias=False,
                 )
                 conv_pw_count += 1
                 port_batch_normalization(
-                    keras_block_prefix + "project_bn",
+                    unfused_block_layer.bn3,
                     hf_block_prefix + f"bn{bn_count}",
                 )
                 bn_count += 1

     # Head/Top
-    port_conv2d("top_conv", "conv_head", port_bias=False)
-    port_batch_normalization("top_bn", "bn2")
+    port_conv2d(backbone.get_layer("top_conv"), "conv_head", port_bias=False)
+    port_batch_normalization(backbone.get_layer("top_bn"), "bn2")


 def convert_head(task, loader, timm_config):
diff --git a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py
index 5790d6130..75810a19a 100644
--- a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py
@@ -2,9 +2,15 @@
 Convert efficientnet checkpoints.

 python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
-    --preset efficientnet_b0_ra_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b0_ra_imagenet
+    --preset efficientnet_b0_ra_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet
 python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
-    --preset efficientnet_b1_ft_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b1_ft_imagenet
+    --preset efficientnet_b1_ft_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet
+python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
+    --preset efficientnet_el_ra_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet
+python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
+    --preset efficientnet_em_ra2_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet
+python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
+    --preset efficientnet_es_ra_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet
 """

 import os
@@ -23,6 +29,9 @@
 PRESET_MAP = {
     "efficientnet_b0_ra_imagenet": "timm/efficientnet_b0.ra_in1k",
     "efficientnet_b1_ft_imagenet": "timm/efficientnet_b1.ft_in1k",
+    "efficientnet_el_ra_imagenet": "timm/efficientnet_el.ra_in1k",
+    "efficientnet_em_ra2_imagenet": "timm/efficientnet_em.ra2_in1k",
+    "efficientnet_es_ra_imagenet": "timm/efficientnet_es.ra_in1k",
 }

 FLAGS = flags.FLAGS
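
For reviewers, a minimal usage sketch (not part of the diff) showing how the new stackwise arguments compose, using the EdgeTPU-Small ("es") configuration from VARIANT_MAP above. The `keras_hub.models.EfficientNetBackbone` export path and the `activation` keyword are assumed from the surrounding codebase rather than shown in this diff:

```python
import numpy as np
import keras_hub

# EdgeTPU-Small ("es") configuration, mirroring VARIANT_MAP in
# convert_efficientnet.py above.
backbone = keras_hub.models.EfficientNetBackbone(
    stackwise_kernel_sizes=[3, 3, 3, 5, 5, 5],
    stackwise_num_repeats=[1, 2, 4, 5, 4, 2],
    stackwise_input_filters=[32, 24, 32, 48, 96, 144],
    stackwise_output_filters=[24, 32, 48, 96, 144, 192],
    stackwise_expansion_ratios=[4, 8, 8, 8, 8, 8],
    stackwise_squeeze_and_excite_ratios=[0] * 6,
    stackwise_strides=[1, 2, 2, 2, 1, 2],
    stackwise_block_types=["fused"] * 3 + ["unfused"] * 3,
    # New in this change: round the first stack's input filters from a
    # forced value of 24 instead of inheriting the stem width ...
    stackwise_force_input_filters=[24, 0, 0, 0, 0, 0],
    # ... and skip the residual connection in the first stack only.
    stackwise_nores_option=[True] + [False] * 5,
    width_coefficient=1.0,
    depth_coefficient=1.0,
    activation="relu",
)
features = backbone(np.random.uniform(size=(1, 224, 224, 3)))
```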
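Once the converted weights are uploaded, loading should follow the usual preset flow. A sketch under two assumptions: the preset names are the ones registered in efficientnet_presets.py above, and the hf:// route relies on keras_hub's standard timm-converter support, which convert_efficientnet.py extends here:

```python
import keras_hub

# Load the Kaggle preset registered in efficientnet_presets.py.
backbone = keras_hub.models.Backbone.from_preset("efficientnet_es_ra_imagenet")

# Or convert the original timm checkpoint on the fly, exercising the
# convert_efficientnet.py path changed in this diff.
backbone = keras_hub.models.Backbone.from_preset("hf://timm/efficientnet_es.ra_in1k")
```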