Fix Automatic Download of Pretrained Weights in DETR #17712

Merged · 11 commits · Jun 21, 2022
4 changes: 4 additions & 0 deletions src/transformers/models/detr/configuration_detr.py
@@ -82,6 +82,8 @@ class DetrConfig(PretrainedConfig):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a
list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5).
class_cost (`float`, *optional*, defaults to 1):
@@ -147,6 +149,7 @@ def __init__(
auxiliary_loss=False,
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
dilation=False,
class_cost=1,
bbox_cost=5,
@@ -180,6 +183,7 @@ def __init__(
self.auxiliary_loss = auxiliary_loss
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
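A minimal sketch of the new option, assuming a transformers version that includes this change (only names that appear in the diff are used):

from transformers import DetrConfig

# Default: the timm backbone is created with pretrained weights
config = DetrConfig()
print(config.use_pretrained_backbone)  # True

# Opt out of downloading backbone weights
config = DetrConfig(use_pretrained_backbone=False)
print(config.use_pretrained_backbone)  # False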
17 changes: 14 additions & 3 deletions src/transformers/models/detr/modeling_detr.py
@@ -326,7 +326,7 @@ class DetrTimmConvEncoder(nn.Module):

"""

def __init__(self, name: str, dilation: bool):
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool):
super().__init__()

kwargs = {}
@@ -335,7 +335,9 @@ def __init__(self, name: str, dilation: bool):

requires_backends(self, ["timm"])

backbone = create_model(name, pretrained=True, features_only=True, out_indices=(1, 2, 3, 4), **kwargs)
backbone = create_model(
name, pretrained=use_pretrained_backbone, features_only=True, out_indices=(1, 2, 3, 4), **kwargs
)
# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
@@ -1177,7 +1179,7 @@ def __init__(self, config: DetrConfig):
super().__init__(config)

# Create backbone + positional encoding
backbone = DetrTimmConvEncoder(config.backbone, config.dilation)
backbone = DetrTimmConvEncoder(config.backbone, config.dilation, config.use_pretrained_backbone)
position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings)

@@ -1234,6 +1236,9 @@ def forward(
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
>>> # the model uses pretrained backbone weights by default
>>> # to prevent this, set use_pretrained_backbone=False
>>> # model = DetrModel.from_pretrained("facebook/detr-resnet-50", use_pretrained_backbone=False)
Contributor:
I don't think this makes sense, because it will load the pre-trained weights anyway.

A better code example (in my opinion) would be:

from transformers import DetrConfig, DetrModel

# randomly initialize a DETR model with pre-trained ResNet weights
config = DetrConfig()
model = DetrModel(config)

# randomly initialize a DETR model (with a randomly initialized ResNet)
config = DetrConfig(use_pretrained_backbone=False)
model = DetrModel(config)

Contributor Author (@AnugunjNaman, Jun 15, 2022):
@NielsRogge. Done. Works like you asked.

>>> model = DetrModel.from_pretrained("facebook/detr-resnet-50")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
@@ -1392,6 +1397,9 @@ def forward(
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
>>> # the model uses pretrained backbone weights by default
>>> # to prevent this, set use_pretrained_backbone=False
>>> # model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", use_pretrained_backbone=False)
>>> model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

>>> inputs = feature_extractor(images=image, return_tensors="pt")
@@ -1548,6 +1556,9 @@ def forward(
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50-panoptic")
>>> # the model uses pretrained backbone weights by default
>>> # to prevent this, set use_pretrained_backbone=False
>>> # model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic", use_pretrained_backbone=False)
>>> model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

>>> inputs = feature_extractor(images=image, return_tensors="pt")
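For reference, a rough sketch of the timm call that the diff now parameterizes; `resnet50` stands in for `config.backbone`, and timm must be installed for this to run:

from timm import create_model

# pretrained=False builds the backbone without downloading a checkpoint,
# which is what use_pretrained_backbone=False now allows
backbone = create_model(
    "resnet50", pretrained=False, features_only=True, out_indices=(1, 2, 3, 4)
)
print(backbone.feature_info.channels())  # [256, 512, 1024, 2048] for resnet50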