
Commit

Merge ec1a7cb into fb39e1e
okotaku authored Dec 9, 2022
2 parents fb39e1e + ec1a7cb commit 51ffb60
Showing 7 changed files with 115 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/en/api/cnn.rst
@@ -31,6 +31,7 @@ Module
GeneralizedAttention
HSigmoid
HSwish
LayerScale
Linear
MaxPool2d
MaxPool3d
1 change: 1 addition & 0 deletions docs/zh_cn/api/cnn.rst
@@ -31,6 +31,7 @@ Module
GeneralizedAttention
HSigmoid
HSwish
LayerScale
Linear
MaxPool2d
MaxPool3d
4 changes: 2 additions & 2 deletions mmcv/cnn/bricks/__init__.py
@@ -14,7 +14,7 @@
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .scale import Scale
from .scale import LayerScale, Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
@@ -28,5 +28,5 @@
    'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
    'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
    'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
    'Conv3d', 'Dropout', 'DropPath'
    'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
]
36 changes: 36 additions & 0 deletions mmcv/cnn/bricks/scale.py
@@ -19,3 +19,39 @@ def __init__(self, scale: float = 1.0):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


class LayerScale(nn.Module):
    """LayerScale layer.

    Args:
        dim (int): Dimension of input features.
        inplace (bool): Whether to perform the operation in-place.
            Default: `False`.
        data_format (str): The input data format, could be 'channels_last'
            or 'channels_first', representing (B, N, C) and
            (B, C, H, W) format data respectively. Default: 'channels_last'.
        scale (float): Initial value of scale factor. Default: 1e-5.
    """

    def __init__(self,
                 dim: int,
                 inplace: bool = False,
                 data_format: str = 'channels_last',
                 scale: float = 1e-5):
        super().__init__()
        assert data_format in ('channels_last', 'channels_first'), \
            "'data_format' could only be channels_last or channels_first."
        self.inplace = inplace
        self.data_format = data_format
        self.weight = nn.Parameter(torch.ones(dim) * scale)

    def forward(self, x) -> torch.Tensor:
        if self.data_format == 'channels_first':
            shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
        else:
            shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
        if self.inplace:
            return x.mul_(self.weight.view(*shape))
        else:
            return x * self.weight.view(*shape)
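For reference, a minimal usage sketch of the new LayerScale brick (it is exported from mmcv.cnn.bricks in this commit; the tensor shapes below are illustrative):

import torch
from mmcv.cnn.bricks import LayerScale

# channels_last (default): input of shape (B, N, C); the per-channel weight
# is broadcast over the last dimension.
ls = LayerScale(dim=256, scale=1e-5)
out = ls(torch.randn(4, 49, 256))  # output keeps the input shape

# channels_first: input of shape (B, C, H, W); the weight is broadcast over dim 1.
ls_cf = LayerScale(dim=256, data_format='channels_first')
out_cf = ls_cf(torch.randn(4, 256, 7, 7))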
12 changes: 11 additions & 1 deletion mmcv/cnn/bricks/transformer.py
@@ -15,6 +15,7 @@
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from .drop import build_dropout
from .scale import LayerScale

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
@@ -572,6 +573,8 @@ class FFN(BaseModule):
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        layer_scale_init_value (float): Initial value of scale factor in
            LayerScale. Default: 0.
    """

    @deprecated_api_warning(
@@ -588,7 +591,8 @@ def __init__(self,
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None):
                 init_cfg=None,
                 layer_scale_init_value=0.):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
@@ -611,13 +615,19 @@ def __init__(self,
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

        if layer_scale_init_value > 0:
            self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
        else:
            self.gamma2 = nn.Identity()

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function adds x to the output tensor if `identity` is None.
        """
        out = self.layers(x)
        out = self.gamma2(out)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
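A hedged sketch of exercising the new FFN option (the 256-dim input follows the test added below; embed_dims and feedforward_channels are assumed to be FFN's usual constructor arguments, not shown in this hunk; layer_scale_init_value defaults to 0., which keeps gamma2 as nn.Identity() so existing behaviour is unchanged):

import torch
from mmcv.cnn.bricks.transformer import FFN

# A positive layer_scale_init_value enables LayerScale on the FFN branch.
ffn = FFN(embed_dims=256, feedforward_channels=1024,
          ffn_drop=0., add_identity=True, layer_scale_init_value=0.1)
x = torch.rand(2, 20, 256)  # (B, N, C) with C == embed_dims
out = ffn(x)  # gamma2 scales the FFN output before the identity is added
assert out.shape == x.shape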
58 changes: 57 additions & 1 deletion tests/test_cnn/test_scale.py
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.cnn.bricks import Scale
from mmcv.cnn.bricks import LayerScale, Scale


def test_scale():
@@ -20,3 +21,58 @@ def test_scale():
    x = torch.rand(1, 3, 64, 64)
    output = scale(x)
    assert output.shape == (1, 3, 64, 64)


def test_layer_scale():
    with pytest.raises(AssertionError):
        cfg = dict(
            dim=10,
            data_format='BNC',
        )
        LayerScale(**cfg)

    # test init
    cfg = dict(dim=10)
    ls = LayerScale(**cfg)
    assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)

    # test forward
    # test channels_last
    cfg = dict(dim=256, inplace=False, data_format='channels_last')
    ls_channels_last = LayerScale(**cfg)
    x = torch.randn((4, 49, 256))
    out = ls_channels_last(x)
    assert tuple(out.size()) == (4, 49, 256)
    assert torch.equal(x * 1e-5, out)

    # test channels_last 2d
    cfg = dict(dim=256, inplace=False, data_format='channels_last')
    ls_channels_last = LayerScale(**cfg)
    x = torch.randn((4, 7, 49, 256))
    out = ls_channels_last(x)
    assert tuple(out.size()) == (4, 7, 49, 256)
    assert torch.equal(x * 1e-5, out)

    # test channels_first
    cfg = dict(dim=256, inplace=False, data_format='channels_first')
    ls_channels_first = LayerScale(**cfg)
    x = torch.randn((4, 256, 7, 7))
    out = ls_channels_first(x)
    assert tuple(out.size()) == (4, 256, 7, 7)
    assert torch.equal(x * 1e-5, out)

    # test channels_first 3D
    cfg = dict(dim=256, inplace=False, data_format='channels_first')
    ls_channels_first = LayerScale(**cfg)
    x = torch.randn((4, 256, 7, 7, 7))
    out = ls_channels_first(x)
    assert tuple(out.size()) == (4, 256, 7, 7, 7)
    assert torch.equal(x * 1e-5, out)

    # test inplace True
    cfg = dict(dim=256, inplace=True, data_format='channels_first')
    ls_channels_first = LayerScale(**cfg)
    x = torch.randn((4, 256, 7, 7))
    out = ls_channels_first(x)
    assert tuple(out.size()) == (4, 256, 7, 7)
    assert x is out
8 changes: 7 additions & 1 deletion tests/test_cnn/test_transformer.py
@@ -538,7 +538,6 @@ def test_ffn():
    with pytest.raises(AssertionError):
        # num_fcs should be no less than 2
        FFN(num_fcs=1)
    FFN(dropout=0, add_residual=True)
    ffn = FFN(dropout=0, add_identity=True)

    input_tensor = torch.rand(2, 20, 256)
@@ -553,6 +552,13 @@ def test_ffn():
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

    # test with layer_scale
    ffn = FFN(dropout=0, add_identity=True, layer_scale_init_value=0.1)

    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():