Add Conv2dNormActivation and Conv3dNormActivation Blocks (#5445)
* Add ops.conv3d

* Refactor for conv2d and 3d

* Refactor

* Fix bug

* Address review

* Fix bug

* nit fix

* Fix flake

* Final fix

* remove documentation

* fix linter

* Update all the implementations to use new Conv

* Small doc fix

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
3 people authored Feb 25, 2022
1 parent cdcb8b6 commit f15ba56
Showing 13 changed files with 171 additions and 73 deletions.
2 changes: 2 additions & 0 deletions docs/source/ops.rst
@@ -45,4 +45,6 @@ Operators
FeaturePyramidNetwork
StochasticDepth
FrozenBatchNorm2d
Conv2dNormActivation
Conv3dNormActivation
SqueezeExcitation
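
For orientation, a minimal usage sketch of the two newly documented blocks. It assumes both classes are exported from torchvision.ops (the RAFT change below imports Conv2dNormActivation from there) and relies on the shared defaults of BatchNorm + ReLU and padding of (kernel_size - 1) // 2:

```python
import torch
from torchvision.ops import Conv2dNormActivation, Conv3dNormActivation

# Each block bundles convolution, normalization and activation in one
# nn.Sequential; with the assumed defaults that is Conv -> BatchNorm -> ReLU.
block2d = Conv2dNormActivation(3, 64, kernel_size=3, stride=2)
block3d = Conv3dNormActivation(3, 64, kernel_size=3, stride=2)

print(block2d(torch.randn(1, 3, 224, 224)).shape)      # torch.Size([1, 64, 112, 112])
print(block3d(torch.randn(1, 3, 16, 112, 112)).shape)  # torch.Size([1, 64, 8, 56, 56])
```
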
4 changes: 2 additions & 2 deletions torchvision/models/convnext.py
@@ -6,7 +6,7 @@
from torch.nn import functional as F

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..ops.misc import Conv2dNormActivation
from ..ops.stochastic_depth import StochasticDepth
from ..utils import _log_api_usage_once

@@ -127,7 +127,7 @@ def __init__(
# Stem
firstconv_output_channels = block_setting[0].input_channels
layers.append(
ConvNormActivation(
Conv2dNormActivation(
3,
firstconv_output_channels,
kernel_size=4,
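
The hunk above is truncated after kernel_size=4; the stride, padding, norm and channel count in this sketch are assumptions chosen to illustrate the usual 4x4, stride-4 "patchify" stem (the real model passes its own norm_layer rather than the BatchNorm default used here):

```python
import torch
from torchvision.ops import Conv2dNormActivation

# Assumed patchify stem: a 4x4 convolution with stride 4 and no activation,
# so each output position covers one non-overlapping 4x4 patch of the input.
stem = Conv2dNormActivation(3, 96, kernel_size=4, stride=4, padding=0, activation_layer=None)

print(stem(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 96, 56, 56])
```
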
10 changes: 5 additions & 5 deletions torchvision/models/detection/ssdlite.py
@@ -7,7 +7,7 @@
from torch import nn, Tensor

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from ...ops.misc import Conv2dNormActivation
from ...utils import _log_api_usage_once
from .. import mobilenet
from . import _utils as det_utils
@@ -29,7 +29,7 @@ def _prediction_block(
) -> nn.Sequential:
return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1
ConvNormActivation(
Conv2dNormActivation(
in_channels,
in_channels,
kernel_size=kernel_size,
@@ -47,11 +47,11 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
intermediate_channels = out_channels // 2
return nn.Sequential(
# 1x1 projection to half output channels
ConvNormActivation(
Conv2dNormActivation(
in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
),
# 3x3 depthwise with stride 2 and padding 1
ConvNormActivation(
Conv2dNormActivation(
intermediate_channels,
intermediate_channels,
kernel_size=3,
@@ -61,7 +61,7 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
activation_layer=activation,
),
# 1x1 projection to output channels
ConvNormActivation(
Conv2dNormActivation(
intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
),
)
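
The SSDLite helpers above pair depthwise and pointwise convolutions; a minimal sketch of that pattern with placeholder channel counts and the default BatchNorm/ReLU, not the exact layers SSDLite builds:

```python
import torch
from torch import nn
from torchvision.ops import Conv2dNormActivation

def depthwise_separable(in_ch: int, out_ch: int, stride: int = 1) -> nn.Sequential:
    # 3x3 depthwise (groups == channels) followed by a 1x1 pointwise projection,
    # the same two-step pattern the extra/prediction blocks above are built from.
    return nn.Sequential(
        Conv2dNormActivation(in_ch, in_ch, kernel_size=3, stride=stride, groups=in_ch),
        Conv2dNormActivation(in_ch, out_ch, kernel_size=1),
    )

print(depthwise_separable(32, 64, stride=2)(torch.randn(1, 32, 40, 40)).shape)
# torch.Size([1, 64, 20, 20])
```
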
12 changes: 6 additions & 6 deletions torchvision/models/efficientnet.py
@@ -8,7 +8,7 @@
from torchvision.ops import StochasticDepth

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..utils import _log_api_usage_once
from ._utils import _make_divisible

@@ -104,7 +104,7 @@ def __init__(
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels:
layers.append(
ConvNormActivation(
Conv2dNormActivation(
cnf.input_channels,
expanded_channels,
kernel_size=1,
@@ -115,7 +115,7 @@

# depthwise
layers.append(
ConvNormActivation(
Conv2dNormActivation(
expanded_channels,
expanded_channels,
kernel_size=cnf.kernel,
@@ -132,7 +132,7 @@ def __init__(

# project
layers.append(
ConvNormActivation(
Conv2dNormActivation(
expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
@@ -193,7 +193,7 @@ def __init__(
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(
ConvNormActivation(
Conv2dNormActivation(
3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
)
)
@@ -224,7 +224,7 @@ def __init__(
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 4 * lastconv_input_channels
layers.append(
ConvNormActivation(
Conv2dNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
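
The EfficientNet block above is assembled from three Conv2dNormActivation calls; a simplified sketch of that expand -> depthwise -> project sequence, omitting squeeze-and-excitation, stochastic depth, the residual connection and the real channel-adjustment logic:

```python
import torch
from torch import nn
from torchvision.ops import Conv2dNormActivation

def mbconv_body(in_ch: int, out_ch: int, expand_ratio: int, stride: int) -> nn.Sequential:
    expanded = in_ch * expand_ratio
    layers = []
    if expanded != in_ch:
        # 1x1 expansion
        layers.append(Conv2dNormActivation(in_ch, expanded, kernel_size=1, activation_layer=nn.SiLU))
    # 3x3 depthwise
    layers.append(Conv2dNormActivation(expanded, expanded, kernel_size=3, stride=stride,
                                       groups=expanded, activation_layer=nn.SiLU))
    # 1x1 projection, linear (no activation)
    layers.append(Conv2dNormActivation(expanded, out_ch, kernel_size=1, activation_layer=None))
    return nn.Sequential(*layers)

print(mbconv_body(16, 24, expand_ratio=6, stride=2)(torch.randn(1, 16, 56, 56)).shape)
# torch.Size([1, 24, 28, 28])
```
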
14 changes: 7 additions & 7 deletions torchvision/models/mobilenetv2.py
@@ -6,7 +6,7 @@
from torch import nn

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ..ops.misc import Conv2dNormActivation
from ..utils import _log_api_usage_once
from ._utils import _make_divisible

@@ -20,11 +20,11 @@


# necessary for backwards compatibility
class _DeprecatedConvBNAct(ConvNormActivation):
class _DeprecatedConvBNAct(Conv2dNormActivation):
def __init__(self, *args, **kwargs):
warnings.warn(
"The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. "
"Use torchvision.ops.misc.ConvNormActivation instead.",
"Use torchvision.ops.misc.Conv2dNormActivation instead.",
FutureWarning,
)
if kwargs.get("norm_layer", None) is None:
@@ -56,12 +56,12 @@ def __init__(
if expand_ratio != 1:
# pw
layers.append(
ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
)
layers.extend(
[
# dw
ConvNormActivation(
Conv2dNormActivation(
hidden_dim,
hidden_dim,
stride=stride,
@@ -144,7 +144,7 @@ def __init__(
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [
ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
@@ -155,7 +155,7 @@ def __init__(
input_channel = output_channel
# building last several layers
features.append(
ConvNormActivation(
Conv2dNormActivation(
input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
)
)
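
For downstream code still using the deprecated ConvBNReLU/ConvBNActivation names, the migration implied by the warning above is a one-line change; a hedged sketch, with the BatchNorm/ReLU6 pairing mirroring what the old MobileNetV2 helper provided:

```python
import torch
from torch import nn
from torchvision.ops import Conv2dNormActivation

# Before (deprecated since 0.12, slated for removal in 0.14):
#     from torchvision.models.mobilenetv2 import ConvBNReLU
#     block = ConvBNReLU(3, 32, stride=2)
# After:
block = Conv2dNormActivation(3, 32, kernel_size=3, stride=2,
                             norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)

print(block(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 32, 112, 112])
```
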
12 changes: 6 additions & 6 deletions torchvision/models/mobilenetv3.py
@@ -6,7 +6,7 @@
from torch import nn, Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..utils import _log_api_usage_once
from ._utils import _make_divisible

@@ -83,7 +83,7 @@ def __init__(
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(
ConvNormActivation(
Conv2dNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
@@ -95,7 +95,7 @@ def __init__(
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(
ConvNormActivation(
Conv2dNormActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
@@ -112,7 +112,7 @@

# project
layers.append(
ConvNormActivation(
Conv2dNormActivation(
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
@@ -172,7 +172,7 @@ def __init__(
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(
ConvNormActivation(
Conv2dNormActivation(
3,
firstconv_output_channels,
kernel_size=3,
@@ -190,7 +190,7 @@
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(
ConvNormActivation(
Conv2dNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
32 changes: 17 additions & 15 deletions torchvision/models/optical_flow/raft.py
@@ -6,7 +6,7 @@
from torch import Tensor
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.ops.misc import ConvNormActivation
from torchvision.ops import Conv2dNormActivation

from ..._internally_replaced_utils import load_state_dict_from_url
from ...utils import _log_api_usage_once
@@ -38,17 +38,17 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
# and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful
# for the rest of the datasets. Technically, we could remove the bias for other norm layers like Instance norm
# because these aren't frozen, but we don't bother (also, we wouldn't be able to load the original weights).
self.convnormrelu1 = ConvNormActivation(
self.convnormrelu1 = Conv2dNormActivation(
in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
)
self.convnormrelu2 = ConvNormActivation(
self.convnormrelu2 = Conv2dNormActivation(
out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
)

if stride == 1:
self.downsample = nn.Identity()
else:
self.downsample = ConvNormActivation(
self.downsample = Conv2dNormActivation(
in_channels,
out_channels,
norm_layer=norm_layer,
@@ -77,21 +77,21 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
super().__init__()

# See note in ResidualBlock for the reason behind bias=True
self.convnormrelu1 = ConvNormActivation(
self.convnormrelu1 = Conv2dNormActivation(
in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True
)
self.convnormrelu2 = ConvNormActivation(
self.convnormrelu2 = Conv2dNormActivation(
out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
)
self.convnormrelu3 = ConvNormActivation(
self.convnormrelu3 = Conv2dNormActivation(
out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True
)
self.relu = nn.ReLU(inplace=True)

if stride == 1:
self.downsample = nn.Identity()
else:
self.downsample = ConvNormActivation(
self.downsample = Conv2dNormActivation(
in_channels,
out_channels,
norm_layer=norm_layer,
@@ -124,7 +124,9 @@ def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_l
assert len(layers) == 5

# See note in ResidualBlock for the reason behind bias=True
self.convnormrelu = ConvNormActivation(3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True)
self.convnormrelu = Conv2dNormActivation(
3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True
)

self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1)
self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2)
@@ -170,17 +172,17 @@ def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128
assert len(flow_layers) == 2
assert len(corr_layers) in (1, 2)

self.convcorr1 = ConvNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
if len(corr_layers) == 2:
self.convcorr2 = ConvNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
else:
self.convcorr2 = nn.Identity()

self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)
self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)

# out_channels - 2 because we cat the flow (2 channels) at the end
self.conv = ConvNormActivation(
self.conv = Conv2dNormActivation(
corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3
)

@@ -301,7 +303,7 @@ class MaskPredictor(nn.Module):

def __init__(self, *, in_channels, hidden_size, multiplier=0.25):
super().__init__()
self.convrelu = ConvNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
# 8 * 8 * 9 because the predicted flow is downsampled by 8, from the downsampling of the initial FeatureEncoder
# and we interpolate with all 9 surrounding neighbors. See paper and appendix B.
self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0)
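
Several RAFT submodules above pass norm_layer=None, in which case the block reduces to a plain convolution followed by the activation; a small sketch of that behaviour:

```python
import torch
from torchvision.ops import Conv2dNormActivation

# With norm_layer=None there is no normalization layer: the block contains just
# a Conv2d (bias enabled by default in this case) and a ReLU.
convrelu = Conv2dNormActivation(2, 128, norm_layer=None, kernel_size=7)

print(len(convrelu))  # 2 modules: Conv2d, ReLU
print(convrelu(torch.randn(1, 2, 32, 32)).shape)  # torch.Size([1, 128, 32, 32])
```
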
4 changes: 2 additions & 2 deletions torchvision/models/quantization/mobilenetv2.py
@@ -6,7 +6,7 @@
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from ...ops.misc import Conv2dNormActivation
from .utils import _fuse_modules, _replace_relu, quantize_model


@@ -54,7 +54,7 @@ def forward(self, x: Tensor) -> Tensor:

def fuse_model(self, is_qat: Optional[bool] = None) -> None:
for m in self.modules():
if type(m) is ConvNormActivation:
if type(m) is Conv2dNormActivation:
_fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
if type(m) is QuantizableInvertedResidual:
m.fuse_model(is_qat)
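
The fuse_model loop above works because Conv2dNormActivation is an nn.Sequential, so its conv, norm and activation children are addressable by the string indices "0", "1" and "2". A sketch of the same fusion done directly with torch.ao.quantization, which the _fuse_modules helper essentially wraps:

```python
import torch
from torchvision.ops import Conv2dNormActivation
from torch.ao.quantization import fuse_modules

block = Conv2dNormActivation(3, 32, kernel_size=3, stride=2)  # Conv2d -> BatchNorm2d -> ReLU
block.eval()  # eager-mode (non-QAT) fusion is done in eval mode

# Fuse the three children by their Sequential indices, as fuse_model does above.
fused = fuse_modules(block, ["0", "1", "2"], inplace=False)

print(type(fused[0]).__name__)  # ConvReLU2d (BatchNorm folded into the conv)
print(type(fused[1]).__name__, type(fused[2]).__name__)  # Identity Identity
```
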
4 changes: 2 additions & 2 deletions torchvision/models/quantization/mobilenetv3.py
@@ -5,7 +5,7 @@
from torch.ao.quantization import QuantStub, DeQuantStub

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation, SqueezeExcitation
from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
from .utils import _fuse_modules, _replace_relu

@@ -103,7 +103,7 @@ def forward(self, x: Tensor) -> Tensor:

def fuse_model(self, is_qat: Optional[bool] = None) -> None:
for m in self.modules():
if type(m) is ConvNormActivation:
if type(m) is Conv2dNormActivation:
modules_to_fuse = ["0", "1"]
if len(m) == 3 and type(m[2]) is nn.ReLU:
modules_to_fuse.append("2")