Add GLEAN #296
@@ -0,0 +1,315 @@
import math

import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.backbones.sr_backbones.rrdb_net import RRDB
from mmedit.models.builder import build_component
from mmedit.models.common import PixelShufflePack, make_layer
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger

@BACKBONES.register_module()
class GLEANStyleGANv2(nn.Module):
    r"""GLEAN (using StyleGANv2) architecture for super-resolution.

    Paper:
        GLEAN: Generative Latent Bank for Large-Factor Image Super-Resolution,
        CVPR, 2021

    This method makes use of StyleGAN2, and hence the arguments mostly follow
    those in 'StyleGANv2Generator'.

    In StyleGAN2, we use a static architecture composed of a style-mapping
    module and a number of convolutional style blocks. More details can be
    found in: Analyzing and Improving the Image Quality of StyleGAN,
    CVPR, 2020.

    You can load a pretrained model by passing the information to the
    ``pretrained`` argument. Official weights are available as follows:

    - stylegan2-ffhq-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth # noqa
    - stylegan2-horse-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth # noqa
    - stylegan2-car-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth # noqa
    - stylegan2-cat-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth # noqa
    - stylegan2-church-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth # noqa

    If you want to load the ema model, you can use the following code:

    .. code-block:: python

        # ckpt_http is one of the valid paths listed above
        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path=ckpt_http,
                                            prefix='generator_ema'))

    Of course, you can also download the checkpoint in advance and set
    ``ckpt_path`` to a local path. If you just want to load the original
    generator (not the ema model), set the prefix to 'generator'.
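    For example, to load a checkpoint that has been downloaded in advance
    (the local path below is hypothetical):

    .. code-block:: python

        generator = StyleGANv2Generator(1024, 512,
                                        pretrained=dict(
                                            ckpt_path='./stylegan2-ffhq.pth',
                                            prefix='generator'))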

    Note that our implementation allows generating BGR images, while the
    original StyleGAN2 outputs RGB images by default. Thus, we provide the
    ``bgr2rgb`` argument to convert the image space.

    Args:
        in_size (int): The size of the input image.
        out_size (int): The output size of the StyleGAN2 generator.
        style_channels (int): The number of channels for the style code.
        num_mlps (int, optional): The number of MLP layers. Defaults to 8.
        channel_multiplier (int, optional): The multiplier factor for the
            channel number. Defaults to 2.
        blur_kernel (list, optional): The blur kernel. Defaults
            to [1, 3, 3, 1].
        lr_mlp (float, optional): The learning rate for the style mapping
            layer. Defaults to 0.01.
        default_style_mode (str, optional): The default mode of style mixing.
            The mixing style mode is adopted in training by default, while
            the 'single' style mode is used in evaluation. `['mix', 'single']`
            are currently supported. Defaults to 'mix'.
        eval_style_mode (str, optional): The evaluation mode of style mixing.
            Defaults to 'single'.
        mix_prob (float, optional): Mixing probability. The value should be
            in the range of [0, 1]. Defaults to 0.9.
        pretrained (dict | None, optional): Information for pretrained models.
            The necessary key is 'ckpt_path'. Besides, you can also provide
            'prefix' to load the generator part from the whole state dict.
            Defaults to None.
        bgr2rgb (bool, optional): Whether to flip the image channel dimension.
            Defaults to False.
    """

    def __init__(self,
                 in_size,
                 out_size,
                 style_channels,
                 num_mlps=8,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01,
                 default_style_mode='mix',
                 eval_style_mode='single',
                 mix_prob=0.9,
                 pretrained=None,
                 bgr2rgb=False):

        super().__init__()

        # latent bank (StyleGANv2), with weights being fixed
        self.generator = build_component(
            dict(
                type='StyleGANv2Generator',
                out_size=out_size,
                style_channels=style_channels,
                num_mlps=num_mlps,
                channel_multiplier=channel_multiplier,
                blur_kernel=blur_kernel,
                lr_mlp=lr_mlp,
                default_style_mode=default_style_mode,
                eval_style_mode=eval_style_mode,
                mix_prob=mix_prob,
                pretrained=pretrained,
                bgr2rgb=bgr2rgb))
        self.generator.requires_grad_(False)

        self.in_size = in_size
        self.style_channels = style_channels
        channels = self.generator.channels

        # encoder
        # number of style codes expected by the generator,
        # e.g. 18 for out_size=1024
        num_styles = int(math.log2(out_size)) * 2 - 2
        encoder_res = [2**i for i in range(int(math.log2(in_size)), 1, -1)]
        self.encoder = nn.ModuleList()
        self.encoder.append(
            nn.Sequential(
                RRDBFeatureExtractor(3, 64, num_blocks=23),
                nn.Conv2d(64, channels[in_size], 3, 1, 1, bias=True),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)))
        for res in encoder_res:
            in_channels = channels[res]
            if res > 4:
                out_channels = channels[res // 2]
                block = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True))
            else:
                block = nn.Sequential(
                    nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=True),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Flatten(),
                    nn.Linear(16 * in_channels, num_styles * style_channels))
            self.encoder.append(block)

        # additional modules for StyleGANv2
        self.fusion_out = nn.ModuleList()
        self.fusion_skip = nn.ModuleList()
        for res in encoder_res[::-1]:
            num_channels = channels[res]
            self.fusion_out.append(
                nn.Conv2d(num_channels * 2, num_channels, 3, 1, 1, bias=True))
            self.fusion_skip.append(
                nn.Conv2d(num_channels + 3, 3, 3, 1, 1, bias=True))

        # decoder
        decoder_res = [
            2**i for i in range(
                int(math.log2(in_size)), int(math.log2(out_size) + 1))
        ]
        self.decoder = nn.ModuleList()
        for res in decoder_res:
            if res == in_size:
                in_channels = channels[res]
            else:
                in_channels = 2 * channels[res]

            if res < out_size:
                out_channels = channels[res * 2]
                self.decoder.append(
                    PixelShufflePack(
                        in_channels, out_channels, 2, upsample_kernel=3))
            else:
                self.decoder.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, 64, 3, 1, 1),
                        nn.LeakyReLU(negative_slope=0.2, inplace=True),
                        nn.Conv2d(64, 3, 3, 1, 1)))
    def forward(self, lr):
        """Forward function.

        Args:
            lr (Tensor): Input LR image with shape (n, c, h, w).

        Returns:
            Tensor: Output HR image.
        """

        n, c, h, w = lr.size()
        if h != self.in_size or w != self.in_size:
            raise AssertionError(
                f'Spatial resolution must equal in_size ({self.in_size}).'
                f' Got ({h}, {w}).')

        # encoder
        feat = lr
        encoder_features = []
        for block in self.encoder:
            feat = block(feat)
            encoder_features.append(feat)
        encoder_features = encoder_features[::-1]

        latent = encoder_features[0].view(lr.size(0), -1, self.style_channels)
        encoder_features = encoder_features[1:]

        # generator
        injected_noise = [
            getattr(self.generator, f'injected_noise_{i}')
            for i in range(self.generator.num_injected_noises)
        ]
        # 4x4 stage
        out = self.generator.constant_input(latent)
        out = self.generator.conv1(out, latent[:, 0], noise=injected_noise[0])
        skip = self.generator.to_rgb1(out, latent[:, 1])

        _index = 1

        # 8x8 ---> higher res
        generator_features = []
        for up_conv, conv, noise1, noise2, to_rgb in zip(
                self.generator.convs[::2], self.generator.convs[1::2],
                injected_noise[1::2], injected_noise[2::2],
                self.generator.to_rgbs):

            # feature fusion by channel-wise concatenation
            if out.size(2) <= self.in_size:
                fusion_index = (_index - 1) // 2
                feat = encoder_features[fusion_index]

                out = torch.cat([out, feat], dim=1)
                out = self.fusion_out[fusion_index](out)

                skip = torch.cat([skip, feat], dim=1)
                skip = self.fusion_skip[fusion_index](skip)

            # original StyleGAN operations
            out = up_conv(out, latent[:, _index], noise=noise1)
            out = conv(out, latent[:, _index + 1], noise=noise2)
            skip = to_rgb(out, latent[:, _index + 2], skip)

            # store features for decoder
            if out.size(2) > self.in_size:
                generator_features.append(out)

            _index += 2

        # decoder
        hr = encoder_features[-1]
        for i, block in enumerate(self.decoder):
            if i > 0:
                hr = torch.cat([hr, generator_features[i - 1]], dim=1)
            hr = block(hr)

        return hr

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')


class RRDBFeatureExtractor(nn.Module):
    """Feature extractor composed of Residual-in-Residual Dense Blocks
    (RRDBs).

    It is equivalent to ESRGAN with the upsampling module removed.

    Args:
        in_channels (int): Channel number of inputs. Default: 3.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 23.
        growth_channels (int): Channels for each growth. Default: 32.
    """

    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 num_blocks=23,
                 growth_channels=32):

        super().__init__()

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
        self.body = make_layer(
            RRDB,
            num_blocks,
            mid_channels=mid_channels,
            growth_channels=growth_channels)
        self.conv_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """

        feat = self.conv_first(x)
        # global residual connection over the RRDB trunk
        return feat + self.conv_body(self.body(feat))
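
A minimal usage sketch, not part of this diff: the sizes and style_channels below are illustrative assumptions (an 8x model mapping 32x32 LR inputs to 256x256 HR outputs), and building the internal StyleGANv2Generator relies on the component registration added in the other files of this PR.

    import torch

    # hypothetical 8x configuration: 32x32 LR inputs -> 256x256 HR outputs
    model = GLEANStyleGANv2(in_size=32, out_size=256, style_channels=512)
    model.init_weights(pretrained=None)
    model.eval()

    lr = torch.rand(1, 3, 32, 32)  # (n, c, h, w) with h == w == in_size
    with torch.no_grad():
        hr = model(lr)
    print(hr.shape)  # expected: torch.Size([1, 3, 256, 256])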
@@ -1,8 +1,10 @@
 from .discriminators import (DeepFillv1Discriminators, GLDiscs, ModifiedVGG,
                              MultiLayerDiscriminator, PatchDiscriminator)
 from .refiners import DeepFillRefiner, PlainRefiner
+from .stylegan2 import StyleGAN2Discriminator, StyleGANv2Generator

 __all__ = [
     'PlainRefiner', 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator',
-    'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator'
+    'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator',
+    'StyleGAN2Discriminator', 'StyleGANv2Generator'
 ]
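
With the re-export above, the new modules become importable from the components package. A quick sanity-check sketch, assuming this file is mmedit/models/components/__init__.py:

    from mmedit.models.components import (StyleGAN2Discriminator,
                                          StyleGANv2Generator)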
@@ -0,0 +1,4 @@
from .generator_discriminator import (StyleGAN2Discriminator,
                                      StyleGANv2Generator)

__all__ = ['StyleGANv2Generator', 'StyleGAN2Discriminator']
Review comment: For consistency, we use np for math ops: math.log2 -> np.log2.
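
A sketch of the suggested change in the encoder setup, assuming numpy is imported as np at the top of the module:

    import numpy as np

    # before: num_styles = int(math.log2(out_size)) * 2 - 2
    num_styles = int(np.log2(out_size)) * 2 - 2
    encoder_res = [2**i for i in range(int(np.log2(in_size)), 1, -1)]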