forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
204 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
try: | ||
import timm | ||
except ImportError: | ||
timm = None | ||
|
||
from mmcv.cnn.bricks.registry import NORM_LAYERS | ||
from mmcv.runner import BaseModule | ||
|
||
from ..builder import BACKBONES | ||
|
||
|
||
@BACKBONES.register_module() | ||
class TIMMBackbone(BaseModule): | ||
"""Wrapper to use backbones from timm library. More details can be found in | ||
`timm <https://github.com/rwightman/pytorch-image-models>`_ . | ||
Args: | ||
model_name (str): Name of timm model to instantiate. | ||
pretrained (bool): Load pretrained weights if True. | ||
checkpoint_path (str): Path of checkpoint to load after | ||
model is initialized. | ||
in_channels (int): Number of input image channels. Default: 3. | ||
init_cfg (dict, optional): Initialization config dict | ||
**kwargs: Other timm & model specific arguments. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name, | ||
features_only=True, | ||
pretrained=True, | ||
checkpoint_path='', | ||
in_channels=3, | ||
init_cfg=None, | ||
**kwargs, | ||
): | ||
if timm is None: | ||
raise RuntimeError('timm is not installed') | ||
super(TIMMBackbone, self).__init__(init_cfg) | ||
if 'norm_layer' in kwargs: | ||
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) | ||
self.timm_model = timm.create_model( | ||
model_name=model_name, | ||
features_only=features_only, | ||
pretrained=pretrained, | ||
in_chans=in_channels, | ||
checkpoint_path=checkpoint_path, | ||
**kwargs, | ||
) | ||
|
||
# Make unused parameters None | ||
self.timm_model.global_pool = None | ||
self.timm_model.fc = None | ||
self.timm_model.classifier = None | ||
|
||
# Hack to use pretrained weights from timm | ||
if pretrained or checkpoint_path: | ||
self._is_init = True | ||
|
||
def forward(self, x): | ||
features = self.timm_model(x) | ||
return features |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmdet.models.backbones import TIMMBackbone | ||
from .utils import check_norm_state | ||
|
||
|
||
def test_timm_backbone(): | ||
with pytest.raises(TypeError): | ||
# pretrained must be a string path | ||
model = TIMMBackbone() | ||
model.init_weights(pretrained=0) | ||
|
||
# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN' | ||
# Test resnet18 from timm, norm_layer='BN2d' | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=32, | ||
norm_layer='BN2d') | ||
|
||
# Test resnet18 from timm, norm_layer='SyncBN' | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=32, | ||
norm_layer='SyncBN') | ||
|
||
# Test resnet18 from timm, features_only=True, output_stride=32 | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=32) | ||
model.init_weights() | ||
model.train() | ||
assert check_norm_state(model.modules(), True) | ||
|
||
imgs = torch.randn(1, 3, 224, 224) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 112, 112)) | ||
assert feats[1] == torch.Size((1, 64, 56, 56)) | ||
assert feats[2] == torch.Size((1, 128, 28, 28)) | ||
assert feats[3] == torch.Size((1, 256, 14, 14)) | ||
assert feats[4] == torch.Size((1, 512, 7, 7)) | ||
|
||
# Test resnet18 from timm, features_only=True, output_stride=16 | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=16) | ||
imgs = torch.randn(1, 3, 224, 224) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 112, 112)) | ||
assert feats[1] == torch.Size((1, 64, 56, 56)) | ||
assert feats[2] == torch.Size((1, 128, 28, 28)) | ||
assert feats[3] == torch.Size((1, 256, 14, 14)) | ||
assert feats[4] == torch.Size((1, 512, 14, 14)) | ||
|
||
# Test resnet18 from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnet18', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 224, 224) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 112, 112)) | ||
assert feats[1] == torch.Size((1, 64, 56, 56)) | ||
assert feats[2] == torch.Size((1, 128, 28, 28)) | ||
assert feats[3] == torch.Size((1, 256, 28, 28)) | ||
assert feats[4] == torch.Size((1, 512, 28, 28)) | ||
|
||
# Test efficientnet_b1 with pretrained weights | ||
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True) | ||
|
||
# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnetv2_50x1_bitm', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 8, 8) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 4, 4)) | ||
assert feats[1] == torch.Size((1, 256, 2, 2)) | ||
assert feats[2] == torch.Size((1, 512, 1, 1)) | ||
assert feats[3] == torch.Size((1, 1024, 1, 1)) | ||
assert feats[4] == torch.Size((1, 2048, 1, 1)) | ||
|
||
# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnetv2_50x3_bitm', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 8, 8) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 192, 4, 4)) | ||
assert feats[1] == torch.Size((1, 768, 2, 2)) | ||
assert feats[2] == torch.Size((1, 1536, 1, 1)) | ||
assert feats[3] == torch.Size((1, 3072, 1, 1)) | ||
assert feats[4] == torch.Size((1, 6144, 1, 1)) | ||
|
||
# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8 | ||
model = TIMMBackbone( | ||
model_name='resnetv2_101x1_bitm', | ||
features_only=True, | ||
pretrained=False, | ||
output_stride=8) | ||
imgs = torch.randn(1, 3, 8, 8) | ||
feats = model(imgs) | ||
feats = [feat.shape for feat in feats] | ||
assert len(feats) == 5 | ||
assert feats[0] == torch.Size((1, 64, 4, 4)) | ||
assert feats[1] == torch.Size((1, 256, 2, 2)) | ||
assert feats[2] == torch.Size((1, 512, 1, 1)) | ||
assert feats[3] == torch.Size((1, 1024, 1, 1)) | ||
assert feats[4] == torch.Size((1, 2048, 1, 1)) |