Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backbone of FBA. #215

Merged
merged 34 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mmedit/datasets/pipelines/matting_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,14 +576,15 @@ def __repr__(self):

@PIPELINES.register_module()
class TransformTrimap:
"""Generate two-channel trimap and encode it into six-channel.
"""Transform trimap into two-channel and six-channel.

This class will generate a two-channel trimap composed of definite
foreground and background masks and encode it into a six-channel trimap
using Gaussian blurs of the generated two-channel trimap at three
different scales. The transformed trimap has 6 channels.

Required key is "trimap", added key is "transformed_trimap".
Required key is "trimap", added key is "transformed_trimap" and
"two_channel_trimap".

Adopted from the following repository:
https://github.com/MarcoForte/FBA_Matting/blob/master/networks/transforms.py.
Expand Down Expand Up @@ -619,6 +620,7 @@ def __call__(self, results):
dt_mask / (2 * ((factor * L)**2)))

results['transformed_trimap'] = trimap_trans
results['two_channel_trimap'] = trimap2
return results

def __repr__(self):
Expand Down
9 changes: 8 additions & 1 deletion mmedit/datasets/pipelines/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ class Normalize:
to_rgb (bool): Whether to convert channels from BGR to RGB.
"""

def __init__(self, keys, mean, std, to_rgb=False):
def __init__(self, keys, mean, std, to_rgb=False, save_original=False):
self.keys = keys
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.to_rgb = to_rgb
self.save_original = save_original

def __call__(self, results):
"""Call function.
Expand All @@ -37,11 +38,17 @@ def __call__(self, results):
"""
for key in self.keys:
if isinstance(results[key], list):
if self.save_original:
results[key + '_unnormalised'] = [
v.copy() for v in results[key]
]
results[key] = [
mmcv.imnormalize(v, self.mean, self.std, self.to_rgb)
for v in results[key]
]
else:
if self.save_original:
results[key + '_unnormalised'] = results[key].copy()
results[key] = mmcv.imnormalize(results[key], self.mean,
self.std, self.to_rgb)

Expand Down
5 changes: 3 additions & 2 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# yapf: disable
from .encoder_decoders import (VGG16, ContextualAttentionNeck, DeepFillDecoder,
DeepFillEncoder, DeepFillEncoderDecoder,
DepthwiseIndexBlock, GLDecoder, GLDilationNeck,
DepthwiseIndexBlock, FBADecoder,
FBAResnetDilated, GLDecoder, GLDilationNeck,
GLEncoder, GLEncoderDecoder, HolisticIndexBlock,
IndexedUpsample, IndexNetDecoder,
IndexNetEncoder, PConvDecoder, PConvEncoder,
Expand All @@ -22,5 +23,5 @@
'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR',
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator'
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder'
]
12 changes: 7 additions & 5 deletions mmedit/models/backbones/encoder_decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .decoders import (DeepFillDecoder, GLDecoder, IndexedUpsample,
from .decoders import (DeepFillDecoder, FBADecoder, GLDecoder, IndexedUpsample,
IndexNetDecoder, PConvDecoder, PlainDecoder,
ResGCADecoder, ResNetDec, ResShortcutDec)
from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock, GLEncoder,
HolisticIndexBlock, IndexNetEncoder, PConvEncoder,
ResGCAEncoder, ResNetEnc, ResShortcutEnc)
from .encoders import (VGG16, DeepFillEncoder, DepthwiseIndexBlock,
FBAResnetDilated, GLEncoder, HolisticIndexBlock,
IndexNetEncoder, PConvEncoder, ResGCAEncoder, ResNetEnc,
ResShortcutEnc)
from .gl_encoder_decoder import GLEncoderDecoder
from .necks import ContextualAttentionNeck, GLDilationNeck
from .pconv_encoder_decoder import PConvEncoderDecoder
Expand All @@ -17,5 +18,6 @@
'ResShortcutDec', 'HolisticIndexBlock', 'DepthwiseIndexBlock',
'DeepFillEncoder', 'DeepFillEncoderDecoder', 'DeepFillDecoder',
'ContextualAttentionNeck', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder'
'IndexNetDecoder', 'ResGCAEncoder', 'ResGCADecoder', 'FBAResnetDilated',
'FBADecoder'
]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .deepfill_decoder import DeepFillDecoder
from .fba_decoder import FBADecoder
from .gl_decoder import GLDecoder
from .indexnet_decoder import IndexedUpsample, IndexNetDecoder
from .pconv_decoder import PConvDecoder
Expand All @@ -7,5 +8,6 @@

__all__ = [
'GLDecoder', 'PlainDecoder', 'PConvDecoder', 'ResNetDec', 'ResShortcutDec',
'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 'ResGCADecoder'
'DeepFillDecoder', 'IndexedUpsample', 'IndexNetDecoder', 'ResGCADecoder',
'FBADecoder'
]
212 changes: 212 additions & 0 deletions mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Give credit if the code is modified from the original code

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class FBADecoder(nn.Module):
    """Decoder for FBA matting.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self,
                 pool_scales,
                 in_channels,
                 channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()

        assert isinstance(pool_scales, (list, tuple))
        # Pyramid Pooling Module
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.batch_norm = False

        # One branch per pooling scale: adaptive average pooling followed by
        # a 1x1 conv (+ norm + activation) producing ``channels`` maps.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                *self._conv_layers(
                    self.in_channels,
                    self.channels,
                    kernel_size=1,
                    padding=0,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg)) for scale in self.pool_scales
        ])

        # Followed the author's implementation that
        # concatenate conv layers described in the supplementary
        # material between up operations
        #
        # The input of conv_up1 is the encoder feature concatenated with the
        # output of every pyramid branch, and each branch emits
        # ``self.channels`` maps, so the width is derived from
        # ``self.channels`` instead of a hard-coded 256.
        ppm_cat_channels = (
            self.in_channels + len(self.pool_scales) * self.channels)
        self.conv_up1 = nn.Sequential(
            *(self._conv_layers(
                ppm_cat_channels,
                self.channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg) + self._conv_layers(
                    self.channels,
                    self.channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg)))

        # NOTE(review): ``channels * 2`` assumes the skip feature
        # ``conv_out[-4]`` has as many channels as ``channels`` -- confirm
        # against the paired encoder.
        self.conv_up2 = nn.Sequential(*self._conv_layers(
            self.channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg))

        # ``norm_cfg`` may be None (no norm layer); guard before indexing.
        # Presumably d_up3 mirrors the channel count of the encoder stem
        # feature ``conv_out[-5]`` -- TODO confirm.
        uses_bn = (
            self.norm_cfg is not None and self.norm_cfg.get('type') == 'BN')
        d_up3 = 128 if uses_bn else 64

        self.conv_up3 = nn.Sequential(*self._conv_layers(
            self.channels + d_up3,
            64,
            kernel_size=3,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg))

        self.unpool = nn.MaxUnpool2d(2, stride=2)

        # Final head: 64 decoder channels + 3 channels taken from the encoder
        # input + 3 image channels + 2 trimap channels -> 7 output channels
        # (1 alpha, 3 foreground, 3 background). No norm layers here.
        self.conv_up4 = nn.Sequential(
            *(self._conv_layers(
                64 + 3 + 3 + 2,
                32,
                kernel_size=3,
                padding=1,
                norm_cfg=None,
                act_cfg=self.act_cfg) + self._conv_layers(
                    32,
                    16,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=None,
                    act_cfg=self.act_cfg) + self._conv_layers(
                        16,
                        7,
                        kernel_size=1,
                        padding=0,
                        norm_cfg=None,
                        act_cfg=None)))

    def _conv_layers(self, in_channels, out_channels, kernel_size, padding,
                     norm_cfg, act_cfg):
        """Build a biased ConvModule and return its child layers as a list."""
        return list(
            ConvModule(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg).children())

    def _upsample(self, x):
        """Bilinearly upsample ``x`` by a factor of 2."""
        return nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)

    def init_weights(self, pretrained=None):
        """Init weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (dict): Output dict of the encoder. It must provide the
                keys ``'conv_out'`` (sequence of feature maps, the last one
                being the deepest), ``'merged'`` (image tensor) and
                ``'two_channel_trimap'``.

        Returns:
            tuple[Tensor]: Predicted alpha, fg and bg of the current batch.
        """
        conv_out = inputs['conv_out']
        img = inputs['merged']
        two_channel_trimap = inputs['two_channel_trimap']

        # Pyramid pooling over the deepest encoder feature; every branch is
        # resized back to that feature's spatial size before concatenation.
        conv5 = conv_out[-1]
        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(
                nn.functional.interpolate(
                    pool_scale(conv5), (input_size[2], input_size[3]),
                    mode='bilinear',
                    align_corners=self.align_corners))
        x = self.conv_up1(torch.cat(ppm_out, 1))

        # Three upsampling stages, each fused with a skip connection.
        x = self._upsample(x)
        x = torch.cat((x, conv_out[-4]), 1)
        x = self.conv_up2(x)

        x = self._upsample(x)
        x = torch.cat((x, conv_out[-5]), 1)
        x = self.conv_up3(x)

        x = self._upsample(x)
        # Only the first three channels of ``conv_out[-6]`` are used here.
        x = torch.cat((x, conv_out[-6][:, :3], img, two_channel_trimap), 1)
        output = self.conv_up4(x)

        # Channel 0 is alpha (clamped to [0, 1]); channels 1-3 and 4-6 are
        # foreground and background, squashed with a sigmoid.
        alpha = torch.clamp(output[:, 0:1], 0, 1)
        fg = torch.sigmoid(output[:, 1:4])
        bg = torch.sigmoid(output[:, 4:7])

        return alpha, fg, bg
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .deepfill_encoder import DeepFillEncoder
from .fba_encoder import FBAResnetDilated
from .gl_encoder import GLEncoder
from .indexnet_encoder import (DepthwiseIndexBlock, HolisticIndexBlock,
IndexNetEncoder)
Expand All @@ -9,5 +10,5 @@
__all__ = [
'GLEncoder', 'VGG16', 'ResNetEnc', 'HolisticIndexBlock',
'DepthwiseIndexBlock', 'ResShortcutEnc', 'PConvEncoder', 'DeepFillEncoder',
'IndexNetEncoder', 'ResGCAEncoder'
'IndexNetEncoder', 'ResGCAEncoder', 'FBAResnetDilated'
]
44 changes: 44 additions & 0 deletions mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from mmedit.models.registry import COMPONENTS
from .resnet import ResNet


@COMPONENTS.register_module()
class FBAResnetDilated(ResNet):
    """ResNet-based encoder for FBA image matting."""

    def forward(self, x):
        """Run the encoder and collect the intermediate feature maps.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W). The channel
                axis is laid out as (merged_t, trimap_t,
                two_channel_trimap, merged), where ``_t`` means
                transformed. Channels 0-10 are fed to the network,
                channels 9-10 are the two-channel trimap and channels
                11-13 are the merged image.

        Returns:
            dict: ``'conv_out'`` is a list holding the network input
                followed by the output of the stem and of each of the four
                residual stages; ``'merged'`` and ``'two_channel_trimap'``
                are the corresponding channel slices of ``x``.
        """
        # Split the stacked input along the channel axis.
        network_input = x[:, 0:11, ...]
        trimap_2ch = x[:, 9:11]
        merged_img = x[:, 11:14]

        features = [network_input]

        # Stem: either the deep stem or the plain conv-norm-activation one.
        if self.deep_stem:
            feat = self.stem(network_input)
        else:
            feat = self.activate(self.norm1(self.conv1(network_input)))
        features.append(feat)

        # The stem feature is recorded before max pooling; every residual
        # stage output is recorded afterwards.
        feat = self.maxpool(feat)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
            features.append(feat)

        return {
            'conv_out': features,
            'merged': merged_img,
            'two_channel_trimap': trimap_2ch
        }
Loading