forked from open-mmlab/mmagic
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add UNetDiscriminatorWithSpectralNorm (open-mmlab#605)
* Add UNetDiscriminatorWithSpectralNorm * Add unittest
- Loading branch information
1 parent
b15f2f4
commit 9b17348
Showing
4 changed files
with
144 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .discriminators import (DeepFillv1Discriminators, GLDiscs, ModifiedVGG, | ||
MultiLayerDiscriminator, PatchDiscriminator) | ||
MultiLayerDiscriminator, PatchDiscriminator, | ||
UNetDiscriminatorWithSpectralNorm) | ||
from .refiners import DeepFillRefiner, PlainRefiner | ||
from .stylegan2 import StyleGAN2Discriminator, StyleGANv2Generator | ||
|
||
__all__ = [ | ||
'PlainRefiner', 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator', | ||
'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator', | ||
'StyleGAN2Discriminator', 'StyleGANv2Generator' | ||
'StyleGAN2Discriminator', 'StyleGANv2Generator', | ||
'UNetDiscriminatorWithSpectralNorm' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import torch.nn as nn | ||
from mmcv.runner import load_checkpoint | ||
from torch.nn.utils import spectral_norm | ||
|
||
from mmedit.models.registry import COMPONENTS | ||
from mmedit.utils import get_root_logger | ||
|
||
|
||
@COMPONENTS.register_module() | ||
class UNetDiscriminatorWithSpectralNorm(nn.Module): | ||
"""A U-Net discriminator with spectral normalization. | ||
Args: | ||
in_channels (int): Channel number of the input. | ||
mid_channels (int, optional): Channel number of the intermediate | ||
features. Default: 64. | ||
skip_connection (bool, optional): Whether to use skip connection. | ||
Default: True. | ||
""" | ||
|
||
def __init__(self, in_channels, mid_channels=64, skip_connection=True): | ||
|
||
super().__init__() | ||
|
||
self.skip_connection = skip_connection | ||
|
||
self.conv_0 = nn.Conv2d( | ||
in_channels, mid_channels, kernel_size=3, stride=1, padding=1) | ||
|
||
# downsample | ||
self.conv_1 = spectral_norm( | ||
nn.Conv2d(mid_channels, mid_channels * 2, 4, 2, 1, bias=False)) | ||
self.conv_2 = spectral_norm( | ||
nn.Conv2d(mid_channels * 2, mid_channels * 4, 4, 2, 1, bias=False)) | ||
self.conv_3 = spectral_norm( | ||
nn.Conv2d(mid_channels * 4, mid_channels * 8, 4, 2, 1, bias=False)) | ||
|
||
# upsample | ||
self.conv_4 = spectral_norm( | ||
nn.Conv2d(mid_channels * 8, mid_channels * 4, 3, 1, 1, bias=False)) | ||
self.conv_5 = spectral_norm( | ||
nn.Conv2d(mid_channels * 4, mid_channels * 2, 3, 1, 1, bias=False)) | ||
self.conv_6 = spectral_norm( | ||
nn.Conv2d(mid_channels * 2, mid_channels, 3, 1, 1, bias=False)) | ||
|
||
# final layers | ||
self.conv_7 = spectral_norm( | ||
nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False)) | ||
self.conv_8 = spectral_norm( | ||
nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False)) | ||
self.conv_9 = nn.Conv2d(mid_channels, 1, 3, 1, 1) | ||
|
||
self.upsample = nn.Upsample( | ||
scale_factor=2, mode='bilinear', align_corners=False) | ||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) | ||
|
||
def forward(self, img): | ||
"""Forward function. | ||
Args: | ||
img (Tensor): Input tensor with shape (n, c, h, w). | ||
Returns: | ||
Tensor: Forward results. | ||
""" | ||
|
||
feat_0 = self.lrelu(self.conv_0(img)) | ||
|
||
# downsample | ||
feat_1 = self.lrelu(self.conv_1(feat_0)) | ||
feat_2 = self.lrelu(self.conv_2(feat_1)) | ||
feat_3 = self.lrelu(self.conv_3(feat_2)) | ||
|
||
# upsample | ||
feat_3 = self.upsample(feat_3) | ||
feat_4 = self.lrelu(self.conv_4(feat_3)) | ||
if self.skip_connection: | ||
feat_4 = feat_4 + feat_2 | ||
|
||
feat_4 = self.upsample(feat_4) | ||
feat_5 = self.lrelu(self.conv_5(feat_4)) | ||
if self.skip_connection: | ||
feat_5 = feat_5 + feat_1 | ||
|
||
feat_5 = self.upsample(feat_5) | ||
feat_6 = self.lrelu(self.conv_6(feat_5)) | ||
if self.skip_connection: | ||
feat_6 = feat_6 + feat_0 | ||
|
||
# final layers | ||
out = self.lrelu(self.conv_7(feat_6)) | ||
out = self.lrelu(self.conv_8(out)) | ||
|
||
return self.conv_9(out) | ||
|
||
def init_weights(self, pretrained=None, strict=True): | ||
"""Init weights for models. | ||
Args: | ||
pretrained (str, optional): Path for pretrained weights. If given | ||
None, pretrained weights will not be loaded. Defaults to None. | ||
strict (boo, optional): Whether strictly load the pretrained model. | ||
Defaults to True. | ||
""" | ||
|
||
if isinstance(pretrained, str): | ||
logger = get_root_logger() | ||
load_checkpoint(self, pretrained, strict=strict, logger=logger) | ||
elif pretrained is not None: # Use PyTorch default initialization. | ||
raise TypeError(f'"pretrained" must be a str or None. ' | ||
f'But received {type(pretrained)}.') |
26 changes: 26 additions & 0 deletions
26
tests/test_models/test_components/test_discriminators/test_unet_disc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmedit.models.components import UNetDiscriminatorWithSpectralNorm | ||
|
||
|
||
def test_unet_disc_with_spectral_norm(): | ||
# cpu | ||
disc = UNetDiscriminatorWithSpectralNorm(in_channels=3) | ||
img = torch.randn(1, 3, 16, 16) | ||
disc(img) | ||
|
||
with pytest.raises(TypeError): | ||
# pretrained must be a string path | ||
disc.init_weights(pretrained=233) | ||
|
||
# cuda | ||
if torch.cuda.is_available(): | ||
disc = disc.cuda() | ||
img = img.cuda() | ||
disc(img) | ||
|
||
with pytest.raises(TypeError): | ||
# pretrained must be a string path | ||
disc.init_weights(pretrained=233) |