From 5631fab687317df8fa67a7e0a8b4afa81cec643a Mon Sep 17 00:00:00 2001
From: Rist115
Date: Wed, 23 Nov 2022 20:29:58 +0900
Subject: [PATCH 01/14] add layer scale

---
 mmcv/cnn/bricks/__init__.py    |  3 ++-
 mmcv/cnn/bricks/transformer.py | 41 ++++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/mmcv/cnn/bricks/__init__.py b/mmcv/cnn/bricks/__init__.py
index 7a22a0e46c5..bb977fc55b5 100644
--- a/mmcv/cnn/bricks/__init__.py
+++ b/mmcv/cnn/bricks/__init__.py
@@ -16,6 +16,7 @@
 from .plugin import build_plugin_layer
 from .scale import Scale
 from .swish import Swish
+from .transformer import LayerScale
 from .upsample import build_upsample_layer
 from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                        Linear, MaxPool2d, MaxPool3d)
@@ -28,5 +29,5 @@
     'Scale', 'ConvAWS2d', 'ConvWS2d', 'conv_ws_2d',
     'DepthwiseSeparableConvModule', 'Swish', 'Linear', 'Conv2dAdaptivePadding',
     'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d',
-    'Conv3d', 'Dropout', 'DropPath'
+    'Conv3d', 'Dropout', 'DropPath', 'LayerScale'
 ]
diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 84956ef508c..234e02010fb 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -551,6 +551,38 @@ def forward(self,
         return identity + self.dropout_layer(self.proj_drop(out))


+class LayerScale(nn.Module):
+    """LayerScale layer.
+
+    Args:
+        dim (int): Dimension of input features.
+        inplace (bool): inplace: can optionally do the
+            operation in-place. Default: `False`.
+        data_format (str): The input data format, could be 'channels_last'
+             or 'channels_first', representing (B, C, H, W) and
+             (B, N, C) format data respectively. Default: 'channels_last'.
+    """
+
+    def __init__(self,
+                 dim: int,
+                 inplace: bool = False,
+                 data_format: str = 'channels_last'):
+        super().__init__()
+        assert data_format in ('channels_last', 'channels_first'), \
+            "'data_format' could only be channels_last or channels_first."
+        self.inplace = inplace
+        self.data_format = data_format
+        self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+
+    def forward(self, x):
+        if self.data_format == 'channels_first':
+            if self.inplace:
+                return x.mul_(self.weight.view(-1, 1, 1))
+            else:
+                return x * self.weight.view(-1, 1, 1)
+        return x.mul_(self.weight) if self.inplace else x * self.weight
+
+
 @MODELS.register_module()
 class FFN(BaseModule):
     """Implements feed-forward networks (FFNs) with identity connection.
@@ -568,6 +600,8 @@ class FFN(BaseModule):
             zeroed in FFN. Default 0.0.
         add_identity (bool, optional): Whether to add the
             identity connection. Default: `True`.
+        use_layer_scale (bool): Whether to use layer_scale in FFN.
+            Default: `True`.
         dropout_layer (obj:`ConfigDict`): The dropout_layer used
             when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
@@ -588,6 +622,7 @@ def __init__(self,
                  ffn_drop=0.,
                  dropout_layer=None,
                  add_identity=True,
+                 use_layer_scale=True,
                  init_cfg=None,
                  **kwargs):
         super().__init__(init_cfg)
@@ -614,6 +649,11 @@ def __init__(self,
             dropout_layer) if dropout_layer else torch.nn.Identity()
         self.add_identity = add_identity

+        if use_layer_scale:
+            self.gamma2 = LayerScale(embed_dims)
+        else:
+            self.gamma2 = nn.Identity()
+
     @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
     def forward(self, x, identity=None):
         """Forward function for `FFN`.

         The function would add x to the output tensor if residue is None.
         """
         out = self.layers(x)
+        out = self.gamma2(out)
         if not self.add_identity:
             return self.dropout_layer(out)
         if identity is None:

From 12e1055d847f9d5286979ba2f2f51ce243d383c1 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Wed, 23 Nov 2022 21:39:27 +0900
Subject: [PATCH 02/14] add layer scale

---
 mmcv/cnn/bricks/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mmcv/cnn/bricks/__init__.py b/mmcv/cnn/bricks/__init__.py
index bb977fc55b5..6cf39300c9f 100644
--- a/mmcv/cnn/bricks/__init__.py
+++ b/mmcv/cnn/bricks/__init__.py
@@ -16,7 +16,6 @@
 from .plugin import build_plugin_layer
 from .scale import Scale
 from .swish import Swish
-from .transformer import LayerScale
 from .upsample import build_upsample_layer
 from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                        Linear, MaxPool2d, MaxPool3d)

From fbf9b6e07b2f30b84a1ea97e884581377226349e Mon Sep 17 00:00:00 2001
From: Rist115
Date: Thu, 24 Nov 2022 09:01:26 +0900
Subject: [PATCH 03/14] add layer scale

---
 tests/test_cnn/test_transformer.py | 49 ++++++++++++++++++++++++++++--
 1 file changed, 47 insertions(+), 2 deletions(-)

diff --git a/tests/test_cnn/test_transformer.py b/tests/test_cnn/test_transformer.py
index c342aa040db..3bf3f6092d4 100644
--- a/tests/test_cnn/test_transformer.py
+++ b/tests/test_cnn/test_transformer.py
@@ -7,7 +7,7 @@
 from mmcv.cnn.bricks.drop import DropPath
 from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
-                                         BaseTransformerLayer,
+                                         BaseTransformerLayer, LayerScale,
                                          MultiheadAttention, PatchEmbed,
                                          PatchMerging,
                                          TransformerLayerSequence)
@@ -538,7 +538,6 @@ def test_ffn():
     with pytest.raises(AssertionError):
         # num_fcs should be no less than 2
         FFN(num_fcs=1)
-    FFN(dropout=0, add_residual=True)

     ffn = FFN(dropout=0, add_identity=True)
     input_tensor = torch.rand(2, 20, 256)
     input_tensor_nbc = input_tensor.transpose(0, 1)
     assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
     residual = torch.rand_like(input_tensor)
     torch.allclose(
         ffn(input_tensor, identity=residual).sum(),
         ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

+    # test with layer_scale
+    ffn = FFN(dropout=0, add_identity=True, use_layer_scale=True)
+
+    input_tensor = torch.rand(2, 20, 256)
+    input_tensor_nbc = input_tensor.transpose(0, 1)
+    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
+
+
+def test_layer_scale():
+    with pytest.raises(AssertionError):
+        cfg = dict(
+            dim=10,
+            data_format='BNC',
+        )
+        LayerScale(**cfg)
+
+    # test init
+    cfg = dict(dim=10)
+    ls = LayerScale(**cfg)
+    assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)
+
+    # test forward
+    # test channels_last
+    cfg = dict(dim=256, inplace=False, data_format='channels_last')
+    ls_channels_last = LayerScale(**cfg)
+    x = torch.randn((4, 49, 256))
+    out = ls_channels_last(x)
+    assert tuple(out.size()) == (4, 49, 256)
+    assert torch.equal(x * 1e-5, out)
+
+    # test channels_first
+    cfg = dict(dim=256, inplace=False, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7)
+    assert torch.equal(x * 1e-5, out)
+
+    # test inplace True
+    cfg = dict(dim=256, inplace=True, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7)
+    assert x is out
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
 def test_basetransformerlayer_cuda():

From 7465c80b7c978f3b904a6ee2087cc86b70fc553e Mon Sep 17 00:00:00 2001
From: takuoko
Date: Fri, 25 Nov 2022 15:25:31 +0900
Subject: [PATCH 04/14] Update mmcv/cnn/bricks/transformer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmcv/cnn/bricks/transformer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 234e02010fb..73eb1e7123a 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -556,11 +556,11 @@ class LayerScale(nn.Module):

     Args:
         dim (int): Dimension of input features.
-        inplace (bool): inplace: can optionally do the
-            operation in-place. Default: `False`.
+        inplace (bool): Whether performs operation in-place.
+            Default: `False`.
         data_format (str): The input data format, could be 'channels_last'
-             or 'channels_first', representing (B, C, H, W) and
-             (B, N, C) format data respectively. Default: 'channels_last'.
+            or 'channels_first', representing (B, C, H, W) and
+            (B, N, C) format data respectively. Default: 'channels_last'.
     """

     def __init__(self,

From 29804dc25b3afc36a793d52e2bf49809c58e2ec9 Mon Sep 17 00:00:00 2001
From: takuoko
Date: Fri, 25 Nov 2022 16:12:34 +0900
Subject: [PATCH 05/14] Update mmcv/cnn/bricks/transformer.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmcv/cnn/bricks/transformer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 73eb1e7123a..29fcd461706 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -571,16 +571,16 @@ def __init__(self,
         assert data_format in ('channels_last', 'channels_first'), \
             "'data_format' could only be channels_last or channels_first."
         self.inplace = inplace
-        self.data_format = data_format
-        self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+        if data_format == 'channels_first':
+            self.weight = nn.Parameter(torch.ones(dim, 1, 1) * 1e-5)
+        else:
+            self.weight = nn.Parameter(torch.ones(dim) * 1e-5)

     def forward(self, x):
-        if self.data_format == 'channels_first':
-            if self.inplace:
-                return x.mul_(self.weight.view(-1, 1, 1))
-            else:
-                return x * self.weight.view(-1, 1, 1)
-        return x.mul_(self.weight) if self.inplace else x * self.weight
+        if self.inplace:
+            return x.mul_(self.weight)
+        else:
+            return x * self.weight


 @MODELS.register_module()

From 995f6a8df17c85390c5e76fb7322d6cb1e6d0642 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Fri, 25 Nov 2022 16:13:46 +0900
Subject: [PATCH 06/14] add layer scale

---
 mmcv/cnn/bricks/transformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 29fcd461706..6b93e76512c 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -600,12 +600,12 @@ class FFN(BaseModule):
             zeroed in FFN. Default 0.0.
         add_identity (bool, optional): Whether to add the
             identity connection. Default: `True`.
-        use_layer_scale (bool): Whether to use layer_scale in FFN.
-            Default: `True`.
         dropout_layer (obj:`ConfigDict`): The dropout_layer used
             when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
+        use_layer_scale (bool): Whether to use layer_scale in FFN.
+            Default: `True`.
     """

     @deprecated_api_warning(
@@ -622,8 +622,8 @@ def __init__(self,
                  ffn_drop=0.,
                  dropout_layer=None,
                  add_identity=True,
-                 use_layer_scale=True,
                  init_cfg=None,
+                 use_layer_scale=True,
                  **kwargs):
         super().__init__(init_cfg)

From 9cc7b9d3f1ee3cc2eb72728d916627f9c8032eb4 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Fri, 25 Nov 2022 18:07:22 +0900
Subject: [PATCH 07/14] move LayerScale

---
 mmcv/cnn/bricks/__init__.py        |  2 +-
 mmcv/cnn/bricks/scale.py           | 32 +++++++++++++++++++++++
 mmcv/cnn/bricks/transformer.py     | 33 +----------------------
 tests/test_cnn/test_scale.py       | 42 +++++++++++++++++++++++++++++-
 tests/test_cnn/test_transformer.py | 41 +-----------------------------
 5 files changed, 76 insertions(+), 74 deletions(-)

diff --git a/mmcv/cnn/bricks/__init__.py b/mmcv/cnn/bricks/__init__.py
index 6cf39300c9f..6c74986953b 100644
--- a/mmcv/cnn/bricks/__init__.py
+++ b/mmcv/cnn/bricks/__init__.py
@@ -14,7 +14,7 @@
 from .norm import build_norm_layer, is_norm
 from .padding import build_padding_layer
 from .plugin import build_plugin_layer
-from .scale import Scale
+from .scale import LayerScale, Scale
 from .swish import Swish
 from .upsample import build_upsample_layer
 from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
diff --git a/mmcv/cnn/bricks/scale.py b/mmcv/cnn/bricks/scale.py
index dbd07c6a445..996c4d7231d 100644
--- a/mmcv/cnn/bricks/scale.py
+++ b/mmcv/cnn/bricks/scale.py
@@ -19,3 +19,35 @@ def __init__(self, scale: float = 1.0):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.scale
+
+
+class LayerScale(nn.Module):
+    """LayerScale layer.
+
+    Args:
+        dim (int): Dimension of input features.
+        inplace (bool): Whether performs operation in-place.
+            Default: `False`.
+        data_format (str): The input data format, could be 'channels_last'
+            or 'channels_first', representing (B, C, H, W) and
+            (B, N, C) format data respectively. Default: 'channels_last'.
+    """
+
+    def __init__(self,
+                 dim: int,
+                 inplace: bool = False,
+                 data_format: str = 'channels_last'):
+        super().__init__()
+        assert data_format in ('channels_last', 'channels_first'), \
+            "'data_format' could only be channels_last or channels_first."
+        self.inplace = inplace
+        if data_format == 'channels_first':
+            self.weight = nn.Parameter(torch.ones(dim, 1, 1) * 1e-5)
+        else:
+            self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+
+    def forward(self, x):
+        if self.inplace:
+            return x.mul_(self.weight)
+        else:
+            return x * self.weight
diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 6b93e76512c..fd28aee7bd0 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -15,6 +15,7 @@
 from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                       build_norm_layer)
 from .drop import build_dropout
+from .scale import LayerScale

 # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
 try:
@@ -551,38 +552,6 @@ def forward(self,
         return identity + self.dropout_layer(self.proj_drop(out))


-class LayerScale(nn.Module):
-    """LayerScale layer.
-
-    Args:
-        dim (int): Dimension of input features.
-        inplace (bool): Whether performs operation in-place.
-            Default: `False`.
-        data_format (str): The input data format, could be 'channels_last'
-            or 'channels_first', representing (B, C, H, W) and
-            (B, N, C) format data respectively. Default: 'channels_last'.
-    """
-
-    def __init__(self,
-                 dim: int,
-                 inplace: bool = False,
-                 data_format: str = 'channels_last'):
-        super().__init__()
-        assert data_format in ('channels_last', 'channels_first'), \
-            "'data_format' could only be channels_last or channels_first."
-        self.inplace = inplace
-        if data_format == 'channels_first':
-            self.weight = nn.Parameter(torch.ones(dim, 1, 1) * 1e-5)
-        else:
-            self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
-
-    def forward(self, x):
-        if self.inplace:
-            return x.mul_(self.weight)
-        else:
-            return x * self.weight
-
-
 @MODELS.register_module()
 class FFN(BaseModule):
     """Implements feed-forward networks (FFNs) with identity connection.
diff --git a/tests/test_cnn/test_scale.py b/tests/test_cnn/test_scale.py
index bee78eb57f2..1a96e263d11 100644
--- a/tests/test_cnn/test_scale.py
+++ b/tests/test_cnn/test_scale.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import pytest
 import torch

-from mmcv.cnn.bricks import Scale
+from mmcv.cnn.bricks import LayerScale, Scale


 def test_scale():
@@ -20,3 +21,42 @@ def test_scale():
     x = torch.rand(1, 3, 64, 64)
     output = scale(x)
     assert output.shape == (1, 3, 64, 64)
+
+
+def test_layer_scale():
+    with pytest.raises(AssertionError):
+        cfg = dict(
+            dim=10,
+            data_format='BNC',
+        )
+        LayerScale(**cfg)
+
+    # test init
+    cfg = dict(dim=10)
+    ls = LayerScale(**cfg)
+    assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)
+
+    # test forward
+    # test channels_last
+    cfg = dict(dim=256, inplace=False, data_format='channels_last')
+    ls_channels_last = LayerScale(**cfg)
+    x = torch.randn((4, 49, 256))
+    out = ls_channels_last(x)
+    assert tuple(out.size()) == (4, 49, 256)
+    assert torch.equal(x * 1e-5, out)
+
+    # test channels_first
+    cfg = dict(dim=256, inplace=False, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7)
+    assert torch.equal(x * 1e-5, out)
+
+    # test inplace True
+    cfg = dict(dim=256, inplace=True, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7)
+    assert x is out
diff --git a/tests/test_cnn/test_transformer.py b/tests/test_cnn/test_transformer.py
index 3bf3f6092d4..019c9fe1ee2 100644
--- a/tests/test_cnn/test_transformer.py
+++ b/tests/test_cnn/test_transformer.py
@@ -7,7 +7,7 @@
 from mmcv.cnn.bricks.drop import DropPath
 from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
-                                         BaseTransformerLayer, LayerScale,
+                                         BaseTransformerLayer,
                                          MultiheadAttention, PatchEmbed,
                                          PatchMerging,
                                          TransformerLayerSequence)
@@ -560,45 +560,6 @@ def test_ffn():
     assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())


-def test_layer_scale():
-    with pytest.raises(AssertionError):
-        cfg = dict(
-            dim=10,
-            data_format='BNC',
-        )
-        LayerScale(**cfg)
-
-    # test init
-    cfg = dict(dim=10)
-    ls = LayerScale(**cfg)
-    assert torch.equal(ls.weight, torch.ones(10, requires_grad=True) * 1e-5)
-
-    # test forward
-    # test channels_last
-    cfg = dict(dim=256, inplace=False, data_format='channels_last')
-    ls_channels_last = LayerScale(**cfg)
-    x = torch.randn((4, 49, 256))
-    out = ls_channels_last(x)
-    assert tuple(out.size()) == (4, 49, 256)
-    assert torch.equal(x * 1e-5, out)
-
-    # test channels_first
-    cfg = dict(dim=256, inplace=False, data_format='channels_first')
-    ls_channels_first = LayerScale(**cfg)
-    x = torch.randn((4, 256, 7, 7))
-    out = ls_channels_first(x)
-    assert tuple(out.size()) == (4, 256, 7, 7)
-    assert torch.equal(x * 1e-5, out)
-
-    # test inplace True
-    cfg = dict(dim=256, inplace=True, data_format='channels_first')
-    ls_channels_first = LayerScale(**cfg)
-    x = torch.randn((4, 256, 7, 7))
-    out = ls_channels_first(x)
-    assert tuple(out.size()) == (4, 256, 7, 7)
-    assert x is out
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
 def test_basetransformerlayer_cuda():
     # To test if the BaseTransformerLayer's behaviour remains

From d20aa0fcdbf440fc9af76b57b946dd431f13ad80 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Sun, 27 Nov 2022 09:51:56 +0900
Subject: [PATCH 08/14] add layer_scale_init_value

---
 mmcv/cnn/bricks/scale.py       |  8 +++++---
 mmcv/cnn/bricks/transformer.py | 10 +++++-----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/mmcv/cnn/bricks/scale.py b/mmcv/cnn/bricks/scale.py
index 996c4d7231d..49d27fc85bc 100644
--- a/mmcv/cnn/bricks/scale.py
+++ b/mmcv/cnn/bricks/scale.py
@@ -31,20 +31,22 @@ class LayerScale(nn.Module):
         data_format (str): The input data format, could be 'channels_last'
             or 'channels_first', representing (B, C, H, W) and
             (B, N, C) format data respectively. Default: 'channels_last'.
+        scale (float): Initial value of scale factor. Default: 1.0
     """

     def __init__(self,
                  dim: int,
                  inplace: bool = False,
-                 data_format: str = 'channels_last'):
+                 data_format: str = 'channels_last',
+                 scale: float = 1e-5):
         super().__init__()
         assert data_format in ('channels_last', 'channels_first'), \
             "'data_format' could only be channels_last or channels_first."
         self.inplace = inplace
         if data_format == 'channels_first':
-            self.weight = nn.Parameter(torch.ones(dim, 1, 1) * 1e-5)
+            self.weight = nn.Parameter(torch.ones(dim, 1, 1) * scale)
         else:
-            self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+            self.weight = nn.Parameter(torch.ones(dim) * scale)

     def forward(self, x):
         if self.inplace:
diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index fd28aee7bd0..0c6a9640330 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -573,8 +573,8 @@ class FFN(BaseModule):
             when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
-        use_layer_scale (bool): Whether to use layer_scale in FFN.
-            Default: `True`.
+        layer_scale_init_value (float): Initial value of scale factor in
+            LayerScale. Default: 1.0
     """

     @deprecated_api_warning(
@@ -592,7 +592,7 @@ def __init__(self,
                  dropout_layer=None,
                  add_identity=True,
                  init_cfg=None,
-                 use_layer_scale=True,
+                 layer_scale_init_value=0.,
                  **kwargs):
         super().__init__(init_cfg)
         assert num_fcs >= 2, 'num_fcs should be no less ' \
@@ -618,8 +618,8 @@ def __init__(self,
             dropout_layer) if dropout_layer else torch.nn.Identity()
         self.add_identity = add_identity

-        if use_layer_scale:
-            self.gamma2 = LayerScale(embed_dims)
+        if layer_scale_init_value > 0:
+            self.gamma2 = LayerScale(embed_dims, scale=layer_scale_init_value)
         else:
             self.gamma2 = nn.Identity()

From 07985db82cfd83bdfda8bb176286f0718c0aa9c3 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Sun, 27 Nov 2022 11:07:44 +0900
Subject: [PATCH 09/14] add typehint

---
 mmcv/cnn/bricks/scale.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/cnn/bricks/scale.py b/mmcv/cnn/bricks/scale.py
index 49d27fc85bc..2aa12b91357 100644
--- a/mmcv/cnn/bricks/scale.py
+++ b/mmcv/cnn/bricks/scale.py
@@ -48,7 +48,7 @@ def __init__(self,
         else:
             self.weight = nn.Parameter(torch.ones(dim) * scale)

-    def forward(self, x):
+    def forward(self, x) -> torch.Tensor:
         if self.inplace:
             return x.mul_(self.weight)
         else:
             return x * self.weight

From 6bd66aa831cc5915cf3ca16844c4e3548961c290 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Mon, 28 Nov 2022 15:45:06 +0900
Subject: [PATCH 10/14] fix for tensor with any dim

---
 mmcv/cnn/bricks/scale.py     | 16 ++++++++++------
 tests/test_cnn/test_scale.py | 16 ++++++++++++++++
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/mmcv/cnn/bricks/scale.py b/mmcv/cnn/bricks/scale.py
index 2aa12b91357..01e2460bc5b 100644
--- a/mmcv/cnn/bricks/scale.py
+++ b/mmcv/cnn/bricks/scale.py
@@ -43,13 +43,17 @@ def __init__(self,
         assert data_format in ('channels_last', 'channels_first'), \
             "'data_format' could only be channels_last or channels_first."
         self.inplace = inplace
-        if data_format == 'channels_first':
-            self.weight = nn.Parameter(torch.ones(dim, 1, 1) * scale)
-        else:
-            self.weight = nn.Parameter(torch.ones(dim) * scale)
+        self.data_format = data_format
+        self.weight = nn.Parameter(torch.ones(dim) * scale)

     def forward(self, x) -> torch.Tensor:
+        if self.data_format == 'channels_first':
+            shape = tuple((-1, *(1 for _ in range(x.dim() - 2))))
+        else:
+            shape = tuple((*(1 for _ in range(x.dim() - 3)), -1))
         if self.inplace:
-            return x.mul_(self.weight)
+            return x.mul_(self.weight.view(*shape))
         else:
-            return x * self.weight
+            print(x.shape, self.weight.shape,
+                  self.weight.view(*shape).shape, shape, self.data_format)
+            return x * self.weight.view(*shape)
diff --git a/tests/test_cnn/test_scale.py b/tests/test_cnn/test_scale.py
index 1a96e263d11..04d75ec16f5 100644
--- a/tests/test_cnn/test_scale.py
+++ b/tests/test_cnn/test_scale.py
@@ -45,6 +45,14 @@ def test_layer_scale():
     assert tuple(out.size()) == (4, 49, 256)
     assert torch.equal(x * 1e-5, out)

+    # test channels_last 2d
+    cfg = dict(dim=256, inplace=False, data_format='channels_last')
+    ls_channels_last = LayerScale(**cfg)
+    x = torch.randn((4, 7, 49, 256))
+    out = ls_channels_last(x)
+    assert tuple(out.size()) == (4, 7, 49, 256)
+    assert torch.equal(x * 1e-5, out)
+
     # test channels_first
     cfg = dict(dim=256, inplace=False, data_format='channels_first')
     ls_channels_first = LayerScale(**cfg)
     x = torch.randn((4, 256, 7, 7))
     out = ls_channels_first(x)
     assert tuple(out.size()) == (4, 256, 7, 7)
     assert torch.equal(x * 1e-5, out)

+    # test channels_first 3D
+    cfg = dict(dim=256, inplace=False, data_format='channels_first')
+    ls_channels_first = LayerScale(**cfg)
+    x = torch.randn((4, 256, 7, 7, 7))
+    out = ls_channels_first(x)
+    assert tuple(out.size()) == (4, 256, 7, 7, 7)
+    assert torch.equal(x * 1e-5, out)
+
     # test inplace True
     cfg = dict(dim=256, inplace=True, data_format='channels_first')
     ls_channels_first = LayerScale(**cfg)
     x = torch.randn((4, 256, 7, 7))
     out = ls_channels_first(x)
     assert tuple(out.size()) == (4, 256, 7, 7)
     assert x is out

From 3e211d1776f16fd1d90ba15204de4babe5f39933 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Fri, 2 Dec 2022 11:25:27 +0900
Subject: [PATCH 11/14] fix layer scale rule

---
 mmcv/cnn/bricks/scale.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mmcv/cnn/bricks/scale.py b/mmcv/cnn/bricks/scale.py
index 01e2460bc5b..cf13997ee8e 100644
--- a/mmcv/cnn/bricks/scale.py
+++ b/mmcv/cnn/bricks/scale.py
@@ -50,10 +50,8 @@ def forward(self, x) -> torch.Tensor:
         if self.data_format == 'channels_first':
             shape = tuple((-1, *(1 for _ in range(x.dim() - 2))))
         else:
-            shape = tuple((*(1 for _ in range(x.dim() - 3)), -1))
+            shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
         if self.inplace:
             return x.mul_(self.weight.view(*shape))
         else:
-            print(x.shape, self.weight.shape,
-                  self.weight.view(*shape).shape, shape, self.data_format)
             return x * self.weight.view(*shape)

From f037cd17a2d5cc20d1aa01c7171f7fbf3837858d Mon Sep 17 00:00:00 2001
From: Rist115
Date: Mon, 5 Dec 2022 21:00:13 +0900
Subject: [PATCH 12/14] fix layer scale rule

---
 mmcv/cnn/bricks/scale.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/cnn/bricks/scale.py b/mmcv/cnn/bricks/scale.py
index cf13997ee8e..a47379898f7 100644
--- a/mmcv/cnn/bricks/scale.py
+++ b/mmcv/cnn/bricks/scale.py
@@ -48,7 +48,7 @@ def __init__(self,

     def forward(self, x) -> torch.Tensor:
         if self.data_format == 'channels_first':
-            shape = tuple((-1, *(1 for _ in range(x.dim() - 2))))
+            shape = tuple((1, -1, *(1 for _ in range(x.dim() - 2))))
         else:
             shape = tuple((*(1 for _ in range(x.dim() - 1)), -1))
         if self.inplace:
From 555e4ddbc1f6f80e86a5f13655adf14207c56508 Mon Sep 17 00:00:00 2001
From: Rist115
Date: Wed, 7 Dec 2022 13:53:01 +0900
Subject: [PATCH 13/14] fix test

---
 tests/test_cnn/test_transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_cnn/test_transformer.py b/tests/test_cnn/test_transformer.py
index 019c9fe1ee2..b5a9562ee72 100644
--- a/tests/test_cnn/test_transformer.py
+++ b/tests/test_cnn/test_transformer.py
@@ -553,7 +553,7 @@ def test_ffn():
         ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

     # test with layer_scale
-    ffn = FFN(dropout=0, add_identity=True, use_layer_scale=True)
+    ffn = FFN(dropout=0, add_identity=True, layer_scale_init_value=0.1)

     input_tensor = torch.rand(2, 20, 256)
     input_tensor_nbc = input_tensor.transpose(0, 1)
     assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())

From ec1a7cb24c9e59b896eb885e812b772a667dc47a Mon Sep 17 00:00:00 2001
From: Rist115
Date: Fri, 9 Dec 2022 17:00:35 +0900
Subject: [PATCH 14/14] add docs

---
 docs/en/api/cnn.rst    | 1 +
 docs/zh_cn/api/cnn.rst | 1 +
 2 files changed, 2 insertions(+)

diff --git a/docs/en/api/cnn.rst b/docs/en/api/cnn.rst
index 8b4c9c13ba3..5cbcb191e9e 100644
--- a/docs/en/api/cnn.rst
+++ b/docs/en/api/cnn.rst
@@ -31,6 +31,7 @@ Module
    GeneralizedAttention
    HSigmoid
    HSwish
+   LayerScale
    Linear
    MaxPool2d
    MaxPool3d
diff --git a/docs/zh_cn/api/cnn.rst b/docs/zh_cn/api/cnn.rst
index 8b4c9c13ba3..5cbcb191e9e 100644
--- a/docs/zh_cn/api/cnn.rst
+++ b/docs/zh_cn/api/cnn.rst
@@ -31,6 +31,7 @@ Module
    GeneralizedAttention
    HSigmoid
    HSwish
+   LayerScale
    Linear
    MaxPool2d
    MaxPool3d
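
Usage sketch (for review, not part of the patches): a minimal example of what this series adds, assuming the final state of the branch after patch 14/14, i.e. LayerScale re-exported from mmcv.cnn.bricks (backed by mmcv/cnn/bricks/scale.py) and FFN accepting layer_scale_init_value. The tensor shapes and the 0.1 init value mirror the tests above and are illustrative only.

import torch

from mmcv.cnn.bricks import LayerScale
from mmcv.cnn.bricks.transformer import FFN

# Standalone LayerScale: a learnable per-channel scale initialised to `scale`.
# 'channels_last' broadcasts the weight over the trailing dim, e.g. (B, N, C);
# 'channels_first' broadcasts it over dim 1, e.g. (B, C, H, W) or (B, C, D, H, W).
ls = LayerScale(dim=256, data_format='channels_last', scale=1e-5)
x = torch.randn(4, 49, 256)
assert ls(x).shape == (4, 49, 256)

# FFN with LayerScale enabled: a positive `layer_scale_init_value` makes gamma2
# a LayerScale module instead of nn.Identity(), so the FFN output is rescaled
# before the identity connection is added.
ffn = FFN(embed_dims=256, feedforward_channels=1024, layer_scale_init_value=0.1)
y = ffn(torch.rand(2, 20, 256))
assert y.shape == (2, 20, 256)

With the default layer_scale_init_value=0., gamma2 stays nn.Identity() and FFN behaves exactly as it did before this series.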