From 6f849e5500715fdfb99c178602d6b345797b06f8 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 6 Apr 2021 16:29:07 +0200 Subject: [PATCH] timm integration (#196) Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: Kaushik Bokka (cherry picked from commit 6e548a443687d7e6de206f646762305ce17c448e) --- flash/utils/imports.py | 6 ++++++ flash/vision/backbones.py | 20 +++++++++++++++++ requirements/extras.txt | 1 + tests/vision/test_backbones.py | 39 +++++++++++++++++++++++++++++++++- 4 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 flash/utils/imports.py create mode 100644 requirements/extras.txt diff --git a/flash/utils/imports.py b/flash/utils/imports.py new file mode 100644 index 0000000000..dbcedc5e3a --- /dev/null +++ b/flash/utils/imports.py @@ -0,0 +1,6 @@ +from pytorch_lightning.utilities.imports import _module_available + +_TABNET_AVAILABLE = _module_available("pytorch_tabnet") +_KORNIA_AVAILABLE = _module_available("kornia") +_COCO_AVAILABLE = _module_available("pycocotools") +_TIMM_AVAILABLE = _module_available("timm") diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py index dcae77822c..65e1b0b557 100644 --- a/flash/vision/backbones.py +++ b/flash/vision/backbones.py @@ -19,6 +19,11 @@ from torch import nn as nn from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +from flash.utils.imports import _TIMM_AVAILABLE + +if _TIMM_AVAILABLE: + import timm + if _BOLTS_AVAILABLE: from pl_bolts.models.self_supervised import SimCLR, SwAV @@ -70,6 +75,9 @@ def backbone_and_num_features( if model_name in TORCHVISION_MODELS: return torchvision_backbone_and_num_features(model_name, pretrained) + if _TIMM_AVAILABLE and model_name in timm.list_models(): + return timm_backbone_and_num_features(model_name, pretrained) + raise ValueError(f"{model_name} is not supported yet.") @@ -140,3 +148,15 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr return backbone, num_features raise ValueError(f"{model_name} is not supported yet.") + + +def timm_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + + if model_name in timm.list_models(): + backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='') + num_features = backbone.num_features + return backbone, num_features + + raise ValueError( + f"{model_name} is not supported in timm yet. https://rwightman.github.io/pytorch-image-models/models/" + ) diff --git a/requirements/extras.txt b/requirements/extras.txt new file mode 100644 index 0000000000..78882bf8f3 --- /dev/null +++ b/requirements/extras.txt @@ -0,0 +1 @@ +timm>=0.4.5 \ No newline at end of file diff --git a/tests/vision/test_backbones.py b/tests/vision/test_backbones.py index b04a96b8b0..aec4064ddd 100644 --- a/tests/vision/test_backbones.py +++ b/tests/vision/test_backbones.py @@ -1,6 +1,13 @@ import pytest +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE -from flash.vision.backbones import backbone_and_num_features +from flash.utils.imports import _TIMM_AVAILABLE +from flash.vision.backbones import ( + backbone_and_num_features, + bolts_backbone_and_num_features, + timm_backbone_and_num_features, + torchvision_backbone_and_num_features, +) @pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280), @@ -11,3 +18,33 @@ def test_backbone_and_num_features(backbone, expected_num_features): assert backbone_model assert num_features == expected_num_features + + +@pytest.mark.skipif(not _TIMM_AVAILABLE, reason="test requires timm") +@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenetv2_100", 1280)]) +def test_timm_backbone_and_num_features(backbone, expected_num_features): + + backbone_model, num_features = timm_backbone_and_num_features(model_name=backbone, pretrained=False) + + assert backbone_model + assert num_features == expected_num_features + + +@pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="test requires bolts") +@pytest.mark.parametrize(["backbone", "expected_num_features"], [("simclr-imagenet", 2048), ("swav-imagenet", 2048)]) +def test_bolts_backbone_and_num_features(backbone, expected_num_features): + + backbone_model, num_features = bolts_backbone_and_num_features(model_name=backbone) + + assert backbone_model + assert num_features == expected_num_features + + +@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires torchvision") +@pytest.mark.parametrize(["backbone", "expected_num_features"], [("resnet34", 512), ("mobilenet_v2", 1280)]) +def test_torchvision_backbone_and_num_features(backbone, expected_num_features): + + backbone_model, num_features = torchvision_backbone_and_num_features(model_name=backbone, pretrained=False) + + assert backbone_model + assert num_features == expected_num_features