Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Enable semantic segmentation backbone and head #412

Merged
merged 11 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Split `backbone` argument to `SemanticSegmentation` into `backbone` and `head` arguments ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412))

### Deprecated


### Fixed

- Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393))
- Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412))

## [0.3.2] - 2021-06-08

Expand Down
4 changes: 2 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ def available_backbones(cls) -> List[str]:
return registry.available_keys()

@classmethod
def available_models(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "models", None)
def available_heads(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "heads", None)
if registry is None:
return []
return registry.available_keys()
Expand Down
98 changes: 20 additions & 78 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,106 +11,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from deprecate import deprecated
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
from torchvision.models import segmentation

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet
from torchvision.models import mobilenetv3, resnet

FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"]
DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"]
LRASPP_MODELS = ["lraspp_mobilenet_v3_large"]
RESNET_MODELS = ["resnet50", "resnet101"]
MOBILENET_MODELS = ["mobilenet_v3_large"]

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:

def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model

for model_name in FCN_MODELS + DEEPLABV3_MODELS:
_type = model_name.split("_")[0]
def _load_resnet(model_name: str, pretrained: bool = True):
backbone = resnet.__dict__[model_name](
pretrained=pretrained,
replace_stride_with_dilation=[False, True, True],
)
return backbone

for model_name in RESNET_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)),
fn=catch_url_error(partial(_load_resnet, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type=_type
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet50' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet50'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet50")),
name="torchvision/fcn_resnet50",
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet101' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet101'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet101")),
name="torchvision/fcn_resnet101",
)

def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)

low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model
def _load_mobilenetv3(model_name: str, pretrained: bool = True):
backbone = mobilenetv3.__dict__[model_name](
pretrained=pretrained,
_dilated=True,
)
return backbone

for model_name in LRASPP_MODELS:
for model_name in MOBILENET_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_lraspp, model_name)),
fn=catch_url_error(partial(_load_mobilenetv3, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type="lraspp"
)

if _BOLTS_AVAILABLE:

def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module:
if pretrained:
rank_zero_warn(
"No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be "
"initialized with random weights!", UserWarning
)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
113 changes: 113 additions & 0 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision.models import MobileNetV3, ResNet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
from torchvision.models.segmentation.fcn import FCN, FCNHead
from torchvision.models.segmentation.lraspp import LRASPP

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet

RESNET_MODELS = ["resnet50", "resnet101"]
MOBILENET_MODELS = ["mobilenet_v3_large"]

SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:

def _get_backbone_meta(backbone):
if isinstance(backbone, ResNet):
out_layer = 'layer4'
out_inplanes = 2048
aux_layer = 'layer3'
aux_inplanes = 1024
elif isinstance(backbone, MobileNetV3):
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
] + [len(backbone) - 1]
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8
aux_layer = str(aux_pos)
aux_inplanes = backbone[aux_pos].out_channels
else:
raise MisconfigurationException(
f"{type(backbone)} backbone is not currently supported for semantic segmentation."
)
return backbone, out_layer, out_inplanes, aux_layer, aux_inplanes

def _load_fcn_deeplabv3(model_name, backbone, num_classes):
backbone, out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone)

return_layers = {out_layer: 'out'}
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

model_map = {
"deeplabv3": (DeepLabHead, DeepLabV3),
"fcn": (FCNHead, FCN),
}
classifier = model_map[model_name][0](out_inplanes, num_classes)
base_model = model_map[model_name][1]

return base_model(backbone, classifier, None)

for model_name in ["fcn", "deeplabv3"]:
SEMANTIC_SEGMENTATION_HEADS(
fn=partial(_load_fcn_deeplabv3, model_name),
name=model_name,
namespace="image/segmentation",
package="torchvision",
)

def _load_lraspp(backbone, num_classes):
backbone, high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone)
backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'})
return LRASPP(backbone, low_channels, high_channels, num_classes)

SEMANTIC_SEGMENTATION_HEADS(
fn=_load_lraspp,
name="lraspp",
namespace="image/segmentation",
package="torchvision",
)

if _BOLTS_AVAILABLE:

def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
rank_zero_warn("The UNet model does not require a backbone, so the backbone will be ignored.", UserWarning)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_HEADS(
fn=_load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
38 changes: 25 additions & 13 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
from flash.image.segmentation.serialization import SegmentationLabels

if _KORNIA_AVAILABLE:
Expand Down Expand Up @@ -54,14 +55,15 @@ class SemanticSegmentation(ClassificationTask):

Args:
num_classes: Number of classes to classify.
backbone: A string or (model, num_features) tuple to use to compute image features,
defaults to ``"torchvision/fcn_resnet50"``.
backbone: A string or model to use to compute image features.
backbone_kwargs: Additional arguments for the backbone configuration.
pretrained: Use a pretrained backbone, defaults to ``False``.
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.AdamW`.
metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
head: A string or (model, num_features) tuple to use to compute image features.
head_kwargs: Additional arguments for the head configuration.
pretrained: Use a pretrained backbone.
loss_fn: Loss function for training.
optimizer: Optimizer to use for training.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
"""
Expand All @@ -70,11 +72,15 @@ class SemanticSegmentation(ClassificationTask):

backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES

heads: FlashRegistry = SEMANTIC_SEGMENTATION_HEADS

def __init__(
self,
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50",
backbone: Union[str, nn.Module] = "resnet50",
backbone_kwargs: Optional[Dict] = None,
head: str = "fcn",
head_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
Expand Down Expand Up @@ -113,8 +119,15 @@ def __init__(
if not backbone_kwargs:
backbone_kwargs = {}

# TODO: pretrained to True causes some issues
self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs)
if not head_kwargs:
head_kwargs = {}

if isinstance(backbone, nn.Module):
self.backbone = backbone
else:
self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs)

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
Expand All @@ -134,8 +147,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
return batch

def forward(self, x) -> torch.Tensor:
# infer the image to the model
res = self.backbone(x)
res = self.head(x)

# some frameworks like torchvision return a dict.
# In particular, torchvision segmentation models return the output logits
Expand All @@ -145,7 +157,7 @@ def forward(self, x) -> torch.Tensor:
elif torch.is_tensor(res):
out = res
else:
raise NotImplementedError(f"Unsupported output type: {type(out)}")
raise NotImplementedError(f"Unsupported output type: {type(res)}")

return out

Expand Down
20 changes: 11 additions & 9 deletions flash_examples/finetuning/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,33 @@
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
batch_size=4,
val_split=0.3,
image_size=(200, 200), # (600, 800)
val_split=0.1,
image_size=(200, 200),
num_classes=21,
)

# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3.a List available backbones
print(SemanticSegmentation.available_backbones())
# 3.a List available backbones and heads
print(f"Backbones: {SemanticSegmentation.available_backbones()}")
print(f"Heads: {SemanticSegmentation.available_heads()}")

# 3.b Build the model
model = SemanticSegmentation(
backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
backbone="mobilenet_v3_large",
head="fcn",
num_classes=datamodule.num_classes,
serializer=SegmentationLabels(visualize=True),
)

# 4. Create the trainer.
trainer = flash.Trainer(
max_epochs=1,
fast_dev_run=1,
)
trainer = flash.Trainer(fast_dev_run=True)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images!
predictions = model.predict([
"data/CameraRGB/F61-1.png",
"data/CameraRGB/F62-1.png",
Expand Down
Loading