diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index 23f59b5d13..6d8329021b 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -5,11 +5,12 @@ from .res_layer import ResLayer from .se_layer import SELayer from .self_attention_block import SelfAttentionBlock -from .shape_convert import nchw_to_nlc, nlc_to_nchw +from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, + nlc_to_nchw) from .up_conv_block import UpConvBlock __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', - 'nchw_to_nlc', 'nlc_to_nchw' + 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc' ] diff --git a/mmseg/models/utils/shape_convert.py b/mmseg/models/utils/shape_convert.py index 0677348c80..cce1e220b6 100644 --- a/mmseg/models/utils/shape_convert.py +++ b/mmseg/models/utils/shape_convert.py @@ -27,3 +27,81 @@ def nchw_to_nlc(x): """ assert len(x.shape) == 4 return x.flatten(2).transpose(1, 2).contiguous() + + +def nchw2nlc2nchw(module, x, contiguous=False, **kwargs): + """Flatten [N, C, H, W] shape tensor `x` to [N, L, C] shape tensor. Use the + reshaped tensor as the input of `module`, and the convert the output of + `module`, whose shape is. + + [N, L, C], to [N, C, H, W]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, L, C] as input. + x (Tensor): The input tensor of shape [N, C, H, W]. + contiguous: + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, C, H, W]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> norm = nn.LayerNorm(4) + >>> feature_map = torch.rand(4, 4, 5, 5) + >>> output = nchw2nlc2nchw(norm, feature_map) + """ + B, C, H, W = x.shape + if not contiguous: + x = x.flatten(2).transpose(1, 2) + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W) + else: + x = x.flatten(2).transpose(1, 2).contiguous() + x = module(x, **kwargs) + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + return x + + +def nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs): + """Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. Use the + reshaped tensor as the input of `module`, and convert the output of + `module`, whose shape is. + + [N, C, H, W], to [N, L, C]. + + Args: + module (Callable): A callable object the takes a tensor + with shape [N, C, H, W] as input. + x (Tensor): The input tensor of shape [N, L, C]. + hw_shape: (Sequence[int]): The height and width of the + feature map with shape [N, C, H, W]. + contiguous (Bool): Whether to make the tensor contiguous + after each shape transform. + + Returns: + Tensor: The output tensor of shape [N, L, C]. + + Example: + >>> import torch + >>> import torch.nn as nn + >>> conv = nn.Conv2d(16, 16, 3, 1, 1) + >>> feature_map = torch.rand(4, 25, 16) + >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5)) + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + if not contiguous: + x = x.transpose(1, 2).reshape(B, C, H, W) + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2) + else: + x = x.transpose(1, 2).reshape(B, C, H, W).contiguous() + x = module(x, **kwargs) + x = x.flatten(2).transpose(1, 2).contiguous() + return x diff --git a/tests/test_models/test_utils/test_shape_convert.py b/tests/test_models/test_utils/test_shape_convert.py new file mode 100644 index 0000000000..60e87f38ed --- /dev/null +++ b/tests/test_models/test_utils/test_shape_convert.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmseg.models.utils import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, + nlc_to_nchw) + + +def test_nchw2nlc2nchw(): + # Test nchw2nlc2nchw function + shape_nchw = (4, 2, 5, 5) + shape_nlc = (4, 25, 2) + + def test_func(x): + assert x.shape == torch.Size(shape_nlc) + return x + + x = torch.rand(*shape_nchw) + output = nchw2nlc2nchw(test_func, x) + assert output.shape == torch.Size(shape_nchw) + + def test_func2(x, arg): + assert x.shape == torch.Size(shape_nlc) + assert arg == 100 + return x + + x = torch.rand(*shape_nchw) + output = nchw2nlc2nchw(test_func2, x, arg=100) + assert output.shape == torch.Size(shape_nchw) + + def test_func3(x): + assert x.is_contiguous() + assert x.shape == torch.Size(shape_nlc) + return x + + x = torch.rand(*shape_nchw) + output = nchw2nlc2nchw(test_func3, x, contiguous=True) + assert output.shape == torch.Size(shape_nchw) + assert output.is_contiguous() + + +def test_nlc2nchw2nlc(): + # Test nlc2nchw2nlc function + shape_nchw = (4, 2, 5, 5) + shape_nlc = (4, 25, 2) + + def test_func(x): + assert x.shape == torch.Size(shape_nchw) + return x + + x = torch.rand(*shape_nlc) + output = nlc2nchw2nlc(test_func, x, shape_nchw[2:]) + assert output.shape == torch.Size(shape_nlc) + + def test_func2(x, arg): + assert x.shape == torch.Size(shape_nchw) + assert arg == 100 + return x + + x = torch.rand(*shape_nlc) + output = nlc2nchw2nlc(test_func2, x, shape_nchw[2:], arg=100) + assert output.shape == torch.Size(shape_nlc) + + def test_func3(x): + assert x.is_contiguous() + assert x.shape == torch.Size(shape_nchw) + return x + + x = torch.rand(*shape_nlc) + output = nlc2nchw2nlc(test_func3, x, shape_nchw[2:], contiguous=True) + assert output.shape == torch.Size(shape_nlc) + assert output.is_contiguous() + + +def test_nchw_to_nlc(): + # Test nchw_to_nlc function + shape_nchw = (4, 2, 5, 5) + shape_nlc = (4, 25, 2) + x = torch.rand(*shape_nchw) + y = nchw_to_nlc(x) + assert y.shape == torch.Size(shape_nlc) + + +def test_nlc_to_nchw(): + # Test nlc_to_nchw function + shape_nchw = (4, 2, 5, 5) + shape_nlc = (4, 25, 2) + x = torch.rand(*shape_nlc) + y = nlc_to_nchw(x, (5, 5)) + assert y.shape == torch.Size(shape_nchw)