Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

timm integration #196

Merged
merged 14 commits into from
Apr 6, 2021
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
_TABNET_AVAILABLE = _module_available("pytorch_tabnet")
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
26 changes: 26 additions & 0 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

from flash.utils.imports import _TIMM_AVAILABLE

if _TIMM_AVAILABLE:
import timm

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV

Expand Down Expand Up @@ -70,6 +75,9 @@ def backbone_and_num_features(
if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, pretrained)

if _TIMM_AVAILABLE and model_name in timm.list_models():
return timm_backbone_and_num_features(model_name, pretrained)

raise ValueError(f"{model_name} is not supported yet.")


Expand Down Expand Up @@ -140,3 +148,21 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
return backbone, num_features

raise ValueError(f"{model_name} is not supported yet.")


def timm_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""
>>> timm_backbone_and_num_features('resnet18') # doctest: +ELLIPSIS
(ResNet(...), 512)
>>> timm_backbone_and_num_features('mobilenetv3_large_100') # doctest: +ELLIPSIS
(MobileNetV3(...), 1280)
"""

if model_name in timm.list_models():
backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='')
num_features = backbone.num_features
return backbone, num_features

raise ValueError(
f"{model_name} is not supported in timm yet. https://rwightman.github.io/pytorch-image-models/models/"
carmocca marked this conversation as resolved.
Show resolved Hide resolved
)
1 change: 1 addition & 0 deletions requirements/extras.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
timm~=0.4.5
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
13 changes: 12 additions & 1 deletion tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from flash.vision.backbones import backbone_and_num_features
from flash.utils.imports import _TIMM_AVAILABLE
from flash.vision.backbones import backbone_and_num_features, timm_backbone_and_num_features


@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280),
Expand All @@ -11,3 +12,13 @@ def test_backbone_and_num_features(backbone, expected_num_features):

assert backbone_model
assert num_features == expected_num_features


@pytest.mark.skipif(not _TIMM_AVAILABLE, reason="test requires timm")
@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenetv2_100", 1280)])
def test_timm_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = timm_backbone_and_num_features(model_name=backbone, pretrained=False)

assert backbone_model
assert num_features == expected_num_features