From 4450e1193ea10b5df9ec1873d72410dbf0b80332 Mon Sep 17 00:00:00 2001 From: Yosuke Shinya <42844407+shinya7y@users.noreply.github.com> Date: Sun, 16 Jan 2022 23:07:37 +0000 Subject: [PATCH] add TIMMBackbone based on https://github.com/open-mmlab/mmclassification/pull/427 https://github.com/open-mmlab/mmsegmentation/pull/998 --- .circleci/config.yml | 2 + .github/workflows/build.yml | 3 + mmdet/models/backbones/__init__.py | 4 +- mmdet/models/backbones/timm_backbone.py | 63 +++++++++ .../test_backbones/test_timm_backbone.py | 133 ++++++++++++++++++ 5 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 mmdet/models/backbones/timm_backbone.py create mode 100644 tests/test_models/test_backbones/test_timm_backbone.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 804bfd761f7..975e7c92060 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -76,6 +76,7 @@ jobs: - run: name: Run unittests command: | + pip install timm coverage run --branch --source mmdet -m pytest tests/ coverage xml coverage report -m @@ -120,6 +121,7 @@ jobs: - run: name: Run unittests command: | + pip install timm pytest tests/ workflows: diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 31d063a16f3..7ba8fafe02a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -71,6 +71,7 @@ jobs: run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report run: | + pip install timm coverage run --branch --source mmdet -m pytest tests/ coverage xml coverage report -m @@ -143,6 +144,7 @@ jobs: TORCH_CUDA_ARCH_LIST=7.0 pip install . - name: Run unittests and generate coverage report run: | + python -m pip install timm coverage run --branch --source mmdet -m pytest tests/ coverage xml coverage report -m @@ -216,6 +218,7 @@ jobs: TORCH_CUDA_ARCH_LIST=7.0 pip install . - name: Run unittests and generate coverage report run: | + python -m pip install timm coverage run --branch --source mmdet -m pytest tests/ coverage xml coverage report -m diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py index 8a4e9cd1b2d..11f1469e4c8 100644 --- a/mmdet/models/backbones/__init__.py +++ b/mmdet/models/backbones/__init__.py @@ -14,11 +14,13 @@ from .resnext import ResNeXt from .ssd_vgg import SSDVGG from .swin import SwinTransformer +from .timm_backbone import TIMMBackbone from .trident_resnet import TridentResNet __all__ = [ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet', 'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet', - 'SwinTransformer', 'PyramidVisionTransformer', 'PyramidVisionTransformerV2' + 'SwinTransformer', 'PyramidVisionTransformer', + 'PyramidVisionTransformerV2', 'TIMMBackbone' ] diff --git a/mmdet/models/backbones/timm_backbone.py b/mmdet/models/backbones/timm_backbone.py new file mode 100644 index 00000000000..01b29fc5ed9 --- /dev/null +++ b/mmdet/models/backbones/timm_backbone.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + import timm +except ImportError: + timm = None + +from mmcv.cnn.bricks.registry import NORM_LAYERS +from mmcv.runner import BaseModule + +from ..builder import BACKBONES + + +@BACKBONES.register_module() +class TIMMBackbone(BaseModule): + """Wrapper to use backbones from timm library. More details can be found in + `timm `_ . + + Args: + model_name (str): Name of timm model to instantiate. + pretrained (bool): Load pretrained weights if True. + checkpoint_path (str): Path of checkpoint to load after + model is initialized. + in_channels (int): Number of input image channels. Default: 3. + init_cfg (dict, optional): Initialization config dict + **kwargs: Other timm & model specific arguments. + """ + + def __init__( + self, + model_name, + features_only=True, + pretrained=True, + checkpoint_path='', + in_channels=3, + init_cfg=None, + **kwargs, + ): + if timm is None: + raise RuntimeError('timm is not installed') + super(TIMMBackbone, self).__init__(init_cfg) + if 'norm_layer' in kwargs: + kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer']) + self.timm_model = timm.create_model( + model_name=model_name, + features_only=features_only, + pretrained=pretrained, + in_chans=in_channels, + checkpoint_path=checkpoint_path, + **kwargs, + ) + + # Make unused parameters None + self.timm_model.global_pool = None + self.timm_model.fc = None + self.timm_model.classifier = None + + # Hack to use pretrained weights from timm + if pretrained or checkpoint_path: + self._is_init = True + + def forward(self, x): + features = self.timm_model(x) + return features diff --git a/tests/test_models/test_backbones/test_timm_backbone.py b/tests/test_models/test_backbones/test_timm_backbone.py new file mode 100644 index 00000000000..e712f4e2ebe --- /dev/null +++ b/tests/test_models/test_backbones/test_timm_backbone.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmdet.models.backbones import TIMMBackbone +from .utils import check_norm_state + + +def test_timm_backbone(): + with pytest.raises(TypeError): + # pretrained must be a string path + model = TIMMBackbone() + model.init_weights(pretrained=0) + + # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN' + # Test resnet18 from timm, norm_layer='BN2d' + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=32, + norm_layer='BN2d') + + # Test resnet18 from timm, norm_layer='SyncBN' + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=32, + norm_layer='SyncBN') + + # Test resnet18 from timm, features_only=True, output_stride=32 + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=32) + model.init_weights() + model.train() + assert check_norm_state(model.modules(), True) + + imgs = torch.randn(1, 3, 224, 224) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 112, 112)) + assert feats[1] == torch.Size((1, 64, 56, 56)) + assert feats[2] == torch.Size((1, 128, 28, 28)) + assert feats[3] == torch.Size((1, 256, 14, 14)) + assert feats[4] == torch.Size((1, 512, 7, 7)) + + # Test resnet18 from timm, features_only=True, output_stride=16 + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=16) + imgs = torch.randn(1, 3, 224, 224) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 112, 112)) + assert feats[1] == torch.Size((1, 64, 56, 56)) + assert feats[2] == torch.Size((1, 128, 28, 28)) + assert feats[3] == torch.Size((1, 256, 14, 14)) + assert feats[4] == torch.Size((1, 512, 14, 14)) + + # Test resnet18 from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnet18', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 224, 224) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 112, 112)) + assert feats[1] == torch.Size((1, 64, 56, 56)) + assert feats[2] == torch.Size((1, 128, 28, 28)) + assert feats[3] == torch.Size((1, 256, 28, 28)) + assert feats[4] == torch.Size((1, 512, 28, 28)) + + # Test efficientnet_b1 with pretrained weights + model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True) + + # Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnetv2_50x1_bitm', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 8, 8) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 4, 4)) + assert feats[1] == torch.Size((1, 256, 2, 2)) + assert feats[2] == torch.Size((1, 512, 1, 1)) + assert feats[3] == torch.Size((1, 1024, 1, 1)) + assert feats[4] == torch.Size((1, 2048, 1, 1)) + + # Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnetv2_50x3_bitm', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 8, 8) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 192, 4, 4)) + assert feats[1] == torch.Size((1, 768, 2, 2)) + assert feats[2] == torch.Size((1, 1536, 1, 1)) + assert feats[3] == torch.Size((1, 3072, 1, 1)) + assert feats[4] == torch.Size((1, 6144, 1, 1)) + + # Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8 + model = TIMMBackbone( + model_name='resnetv2_101x1_bitm', + features_only=True, + pretrained=False, + output_stride=8) + imgs = torch.randn(1, 3, 8, 8) + feats = model(imgs) + feats = [feat.shape for feat in feats] + assert len(feats) == 5 + assert feats[0] == torch.Size((1, 64, 4, 4)) + assert feats[1] == torch.Size((1, 256, 2, 2)) + assert feats[2] == torch.Size((1, 512, 1, 1)) + assert feats[3] == torch.Size((1, 1024, 1, 1)) + assert feats[4] == torch.Size((1, 2048, 1, 1))