Skip to content

Commit

Permalink
add TIMMBackbone
Browse files Browse the repository at this point in the history
  • Loading branch information
shinya7y committed Jan 16, 2022
1 parent ff9bc39 commit 4450e11
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ jobs:
- run:
name: Run unittests
command: |
pip install timm
coverage run --branch --source mmdet -m pytest tests/
coverage xml
coverage report -m
Expand Down Expand Up @@ -120,6 +121,7 @@ jobs:
- run:
name: Run unittests
command: |
pip install timm
pytest tests/
workflows:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ jobs:
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
run: |
pip install timm
coverage run --branch --source mmdet -m pytest tests/
coverage xml
coverage report -m
Expand Down Expand Up @@ -143,6 +144,7 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmdet -m pytest tests/
coverage xml
coverage report -m
Expand Down Expand Up @@ -216,6 +218,7 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmdet -m pytest tests/
coverage xml
coverage report -m
Expand Down
4 changes: 3 additions & 1 deletion mmdet/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .trident_resnet import TridentResNet

__all__ = [
'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet',
'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet',
'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet',
'SwinTransformer', 'PyramidVisionTransformer', 'PyramidVisionTransformerV2'
'SwinTransformer', 'PyramidVisionTransformer',
'PyramidVisionTransformerV2', 'TIMMBackbone'
]
63 changes: 63 additions & 0 deletions mmdet/models/backbones/timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None

from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule

from ..builder import BACKBONES


@BACKBONES.register_module()
class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
Args:
model_name (str): Name of timm model to instantiate.
pretrained (bool): Load pretrained weights if True.
checkpoint_path (str): Path of checkpoint to load after
model is initialized.
in_channels (int): Number of input image channels. Default: 3.
init_cfg (dict, optional): Initialization config dict
**kwargs: Other timm & model specific arguments.
"""

def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs,
):
if timm is None:
raise RuntimeError('timm is not installed')
super(TIMMBackbone, self).__init__(init_cfg)
if 'norm_layer' in kwargs:
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)

# Make unused parameters None
self.timm_model.global_pool = None
self.timm_model.fc = None
self.timm_model.classifier = None

# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True

def forward(self, x):
features = self.timm_model(x)
return features
133 changes: 133 additions & 0 deletions tests/test_models/test_backbones/test_timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet.models.backbones import TIMMBackbone
from .utils import check_norm_state


def test_timm_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = TIMMBackbone()
model.init_weights(pretrained=0)

# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
# Test resnet18 from timm, norm_layer='BN2d'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='BN2d')

# Test resnet18 from timm, norm_layer='SyncBN'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='SyncBN')

# Test resnet18 from timm, features_only=True, output_stride=32
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)

imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 7, 7))

# Test resnet18 from timm, features_only=True, output_stride=16
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=16)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 14, 14))

# Test resnet18 from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 28, 28))
assert feats[4] == torch.Size((1, 512, 28, 28))

# Test efficientnet_b1 with pretrained weights
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)

# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))

# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x3_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 192, 4, 4))
assert feats[1] == torch.Size((1, 768, 2, 2))
assert feats[2] == torch.Size((1, 1536, 1, 1))
assert feats[3] == torch.Size((1, 3072, 1, 1))
assert feats[4] == torch.Size((1, 6144, 1, 1))

# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_101x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))

0 comments on commit 4450e11

Please sign in to comment.