diff --git a/mmedit/datasets/pipelines/matting_aug.py b/mmedit/datasets/pipelines/matting_aug.py
index de5b4fbb58..28fb04f233 100644
--- a/mmedit/datasets/pipelines/matting_aug.py
+++ b/mmedit/datasets/pipelines/matting_aug.py
@@ -576,14 +576,15 @@ def __repr__(self):
 
 @PIPELINES.register_module()
 class TransformTrimap:
-    """Generate two-channel trimap and encode it into six-channel.
+    """Transform trimap into two-channel and six-channel representations.
 
     This class will generate a two-channel trimap composed of definite
     foreground and background masks and encode it into a six-channel trimap
     using Gaussian blurs of the generated two-channel trimap at three
     different scales. The transformed trimap has 6 channels.
 
-    Required key is "trimap", added key is "transformed_trimap".
+    Required key is "trimap", added keys are "transformed_trimap" and
+    "two_channel_trimap".
 
     Adopted from the following repository:
     https://github.com/MarcoForte/FBA_Matting/blob/master/networks/transforms.py.
@@ -619,6 +620,7 @@ def __call__(self, results):
                     dt_mask / (2 * ((factor * L)**2)))
 
         results['transformed_trimap'] = trimap_trans
+        results['two_channel_trimap'] = trimap2
         return results
 
     def __repr__(self):
diff --git a/mmedit/datasets/pipelines/normalization.py b/mmedit/datasets/pipelines/normalization.py
index fcb462b1e7..1b41ef9e23 100644
--- a/mmedit/datasets/pipelines/normalization.py
+++ b/mmedit/datasets/pipelines/normalization.py
@@ -19,11 +19,14 @@ class Normalize:
         to_rgb (bool): Whether to convert channels from BGR to RGB.
+        save_original (bool): Whether to save the original image in the
+            results dict under the key ``f'{key}_unnormalised'``. Default: False.
     """
 
-    def __init__(self, keys, mean, std, to_rgb=False):
+    def __init__(self, keys, mean, std, to_rgb=False, save_original=False):
        self.keys = keys
         self.mean = np.array(mean, dtype=np.float32)
         self.std = np.array(std, dtype=np.float32)
         self.to_rgb = to_rgb
+        self.save_original = save_original
 
     def __call__(self, results):
         """Call function.
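
For reference, here is a standalone sketch of what the `TransformTrimap` pipeline step above computes, on a toy trimap. The scale constant `L = 320` and the three factors `(0.02, 0.08, 0.16)` follow the referenced FBA_Matting repository and are assumptions here, not values read from this diff:

```python
import cv2
import numpy as np

trimap = np.full((64, 64), 128, dtype=np.uint8)  # toy trimap, all unknown
trimap[:16] = 0       # definite background
trimap[-16:] = 255    # definite foreground

h, w = trimap.shape
# Two-channel trimap: definite foreground / definite background masks.
trimap2 = np.stack([trimap == 255, trimap == 0], axis=-1).astype(np.uint8)

# Six-channel encoding: Gaussian decay of the squared distance to each
# mask, evaluated at three spatial scales.
trimap_trans = np.zeros((h, w, 6), dtype=np.float32)
L = 320  # assumed scale constant from the FBA reference implementation
for k in range(2):
    if trimap2[..., k].any():
        dt_mask = -cv2.distanceTransform(
            1 - trimap2[..., k], cv2.DIST_L2, 0)**2
        for i, factor in enumerate((0.02, 0.08, 0.16)):
            trimap_trans[..., 3 * k + i] = np.exp(
                dt_mask / (2 * (factor * L)**2))
```
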
@@ -37,11 +38,17 @@ def __call__(self, results): """ for key in self.keys: if isinstance(results[key], list): + if self.save_original: + results[key + '_unnormalised'] = [ + v.copy() for v in results[key] + ] results[key] = [ mmcv.imnormalize(v, self.mean, self.std, self.to_rgb) for v in results[key] ] else: + if self.save_original: + results[key + '_unnormalised'] = results[key].copy() results[key] = mmcv.imnormalize(results[key], self.mean, self.std, self.to_rgb) diff --git a/mmedit/models/backbones/__init__.py b/mmedit/models/backbones/__init__.py index ae54e77ce8..93393a1863 100644 --- a/mmedit/models/backbones/__init__.py +++ b/mmedit/models/backbones/__init__.py @@ -1,7 +1,8 @@ # yapf: disable from .encoder_decoders import (VGG16, ContextualAttentionNeck, DeepFillDecoder, DeepFillEncoder, DeepFillEncoderDecoder, - DepthwiseIndexBlock, GLDecoder, GLDilationNeck, + DepthwiseIndexBlock, FBADecoder, + FBAResnetDilated, GLDecoder, GLDilationNeck, GLEncoder, GLEncoderDecoder, HolisticIndexBlock, IndexedUpsample, IndexNetDecoder, IndexNetEncoder, PConvDecoder, PConvEncoder, @@ -22,5 +23,5 @@ 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder', 'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN', - 'UnetGenerator', 'ResnetGenerator' + 'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder' ] diff --git a/mmedit/models/backbones/encoder_decoders/__init__.py b/mmedit/models/backbones/encoder_decoders/__init__.py index 6e2c6a11f2..6155b171f6 100644 --- a/mmedit/models/backbones/encoder_decoders/__init__.py +++ b/mmedit/models/backbones/encoder_decoders/__init__.py @@ -1,9 +1,10 @@ -from .decoders import (DeepFillDecoder, GLDecoder, IndexedUpsample, +from .decoders import (DeepFillDecoder, FBADecoder, GLDecoder, IndexedUpsample, IndexNetDecoder, PConvDecoder, PlainDecoder, ResGCADecoder, ResNetDec, ResShortcutDec) -from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock, GLEncoder, - HolisticIndexBlock, IndexNetEncoder, PConvEncoder, - ResGCAEncoder, ResNetEnc, ResShortcutEnc) +from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock, + FBAResnetDilated, GLEncoder, HolisticIndexBlock, + IndexNetEncoder, PConvEncoder, ResGCAEncoder, ResNetEnc, + ResShortcutEnc) from .gl_encoder_decoder import GLEncoderDecoder from .necks import ContextualAttentionNeck, GLDilationNeck from .pconv_encoder_decoder import PConvEncoderDecoder @@ -17,5 +18,6 @@ 'ResShortcutDec', 'HolisticIndexBlock', 'DepthwiseIndexBlock', 'DeepFillEncoder', 'DeepFillEncoderDecoder', 'DeepFillDecoder', 'ContextualAttentionNeck', 'IndexedUpsample', 'IndexNetEncoder', - 'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder' + 'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder', 'FBAResnetDilated', + 'FBADecoder' ] diff --git a/mmedit/models/backbones/encoder_decoders/decoders/__init__.py b/mmedit/models/backbones/encoder_decoders/decoders/__init__.py index 023e2801d1..054bbc489d 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/__init__.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/__init__.py @@ -1,4 +1,5 @@ from .deepfill_decoder import DeepFillDecoder +from .fba_decoder import FBADecoder from .gl_decoder import GLDecoder from .indexnet_decoder import IndexedUpsample, IndexNetDecoder from .pconv_decoder import PConvDecoder @@ -7,5 +8,6 @@ __all__ = [ 'GLDecoder', 'PlainDecoder', 'PConvDecoder', 'ResNetDec', 'ResShortcutDec', - 'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 
'ResGCADecoder'
+    'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 'ResGCADecoder',
+    'FBADecoder'
 ]
diff --git a/mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py
new file mode 100644
index 0000000000..fe01e35c4a
--- /dev/null
+++ b/mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py
@@ -0,0 +1,212 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.runner import load_checkpoint
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmedit.models.registry import COMPONENTS
+from mmedit.utils import get_root_logger
+
+
+@COMPONENTS.register_module()
+class FBADecoder(nn.Module):
+    """Decoder for FBA matting.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+    """
+
+    def __init__(self,
+                 pool_scales,
+                 in_channels,
+                 channels,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 align_corners=False):
+        super().__init__()
+
+        assert isinstance(pool_scales, (list, tuple))
+        # Pyramid Pooling Module
+        self.pool_scales = pool_scales
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.align_corners = align_corners
+        self.batch_norm = False
+
+        self.ppm = []
+        for scale in self.pool_scales:
+            self.ppm.append(
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(scale),
+                    *(ConvModule(
+                        self.in_channels,
+                        self.channels,
+                        kernel_size=1,
+                        bias=True,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg).children())))
+        self.ppm = nn.ModuleList(self.ppm)
+
+        # Follow the author's implementation, which adds the conv layers
+        # described in the supplementary material between the upsampling
+        # operations.
+        self.conv_up1 = nn.Sequential(*(list(
+            ConvModule(
+                self.in_channels + len(pool_scales) * 256,
+                self.channels,
+                padding=1,
+                kernel_size=3,
+                bias=True,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg).children()) + list(
+                    ConvModule(
+                        self.channels,
+                        self.channels,
+                        padding=1,
+                        bias=True,
+                        kernel_size=3,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg).children())))
+
+        self.conv_up2 = nn.Sequential(*(list(
+            ConvModule(
+                self.channels * 2,
+                self.channels,
+                padding=1,
+                kernel_size=3,
+                bias=True,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg).children())))
+
+        if self.norm_cfg['type'] == 'BN':
+            d_up3 = 128
+        else:
+            d_up3 = 64
+
+        self.conv_up3 = nn.Sequential(*(list(
+            ConvModule(
+                self.channels + d_up3,
+                64,
+                padding=1,
+                kernel_size=3,
+                bias=True,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg).children())))
+
+        self.unpool = nn.MaxUnpool2d(2, stride=2)
+
+        self.conv_up4 = nn.Sequential(*(list(
+            ConvModule(
+                64 + 3 + 3 + 2,
+                32,
+                padding=1,
+                kernel_size=3,
+                bias=True,
+                conv_cfg=self.conv_cfg,
+                act_cfg=self.act_cfg).children()) + list(
+                    ConvModule(
+                        32,
+                        16,
+                        padding=1,
+                        kernel_size=3,
+                        bias=True,
+                        conv_cfg=self.conv_cfg,
+                        act_cfg=self.act_cfg).children()) + list(
+                            ConvModule(
+                                16,
+                                7,
+                                padding=0,
+                                kernel_size=1,
+                                bias=True,
+                                conv_cfg=self.conv_cfg,
+                                act_cfg=None).children())))
+
+    def init_weights(self, pretrained=None):
+        """Init weights for the model.
+
+        Args:
+            pretrained (str, optional): Path for pretrained weights. If given
+                None, pretrained weights will not be loaded. Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, inputs):
+        """Forward function.
+        Args:
+            inputs (dict): Output dict of the FBA encoder.
+        Returns:
+            tuple(Tensor): Predicted alpha, fg and bg of the current batch.
+        """
+
+        conv_out = inputs['conv_out']
+        img = inputs['merged']
+        two_channel_trimap = inputs['two_channel_trimap']
+        conv5 = conv_out[-1]
+        input_size = conv5.size()
+        ppm_out = [conv5]
+        for pool_scale in self.ppm:
+            ppm_out.append(
+                nn.functional.interpolate(
+                    pool_scale(conv5), (input_size[2], input_size[3]),
+                    mode='bilinear',
+                    align_corners=self.align_corners))
+        ppm_out = torch.cat(ppm_out, 1)
+        x = self.conv_up1(ppm_out)
+
+        x = torch.nn.functional.interpolate(
+            x,
+            scale_factor=2,
+            mode='bilinear',
+            align_corners=self.align_corners)
+
+        x = torch.cat((x, conv_out[-4]), 1)
+
+        x = self.conv_up2(x)
+        x = torch.nn.functional.interpolate(
+            x,
+            scale_factor=2,
+            mode='bilinear',
+            align_corners=self.align_corners)
+
+        x = torch.cat((x, conv_out[-5]), 1)
+        x = self.conv_up3(x)
+
+        x = torch.nn.functional.interpolate(
+            x,
+            scale_factor=2,
+            mode='bilinear',
+            align_corners=self.align_corners)
+
+        x = torch.cat((x, conv_out[-6][:, :3], img, two_channel_trimap), 1)
+        output = self.conv_up4(x)
+        alpha = torch.clamp(output[:, 0:1], 0, 1)
+        F = torch.sigmoid(output[:, 1:4])
+        B = torch.sigmoid(output[:, 4:7])
+
+        return alpha, F, B
diff --git a/mmedit/models/backbones/encoder_decoders/encoders/__init__.py b/mmedit/models/backbones/encoder_decoders/encoders/__init__.py
index dd61b6b3ce..2c1456e4c0 100644
--- a/mmedit/models/backbones/encoder_decoders/encoders/__init__.py
+++ b/mmedit/models/backbones/encoder_decoders/encoders/__init__.py
@@ -1,4 +1,5 @@
 from .deepfill_encoder import DeepFillEncoder
+from .fba_encoder import FBAResnetDilated
 from .gl_encoder import GLEncoder
 from .indexnet_encoder import (DepthwiseIndexBlock, HolisticIndexBlock,
                                IndexNetEncoder)
@@ -9,5 +10,5 @@
 __all__ = [
     'GLEncoder', 'VGG16', 'ResNetEnc', 'HolisticIndexBlock',
     'DepthwiseIndexBlock', 'ResShortcutEnc', 'PConvEncoder', 'DeepFillEncoder',
-    'IndexNetEncoder', 'ResGCAEncoder'
+    'IndexNetEncoder', 'ResGCAEncoder', 'FBAResnetDilated'
 ]
diff --git a/mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py b/mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py
new file mode 100644
index 0000000000..d8f404d926
--- /dev/null
+++ b/mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py
@@ -0,0 +1,48 @@
+from mmedit.models.registry import COMPONENTS
+from .resnet import ResNet
+
+
+@COMPONENTS.register_module()
+class FBAResnetDilated(ResNet):
+    """ResNet-based encoder for FBA image matting."""
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (N, C, H, W).
+
+        Returns:
+            dict: Output dict with 'conv_out', 'merged', 'two_channel_trimap'.
+        """
+        # x: (merged_t, trimap_t, two_channel_trimap, merged)
+        # t refers to transformed.
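+        # Channel layout of x along dim 1 (14 channels in total), as
+        # inferred from the slicing below: 0-2 transformed merged image,
+        # 3-8 six-channel transformed trimap, 9-10 two-channel trimap,
+        # 11-13 unnormalised merged image.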
+ two_channel_trimap = x[:, 9:11] + merged = x[:, 11:14] + x = x[:, 0:11, ...] + conv_out = [x] + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.activate(x) + conv_out.append(x) + x = self.maxpool(x) + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + return { + 'conv_out': conv_out, + 'merged': merged, + 'two_channel_trimap': two_channel_trimap + } diff --git a/mmedit/models/backbones/encoder_decoders/encoders/resnet.py b/mmedit/models/backbones/encoder_decoders/encoders/resnet.py new file mode 100644 index 0000000000..79cb5fbbb0 --- /dev/null +++ b/mmedit/models/backbones/encoder_decoders/encoders/resnet.py @@ -0,0 +1,476 @@ +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import (build_activation_layer, build_conv_layer, + build_norm_layer, constant_init, kaiming_init) +from mmcv.runner import load_checkpoint +from mmcv.utils.parrots_wrapper import _BatchNorm + +from mmedit.utils import get_root_logger + + +class BasicBlock(nn.Module): + """Basic block for ResNet.""" + + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + act_cfg=dict(type='ReLU'), + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False): + super(BasicBlock, self).__init__() + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.activate = build_activation_layer(act_cfg) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.activate(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.activate(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block for ResNet.""" + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + act_cfg=dict(type='ReLU'), + conv_cfg=None, + norm_cfg=dict(type='BN'), + with_cp=False): + super(Bottleneck, self).__init__() + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.act_cfg = act_cfg + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.conv1_stride = 1 + self.conv2_stride = stride + self.with_cp = with_cp + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 
= build_norm_layer(
+            norm_cfg, planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            conv_cfg,
+            inplanes,
+            planes,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+
+        self.conv2 = build_conv_layer(
+            conv_cfg,
+            planes,
+            planes,
+            kernel_size=3,
+            stride=self.conv2_stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.conv3 = build_conv_layer(
+            conv_cfg,
+            planes,
+            planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+        self.activate = build_activation_layer(act_cfg)
+        self.downsample = downsample
+
+    @property
+    def norm1(self):
+        """nn.Module: normalization layer after the first convolution layer"""
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        """nn.Module: normalization layer after the second convolution layer"""
+        return getattr(self, self.norm2_name)
+
+    @property
+    def norm3(self):
+        """nn.Module: normalization layer after the third convolution layer"""
+        return getattr(self, self.norm3_name)
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = self.activate(out)
+
+        out = self.conv2(out)
+        out = self.norm2(out)
+        out = self.activate(out)
+
+        out = self.conv3(out)
+        out = self.norm3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.activate(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    """General ResNet.
+
+    This class is adapted from
+    https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/backbones/resnet.py.
+
+    Args:
+        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+        in_channels (int): Number of input image channels. Default: 3.
+        stem_channels (int): Number of stem channels. Default: 64.
+        base_channels (int): Number of base channels of res layer. Default: 64.
+        num_stages (int): Resnet stages, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        deep_stem (bool): Replace 7x7 conv in input stem with three 3x3 convs.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        act_cfg (dict): Dictionary to construct and config activation layer.
+        conv_cfg (dict): Dictionary to construct and config convolution layer.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
+            stage. Default: None.
+        contract_dilation (bool): Whether to contract the first dilation of
+            each layer. Default: False.
+        zero_init_residual (bool): Whether to use zero init for the last
+            norm layer in resblocks to let them behave as identity.
+    """
+
+    arch_settings = {
+        18: (BasicBlock, (2, 2, 2, 2)),
+        34: (BasicBlock, (3, 4, 6, 3)),
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 depth,
+                 in_channels,
+                 stem_channels,
+                 base_channels,
+                 num_stages=4,
+                 strides=(1, 2, 2, 2),
+                 dilations=(1, 1, 2, 4),
+                 deep_stem=False,
+                 avg_down=False,
+                 frozen_stages=-1,
+                 act_cfg=dict(type='ReLU'),
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 with_cp=False,
+                 multi_grid=None,
+                 contract_dilation=False,
+                 zero_init_residual=True):
+        super(ResNet, self).__init__()
+        from functools import partial
+
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for resnet')
+        self.block, stage_blocks = self.arch_settings[depth]
+        self.depth = depth
+        self.inplanes = stem_channels
+        self.stem_channels = stem_channels
+        self.base_channels = base_channels
+        self.num_stages = num_stages
+        assert num_stages >= 1 and num_stages <= 4
+
+        self.strides = strides
+        self.dilations = dilations
+        assert len(strides) == len(dilations) == num_stages
+
+        self.deep_stem = deep_stem
+        self.avg_down = avg_down
+        self.frozen_stages = frozen_stages
+
+        self.conv_cfg = conv_cfg
+        self.act_cfg = act_cfg
+        self.norm_cfg = norm_cfg
+
+        self.with_cp = with_cp
+        self.multi_grid = multi_grid
+        self.contract_dilation = contract_dilation
+        self.zero_init_residual = zero_init_residual
+
+        self._make_stem_layer(in_channels, stem_channels)
+
+        self.layer1 = self._make_layer(
+            self.block, 64, stage_blocks[0], stride=strides[0])
+        self.layer2 = self._make_layer(
+            self.block, 128, stage_blocks[1], stride=strides[1])
+        self.layer3 = self._make_layer(
+            self.block, 256, stage_blocks[2], stride=strides[2])
+        self.layer4 = self._make_layer(
+            self.block, 512, stage_blocks[3], stride=strides[3])
+
+        self.layer1.apply(partial(self._nostride_dilate, dilate=dilations[0]))
+        self.layer2.apply(partial(self._nostride_dilate, dilate=dilations[1]))
+        self.layer3.apply(partial(self._nostride_dilate, dilate=dilations[2]))
+        self.layer4.apply(partial(self._nostride_dilate, dilate=dilations[3]))
+
+        self._freeze_stages()
+
+    def _make_stem_layer(self, in_channels, stem_channels):
+        """Make stem layer for ResNet."""
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                build_activation_layer(self.act_cfg),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                build_activation_layer(self.act_cfg),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels)[1],
+                build_activation_layer(self.act_cfg))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                stem_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, stem_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.activate = build_activation_layer(self.act_cfg)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    @property
+    def norm1(self):
+        """nn.Module: normalization layer after the first convolution layer"""
+        return getattr(self, self.norm1_name)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    self.inplanes,
+                    planes * block.expansion,
+                    stride=stride,
+                    kernel_size=1,
+                    dilation=dilation,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+        layers = []
+        layers.append(
+            block(
+                self.inplanes,
+                planes,
+                stride,
+                downsample=downsample,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg,
+                conv_cfg=self.conv_cfg))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    conv_cfg=self.conv_cfg))
+
+        return nn.Sequential(*layers)
+
+    def _nostride_dilate(self, m, dilate):
+        classname = m.__class__.__name__
+        if classname.find('Conv') != -1 and dilate > 1:
+            # convolutions with stride 2: replace the stride with dilation
+
+            if m.stride == (2, 2):
+                m.stride = (1, 1)
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate // 2, dilate // 2)
+                    m.padding = (dilate // 2, dilate // 2)
+            # other convolutions
+            else:
+                if m.kernel_size == (3, 3):
+                    m.dilation = (dilate, dilate)
+                    m.padding = (dilate, dilate)
+
+    def init_weights(self, pretrained=None):
+        """Init weights for the model.
+
+        Args:
+            pretrained (str, optional): Path for pretrained weights. If given
+                None, pretrained weights will not be loaded. Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (N, C, H, W).
+
+        Returns:
+            list[Tensor]: Intermediate feature maps of the stem and the
+                four res layers.
+        """
+        conv_out = [x]
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.activate(x)
+        conv_out.append(x)
+        x = self.maxpool(x)
+        x = self.layer1(x)
+        conv_out.append(x)
+        x = self.layer2(x)
+        conv_out.append(x)
+        x = self.layer3(x)
+        conv_out.append(x)
+        x = self.layer4(x)
+        conv_out.append(x)
+        return conv_out
diff --git a/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py b/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py
index 08c5a411b2..4c6270ee27 100644
--- a/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py
+++ b/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py
@@ -17,7 +17,8 @@ def __init__(self, encoder, decoder):
         super().__init__()
 
         self.encoder = build_component(encoder)
-        decoder['in_channels'] = self.encoder.out_channels
+        if hasattr(self.encoder, 'out_channels'):
+            decoder['in_channels'] = self.encoder.out_channels
         self.decoder = build_component(decoder)
 
     def init_weights(self, pretrained=None):
diff --git a/mmedit/models/mattors/utils.py b/mmedit/models/mattors/utils.py
index 580de0ecc7..4edbc5b587 100644
--- a/mmedit/models/mattors/utils.py
+++ b/mmedit/models/mattors/utils.py
@@ -1,3 +1,6 @@
+import torch
+
+
 def get_unknown_tensor(trimap, meta):
     """Get 1-channel unknown area tensor from the 3 or 1-channel trimap
     tensor.
@@ -20,3 +23,33 @@ def get_unknown_tensor(trimap, meta):
         # 0 for bg, 128/255 for unknown, 1 for fg
         weight = trimap.eq(128 / 255).float()
     return weight
+
+
+def fba_fusion(alpha, img, F, B):
+    """Postprocess the predicted alpha, foreground and background.
+
+    This function is adapted from
+    https://github.com/MarcoForte/FBA_Matting.
+
+    Args:
+        alpha (Tensor): Tensor with shape (N, 1, H, W).
+        img (Tensor): Tensor with shape (N, 3, H, W).
+        F (Tensor): Tensor with shape (N, 3, H, W).
+        B (Tensor): Tensor with shape (N, 3, H, W).
+
+    Returns:
+        alpha (Tensor): Tensor with shape (N, 1, H, W).
+        F (Tensor): Tensor with shape (N, 3, H, W).
+        B (Tensor): Tensor with shape (N, 3, H, W).
+    """
+    F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
+    B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha *
+         (1 - alpha) * F)
+
+    F = torch.clamp(F, 0, 1)
+    B = torch.clamp(B, 0, 1)
+    la = 0.1
+    alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (
+        torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
+    alpha = torch.clamp(alpha, 0, 1)
+    return alpha, F, B
diff --git a/tests/test_decoders.py b/tests/test_decoders.py
index 07c714ed6f..7f34f1ae35 100644
--- a/tests/test_decoders.py
+++ b/tests/test_decoders.py
@@ -2,10 +2,11 @@
 import pytest
 import torch
 
-from mmedit.models.backbones import (VGG16, IndexedUpsample, IndexNetDecoder,
-                                     IndexNetEncoder, PlainDecoder,
-                                     ResGCADecoder, ResGCAEncoder, ResNetDec,
-                                     ResNetEnc, ResShortcutDec, ResShortcutEnc)
+from mmedit.models.backbones import (VGG16, FBADecoder, IndexedUpsample,
+                                     IndexNetDecoder, IndexNetEncoder,
+                                     PlainDecoder, ResGCADecoder,
+                                     ResGCAEncoder, ResNetDec, ResNetEnc,
+                                     ResShortcutDec, ResShortcutEnc)
 
 
 def assert_tensor_with_shape(tensor, shape):
@@ -14,6 +15,20 @@ def assert_tensor_with_shape(tensor, shape):
     assert tensor.shape == shape
 
 
+def _demo_inputs(input_shape=(1, 4, 64, 64)):
+    """
+    Create a superset of inputs needed to run encoder.
+
+    Args:
+        input_shape (tuple): input batch dimensions.
+            Default: (1, 4, 64, 64).
+ """ + img = np.random.random(input_shape).astype(np.float32) + img = torch.from_numpy(img) + + return img + + def test_plain_decoder(): """Test PlainDecoder.""" @@ -199,15 +214,30 @@ def test_indexnet_decoder(): assert out.shape == (2, 1, 32, 32) -def _demo_inputs(input_shape=(1, 4, 64, 64)): - """ - Create a superset of inputs needed to run encoder. +def test_fba_decoder(): - Args: - input_shape (tuple): input batch dimensions. - Default: (1, 4, 64, 64). - """ - img = np.random.random(input_shape).astype(np.float32) - img = torch.from_numpy(img) - - return img + with pytest.raises(AssertionError): + # pool_scales must be list|tuple + FBADecoder(pool_scales=1, in_channels=32, channels=16) + inputs = dict() + conv_out_1 = _demo_inputs((1, 11, 320, 320)) + conv_out_2 = _demo_inputs((1, 64, 160, 160)) + conv_out_3 = _demo_inputs((1, 256, 80, 80)) + conv_out_4 = _demo_inputs((1, 512, 40, 40)) + conv_out_5 = _demo_inputs((1, 1024, 40, 40)) + conv_out_6 = _demo_inputs((1, 2048, 40, 40)) + inputs['conv_out'] = [ + conv_out_1, conv_out_2, conv_out_3, conv_out_4, conv_out_5, conv_out_6 + ] + inputs['merged'] = _demo_inputs((1, 3, 320, 320)) + inputs['two_channel_trimap'] = _demo_inputs((1, 2, 320, 320)) + model = FBADecoder( + pool_scales=(1, 2, 3, 6), + in_channels=2048, + channels=256, + norm_cfg=dict(type='GN', num_groups=32)) + + alpha, F, B = model(inputs) + assert_tensor_with_shape(alpha, torch.Size([1, 1, 320, 320])) + assert_tensor_with_shape(F, torch.Size([1, 3, 320, 320])) + assert_tensor_with_shape(B, torch.Size([1, 3, 320, 320])) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index e9b62bc492..9c82a429e5 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -6,8 +6,11 @@ from mmcv.utils.parrots_wrapper import _BatchNorm from mmedit.models.backbones import (VGG16, DepthwiseIndexBlock, - HolisticIndexBlock, IndexNetEncoder, - ResGCAEncoder, ResNetEnc, ResShortcutEnc) + FBAResnetDilated, HolisticIndexBlock, + IndexNetEncoder, ResGCAEncoder, ResNetEnc, + ResShortcutEnc) +from mmedit.models.backbones.encoder_decoders.encoders.resnet import ( + BasicBlock, Bottleneck) def check_norm_state(modules, train_state): @@ -19,6 +22,13 @@ def check_norm_state(modules, train_state): return True +def is_block(modules): + """Check if is ResNet building block.""" + if isinstance(modules, (BasicBlock, Bottleneck)): + return True + return False + + def assert_tensor_with_shape(tensor, shape): """"Check if the shape of the tensor is equal to the target shape.""" assert isinstance(tensor, torch.Tensor) @@ -31,6 +41,20 @@ def assert_mid_feat_shape(mid_feat, target_shape): assert_tensor_with_shape(mid_feat[i], torch.Size(target_shape[i])) +def _demo_inputs(input_shape=(2, 4, 64, 64)): + """ + Create a superset of inputs needed to run encoder. + + Args: + input_shape (tuple): input batch dimensions. + Default: (1, 4, 64, 64). + """ + img = np.random.random(input_shape).astype(np.float32) + img = torch.from_numpy(img) + + return img + + def test_vgg16_encoder(): """Test VGG16 encoder.""" target_shape = [(2, 64, 32, 32), (2, 128, 16, 16), (2, 256, 8, 8), @@ -492,15 +516,111 @@ def test_indexnet_encoder(): assert dec_idx_feat.shape == target_shape -def _demo_inputs(input_shape=(2, 4, 64, 64)): - """ - Create a superset of inputs needed to run encoder. 
+def test_fba_encoder(): + """Test FBA encoder.""" + + with pytest.raises(KeyError): + # ResNet depth should be in [18, 34, 50, 101, 152] + FBAResnetDilated( + 20, + in_channels=11, + stem_channels=64, + base_channels=64, + ) + + with pytest.raises(AssertionError): + # In ResNet: 1 <= num_stages <= 4 + FBAResnetDilated( + 50, + in_channels=11, + stem_channels=64, + base_channels=64, + num_stages=0) + + with pytest.raises(AssertionError): + # In ResNet: 1 <= num_stages <= 4 + FBAResnetDilated( + 50, + in_channels=11, + stem_channels=64, + base_channels=64, + num_stages=5) + + with pytest.raises(AssertionError): + # len(strides) == len(dilations) == num_stages + FBAResnetDilated( + 50, + in_channels=11, + stem_channels=64, + base_channels=64, + strides=(1, ), + dilations=(1, 1), + num_stages=3) - Args: - input_shape (tuple): input batch dimensions. - Default: (1, 4, 64, 64). - """ - img = np.random.random(input_shape).astype(np.float32) - img = torch.from_numpy(img) + with pytest.raises(TypeError): + # pretrained must be a string path + model = FBAResnetDilated( + 50, + in_channels=11, + stem_channels=64, + base_channels=64, + ) + model.init_weights(pretrained=233) + + model = FBAResnetDilated( + depth=50, + in_channels=11, + stem_channels=64, + base_channels=64, + conv_cfg=dict(type='ConvWS'), + norm_cfg=dict(type='GN', num_groups=32)) - return img + model.init_weights() + model.train() + + input = _demo_inputs((1, 14, 320, 320)) + + output = model(input) + + assert 'conv_out' in output.keys() + assert 'merged' in output.keys() + assert 'two_channel_trimap' in output.keys() + + assert isinstance(output['conv_out'], list) + assert len(output['conv_out']) == 6 + + assert isinstance(output['merged'], torch.Tensor) + assert_tensor_with_shape(output['merged'], torch.Size([1, 3, 320, 320])) + + assert isinstance(output['two_channel_trimap'], torch.Tensor) + assert_tensor_with_shape(output['two_channel_trimap'], + torch.Size([1, 2, 320, 320])) + if torch.cuda.is_available(): + model = FBAResnetDilated( + depth=50, + in_channels=11, + stem_channels=64, + base_channels=64, + conv_cfg=dict(type='ConvWS'), + norm_cfg=dict(type='GN', num_groups=32)) + model.init_weights() + model.train() + model.cuda() + + input = _demo_inputs((1, 14, 320, 320)).cuda() + output = model(input) + + assert 'conv_out' in output.keys() + assert 'merged' in output.keys() + assert 'two_channel_trimap' in output.keys() + + assert isinstance(output['conv_out'], list) + assert len(output['conv_out']) == 6 + + assert isinstance(output['merged'], torch.Tensor) + assert_tensor_with_shape(output['merged'], torch.Size([1, 3, 320, + 320])) + + assert isinstance(output['two_channel_trimap'], torch.Tensor) + assert_tensor_with_shape(output['two_channel_trimap'], + torch.Size([1, 2, 320, 320]))
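
Taken together, here is a hedged sketch of how the new pieces compose: building the FBA backbone through `SimpleEncoderDecoder` (whose `in_channels` forwarding is now conditional, as patched above) and running a shape smoke test. This assumes `SimpleEncoderDecoder` is registered as a backbone and simply chains encoder and decoder, as its name suggests; the config values mirror the tests above.

```python
import torch

from mmedit.models import build_backbone

# The FBA encoder has no `out_channels` attribute, so with the patch above
# the decoder keeps its configured in_channels (2048) instead of crashing.
backbone = build_backbone(
    dict(
        type='SimpleEncoderDecoder',
        encoder=dict(
            type='FBAResnetDilated',
            depth=50,
            in_channels=11,
            stem_channels=64,
            base_channels=64,
            norm_cfg=dict(type='GN', num_groups=32)),
        decoder=dict(
            type='FBADecoder',
            pool_scales=(1, 2, 3, 6),
            in_channels=2048,
            channels=256,
            norm_cfg=dict(type='GN', num_groups=32))))
backbone.init_weights()

# 14 input channels: 3 (transformed merged) + 6 (transformed trimap)
# + 2 (two-channel trimap) + 3 (unnormalised merged); random data is
# enough for a shape check.
alpha, F, B = backbone(torch.rand(1, 14, 320, 320))
assert alpha.shape == (1, 1, 320, 320)
assert F.shape == B.shape == (1, 3, 320, 320)
```
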