Add Conv2dNormActivation and Conv3dNormActivation Blocks #5445

Merged 17 commits on Feb 25, 2022

Changes from 3 commits
2 changes: 2 additions & 0 deletions docs/source/ops.rst
@@ -46,4 +46,6 @@ Operators
StochasticDepth
FrozenBatchNorm2d
ConvNormActivation
Conv2dNormActivation
Conv3dNormActivation
SqueezeExcitation
3 changes: 2 additions & 1 deletion torchvision/ops/__init__.py
@@ -14,7 +14,7 @@
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation
from .misc import FrozenBatchNorm2d, ConvNormActivation, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
@@ -52,6 +52,7 @@
"StochasticDepth",
"FrozenBatchNorm2d",
"ConvNormActivation",
"Conv2dNormActivation" "Conv3dNormActivation",
"SqueezeExcitation",
"generalized_box_iou_loss",
]
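
Worth noting on this hunk: the comma between the two new entries matters more than it looks. Python concatenates adjacent string literals at parse time, so omitting it would silently export a single bogus name instead of two. A minimal illustration of the pitfall (not part of the diff):

exports = [
    "Conv2dNormActivation" "Conv3dNormActivation",  # missing comma: parses as one string
]
assert exports == ["Conv2dNormActivationConv3dNormActivation"]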
186 changes: 165 additions & 21 deletions torchvision/ops/misc.py
@@ -1,3 +1,4 @@
import warnings
from typing import Callable, List, Optional

import torch
@@ -65,29 +66,12 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"


class ConvNormActivation(torch.nn.Sequential):
"""
Configurable block used for Convolution-Normalization-Activation blocks.

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

"""

class _ConvNormActivation(torch.nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
layer: Callable[..., torch.nn.Module],
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
@@ -98,12 +82,14 @@ def __init__(
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:

if padding is None:
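# default to "same"-style padding, e.g. kernel_size=3, dilation=1 gives padding=1 (spatial size preserved at stride 1)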
padding = (kernel_size - 1) // 2 * dilation
if bias is None:
bias = norm_layer is None

layers = [
torch.nn.Conv2d(
layer(
in_channels,
out_channels,
kernel_size,
@@ -114,16 +100,174 @@ def __init__(
bias=bias,
)
]

if norm_layer is not None:
layers.append(norm_layer(out_channels))

if activation_layer is not None:
params = {} if inplace is None else {"inplace": inplace}
layers.append(activation_layer(**params))
super().__init__(*layers)
_log_api_usage_once(self)

self.out_channels = out_channels


class ConvNormActivation(_ConvNormActivation):
"""
Configurable block used for Convolution-Normalization-Activation blocks.

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:
warnings.warn(
"The ConvNormActivation class are deprecated since 0.13 and will be removed in 0.15. "
"Use torchvision.ops.Conv2dNormActivation instead.",
FutureWarning,
)
layer = torch.nn.Conv2d
super().__init__(
in_channels,
out_channels,
layer,
kernel_size,
stride,
padding,
groups,
norm_layer,
activation_layer,
dilation,
inplace,
bias,
)


class Conv2dNormActivation(_ConvNormActivation):
"""
Configurable block used for Convolution-Normalization-Activation blocks.

Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.

"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:

layer = torch.nn.Conv2d
super().__init__(
in_channels,
out_channels,
layer,
kernel_size,
stride,
padding,
groups,
norm_layer,
activation_layer,
dilation,
inplace,
bias,
)


class Conv3dNormActivation(_ConvNormActivation):
"""
Configurable block used for Convolution-Normalization-Activation blocks.

Args:
in_channels (int): Number of channels in the input video.
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
kernel_size (int, optional): Size of the convolving kernel. Default: 3
stride (int, optional): Stride of the convolution. Default: 1
padding (int, tuple or str, optional): Padding added to all six sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
dilation (int): Spacing between kernel elements. Default: 1
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
groups: int = 1,
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
dilation: int = 1,
inplace: Optional[bool] = True,
bias: Optional[bool] = None,
) -> None:

layer = torch.nn.Conv3d
super().__init__(
in_channels,
out_channels,
layer,
kernel_size,
stride,
padding,
groups,
norm_layer,
activation_layer,
dilation,
inplace,
bias,
)


class SqueezeExcitation(torch.nn.Module):
"""
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
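
For context, a minimal usage sketch of the new blocks, assuming a torchvision build that includes this PR (shapes in the comments are illustrative):

import torch
from torchvision.ops import ConvNormActivation, Conv2dNormActivation, Conv3dNormActivation

# Conv2d -> BatchNorm2d -> ReLU; padding defaults to (3 - 1) // 2 = 1, so 32x32 stays 32x32
block2d = Conv2dNormActivation(3, 16, kernel_size=3, stride=1)
out2d = block2d(torch.randn(1, 3, 32, 32))  # (1, 16, 32, 32)

# Conv3d -> BatchNorm3d -> ReLU for (N, C, T, H, W) video-style inputs
block3d = Conv3dNormActivation(3, 16, kernel_size=3)
out3d = block3d(torch.randn(1, 3, 8, 32, 32))  # (1, 16, 8, 32, 32)

# The legacy class still works but now emits a FutureWarning pointing to Conv2dNormActivation
legacy = ConvNormActivation(3, 16)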