Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
timm integration (#196)
Browse files Browse the repository at this point in the history
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Kaushik Bokka <[email protected]>
  • Loading branch information
3 people authored Apr 6, 2021
1 parent 3b6a5de commit 6e548a4
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 1 deletion.
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
_TABNET_AVAILABLE = _module_available("pytorch_tabnet")
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
20 changes: 20 additions & 0 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

from flash.utils.imports import _TIMM_AVAILABLE

if _TIMM_AVAILABLE:
import timm

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV

Expand Down Expand Up @@ -70,6 +75,9 @@ def backbone_and_num_features(
if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, pretrained)

if _TIMM_AVAILABLE and model_name in timm.list_models():
return timm_backbone_and_num_features(model_name, pretrained)

raise ValueError(f"{model_name} is not supported yet.")


Expand Down Expand Up @@ -140,3 +148,15 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
return backbone, num_features

raise ValueError(f"{model_name} is not supported yet.")


def timm_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:

if model_name in timm.list_models():
backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='')
num_features = backbone.num_features
return backbone, num_features

raise ValueError(
f"{model_name} is not supported in timm yet. https://rwightman.github.io/pytorch-image-models/models/"
)
1 change: 1 addition & 0 deletions requirements/extras.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
timm>=0.4.5
39 changes: 38 additions & 1 deletion tests/vision/test_backbones.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import pytest
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.vision.backbones import backbone_and_num_features
from flash.utils.imports import _TIMM_AVAILABLE
from flash.vision.backbones import (
backbone_and_num_features,
bolts_backbone_and_num_features,
timm_backbone_and_num_features,
torchvision_backbone_and_num_features,
)


@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280),
Expand All @@ -11,3 +18,33 @@ def test_backbone_and_num_features(backbone, expected_num_features):

assert backbone_model
assert num_features == expected_num_features


@pytest.mark.skipif(not _TIMM_AVAILABLE, reason="test requires timm")
@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenetv2_100", 1280)])
def test_timm_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = timm_backbone_and_num_features(model_name=backbone, pretrained=False)

assert backbone_model
assert num_features == expected_num_features


@pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="test requires bolts")
@pytest.mark.parametrize(["backbone", "expected_num_features"], [("simclr-imagenet", 2048), ("swav-imagenet", 2048)])
def test_bolts_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = bolts_backbone_and_num_features(model_name=backbone)

assert backbone_model
assert num_features == expected_num_features


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires torchvision")
@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280)])
def test_torchvision_backbone_and_num_features(backbone, expected_num_features):

backbone_model, num_features = torchvision_backbone_and_num_features(model_name=backbone, pretrained=False)

assert backbone_model
assert num_features == expected_num_features

0 comments on commit 6e548a4

Please sign in to comment.