diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 4e05cf79c32..a50e202e00b 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -45,4 +45,6 @@ Operators FeaturePyramidNetwork StochasticDepth FrozenBatchNorm2d + Conv2dNormActivation + Conv3dNormActivation SqueezeExcitation diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 9067b6876fd..3a0dcdb31cd 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -6,7 +6,7 @@ from torch.nn import functional as F from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth from ..utils import _log_api_usage_once @@ -127,7 +127,7 @@ def __init__( # Stem firstconv_output_channels = block_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=4, diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 47ecf59f1e2..1ee59e069ea 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from ...ops.misc import Conv2dNormActivation from ...utils import _log_api_usage_once from .. import mobilenet from . import _utils as det_utils @@ -29,7 +29,7 @@ def _prediction_block( ) -> nn.Sequential: return nn.Sequential( # 3x3 depthwise with stride 1 and padding 1 - ConvNormActivation( + Conv2dNormActivation( in_channels, in_channels, kernel_size=kernel_size, @@ -47,11 +47,11 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., intermediate_channels = out_channels // 2 return nn.Sequential( # 1x1 projection to half output channels - ConvNormActivation( + Conv2dNormActivation( in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation ), # 3x3 depthwise with stride 2 and padding 1 - ConvNormActivation( + Conv2dNormActivation( intermediate_channels, intermediate_channels, kernel_size=3, @@ -61,7 +61,7 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., activation_layer=activation, ), # 1x1 projetion to output channels - ConvNormActivation( + Conv2dNormActivation( intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation ), ) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f7eba46cb39..f245d00cffe 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -8,7 +8,7 @@ from torchvision.ops import StochasticDepth from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -104,7 +104,7 @@ def __init__( expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) if expanded_channels != cnf.input_channels: layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.input_channels, expanded_channels, kernel_size=1, @@ -115,7 +115,7 @@ def __init__( # depthwise layers.append( - ConvNormActivation( + Conv2dNormActivation( expanded_channels, expanded_channels, kernel_size=cnf.kernel, @@ -132,7 +132,7 @@ def __init__( # project layers.append( - ConvNormActivation( + Conv2dNormActivation( expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None ) ) @@ -193,7 +193,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU ) ) @@ -224,7 +224,7 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 4 * lastconv_input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index e24c5962d7e..930f68d13e9 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -6,7 +6,7 @@ from torch import nn from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -20,11 +20,11 @@ # necessary for backwards compatibility -class _DeprecatedConvBNAct(ConvNormActivation): +class _DeprecatedConvBNAct(Conv2dNormActivation): def __init__(self, *args, **kwargs): warnings.warn( "The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. " - "Use torchvision.ops.misc.ConvNormActivation instead.", + "Use torchvision.ops.misc.Conv2dNormActivation instead.", FutureWarning, ) if kwargs.get("norm_layer", None) is None: @@ -56,12 +56,12 @@ def __init__( if expand_ratio != 1: # pw layers.append( - ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) + Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) ) layers.extend( [ # dw - ConvNormActivation( + Conv2dNormActivation( hidden_dim, hidden_dim, stride=stride, @@ -144,7 +144,7 @@ def __init__( input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) features: List[nn.Module] = [ - ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) + Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) ] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: @@ -155,7 +155,7 @@ def __init__( input_channel = output_channel # building last several layers features.append( - ConvNormActivation( + Conv2dNormActivation( input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 ) ) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 711888b7c8b..530467d6d53 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -83,7 +83,7 @@ def __init__( # expand if cnf.expanded_channels != cnf.input_channels: layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.input_channels, cnf.expanded_channels, kernel_size=1, @@ -95,7 +95,7 @@ def __init__( # depthwise stride = 1 if cnf.dilation > 1 else cnf.stride layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, @@ -112,7 +112,7 @@ def __init__( # project layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None ) ) @@ -172,7 +172,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, @@ -190,7 +190,7 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 83645cf38df..4dfd232d499 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -6,7 +6,7 @@ from torch import Tensor from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.ops.misc import ConvNormActivation +from torchvision.ops import Conv2dNormActivation from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once @@ -38,17 +38,17 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful # for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm # because these aren't frozen, but we don't bother (also, we woudn't be able to load the original weights). - self.convnormrelu1 = ConvNormActivation( + self.convnormrelu1 = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True ) - self.convnormrelu2 = ConvNormActivation( + self.convnormrelu2 = Conv2dNormActivation( out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True ) if stride == 1: self.downsample = nn.Identity() else: - self.downsample = ConvNormActivation( + self.downsample = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, @@ -77,13 +77,13 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): super().__init__() # See note in ResidualBlock for the reason behind bias=True - self.convnormrelu1 = ConvNormActivation( + self.convnormrelu1 = Conv2dNormActivation( in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True ) - self.convnormrelu2 = ConvNormActivation( + self.convnormrelu2 = Conv2dNormActivation( out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True ) - self.convnormrelu3 = ConvNormActivation( + self.convnormrelu3 = Conv2dNormActivation( out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True ) self.relu = nn.ReLU(inplace=True) @@ -91,7 +91,7 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): if stride == 1: self.downsample = nn.Identity() else: - self.downsample = ConvNormActivation( + self.downsample = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, @@ -124,7 +124,9 @@ def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_l assert len(layers) == 5 # See note in ResidualBlock for the reason behind bias=True - self.convnormrelu = ConvNormActivation(3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True) + self.convnormrelu = Conv2dNormActivation( + 3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True + ) self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1) self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2) @@ -170,17 +172,17 @@ def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128 assert len(flow_layers) == 2 assert len(corr_layers) in (1, 2) - self.convcorr1 = ConvNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) + self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) if len(corr_layers) == 2: - self.convcorr2 = ConvNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) + self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) else: self.convcorr2 = nn.Identity() - self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) - self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) + self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) + self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) # out_channels - 2 because we cat the flow (2 channels) at the end - self.conv = ConvNormActivation( + self.conv = Conv2dNormActivation( corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 ) @@ -301,7 +303,7 @@ class MaskPredictor(nn.Module): def __init__(self, *, in_channels, hidden_size, multiplier=0.25): super().__init__() - self.convrelu = ConvNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) + self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) # 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder # and we interpolate with all 9 surrounding neighbors. See paper and appendix B. self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index b1e38f2cbbb..8cd9f16d13e 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -6,7 +6,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from ...ops.misc import Conv2dNormActivation from .utils import _fuse_modules, _replace_relu, quantize_model @@ -54,7 +54,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self, is_qat: Optional[bool] = None) -> None: for m in self.modules(): - if type(m) is ConvNormActivation: + if type(m) is Conv2dNormActivation: _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True) if type(m) is QuantizableInvertedResidual: m.fuse_model(is_qat) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 2f58cd96ace..4d7e2f7baad 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -5,7 +5,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation, SqueezeExcitation +from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf from .utils import _fuse_modules, _replace_relu @@ -103,7 +103,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self, is_qat: Optional[bool] = None) -> None: for m in self.modules(): - if type(m) is ConvNormActivation: + if type(m) is Conv2dNormActivation: modules_to_fuse = ["0", "1"] if len(m) == 3 and type(m[2]) is nn.ReLU: modules_to_fuse.append("2") diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 3f393c8e82d..74abd20b237 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -12,7 +12,7 @@ from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -55,7 +55,7 @@ } -class SimpleStemIN(ConvNormActivation): +class SimpleStemIN(Conv2dNormActivation): """Simple stem for ImageNet: 3x3, BN, ReLU.""" def __init__( @@ -88,10 +88,10 @@ def __init__( w_b = int(round(width_out * bottleneck_multiplier)) g = w_b // group_width - layers["a"] = ConvNormActivation( + layers["a"] = Conv2dNormActivation( width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer ) - layers["b"] = ConvNormActivation( + layers["b"] = Conv2dNormActivation( w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer ) @@ -105,7 +105,7 @@ def __init__( activation=activation_layer, ) - layers["c"] = ConvNormActivation( + layers["c"] = Conv2dNormActivation( w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None ) super().__init__(layers) @@ -131,7 +131,7 @@ def __init__( self.proj = None should_proj = (width_in != width_out) or (stride != 1) if should_proj: - self.proj = ConvNormActivation( + self.proj = Conv2dNormActivation( width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None ) self.f = BottleneckTransform( diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index b36658e34d8..29f756ccbe5 100644 --- a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once __all__ = [ @@ -163,7 +163,7 @@ def __init__( for i, conv_stem_layer_config in enumerate(conv_stem_configs): seq_proj.add_module( f"conv_bn_relu_{i}", - ConvNormActivation( + Conv2dNormActivation( in_channels=prev_channels, out_channels=conv_stem_layer_config.out_channels, kernel_size=conv_stem_layer_config.kernel_size, diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index c0e3f7f238e..9da336764d3 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -14,7 +14,7 @@ from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, SqueezeExcitation +from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -51,6 +51,8 @@ "stochastic_depth", "StochasticDepth", "FrozenBatchNorm2d", + "Conv2dNormActivation", + "Conv3dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", ] diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 268962e204c..c7c52a86ff1 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, List, Optional import torch @@ -66,24 +67,6 @@ def __repr__(self) -> str: class ConvNormActivation(torch.nn.Sequential): - """ - Configurable block used for Convolution-Normalzation-Activation blocks. - - Args: - in_channels (int): Number of channels in the input image - out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block - kernel_size: (int, optional): Size of the convolving kernel. Default: 3 - stride (int, optional): Stride of the convolution. Default: 1 - padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` - groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 - norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolutiuon layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` - activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` - dilation (int): Spacing between kernel elements. Default: 1 - inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` - bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. - - """ - def __init__( self, in_channels: int, @@ -97,13 +80,16 @@ def __init__( dilation: int = 1, inplace: Optional[bool] = True, bias: Optional[bool] = None, + conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, ) -> None: + if padding is None: padding = (kernel_size - 1) // 2 * dilation if bias is None: bias = norm_layer is None + layers = [ - torch.nn.Conv2d( + conv_layer( in_channels, out_channels, kernel_size, @@ -114,8 +100,10 @@ def __init__( bias=bias, ) ] + if norm_layer is not None: layers.append(norm_layer(out_channels)) + if activation_layer is not None: params = {} if inplace is None else {"inplace": inplace} layers.append(activation_layer(**params)) @@ -123,6 +111,110 @@ def __init__( _log_api_usage_once(self) self.out_channels = out_channels + if self.__class__ == ConvNormActivation: + warnings.warn( + "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." + ) + + +class Conv2dNormActivation(ConvNormActivation): + """ + Configurable block used for Convolution2d-Normalzation-Activation blocks. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. + + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + torch.nn.Conv2d, + ) + + +class Conv3dNormActivation(ConvNormActivation): + """ + Configurable block used for Convolution3d-Normalzation-Activation blocks. + + Args: + in_channels (int): Number of channels in the input video. + out_channels (int): Number of channels produced by the Convolution-Normalzation-Activation block + kernel_size: (int, optional): Size of the convolving kernel. Default: 3 + stride (int, optional): Stride of the convolution. Default: 1 + padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in wich case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation`` + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 + norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d`` + activation_layer (Callable[..., torch.nn.Module], optinal): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU`` + dilation (int): Spacing between kernel elements. Default: 1 + inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` + bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + torch.nn.Conv3d, + ) + class SqueezeExcitation(torch.nn.Module): """