Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add support for more backbones(mobilnet, vgg, densenet, resnext) & refactor #45

Merged
merged 12 commits into from
Feb 2, 2021
36 changes: 36 additions & 0 deletions flash/vision/classification/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Tuple

import torch.nn as nn
import torchvision
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:

model = getattr(torchvision.models, model_name, None)
if model is None:
raise MisconfigurationException(f"{model_name} is not supported by torchvision")

if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
model = model(pretrained=pretrained)
backbone = model.features
num_features = model.classifier[-1].in_features
return backbone, num_features

elif model_name in [
"resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
]:
model = model(pretrained=pretrained)
# remove the last two layers & turn it into a Sequential model
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features

elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]:
model = model(pretrained=pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features

else:
raise ValueError(f"{model_name} is not supported yet.")
carmocca marked this conversation as resolved.
Show resolved Hide resolved
20 changes: 2 additions & 18 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,9 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.classification.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline

_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731
_resnet_feats = lambda model: model.fc.in_features # noqa: E731

_backbones = {
"resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
"resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
"resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
"resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
"resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
}


class ImageClassifier(ClassificationTask):
"""Task that classifies images.
Expand Down Expand Up @@ -67,13 +57,7 @@ def __init__(

self.save_hyperparameters()

if backbone not in _backbones:
raise NotImplementedError(f"Backbone {backbone} is not yet supported")

backbone_fn, split, num_feats = _backbones[backbone]
backbone = backbone_fn(pretrained=pretrained)
self.backbone = split(backbone)
num_features = num_feats(backbone)
self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)

self.head = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
Expand Down