Update DeepLab models #959

Open · wants to merge 3 commits into base: main
Changes from all commits
64 changes: 47 additions & 17 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
@@ -30,17 +30,33 @@
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

from collections.abc import Iterable, Sequence
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ["DeepLabV3Decoder"]
__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"]


class DeepLabV3Decoder(nn.Sequential):
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
def __init__(
self,
in_channels: int,
out_channels: int,
atrous_rates: Iterable[int],
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__(
ASPP(in_channels, out_channels, atrous_rates),
ASPP(
in_channels,
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
@@ -54,10 +70,12 @@ def forward(self, *features):
class DeepLabV3PlusDecoder(nn.Module):
def __init__(
self,
encoder_channels,
out_channels=256,
atrous_rates=(12, 24, 36),
output_stride=16,
encoder_channels: Sequence[int],
out_channels: int,
atrous_rates: Iterable[int],
output_stride: Literal[8, 16],
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__()
if output_stride not in {8, 16}:
@@ -69,7 +87,13 @@ def __init__(
self.output_stride = output_stride

self.aspp = nn.Sequential(
ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
ASPP(
encoder_channels[-1],
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
),
SeparableConv2d(
out_channels, out_channels, kernel_size=3, padding=1, bias=False
),
@@ -111,7 +135,7 @@ def forward(self, *features):


class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
def __init__(self, in_channels: int, out_channels: int, dilation: int):
super().__init__(
nn.Conv2d(
in_channels,
@@ -127,7 +151,7 @@ def __init__(self, in_channels, out_channels, dilation):


class ASPPSeparableConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
def __init__(self, in_channels: int, out_channels: int, dilation: int):
super().__init__(
SeparableConv2d(
in_channels,
@@ -143,7 +167,7 @@ def __init__(self, in_channels, out_channels, dilation):


class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int):
super().__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
@@ -159,16 +183,22 @@ def forward(self, x):


class ASPP(nn.Module):
def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
def __init__(
self,
in_channels: int,
out_channels: int,
atrous_rates: Iterable[int],
separable: bool,
dropout: float,
):
super(ASPP, self).__init__()
modules = []
modules.append(
modules = [
nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
)
]

rate1, rate2, rate3 = tuple(atrous_rates)
ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
@@ -184,7 +214,7 @@ def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Dropout(0.5),
nn.Dropout(dropout),
)

def forward(self, x):
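With this change, the decoder classes stop hard-coding the ASPP configuration: atrous_rates, aspp_separable and aspp_dropout become explicit constructor arguments. Below is a minimal sketch of calling the new DeepLabV3Decoder signature directly; the import path follows the file above, and the channel counts and expected output shape are illustrative assumptions based on the unchanged parts of this module.

import torch

# Import path taken from the file shown above; treat the exact location as an
# assumption if the public API re-exports these classes elsewhere.
from segmentation_models_pytorch.decoders.deeplabv3.decoder import DeepLabV3Decoder

# ASPP options are now explicit arguments instead of hard-coded defaults.
decoder = DeepLabV3Decoder(
    in_channels=512,             # e.g. the last stage of a resnet34 encoder
    out_channels=256,
    atrous_rates=(12, 24, 36),
    aspp_separable=False,
    aspp_dropout=0.5,
)

features = torch.rand(2, 512, 32, 32)    # batch of 2 keeps BatchNorm stats valid in train mode
output = decoder(features)               # forward(*features) consumes the last feature map
print(output.shape)                      # expected: torch.Size([2, 256, 32, 32])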
49 changes: 31 additions & 18 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
@@ -1,11 +1,13 @@
from typing import Optional
from collections.abc import Iterable
from typing import Literal, Optional

from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
ClassificationHead,
SegmentationHead,
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder


@@ -22,13 +24,17 @@ class DeepLabV3(SegmentationModel):
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_channels: A number of convolution filters in ASPP module. Default is 256
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
decoder_aspp_separable: Use separable convolutions in ASPP module. Default is False
decoder_aspp_dropout: Dropout rate for the ASPP module projection layer. Default is 0.5
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
**callable** and **None**.
Default is **None**
upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
upsampling: Final upsampling factor; if **None** (default), ``encoder_output_stride`` is used to preserve input-output spatial shape identity.
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
on top of encoder if **aux_params** is not **None** (default). Supported params:
- classes (int): A number of classes
@@ -49,11 +55,15 @@ def __init__(
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
encoder_output_stride: Literal[8, 16] = 8,
decoder_channels: int = 256,
decoder_atrous_rates: Iterable[int] = (12, 24, 36),
decoder_aspp_separable: bool = False,
decoder_aspp_dropout: float = 0.5,
in_channels: int = 3,
classes: int = 1,
activation: Optional[str] = None,
upsampling: int = 8,
upsampling: Optional[int] = None,
aux_params: Optional[dict] = None,
):
super().__init__()
@@ -63,19 +73,23 @@ def __init__(
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
output_stride=8,
output_stride=encoder_output_stride,
)

self.decoder = DeepLabV3Decoder(
in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels
in_channels=self.encoder.out_channels[-1],
out_channels=decoder_channels,
atrous_rates=decoder_atrous_rates,
aspp_separable=decoder_aspp_separable,
aspp_dropout=decoder_aspp_dropout,
)

self.segmentation_head = SegmentationHead(
in_channels=self.decoder.out_channels,
out_channels=classes,
activation=activation,
kernel_size=1,
upsampling=upsampling,
upsampling=encoder_output_stride if upsampling is None else upsampling,
)

if aux_params is not None:
@@ -100,7 +114,9 @@ class DeepLabV3Plus(SegmentationModel):
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
decoder_atrous_rates: Dilation rates for ASPP module (should be an iterable of 3 integer values)
decoder_aspp_separable: Use separable convolutions in ASPP module. Default is True
decoder_aspp_dropout: Dropout rate for the ASPP module projection layer. Default is 0.5
decoder_channels: A number of convolution filters in ASPP module. Default is 256
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -129,9 +145,11 @@ def __init__(
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
encoder_output_stride: int = 16,
encoder_output_stride: Literal[8, 16] = 16,
decoder_channels: int = 256,
decoder_atrous_rates: tuple = (12, 24, 36),
decoder_atrous_rates: Iterable[int] = (12, 24, 36),
decoder_aspp_separable: bool = True,
decoder_aspp_dropout: float = 0.5,
in_channels: int = 3,
classes: int = 1,
activation: Optional[str] = None,
@@ -140,13 +158,6 @@ def __init__(
):
super().__init__()

if encoder_output_stride not in [8, 16]:
raise ValueError(
"Encoder output stride should be 8 or 16, got {}".format(
encoder_output_stride
)
)

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
@@ -160,6 +171,8 @@ def __init__(
out_channels=decoder_channels,
atrous_rates=decoder_atrous_rates,
output_stride=encoder_output_stride,
aspp_separable=decoder_aspp_separable,
aspp_dropout=decoder_aspp_dropout,
)

self.segmentation_head = SegmentationHead(
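For reference, a minimal usage sketch of the updated model-level keyword arguments in this PR; the argument names mirror the signatures in the diff, while random weights and eval mode are used here only to keep the example self-contained.

import torch
import segmentation_models_pytorch as smp

# Keyword names below follow the updated DeepLabV3Plus signature shown above.
model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights=None,                # random init, avoids a weight download
    encoder_output_stride=16,
    decoder_channels=256,
    decoder_atrous_rates=(12, 24, 36),
    decoder_aspp_separable=True,
    decoder_aspp_dropout=0.5,
    in_channels=3,
    classes=2,
)
model.eval()                             # BatchNorm in the ASPP pooling branch needs eval mode for batch size 1

x = torch.rand(1, 3, 256, 256)
with torch.inference_mode():
    mask = model(x)
print(mask.shape)                        # expected: torch.Size([1, 2, 256, 256])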