Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Automatic Download of Pretrained Weights in DETR #17712

Merged
merged 11 commits into from
Jun 21, 2022

Conversation

AnugunjNaman
Copy link
Contributor

What does this PR do?

Fixes #15764

@AnugunjNaman AnugunjNaman changed the title added use_backbone_pretrained Fix Automatic Download of Pretrained Weights in DETR Jun 15, 2022
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 15, 2022

The documentation is not available anymore as the PR was closed or merged.

@AnugunjNaman
Copy link
Contributor Author

@NielsRogge

Comment on lines 1239 to 1241
>>> # model use pretrained backbone weights by default
>>> # to prevent this set use_pretrained_backbone = False
>>> # model = DetrModel.from_pretrained("facebook/detr-resnet-50", use_pretrained_backbone = False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this makes sense, cause it will load pre-trained weights anyway.

A better code example (in my opinion) would be:

# randomly initialize a DETR model with pre-trained ResNet weights
config = DetrConfig()
model = DetrModel(config)

# randomly initialize a DETR model (with randomly initialized ResNet)
config = DetrConfig(use_pretrained_backbone=False)
model = DetrModel(config)

Copy link
Contributor Author

@AnugunjNaman AnugunjNaman Jun 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NielsRogge. Done. Works like you asked.

Copy link
Contributor

@NielsRogge NielsRogge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding! We could perhaps add a small code example snippet in detr.mdx.

@AnugunjNaman
Copy link
Contributor Author

@NielsRogge Any Update here?

Comment on lines 117 to 128
from transformers import DetrConfig

Option 1: instantiate DETR with pre-trained weights for entire model
model = DetrForObjectDetection.from_pretrained("facebook/resnet-50")

Option 2: instantiate DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
config = DetrConfig()
model = DetrForObjectDetection(config)

Option 3: instantiate DETR with randomly initialized weights for backbone + Transformer
config = DetrConfig(use_pretrained_backbone=False)
model = DetrForObjectDetection(config)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from transformers import DetrConfig
Option 1: instantiate DETR with pre-trained weights for entire model
model = DetrForObjectDetection.from_pretrained("facebook/resnet-50")
Option 2: instantiate DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
config = DetrConfig()
model = DetrForObjectDetection(config)
Option 3: instantiate DETR with randomly initialized weights for backbone + Transformer
config = DetrConfig(use_pretrained_backbone=False)
model = DetrForObjectDetection(config)
```
from transformers import DetrConfig
Option 1: instantiate DETR with pre-trained weights for entire model
model = DetrForObjectDetection.from_pretrained("facebook/resnet-50")
Option 2: instantiate DETR with randomly initialized weights for Transformer, but pre-trained weights for backbone
config = DetrConfig()
model = DetrForObjectDetection(config)
Option 3: instantiate DETR with randomly initialized weights for backbone + Transformer
config = DetrConfig(use_pretrained_backbone=False)
model = DetrForObjectDetection(config)
```

This won't properly render if you don't add the code statements.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The black command failed when added python after `. Unclear how to proceed here in that case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I think you need to do ```python for the first statement

Copy link
Contributor Author

@AnugunjNaman AnugunjNaman Jun 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did exactly that 66eae21 The circle ci failed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution!

src/transformers/models/detr/configuration_detr.py Outdated Show resolved Hide resolved
@NielsRogge NielsRogge merged commit 27e9073 into huggingface:main Jun 21, 2022
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
* added use_backbone_pretrained

* style fixes

* update

* Update detr.mdx

* Update detr.mdx

* Update detr.mdx

* update using doc py

* Update detr.mdx

* Update src/transformers/models/detr/configuration_detr.py

Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 29, 2022
* added use_backbone_pretrained

* style fixes

* update

* Update detr.mdx

* Update detr.mdx

* Update detr.mdx

* update using doc py

* Update detr.mdx

* Update src/transformers/models/detr/configuration_detr.py

Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Resnet weights should not be downloaded when building DETR by from_pretrained
4 participants