Showing 13 changed files with 971 additions and 40 deletions.
212 changes: 212 additions & 0 deletions
mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py
@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger

@COMPONENTS.register_module()
class FBADecoder(nn.Module):
    """Decoder for FBA matting.

    Args:
        pool_scales (tuple[int]): Pooling scales used in the Pyramid Pooling
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

|
||
def __init__(self, | ||
pool_scales, | ||
in_channels, | ||
channels, | ||
conv_cfg=None, | ||
norm_cfg=dict(type='BN'), | ||
act_cfg=dict(type='ReLU'), | ||
align_corners=False): | ||
super().__init__() | ||
|
||
assert isinstance(pool_scales, (list, tuple)) | ||
# Pyramid Pooling Module | ||
self.pool_scales = pool_scales | ||
self.in_channels = in_channels | ||
self.channels = channels | ||
self.conv_cfg = conv_cfg | ||
self.norm_cfg = norm_cfg | ||
self.act_cfg = act_cfg | ||
self.align_corners = align_corners | ||
self.batch_norm = False | ||
|
||
self.ppm = [] | ||
for scale in self.pool_scales: | ||
self.ppm.append( | ||
nn.Sequential( | ||
nn.AdaptiveAvgPool2d(scale), | ||
*(ConvModule( | ||
self.in_channels, | ||
self.channels, | ||
kernel_size=1, | ||
bias=True, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg).children()))) | ||
self.ppm = nn.ModuleList(self.ppm) | ||
|
||
        # Following the author's implementation, the conv layers described
        # in the supplementary material are placed between the upsampling
        # operations.
        self.conv_up1 = nn.Sequential(*(list(
            ConvModule(
                self.in_channels + len(pool_scales) * 256,
                self.channels,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children()) + list(
                    ConvModule(
                        self.channels,
                        self.channels,
                        padding=1,
                        bias=True,
                        kernel_size=3,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg).children())))

        self.conv_up2 = nn.Sequential(*(list(
            ConvModule(
                self.channels * 2,
                self.channels,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children())))

        # The width of the shallow encoder feature (conv_out[-5]) differs
        # between the normalization variants.
        if self.norm_cfg['type'] == 'BN':
            d_up3 = 128
        else:
            d_up3 = 64

        self.conv_up3 = nn.Sequential(*(list(
            ConvModule(
                self.channels + d_up3,
                64,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children())))

        self.unpool = nn.MaxUnpool2d(2, stride=2)

        # The final block takes the upsampled feature concatenated with the
        # first 3 channels of the raw encoder input, the merged image and the
        # two-channel trimap (64 + 3 + 3 + 2 channels) and predicts 7
        # channels: 1 for alpha, 3 for foreground and 3 for background.
        self.conv_up4 = nn.Sequential(*(list(
            ConvModule(
                64 + 3 + 3 + 2,
                32,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                act_cfg=self.act_cfg).children()) + list(
                    ConvModule(
                        32,
                        16,
                        padding=1,
                        kernel_size=3,
                        bias=True,
                        conv_cfg=self.conv_cfg,
                        act_cfg=self.act_cfg).children()) + list(
                            ConvModule(
                                16,
                                7,
                                padding=0,
                                kernel_size=1,
                                bias=True,
                                conv_cfg=self.conv_cfg,
                                act_cfg=None).children())))

    def init_weights(self, pretrained=None):
        """Init weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (dict): Output dict of the FBA encoder
                (FBAResnetDilated).

        Returns:
            tuple[Tensor]: Predicted alpha, fg and bg of the current batch.
        """
        conv_out = inputs['conv_out']
        img = inputs['merged']
        two_channel_trimap = inputs['two_channel_trimap']
        conv5 = conv_out[-1]
        input_size = conv5.size()
        # Pool conv5 at each scale and upsample the pooled features back to
        # the conv5 resolution before concatenation.
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(
                nn.functional.interpolate(
                    pool_scale(conv5), (input_size[2], input_size[3]),
                    mode='bilinear',
                    align_corners=self.align_corners))
        ppm_out = torch.cat(ppm_out, 1)
        x = self.conv_up1(ppm_out)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = torch.cat((x, conv_out[-4]), 1)
        x = self.conv_up2(x)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        x = torch.cat((x, conv_out[-5]), 1)
        x = self.conv_up3(x)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        # At full resolution, concatenate the first 3 channels of the raw
        # encoder input, the merged image and the two-channel trimap.
        x = torch.cat((x, conv_out[-6][:, :3], img, two_channel_trimap), 1)
        output = self.conv_up4(x)
        alpha = torch.clamp(output[:, 0:1], 0, 1)
        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        return alpha, F, B
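
For orientation, here is a minimal usage sketch (not part of this commit) showing how the decoder consumes the encoder output. The shapes assume a dilated ResNet-50 backbone with the default BN norm_cfg (so the stem feature has 128 channels), pool_scales=(1, 2, 3, 6), in_channels=2048 and channels=256; the random tensors merely stand in for real encoder features.

import torch

decoder = FBADecoder(pool_scales=(1, 2, 3, 6), in_channels=2048, channels=256)
decoder.init_weights()
n, h, w = 1, 320, 320
inputs = {
    'conv_out': [
        torch.rand(n, 11, h, w),              # raw 11-channel encoder input
        torch.rand(n, 128, h // 2, w // 2),   # stem output (128 channels for BN)
        torch.rand(n, 256, h // 4, w // 4),   # layer1 output
        torch.rand(n, 512, h // 8, w // 8),   # layer2 output (unused by decoder)
        torch.rand(n, 1024, h // 8, w // 8),  # layer3 output (unused by decoder)
        torch.rand(n, 2048, h // 8, w // 8),  # layer4 output, i.e. conv5
    ],
    'merged': torch.rand(n, 3, h, w),
    'two_channel_trimap': torch.rand(n, 2, h, w),
}
alpha, fg, bg = decoder(inputs)  # (n, 1, h, w), (n, 3, h, w), (n, 3, h, w)

Only conv_out[-1], conv_out[-4], conv_out[-5] and conv_out[-6] are actually consumed: the three bilinear upsampling stages bring conv5 from 1/8 resolution back to full resolution before the 7-channel prediction is split into alpha, foreground and background.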
44 changes: 44 additions & 0 deletions
mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py
@@ -0,0 +1,44 @@
from mmedit.models.registry import COMPONENTS
from .resnet import ResNet


@COMPONENTS.register_module()
class FBAResnetDilated(ResNet):
    """ResNet-based encoder for FBA image matting."""

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (N, 14, H, W), laid out as
                (merged_t, trimap_t, two_channel_trimap, merged), where the
                suffix "_t" refers to "transformed".

        Returns:
            dict: Intermediate feature maps (conv_out), the raw merged
                image and the two-channel trimap.
        """
        # Split off the auxiliary channels; only the first 11 channels are
        # fed through the ResNet.
        two_channel_trimap = x[:, 9:11]
        merged = x[:, 11:14]
        x = x[:, 0:11, ...]
        conv_out = [x]
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.activate(x)
        conv_out.append(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)
        return {
            'conv_out': conv_out,
            'merged': merged,
            'two_channel_trimap': two_channel_trimap
        }
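
A similar sketch (again, not part of this commit) of how the expected 14-channel input could be assembled. The 6-channel width of the transformed trimap is an inference from the slicing above, and the random tensors stand in for the real matting data pipeline.

import torch

n, h, w = 1, 320, 320
merged_t = torch.rand(n, 3, h, w)            # transformed merged image, channels 0:3
trimap_t = torch.rand(n, 6, h, w)            # transformed trimap, channels 3:9 (inferred width)
two_channel_trimap = torch.rand(n, 2, h, w)  # two-channel trimap, channels 9:11
merged = torch.rand(n, 3, h, w)              # raw merged image, channels 11:14
x = torch.cat((merged_t, trimap_t, two_channel_trimap, merged), dim=1)
# x would then be passed to an FBAResnetDilated instance built from the model
# config; the returned dict's 'conv_out' holds the six feature maps consumed
# by FBADecoder.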