Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
segmentation_models.pytorch integration (#562)
Browse files Browse the repository at this point in the history
* init smp integration 🚀

* fix backbone & head

* backbone/heads backward compatibility

* update

* update

* move ENCODERS to bottom

* update self.encoder

* remove lraspp

* update tests ✅

* fix model tests

* update

* Fixes

* Update CHANGELOG.md

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
3 people authored Jul 13, 2021
1 parent c318e4a commit 78867ad
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 182 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))

- Added support for Semantic Segmentation backbones and heads from `segmentation-models.pytorch` ([#562](https://github.com/PyTorchLightning/lightning-flash/pull/562))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/general/jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ This table gives a breakdown of the supported features.
- Yes
- Yes
* - :class:`~flash.image.segmentation.model.SemanticSegmentation`
- Yes
- No
- Yes
- Yes
* - :class:`~flash.image.style_transfer.model.StyleTransfer`
Expand Down
2 changes: 2 additions & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _compare_version(package: str, op, version) -> bool:
_CYTOOLZ_AVAILABLE = _module_available("cytoolz")
_UVICORN_AVAILABLE = _module_available("uvicorn")
_PIL_AVAILABLE = _module_available("PIL")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand All @@ -100,6 +101,7 @@ def _compare_version(package: str, op, version) -> bool:
_COCO_AVAILABLE,
_FIFTYONE_AVAILABLE,
_PYSTICHE_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
])
_SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE

Expand Down
43 changes: 11 additions & 32 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,45 +14,24 @@
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision.models import mobilenetv3, resnet

MOBILENET_MODELS = ["mobilenet_v3_large"]
RESNET_MODELS = ["resnet50", "resnet101"]
if _SEGMENTATION_MODELS_AVAILABLE:
import segmentation_models_pytorch as smp

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
if _SEGMENTATION_MODELS_AVAILABLE:

def _load_resnet(model_name: str, pretrained: bool = True):
backbone = resnet.__dict__[model_name](
pretrained=pretrained,
replace_stride_with_dilation=[False, True, True],
)
return backbone
ENCODERS = smp.encoders.get_encoder_names()

for model_name in RESNET_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_load_resnet, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
)

def _load_mobilenetv3(model_name: str, pretrained: bool = True):
backbone = mobilenetv3.__dict__[model_name](
pretrained=pretrained,
_dilated=True,
)
def _load_smp_backbone(backbone: str, **_) -> str:
return backbone

for model_name in MOBILENET_MODELS:
for encoder_name in ENCODERS:
short_name = encoder_name
if short_name.startswith("timm-"):
short_name = encoder_name[5:]
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_load_mobilenetv3, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
partial(_load_smp_backbone, backbone=encoder_name), name=short_name, namespace="image/segmentation"
)
2 changes: 1 addition & 1 deletion flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init__(
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
image_size: Tuple[int, int] = (196, 196),
image_size: Tuple[int, int] = (128, 128),
deserializer: Optional['Deserializer'] = None,
num_classes: int = None,
labels_map: Dict[int, Tuple[int, int, int]] = None,
Expand Down
125 changes: 38 additions & 87 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,103 +11,54 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import Callable, Union

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision.models import MobileNetV3, ResNet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
from torchvision.models.segmentation.fcn import FCN, FCNHead
from torchvision.models.segmentation.lraspp import LRASPP
if _SEGMENTATION_MODELS_AVAILABLE:
import segmentation_models_pytorch as smp

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet
SMP_MODEL_CLASS = [
smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.Linknet, smp.FPN, smp.PSPNet, smp.DeepLabV3, smp.DeepLabV3Plus,
smp.PAN
]
SMP_MODELS = {a.__name__.lower(): a for a in SMP_MODEL_CLASS}

# Registry of segmentation heads; named "heads" so it is not confused with the
# separate backbones registry, which is already named "backbones".
SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("heads")

if _TORCHVISION_AVAILABLE:

def _get_backbone_meta(backbone):
"""Adapted from torchvision.models.segmentation.segmentation._segm_model:
https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25
"""
if isinstance(backbone, ResNet):
out_layer = 'layer4'
out_inplanes = 2048
aux_layer = 'layer3'
aux_inplanes = 1024
elif isinstance(backbone, MobileNetV3):
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)]
stage_indices = [0] + stage_indices + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise MisconfigurationException(
f"{type(backbone)} backbone is not currently supported for semantic segmentation."
)
return backbone, out_layer, out_inplanes, aux_layer, aux_inplanes

def _load_fcn_deeplabv3(model_name, backbone, num_classes):
backbone, out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone)

return_layers = {out_layer: 'out'}
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

model_map = {
"deeplabv3": (DeepLabHead, DeepLabV3),
"fcn": (FCNHead, FCN),
}
classifier = model_map[model_name][0](out_inplanes, num_classes)
base_model = model_map[model_name][1]

return base_model(backbone, classifier, None)
if _SEGMENTATION_MODELS_AVAILABLE:

def _load_smp_head(
    head: str,
    backbone: str,
    pretrained: Union[bool, str] = True,
    num_classes: int = 1,
    in_channels: int = 3,
    **kwargs,
) -> Callable:
    """Build a ``segmentation_models_pytorch`` model for a registered head.

    Args:
        head: Lower-cased architecture name; must be a key of ``SMP_MODELS``.
        backbone: Encoder name, forwarded to smp as ``encoder_name``.
        pretrained: ``True`` requests the ``"imagenet"`` encoder weights and a
            falsy value requests none (``encoder_weights=None``). A non-empty
            string is forwarded as-is so that encoders publishing other weight
            sets can be selected by name.
        num_classes: Forwarded to smp as ``classes``.
        in_channels: Forwarded to smp as ``in_channels``.
        **kwargs: Forwarded unchanged to ``smp.create_model``.

    Returns:
        The object returned by ``smp.create_model``.

    Raises:
        NotImplementedError: If ``head`` is not a key of ``SMP_MODELS``.
    """
    if head not in SMP_MODELS:
        # Join the names explicitly so the message does not show a raw
        # ``dict_keys([...])`` repr.
        supported = ", ".join(sorted(SMP_MODELS))
        raise NotImplementedError(f"{head} is not implemented! Supported heads -> {supported}")

    if isinstance(pretrained, str) and pretrained:
        encoder_weights = pretrained
    elif pretrained:
        encoder_weights = "imagenet"
    else:
        encoder_weights = None

    return smp.create_model(
        arch=head,
        encoder_name=backbone,
        encoder_weights=encoder_weights,
        classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )

for model_name in ["fcn", "deeplabv3"]:
for model_name in SMP_MODELS:
SEMANTIC_SEGMENTATION_HEADS(
fn=partial(_load_fcn_deeplabv3, model_name),
partial(_load_smp_head, head=model_name),
name=model_name,
namespace="image/segmentation",
package="torchvision",
package="segmentation_models.pytorch"
)

def _load_lraspp(backbone, num_classes):
backbone, high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone)
backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'})
return LRASPP(backbone, low_channels, high_channels, num_classes)

SEMANTIC_SEGMENTATION_HEADS(
fn=_load_lraspp,
name="lraspp",
namespace="image/segmentation",
package="torchvision",
)

if _BOLTS_AVAILABLE:

def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
rank_zero_warn("The UNet model does not require a backbone, so the backbone will be ignored.", UserWarning)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_HEADS(
fn=_load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
9 changes: 6 additions & 3 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
num_classes: int,
backbone: Union[str, nn.Module] = "resnet50",
backbone_kwargs: Optional[Dict] = None,
head: str = "fcn",
head: str = "fpn",
head_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
Expand Down Expand Up @@ -117,9 +117,12 @@ def __init__(
if isinstance(backbone, nn.Module):
self.backbone = backbone
else:
self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)
self.backbone = self.backbones.get(backbone)(**backbone_kwargs)

self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs)
self.head: nn.Module = self.heads.get(head)(
backbone=self.backbone, num_classes=num_classes, pretrained=pretrained, **head_kwargs
)
self.backbone = self.head.encoder

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
Expand Down
6 changes: 3 additions & 3 deletions flash_examples/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
val_split=0.1,
image_size=(200, 200),
image_size=(256, 256),
num_classes=21,
)

# 2. Build the task
model = SemanticSegmentation(
backbone="mobilenet_v3_large",
head="fcn",
backbone="mobilenetv3_large_100",
head="fpn",
num_classes=datamodule.num_classes,
)

Expand Down
1 change: 1 addition & 0 deletions requirements/datatype_image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ matplotlib
pycocotools>=2.0.2 ; python_version >= "3.7"
fiftyone
pystiche>=0.7.2
segmentation-models-pytorch
13 changes: 5 additions & 8 deletions tests/image/segmentation/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE

from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES


@pytest.mark.parametrize(["backbone"], [
pytest.param("resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("resnet50", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
pytest.param("dpn131", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
])
def test_semantic_segmentation_backbones_registry(backbone):
img = torch.rand(1, 3, 32, 32)
backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)(pretrained=False)
backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)()
assert backbone
backbone.eval()
assert backbone(img) is not None
assert isinstance(backbone, str)
Loading

0 comments on commit 78867ad

Please sign in to comment.