diff --git a/mmedit/models/components/__init__.py b/mmedit/models/components/__init__.py index ce83d04cc2..db155cc0bd 100644 --- a/mmedit/models/components/__init__.py +++ b/mmedit/models/components/__init__.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from .discriminators import (DeepFillv1Discriminators, GLDiscs, ModifiedVGG, - MultiLayerDiscriminator, PatchDiscriminator) + MultiLayerDiscriminator, PatchDiscriminator, + UNetDiscriminatorWithSpectralNorm) from .refiners import DeepFillRefiner, PlainRefiner from .stylegan2 import StyleGAN2Discriminator, StyleGANv2Generator __all__ = [ 'PlainRefiner', 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator', 'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator', - 'StyleGAN2Discriminator', 'StyleGANv2Generator' + 'StyleGAN2Discriminator', 'StyleGANv2Generator', + 'UNetDiscriminatorWithSpectralNorm' ] diff --git a/mmedit/models/components/discriminators/__init__.py b/mmedit/models/components/discriminators/__init__.py index d533651fe6..efcfb95034 100644 --- a/mmedit/models/components/discriminators/__init__.py +++ b/mmedit/models/components/discriminators/__init__.py @@ -6,8 +6,10 @@ from .multi_layer_disc import MultiLayerDiscriminator from .patch_disc import PatchDiscriminator from .ttsr_disc import TTSRDiscriminator +from .unet_disc import UNetDiscriminatorWithSpectralNorm __all__ = [ 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator', 'TTSRDiscriminator', - 'DeepFillv1Discriminators', 'PatchDiscriminator', 'LightCNN' + 'DeepFillv1Discriminators', 'PatchDiscriminator', 'LightCNN', + 'UNetDiscriminatorWithSpectralNorm' ] diff --git a/mmedit/models/components/discriminators/unet_disc.py b/mmedit/models/components/discriminators/unet_disc.py new file mode 100644 index 0000000000..29465823ca --- /dev/null +++ b/mmedit/models/components/discriminators/unet_disc.py @@ -0,0 +1,111 @@ +import torch.nn as nn +from mmcv.runner import load_checkpoint +from torch.nn.utils import spectral_norm 
from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


@COMPONENTS.register_module()
class UNetDiscriminatorWithSpectralNorm(nn.Module):
    """A U-Net discriminator with spectral normalization.

    The network applies one stride-1 convolution, three stride-2
    convolutions (each halving the spatial size), three bilinear x2
    upsampling stages, and three final stride-1 convolutions. Every
    convolution except the first (``conv_0``) and the last (``conv_9``) is
    wrapped with spectral normalization.

    When ``skip_connection`` is True, encoder features are added
    element-wise to the decoder features of the same scale. This requires
    the input height and width to be divisible by 8; otherwise the
    upsampled features do not match the encoder features in size.

    Args:
        in_channels (int): Channel number of the input.
        mid_channels (int, optional): Channel number of the intermediate
            features. Default: 64.
        skip_connection (bool, optional): Whether to use skip connection.
            Default: True.
    """

    def __init__(self, in_channels, mid_channels=64, skip_connection=True):

        super().__init__()

        self.skip_connection = skip_connection

        # input projection (no spectral normalization)
        self.conv_0 = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1)

        # downsample: kernel 4, stride 2, padding 1 halves the resolution
        self.conv_1 = spectral_norm(
            nn.Conv2d(mid_channels, mid_channels * 2, 4, 2, 1, bias=False))
        self.conv_2 = spectral_norm(
            nn.Conv2d(mid_channels * 2, mid_channels * 4, 4, 2, 1, bias=False))
        self.conv_3 = spectral_norm(
            nn.Conv2d(mid_channels * 4, mid_channels * 8, 4, 2, 1, bias=False))

        # upsample: these convolutions run after ``self.upsample`` in forward
        self.conv_4 = spectral_norm(
            nn.Conv2d(mid_channels * 8, mid_channels * 4, 3, 1, 1, bias=False))
        self.conv_5 = spectral_norm(
            nn.Conv2d(mid_channels * 4, mid_channels * 2, 3, 1, 1, bias=False))
        self.conv_6 = spectral_norm(
            nn.Conv2d(mid_channels * 2, mid_channels, 3, 1, 1, bias=False))

        # final layers: ``conv_9`` maps to a single output channel
        self.conv_7 = spectral_norm(
            nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False))
        self.conv_8 = spectral_norm(
            nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=False))
        self.conv_9 = nn.Conv2d(mid_channels, 1, 3, 1, 1)

        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, img):
        """Forward function.

        Args:
            img (Tensor): Input tensor with shape (n, c, h, w). With
                ``skip_connection=True``, ``h`` and ``w`` must be divisible
                by 8 so that the skip additions are shape-compatible.

        Returns:
            Tensor: Forward results with a single channel. For ``h`` and
                ``w`` divisible by 8 the shape is (n, 1, h, w).
        """

        feat_0 = self.lrelu(self.conv_0(img))

        # downsample
        feat_1 = self.lrelu(self.conv_1(feat_0))
        feat_2 = self.lrelu(self.conv_2(feat_1))
        feat_3 = self.lrelu(self.conv_3(feat_2))

        # upsample, adding the encoder feature of the matching scale
        feat_3 = self.upsample(feat_3)
        feat_4 = self.lrelu(self.conv_4(feat_3))
        if self.skip_connection:
            feat_4 = feat_4 + feat_2

        feat_4 = self.upsample(feat_4)
        feat_5 = self.lrelu(self.conv_5(feat_4))
        if self.skip_connection:
            feat_5 = feat_5 + feat_1

        feat_5 = self.upsample(feat_5)
        feat_6 = self.lrelu(self.conv_6(feat_5))
        if self.skip_connection:
            feat_6 = feat_6 + feat_0

        # final layers
        out = self.lrelu(self.conv_7(feat_6))
        out = self.lrelu(self.conv_8(out))

        return self.conv_9(out)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """

        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
        # When ``pretrained`` is None nothing is done here, so the parameters
        # keep the PyTorch default initialization.
import pytest
import torch

from mmedit.models.components import UNetDiscriminatorWithSpectralNorm


def test_unet_disc_with_spectral_norm():
    """Check forward output shape and ``init_weights`` argument validation."""
    # cpu
    disc = UNetDiscriminatorWithSpectralNorm(in_channels=3)
    img = torch.randn(1, 3, 16, 16)
    output = disc(img)
    # 16 is divisible by 8, so the single-channel output keeps the input
    # resolution.
    assert output.shape == (1, 1, 16, 16)

    # the variant without skip connections yields the same output shape
    disc_no_skip = UNetDiscriminatorWithSpectralNorm(
        in_channels=3, skip_connection=False)
    assert disc_no_skip(img).shape == (1, 1, 16, 16)

    # pretrained=None is accepted and loads nothing
    disc.init_weights(pretrained=None)

    with pytest.raises(TypeError):
        # pretrained must be a string path
        disc.init_weights(pretrained=233)

    # cuda
    if torch.cuda.is_available():
        disc = disc.cuda()
        img = img.cuda()
        output = disc(img)
        assert output.shape == (1, 1, 16, 16)

        with pytest.raises(TypeError):
            # pretrained must be a string path
            disc.init_weights(pretrained=233)