Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support TIMMBackbone #998

Merged
merged 17 commits into from
Nov 2, 2021
17 changes: 17 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,17 @@ jobs:
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
run: |
pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could CI work with TIMM?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/open-mmlab/mmsegmentation/runs/4016108790?check_suite_focus=true

Can work with PyTorch >= 1.6.0, but failed with PyTorch 1.5.1
Please take a look at these two commits:
b75c201
bc7e80e

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could undo the ignore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIMM does not support pt1.3

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if: ${{matrix.torch < '1.5.0'}} will skip the timm unittests.

coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}

build_cuda101:
runs-on: ubuntu-18.04
Expand Down Expand Up @@ -142,9 +150,17 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
- name: Upload coverage to Codecov
uses: codecov/[email protected]
with:
Expand Down Expand Up @@ -198,6 +214,7 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .unet import UNet
from .vit import VisionTransformer

__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet'
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone'
]
63 changes: 63 additions & 0 deletions mmseg/models/backbones/timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None

from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule

from ..builder import BACKBONES


@BACKBONES.register_module()
class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .

Args:
model_name (str): Name of timm model to instantiate.
pretrained (bool): Load pretrained weights if True.
checkpoint_path (str): Path of checkpoint to load after
model is initialized.
in_channels (int): Number of input image channels. Default: 3.
init_cfg (dict, optional): Initialization config dict
**kwargs: Other timm & model specific arguments.
"""

def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs,
):
if timm is None:
raise RuntimeError('timm is not installed')
super(TIMMBackbone, self).__init__(init_cfg)
if 'norm_layer' in kwargs:
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)

# Make unused parameters None
self.timm_model.global_pool = None
self.timm_model.fc = None
self.timm_model.classifier = None

# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True

def forward(self, x):
features = self.timm_model(x)
return features
133 changes: 133 additions & 0 deletions tests/test_models/test_backbones/test_timm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import TIMMBackbone
from .utils import check_norm_state


def test_timm_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = TIMMBackbone()
model.init_weights(pretrained=0)

# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
# Test resnet18 from timm, norm_layer='BN2d'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='BN2d')

# Test resnet18 from timm, norm_layer='SyncBN'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='SyncBN')

# Test resnet18 from timm, features_only=True, output_stride=32
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)

imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 7, 7))

# Test resnet18 from timm, features_only=True, output_stride=16
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=16)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 14, 14))

# Test resnet18 from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 28, 28))
assert feats[4] == torch.Size((1, 512, 28, 28))

# Test efficientnet_b1 with pretrained weights
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)

# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))

# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x3_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 192, 4, 4))
assert feats[1] == torch.Size((1, 768, 2, 2))
assert feats[2] == torch.Size((1, 1536, 1, 1))
assert feats[3] == torch.Size((1, 3072, 1, 1))
assert feats[4] == torch.Size((1, 6144, 1, 1))

# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_101x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))