Backbone of FBA. #215
Merged
Commits (34)
8123792  Backbone of FBA.
7db053d  Init.
3da2f0d  Doc string for forward.
33e5100  Doc string of Init.
c05f367  Modified API.
eea8b2f  FBAencoder.
ebcb92a  Tiny.
61d8c66  Doc string.
3837b65  Decoder.
d68dfb0  Init.
beff29c  General Res.
595ee42  Added test.
fb8bc23  Added tests.
fc0fcaf  Added two_channel_trimap key.
7b79b10  Tiny.
b763b3f  Restore shape.
c412f9b  Tiny.
a845d72  Added utis. May change.
32dca1f  Added FBA mattor.
1fb35c0  Improved.
1e82ab9  Tiny.
35501a3  Modified.
17f8b0e  Tiny.
15b99c1  Tiny.
49f728f  Tiny.
e282240  Delete plugins.
bcc5491  Modified.
c0ce9b8  Tiny.
5bd3944  Mattor.
76b4d23  Tiny.
8edc76f  Tiny.
230709a  Postponed to next branch.
4b6bbeb  Tiny.
fd8954b  Update base_mattor.py
mmedit/models/backbones/encoder_decoders/decoders/fba_decoder.py (212 additions, 0 deletions)

@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger

@COMPONENTS.register_module()
class FBADecoder(nn.Module):
    """Decoder for FBA matting.

    Args:
        pool_scales (tuple[int]): Pooling scales used in the Pyramid
            Pooling Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self,
                 pool_scales,
                 in_channels,
                 channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super().__init__()

        assert isinstance(pool_scales, (list, tuple))
        # Pyramid Pooling Module
        self.pool_scales = pool_scales
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.batch_norm = False

        self.ppm = []
        for scale in self.pool_scales:
            self.ppm.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    *(ConvModule(
                        self.in_channels,
                        self.channels,
                        kernel_size=1,
                        bias=True,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg).children())))
        self.ppm = nn.ModuleList(self.ppm)

        # Following the authors' implementation, the concatenated conv
        # layers described in the supplementary material are inserted
        # between the upsampling operations.
        self.conv_up1 = nn.Sequential(*(list(
            ConvModule(
                self.in_channels + len(pool_scales) * 256,
                self.channels,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children()) + list(
                    ConvModule(
                        self.channels,
                        self.channels,
                        padding=1,
                        bias=True,
                        kernel_size=3,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg).children())))

        self.conv_up2 = nn.Sequential(*(list(
            ConvModule(
                self.channels * 2,
                self.channels,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children())))

        # Channel width of the encoder skip feature concatenated before
        # conv_up3 depends on the norm type of the backbone.
        if (self.norm_cfg['type'] == 'BN'):
            d_up3 = 128
        else:
            d_up3 = 64

        self.conv_up3 = nn.Sequential(*(list(
            ConvModule(
                self.channels + d_up3,
                64,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg).children())))

        self.unpool = nn.MaxUnpool2d(2, stride=2)

        # Prediction head. Its input concatenates 64 decoder channels with
        # the first 3 channels of the raw encoder input, the merged image
        # (3) and the two-channel trimap (2); the 7 output channels are
        # 1 (alpha) + 3 (foreground) + 3 (background).
        self.conv_up4 = nn.Sequential(*(list(
            ConvModule(
                64 + 3 + 3 + 2,
                32,
                padding=1,
                kernel_size=3,
                bias=True,
                conv_cfg=self.conv_cfg,
                act_cfg=self.act_cfg).children()) + list(
                    ConvModule(
                        32,
                        16,
                        padding=1,
                        kernel_size=3,
                        bias=True,
                        conv_cfg=self.conv_cfg,
                        act_cfg=self.act_cfg).children()) + list(
                            ConvModule(
                                16,
                                7,
                                padding=0,
                                kernel_size=1,
                                bias=True,
                                conv_cfg=self.conv_cfg,
                                act_cfg=None).children())))

    def init_weights(self, pretrained=None):
        """Init weights for the model.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (dict): Output dict of the FBA encoder
                (FBAResnetDilated).

        Returns:
            tuple(Tensor): Predicted alpha, fg and bg of the current batch.
        """
        conv_out = inputs['conv_out']
        img = inputs['merged']
        two_channel_trimap = inputs['two_channel_trimap']
        conv5 = conv_out[-1]
        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(
                nn.functional.interpolate(
                    pool_scale(conv5), (input_size[2], input_size[3]),
                    mode='bilinear',
                    align_corners=self.align_corners))
        ppm_out = torch.cat(ppm_out, 1)
        x = self.conv_up1(ppm_out)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)

        x = torch.cat((x, conv_out[-4]), 1)

        x = self.conv_up2(x)
        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)

        x = torch.cat((x, conv_out[-5]), 1)
        x = self.conv_up3(x)

        x = torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)

        x = torch.cat((x, conv_out[-6][:, :3], img, two_channel_trimap), 1)
        output = self.conv_up4(x)
        alpha = torch.clamp(output[:, 0:1], 0, 1)
        F = torch.sigmoid(output[:, 1:4])
        B = torch.sigmoid(output[:, 4:7])

        return alpha, F, B
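
For review convenience, here is a minimal shape smoke test for the decoder. It is an illustrative sketch, not part of the PR: the feature shapes assume a dilated ResNet-50 encoder (output stride 8) in the BN variant, with the stem assumed to output 128 channels to match d_up3 above; eval() is needed because the 1x1 pooled PPM branches cannot be batch-normalized in training mode with a single sample.

import torch

from mmedit.models.backbones.encoder_decoders.decoders.fba_decoder import \
    FBADecoder

decoder = FBADecoder(pool_scales=(1, 2, 3, 6), in_channels=2048, channels=256)
decoder.init_weights()
decoder.eval()  # BN on the 1x1 pooled branches requires eval mode for N=1

n, h, w = 1, 64, 64
inputs = {
    'conv_out': [
        torch.rand(n, 11, h, w),              # raw 11-channel encoder input
        torch.rand(n, 128, h // 2, w // 2),   # stem features (BN variant)
        torch.rand(n, 256, h // 4, w // 4),   # layer1
        torch.rand(n, 512, h // 8, w // 8),   # layer2
        torch.rand(n, 1024, h // 8, w // 8),  # layer3 (dilated)
        torch.rand(n, 2048, h // 8, w // 8),  # layer4 (dilated)
    ],
    'merged': torch.rand(n, 3, h, w),
    'two_channel_trimap': torch.rand(n, 2, h, w),
}
with torch.no_grad():
    alpha, fg, bg = decoder(inputs)
assert alpha.shape == (n, 1, h, w)
assert fg.shape == bg.shape == (n, 3, h, w)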
mmedit/models/backbones/encoder_decoders/encoders/fba_encoder.py (44 additions, 0 deletions)

@@ -0,0 +1,44 @@
from mmedit.models.registry import COMPONENTS
from .resnet import ResNet


@COMPONENTS.register_module()
class FBAResnetDilated(ResNet):
    """ResNet-based encoder for FBA image matting."""

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            dict: Dict containing the multi-level features ('conv_out'),
                the original merged image ('merged') and the two-channel
                trimap ('two_channel_trimap').
        """
        # x concatenates (merged_t, trimap_t, two_channel_trimap, merged)
        # along the channel dim; the suffix `_t` denotes transformed inputs.
        # Channels 9:11 hold the two-channel trimap and 11:14 the original
        # merged image.
        two_channel_trimap = x[:, 9:11]
        merged = x[:, 11:14]
        x = x[:, 0:11, ...]
        conv_out = [x]
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.activate(x)
        conv_out.append(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)
        return {
            'conv_out': conv_out,
            'merged': merged,
            'two_channel_trimap': two_channel_trimap
        }
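
The channel layout the encoder expects can be reconstructed from the slicing above. The following sketch is an inference from the inline comment and the slices, not a definitive spec; in particular, the 6-channel width of trimap_t is simply what remains between the transformed image and the two-channel trimap.

import torch

n, h, w = 1, 64, 64
merged_t = torch.rand(n, 3, h, w)            # transformed (normalized) image
trimap_t = torch.rand(n, 6, h, w)            # transformed trimap features
two_channel_trimap = torch.rand(n, 2, h, w)  # two-channel trimap
merged = torch.rand(n, 3, h, w)              # original merged image

x = torch.cat((merged_t, trimap_t, two_channel_trimap, merged), dim=1)
assert x.size(1) == 14
# x[:, 0:11]  -> 11 channels fed through the ResNet stem and stages
# x[:, 9:11]  -> two_channel_trimap, passed through to the decoder
# x[:, 11:14] -> original merged image, passed through to the decoder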
Review comment: Give credit if the code is modified from the original code.