From 40274d373e12b0a8ff73d2f4670d9558ce17dff9 Mon Sep 17 00:00:00 2001
From: Alan Blanchet
Date: Mon, 12 Aug 2024 16:44:02 +0200
Subject: [PATCH 1/2] fix: Parameterized norm freezing

For the R18 model, the authors don't freeze norms in the backbone.
---
 src/transformers/models/rt_detr/configuration_rt_detr.py | 4 ++++
 src/transformers/models/rt_detr/modeling_rt_detr.py      | 7 ++++---
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/rt_detr/configuration_rt_detr.py b/src/transformers/models/rt_detr/configuration_rt_detr.py
index 0e34d0376f9fa6..4b1c9b337dfc39 100644
--- a/src/transformers/models/rt_detr/configuration_rt_detr.py
+++ b/src/transformers/models/rt_detr/configuration_rt_detr.py
@@ -55,6 +55,8 @@ class RTDetrConfig(PretrainedConfig):
         use_timm_backbone (`bool`, *optional*, defaults to `False`):
             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
             library.
+        freeze_backbone_batch_norm (`bool`, *optional*, defaults to `True`):
+            Whether to freeze the batch normalization layers in the backbone.
         backbone_kwargs (`dict`, *optional*):
             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
             e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
@@ -190,6 +192,7 @@ def __init__(
         backbone=None,
         use_pretrained_backbone=False,
         use_timm_backbone=False,
+        freeze_backbone_batch_norms=True,
         backbone_kwargs=None,
         # encoder HybridEncoder
         encoder_hidden_dim=256,
@@ -280,6 +283,7 @@ def __init__(
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
         self.use_timm_backbone = use_timm_backbone
+        self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
         self.backbone_kwargs = backbone_kwargs
         # encoder
         self.encoder_hidden_dim = encoder_hidden_dim
diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py
index 3f476725941e3c..ab83a81f50674d 100644
--- a/src/transformers/models/rt_detr/modeling_rt_detr.py
+++ b/src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -559,9 +559,10 @@ def __init__(self, config):
 
         backbone = load_backbone(config)
 
-        # replace batch norm by frozen batch norm
-        with torch.no_grad():
-            replace_batch_norm(backbone)
+        if config.freeze_backbone_batch_norms:
+            # replace batch norm by frozen batch norm
+            with torch.no_grad():
+                replace_batch_norm(backbone)
 
         self.model = backbone
         self.intermediate_channel_sizes = self.model.channels

From 4cc6e710ab7afbbba5552cbdce944ab3c3932f35 Mon Sep 17 00:00:00 2001
From: Alan-Blanchet
Date: Tue, 13 Aug 2024 09:57:27 +0200
Subject: [PATCH 2/2] Update src/transformers/models/rt_detr/configuration_rt_detr.py

Co-authored-by: Pavel Iakubovskii
---
 src/transformers/models/rt_detr/configuration_rt_detr.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/rt_detr/configuration_rt_detr.py b/src/transformers/models/rt_detr/configuration_rt_detr.py
index 4b1c9b337dfc39..ca20cc584dfd0b 100644
--- a/src/transformers/models/rt_detr/configuration_rt_detr.py
+++ b/src/transformers/models/rt_detr/configuration_rt_detr.py
@@ -55,7 +55,7 @@ class RTDetrConfig(PretrainedConfig):
         use_timm_backbone (`bool`, *optional*, defaults to `False`):
             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
             library.
-        freeze_backbone_batch_norm (`bool`, *optional*, defaults to `True`):
+        freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
            Whether to freeze the batch normalization layers in the backbone.
         backbone_kwargs (`dict`, *optional*):
             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
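
Usage note (not part of the patch): a minimal sketch of how the new flag is exercised, assuming the public `RTDetrConfig` and `RTDetrForObjectDetection` classes from `transformers`; only `freeze_backbone_batch_norms` itself is introduced by this diff.

    from transformers import RTDetrConfig, RTDetrForObjectDetection

    # Default behaviour is unchanged: batch norm layers in the backbone are
    # replaced by frozen batch norm when the model is built.
    config = RTDetrConfig()
    assert config.freeze_backbone_batch_norms

    # R18-style setup from the commit message: keep the backbone norms trainable
    # by disabling the flag before constructing the model.
    config = RTDetrConfig(freeze_backbone_batch_norms=False)
    model = RTDetrForObjectDetection(config)

Keeping the default at `True` preserves the previously hard-coded behaviour, so existing configs and checkpoints are unaffected.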