diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 6b215572900..f8485607168 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -43,3 +43,6 @@ Operators MultiScaleRoIAlign FeaturePyramidNetwork StochasticDepth + FrozenBatchNorm2d + ConvNormActivation + SqueezeExcitation diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 0acd5fe990f..d5cdf39d20f 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -13,6 +13,7 @@ from .deform_conv import deform_conv2d, DeformConv2d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss +from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -48,4 +49,7 @@ "sigmoid_focal_loss", "stochastic_depth", "StochasticDepth", + "FrozenBatchNorm2d", + "ConvNormActivation", + "SqueezeExcitation", ] diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 3c344a8c5ce..fac9a3570d6 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,13 +1,3 @@ -""" -helper class that supports empty tensors on some nn functions. - -Ideally, add support directly in PyTorch to empty tensors in -those functions. - -This can be removed once https://github.com/pytorch/pytorch/issues/12013 -is implemented -""" - import warnings from typing import Callable, List, Optional @@ -53,8 +43,11 @@ def __init__(self, *args, **kwargs): # This is not in nn class FrozenBatchNorm2d(torch.nn.Module): """ - BatchNorm2d where the batch statistics and the affine parameters - are fixed + BatchNorm2d where the batch statistics and the affine parameters are fixed + + Args: + num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)`` + eps (float): a value added to the denominator for numerical stability. Default: 1e-5 """ def __init__( @@ -109,6 +102,23 @@ def __repr__(self) -> str: class ConvNormActivation(torch.nn.Sequential): + """ + Configurable block used for Convolution-Normalzation-Activation blocks. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolutiuon layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + + """ + def __init__( self, in_channels: int, @@ -146,6 +156,17 @@ def __init__( class SqueezeExcitation(torch.nn.Module): + """ + This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). + Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in in eq. 3. + + Args: + input_channels (int): Number of channels in the input image + squeeze_channels (int): Number of squeeze channels + activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU`` + scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid`` + """ + def __init__( self, input_channels: int,