Skip to content

Commit

Permalink
[Feature] Add UNetDiscriminatorWithSpectralNorm (open-mmlab#605)
Browse files Browse the repository at this point in the history
* Add UNetDiscriminatorWithSpectralNorm

* Add unittest
  • Loading branch information
ckkelvinchan authored Nov 30, 2021
1 parent b15f2f4 commit 9b17348
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 3 deletions.
6 changes: 4 additions & 2 deletions mmedit/models/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .discriminators import (DeepFillv1Discriminators, GLDiscs, ModifiedVGG,
MultiLayerDiscriminator, PatchDiscriminator)
MultiLayerDiscriminator, PatchDiscriminator,
UNetDiscriminatorWithSpectralNorm)
from .refiners import DeepFillRefiner, PlainRefiner
from .stylegan2 import StyleGAN2Discriminator, StyleGANv2Generator

__all__ = [
'PlainRefiner', 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator',
'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator',
'StyleGAN2Discriminator', 'StyleGANv2Generator'
'StyleGAN2Discriminator', 'StyleGANv2Generator',
'UNetDiscriminatorWithSpectralNorm'
]
4 changes: 3 additions & 1 deletion mmedit/models/components/discriminators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from .multi_layer_disc import MultiLayerDiscriminator
from .patch_disc import PatchDiscriminator
from .ttsr_disc import TTSRDiscriminator
from .unet_disc import UNetDiscriminatorWithSpectralNorm

__all__ = [
'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator', 'TTSRDiscriminator',
'DeepFillv1Discriminators', 'PatchDiscriminator', 'LightCNN'
'DeepFillv1Discriminators', 'PatchDiscriminator', 'LightCNN',
'UNetDiscriminatorWithSpectralNorm'
]
111 changes: 111 additions & 0 deletions mmedit/models/components/discriminators/unet_disc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import torch.nn as nn
from mmcv.runner import load_checkpoint
from torch.nn.utils import spectral_norm

from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class UNetDiscriminatorWithSpectralNorm(nn.Module):
"""A U-Net discriminator with spectral normalization.
Args:
in_channels (int): Channel number of the input.
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
skip_connection (bool, optional): Whether to use skip connection.
Default: True.
"""

def __init__(self, in_channels, mid_channels=64, skip_connection=True):

super().__init__()

self.skip_connection = skip_connection

self.conv_0 = nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1)

# downsample
self.conv_1 = spectral_norm(
nn.Conv2d(mid_channels, mid_channels * 2, 4, 2, 1, bias=False))
self.conv_2 = spectral_norm(
nn.Conv2d(mid_channels * 2, mid_channels * 4, 4, 2, 1, bias=False))
self.conv_3 = spectral_norm(
nn.Conv2d(mid_channels * 4, mid_channels * 8, 4, 2, 1, bias=False))

# upsample
self.conv_4 = spectral_norm(
nn.Conv2d(mid_channels * 8, mid_channels * 4, 3, 1, 1, bias=False))
self.conv_5 = spectral_norm(
nn.Conv2d(mid_channels * 4, mid_channels * 2, 3, 1, 1, bias=False))
self.conv_6 = spectral_norm(
nn.Conv2d(mid_channels * 2, mid_channels, 3, 1, 1, bias=False))

# final layers
self.conv_7 = spectral_norm(
nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False))
self.conv_8 = spectral_norm(
nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False))
self.conv_9 = nn.Conv2d(mid_channels, 1, 3, 1, 1)

self.upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

def forward(self, img):
"""Forward function.
Args:
img (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""

feat_0 = self.lrelu(self.conv_0(img))

# downsample
feat_1 = self.lrelu(self.conv_1(feat_0))
feat_2 = self.lrelu(self.conv_2(feat_1))
feat_3 = self.lrelu(self.conv_3(feat_2))

# upsample
feat_3 = self.upsample(feat_3)
feat_4 = self.lrelu(self.conv_4(feat_3))
if self.skip_connection:
feat_4 = feat_4 + feat_2

feat_4 = self.upsample(feat_4)
feat_5 = self.lrelu(self.conv_5(feat_4))
if self.skip_connection:
feat_5 = feat_5 + feat_1

feat_5 = self.upsample(feat_5)
feat_6 = self.lrelu(self.conv_6(feat_5))
if self.skip_connection:
feat_6 = feat_6 + feat_0

# final layers
out = self.lrelu(self.conv_7(feat_6))
out = self.lrelu(self.conv_8(out))

return self.conv_9(out)

def init_weights(self, pretrained=None, strict=True):
"""Init weights for models.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Defaults to None.
strict (boo, optional): Whether strictly load the pretrained model.
Defaults to True.
"""

if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif pretrained is not None: # Use PyTorch default initialization.
raise TypeError(f'"pretrained" must be a str or None. '
f'But received {type(pretrained)}.')
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmedit.models.components import UNetDiscriminatorWithSpectralNorm


def test_unet_disc_with_spectral_norm():
# cpu
disc = UNetDiscriminatorWithSpectralNorm(in_channels=3)
img = torch.randn(1, 3, 16, 16)
disc(img)

with pytest.raises(TypeError):
# pretrained must be a string path
disc.init_weights(pretrained=233)

# cuda
if torch.cuda.is_available():
disc = disc.cuda()
img = img.cuda()
disc(img)

with pytest.raises(TypeError):
# pretrained must be a string path
disc.init_weights(pretrained=233)

0 comments on commit 9b17348

Please sign in to comment.