Skip to content

Commit

Permalink
Merge fd8954b into 41c1a49
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochaorui authored Mar 4, 2021
2 parents 41c1a49 + fd8954b commit b83f4c7
Show file tree
Hide file tree
Showing 13 changed files with 971 additions and 40 deletions.
6 changes: 4 additions & 2 deletions mmedit/datasets/pipelines/matting_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,15 @@ def __repr__(self):

@PIPELINES.register_module()
class TransformTrimap:
"""Generate two-channel trimap and encode it into six-channel.
"""Transform trimap into two-channel and six-channel.
This calss will generate a two-channel trimap composed of definite
foreground and backgroud masks and encode it into a six-channel trimap
using Gaussian blurs of the generated two-channel trimap at three
different scales. The transformed trimap has 6 channels.
Required key is "trimap", added key is "transformed_trimap".
Required key is "trimap", added key is "transformed_trimap" and
"two_channel_trimap".
Adopted from the following repository:
https://github.com/MarcoForte/FBA_Matting/blob/master/networks/transforms.py.
Expand Down Expand Up @@ -619,6 +620,7 @@ def __call__(self, results):
dt_mask / (2 * ((factor * L)**2)))

results['transformed_trimap'] = trimap_trans
results['two_channel_trimap'] = trimap2
return results

def __repr__(self):
Expand Down
9 changes: 8 additions & 1 deletion mmedit/datasets/pipelines/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ class Normalize:
to_rgb (bool): Whether to convert channels from BGR to RGB.
"""

def __init__(self, keys, mean, std, to_rgb=False):
def __init__(self, keys, mean, std, to_rgb=False, save_original=False):
self.keys = keys
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
self.save_original = save_original

def __call__(self, results):
"""Call function.
Expand All @@ -37,11 +38,17 @@ def __call__(self, results):
"""
for key in self.keys:
if isinstance(results[key], list):
if self.save_original:
results[key + '_unnormalised'] = [
v.copy() for v in results[key]
]
results[key] = [
mmcv.imnormalize(v, self.mean, self.std, self.to_rgb)
for v in results[key]
]
else:
if self.save_original:
results[key + '_unnormalised'] = results[key].copy()
results[key] = mmcv.imnormalize(results[key], self.mean,
self.std, self.to_rgb)

Expand Down
5 changes: 3 additions & 2 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# yapf: disable
from .encoder_decoders import (VGG16, ContextualAttentionNeck, DeepFillDecoder,
DeepFillEncoder, DeepFillEncoderDecoder,
DepthwiseIndexBlock, GLDecoder, GLDilationNeck,
DepthwiseIndexBlock, FBADecoder,
FBAResnetDilated, GLDecoder, GLDilationNeck,
GLEncoder, GLEncoderDecoder, HolisticIndexBlock,
IndexedUpsample, IndexNetDecoder,
IndexNetEncoder, PConvDecoder, PConvEncoder,
Expand All @@ -22,5 +23,5 @@
'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR',
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator'
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder'
]
12 changes: 7 additions & 5 deletions mmedit/models/backbones/encoder_decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .decoders import (DeepFillDecoder, GLDecoder, IndexedUpsample,
from .decoders import (DeepFillDecoder, FBADecoder, GLDecoder, IndexedUpsample,
IndexNetDecoder, PConvDecoder, PlainDecoder,
ResGCADecoder, ResNetDec, ResShortcutDec)
from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock, GLEncoder,
HolisticIndexBlock, IndexNetEncoder, PConvEncoder,
ResGCAEncoder, ResNetEnc, ResShortcutEnc)
from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock,
FBAResnetDilated, GLEncoder, HolisticIndexBlock,
IndexNetEncoder, PConvEncoder, ResGCAEncoder, ResNetEnc,
ResShortcutEnc)
from .gl_encoder_decoder import GLEncoderDecoder
from .necks import ContextualAttentionNeck, GLDilationNeck
from .pconv_encoder_decoder import PConvEncoderDecoder
Expand All @@ -17,5 +18,6 @@
'ResShortcutDec', 'HolisticIndexBlock', 'DepthwiseIndexBlock',
'DeepFillEncoder', 'DeepFillEncoderDecoder', 'DeepFillDecoder',
'ContextualAttentionNeck', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder'
'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder', 'FBAResnetDilated',
'FBADecoder'
]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .deepfill_decoder import DeepFillDecoder
from .fba_decoder import FBADecoder
from .gl_decoder import GLDecoder
from .indexnet_decoder import IndexedUpsample, IndexNetDecoder
from .pconv_decoder import PConvDecoder
Expand All @@ -7,5 +8,6 @@

__all__ = [
'GLDecoder', 'PlainDecoder', 'PConvDecoder', 'ResNetDec', 'ResShortcutDec',
'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 'ResGCADecoder'
'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 'ResGCADecoder',
'FBADecoder'
]
212 changes: 212 additions & 0 deletions mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class FBADecoder(nn.Module):
"""Decoder for FBA matting.
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""

def __init__(self,
pool_scales,
in_channels,
channels,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False):
super().__init__()

assert isinstance(pool_scales, (list, tuple))
# Pyramid Pooling Module
self.pool_scales = pool_scales
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.align_corners = align_corners
self.batch_norm = False

self.ppm = []
for scale in self.pool_scales:
self.ppm.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(scale),
*(ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
bias=True,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg).children())))
self.ppm = nn.ModuleList(self.ppm)

# Follwed the author's implementation that
# concatenate conv layers described in the supplementary
# material between up operations
self.conv_up1 = nn.Sequential(*(list(
ConvModule(
self.in_channels + len(pool_scales) * 256,
self.channels,
padding=1,
kernel_size=3,
bias=True,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg).children()) + list(
ConvModule(
self.channels,
self.channels,
padding=1,
bias=True,
kernel_size=3,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg).children())))

self.conv_up2 = nn.Sequential(*(list(
ConvModule(
self.channels * 2,
self.channels,
padding=1,
kernel_size=3,
bias=True,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg).children())))

if (self.norm_cfg['type'] == 'BN'):
d_up3 = 128
else:
d_up3 = 64

self.conv_up3 = nn.Sequential(*(list(
ConvModule(
self.channels + d_up3,
64,
padding=1,
kernel_size=3,
bias=True,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg).children())))

self.unpool = nn.MaxUnpool2d(2, stride=2)

self.conv_up4 = nn.Sequential(*(list(
ConvModule(
64 + 3 + 3 + 2,
32,
padding=1,
kernel_size=3,
bias=True,
conv_cfg=self.conv_cfg,
act_cfg=self.act_cfg).children()) + list(
ConvModule(
32,
16,
padding=1,
kernel_size=3,
bias=True,
conv_cfg=self.conv_cfg,
act_cfg=self.act_cfg).children()) + list(
ConvModule(
16,
7,
padding=0,
kernel_size=1,
bias=True,
conv_cfg=self.conv_cfg,
act_cfg=None).children())))

def init_weights(self, pretrained=None):
"""Init weights for the model.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

else:
raise TypeError('pretrained must be a str or None')

def forward(self, inputs):
"""Forward function.
Args:
inputs (dict): Output dict of FbaEncoder.
Returns:
Tensor: Predicted alpha, fg and bg of the current batch.
"""

conv_out = inputs['conv_out']
img = inputs['merged']
two_channel_trimap = inputs['two_channel_trimap']
conv5 = conv_out[-1]
input_size = conv5.size()
ppm_out = [conv5]
for pool_scale in self.ppm:
ppm_out.append(
nn.functional.interpolate(
pool_scale(conv5), (input_size[2], input_size[3]),
mode='bilinear',
align_corners=self.align_corners))
ppm_out = torch.cat(ppm_out, 1)
x = self.conv_up1(ppm_out)

x = torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)

x = torch.cat((x, conv_out[-4]), 1)

x = self.conv_up2(x)
x = torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)

x = torch.cat((x, conv_out[-5]), 1)
x = self.conv_up3(x)

x = torch.nn.functional.interpolate(
x,
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners)

x = torch.cat((x, conv_out[-6][:, :3], img, two_channel_trimap), 1)
output = self.conv_up4(x)
alpha = torch.clamp(output[:, 0:1], 0, 1)
F = torch.sigmoid(output[:, 1:4])
B = torch.sigmoid(output[:, 4:7])

return alpha, F, B
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .deepfill_encoder import DeepFillEncoder
from .fba_encoder import FBAResnetDilated
from .gl_encoder import GLEncoder
from .indexnet_encoder import (DepthwiseIndexBlock, HolisticIndexBlock,
IndexNetEncoder)
Expand All @@ -9,5 +10,5 @@
__all__ = [
'GLEncoder', 'VGG16', 'ResNetEnc', 'HolisticIndexBlock',
'DepthwiseIndexBlock', 'ResShortcutEnc', 'PConvEncoder', 'DeepFillEncoder',
'IndexNetEncoder', 'ResGCAEncoder'
'IndexNetEncoder', 'ResGCAEncoder', 'FBAResnetDilated'
]
44 changes: 44 additions & 0 deletions mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from mmedit.models.registry import COMPONENTS
from .resnet import ResNet


@COMPONENTS.register_module()
class FBAResnetDilated(ResNet):
"""ResNet-based encoder for FBA image matting."""

def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (N, C, H, W).
Returns:
Tensor: Output tensor.
"""
# x: (merged_t, trimap_t, two_channel_trimap,merged)
# t refers to tranformed.
two_channel_trimap = x[:, 9:11]
merged = x[:, 11:14]
x = x[:, 0:11, ...]
conv_out = [x]
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.activate(x)
conv_out.append(x)
x = self.maxpool(x)
x = self.layer1(x)
conv_out.append(x)
x = self.layer2(x)
conv_out.append(x)
x = self.layer3(x)
conv_out.append(x)
x = self.layer4(x)
conv_out.append(x)
return {
'conv_out': conv_out,
'merged': merged,
'two_channel_trimap': two_channel_trimap
}
Loading

0 comments on commit b83f4c7

Please sign in to comment.