Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

add swav and simclr models to imageclassifier + backbone reorg #68

Merged
merged 12 commits into from
Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,71 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import suppress
carmocca marked this conversation as resolved.
Show resolved Hide resolved
from typing import Tuple

import torch.nn as nn
import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn

if _BOLTS_AVAILABLE:
with suppress(TypeError):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
from pl_bolts.models.self_supervised import SimCLR, SwAV
carmocca marked this conversation as resolved.
Show resolved Hide resolved

ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"

MOBILENET_MODELS = ["mobilenet_v2"]
VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"]
RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"]
DENSENET_MODELS = ["densenet121", "densenet169", "densenet161", "densenet161"]
TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS

BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]


def backbone_and_num_features(model_name: str, *args, **kwargs) -> Tuple[nn.Module, int]:
if model_name in BOLTS_MODELS:
return bolts_backbone_and_num_features(model_name)

if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, *args, **kwargs)

raise ValueError(f"{model_name} is not supported yet.")


def bolts_backbone_and_num_features(model_name: str) -> Tuple[nn.Module, int]:
"""
>>> bolts_backbone_and_num_features('simclr-imagenet') # doctest: +ELLIPSIS
(Sequential(...), 2048)
>>> bolts_backbone_and_num_features('swav-imagenet') # doctest: +ELLIPSIS
(Sequential(...), 3000)
"""

# TODO: maybe we should plain pytorch weights so we don't need to rely on bolts to load these
# also mabye just use torchhub for the ssl lib
def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
return backbone, 2048

def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(swav.model.children())[:-2])
return backbone, 3000

models = {
'simclr-imagenet': load_simclr_imagenet,
'swav-imagenet': load_swav_imagenet,
}
if not _BOLTS_AVAILABLE:
raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")
if model_name in models:
return models[model_name]()

raise ValueError(f"{model_name} is not supported yet.")


def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
Expand All @@ -31,22 +91,20 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
if model is None:
raise MisconfigurationException(f"{model_name} is not supported by torchvision")

if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
if model_name in MOBILENET_MODELS + VGG_MODELS:
model = model(pretrained=pretrained)
backbone = model.features
num_features = model.classifier[-1].in_features
return backbone, num_features

elif model_name in [
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
]:
elif model_name in RESNET_MODELS:
model = model(pretrained=pretrained)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features

elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]:
elif model_name in DENSENET_MODELS:
model = model(pretrained=pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
Expand Down
4 changes: 2 additions & 2 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.backbones import backbone_and_num_features
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline


Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(

self.save_hyperparameters()

self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
self.backbone, num_features = backbone_and_num_features(backbone, pretrained)

self.head = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Expand Down
10 changes: 2 additions & 8 deletions flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
from flash.core import Task
from flash.core.data import TaskDataPipeline
from flash.core.data.utils import _contains_any_tensor
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.backbones import backbone_and_num_features
from flash.vision.classification.data import _default_valid_transforms, _pil_loader
from flash.vision.embedding.model_map import _load_bolts_model, _models


class ImageEmbedderDataPipeline(TaskDataPipeline):
Expand Down Expand Up @@ -115,12 +114,7 @@ def __init__(
assert pooling_fn in [torch.mean, torch.max]
self.pooling_fn = pooling_fn

if backbone in _models:
config = _load_bolts_model(backbone)
self.backbone = config['model']
num_features = config['num_features']
else:
self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
self.backbone, num_features = backbone_and_num_features(backbone, pretrained)

if embedding_dim is None:
self.head = nn.Identity()
Expand Down
49 changes: 0 additions & 49 deletions flash/vision/embedding/model_map.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/vision/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_init_train(tmpdir, backbone):


def test_non_existent_backbone():
with pytest.raises(MisconfigurationException):
with pytest.raises(ValueError):
ImageClassifier(2, "i am never going to implement this lol")


Expand Down
4 changes: 2 additions & 2 deletions tests/vision/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.
import pytest

from flash.vision.embedding.model_map import _load_bolts_model
from flash.vision.backbones import bolts_backbone_and_num_features


@pytest.mark.parametrize("name", ['simclr-imagenet', 'swav-imagenet'])
def test_load_bolts(name):
_load_bolts_model(name)
bolts_backbone_and_num_features(name)