Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GLEAN #296

Merged
merged 40 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
c0dc3d3
clone from MMEditing
ckkelvinchan Mar 26, 2021
5e107e0
add GenerateFrameIndicesForRecurrent
ckkelvinchan Mar 26, 2021
ae02181
add unit test
ckkelvinchan Mar 26, 2021
8cc6c8b
update format
ckkelvinchan Mar 26, 2021
c9781e8
Merge pull request #2 from ckkelvinchan/augmentations
nbei Mar 26, 2021
7d3c668
add stylegan2 with pretrained models
nbei Mar 27, 2021
d6094e4
Add StyleGAN2 components
ckkelvinchan Apr 21, 2021
0693e55
add GLEAN architecture (WIP)
ckkelvinchan Apr 22, 2021
d2bf8d3
GLEAN architecture
ckkelvinchan Apr 23, 2021
b5b2efd
Use build_component to build StyleGANv2
ckkelvinchan Apr 24, 2021
ca39f10
add GenerateFrameIndicesForRecurrent
ckkelvinchan Mar 26, 2021
a7cb67b
add stylegan2 with pretrained models
nbei Mar 27, 2021
34b1d8f
Merge branch 'glean' of https://github.com/ckkelvinchan/mmediting int…
ckkelvinchan Apr 25, 2021
cb735af
Merge branch 'master' of https://github.com/open-mmlab/mmediting into…
ckkelvinchan Apr 25, 2021
feb7f7c
minor revision of architecture
ckkelvinchan Apr 25, 2021
f9bf419
Minor fix
ckkelvinchan Apr 30, 2021
112e140
Add MSELoss for perceptual loss
ckkelvinchan May 5, 2021
1442b39
merge master
ckkelvinchan May 5, 2021
a1fd615
merge master
ckkelvinchan May 5, 2021
8815e7f
Merge branch 'master' of https://github.com/open-mmlab/mmediting into…
ckkelvinchan May 5, 2021
f257964
sort
ckkelvinchan May 5, 2021
ffa651b
Add GLEAN model
ckkelvinchan May 5, 2021
326fd3b
Add unittest
ckkelvinchan May 9, 2021
37d0fe9
Remove pretrained in test
ckkelvinchan May 9, 2021
a35bbf9
Replace by test_srgan for verification
ckkelvinchan May 9, 2021
96ffe92
Change disc to ModifiedVGG
ckkelvinchan May 9, 2021
06fb8ce
Remove init_weights
ckkelvinchan May 9, 2021
e53d097
Remove _load_pretrained_model
ckkelvinchan May 9, 2021
cb4c4af
revert to test_srgan
ckkelvinchan May 9, 2021
dbb0212
Change Discriminator
ckkelvinchan May 9, 2021
0de7d9a
Use original StyleGAN2 discriminator
ckkelvinchan May 9, 2021
830ec65
Use GLEAN as generator
ckkelvinchan May 9, 2021
2ac5e59
Change to GLEAN model
ckkelvinchan May 9, 2021
e279361
minor change
ckkelvinchan May 9, 2021
604c102
remove redundancy in test_glean.py
ckkelvinchan May 9, 2021
6c09470
install
ckkelvinchan May 22, 2021
854e14a
Add unittests
ckkelvinchan May 22, 2021
b36c2bd
Revert to original StyleGAN2 discriminator
ckkelvinchan May 27, 2021
054cf83
Change math.log2 to np.log2
ckkelvinchan May 27, 2021
d1a05e1
rebase master
ckkelvinchan May 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
SimpleEncoderDecoder)
# yapf: enable
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
MSRResNet, RRDBNet, TOFlow)
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet,
GLEANStyleGANv2, IconVSR, MSRResNet, RRDBNet,
TOFlow)

__all__ = [
'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
Expand All @@ -25,5 +26,5 @@
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
'BasicVSRNet', 'IconVSR'
'BasicVSRNet', 'IconVSR', 'GLEANStyleGANv2'
]
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .basicvsr_net import BasicVSRNet
from .edsr import EDSR
from .edvr_net import EDVRNet
from .glean_styleganv2 import GLEANStyleGANv2
from .iconvsr import IconVSR
from .rdn import RDN
from .rrdb_net import RRDBNet
Expand All @@ -10,5 +11,5 @@

__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
'BasicVSRNet', 'IconVSR', 'RDN'
'BasicVSRNet', 'IconVSR', 'RDN', 'GLEANStyleGANv2'
]
315 changes: 315 additions & 0 deletions mmedit/models/backbones/sr_backbones/glean_styleganv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
import math

import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.backbones.sr_backbones.rrdb_net import RRDB
from mmedit.models.builder import build_component
from mmedit.models.common import PixelShufflePack, make_layer
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


@BACKBONES.register_module()
class GLEANStyleGANv2(nn.Module):
r"""GLEAN (using StyleGANv2) architecture for super-resolution.

Paper:
GLEAN: Generative Latent Bank for Large-Factor Image Super-Resolution,
CVPR, 2021

This method makes use of StyleGAN2 and hence the arguments mostly follow
those of 'StyleGANv2Generator'.

In StyleGAN2, we use a static architecture composed of a style mapping
module and a number of convolutional style blocks. More details can be found
in: Analyzing and Improving the Image Quality of StyleGAN CVPR2020.

You can load a pretrained model by passing information into the
``pretrained`` argument. We already offer official weights as
follows:

- stylegan2-ffhq-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-ffhq-config-f-official_20210327_171224-bce9310c.pth # noqa
- stylegan2-horse-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-horse-config-f-official_20210327_173203-ef3e69ca.pth # noqa
- stylegan2-car-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-car-config-f-official_20210327_172340-8cfe053c.pth # noqa
- stylegan2-cat-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-cat-config-f-official_20210327_172444-15bc485b.pth # noqa
- stylegan2-church-config-f: http://download.openmmlab.com/mmgen/stylegan2/official_weights/stylegan2-church-config-f-official_20210327_172657-1d42b7d1.pth # noqa

If you want to load the ema model, you can use the following code:

.. code-block:: python

# ckpt_http is one of the valid path from http source
generator = StyleGANv2Generator(1024, 512,
pretrained=dict(
ckpt_path=ckpt_http,
prefix='generator_ema'))

``GLEANStyleGANv2`` forwards its own ``pretrained`` argument unchanged to the
internal ``StyleGANv2Generator``, so the same dict can be passed here.

Of course, you can also download the checkpoint in advance and set
``ckpt_path`` to a local path. If you just want to load the original
generator (not the ema model), please set the prefix to 'generator'.

Note that our implementation allows generating BGR images, while the
original StyleGAN2 outputs RGB images by default. Thus, we provide the
``bgr2rgb`` argument to convert the image space.

Args:
in_size (int): The size of the input image.
out_size (int): The output size of the StyleGAN2 generator.
style_channels (int): The number of channels for style code.
num_mlps (int, optional): The number of MLP layers. Defaults to 8.
channel_multiplier (int, optional): The multiplier factor for the
channel number. Defaults to 2.
blur_kernel (list, optional): The blurry kernel. Defaults
to [1, 3, 3, 1].
lr_mlp (float, optional): The learning rate for the style mapping
layer. Defaults to 0.01.
default_style_mode (str, optional): The default mode of style mixing.
In training, we adopt the mixing style mode by default. However, in
evaluation, we use 'single' style mode. `['mix', 'single']` are
currently supported. Defaults to 'mix'.
eval_style_mode (str, optional): The evaluation mode of style mixing.
Defaults to 'single'.
mix_prob (float, optional): Mixing probability. The value should be
in range of [0, 1]. Defaults to 0.9.
pretrained (dict | None, optional): Information for pretrained models.
The necessary key is 'ckpt_path'. Besides, you can also provide
'prefix' to load the generator part from the whole state dict.
Defaults to None.
bgr2rgb (bool, optional): Whether to flip the image channel dimension.
Defaults to False.
"""

# Build the three parts of GLEAN: a frozen StyleGANv2 generator (the latent
# bank), an encoder that maps the LR image to multi-resolution features plus
# style codes, and a decoder that upsamples to the output image.
def __init__(self,
in_size,
out_size,
style_channels,
num_mlps=8,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
default_style_mode='mix',
eval_style_mode='single',
mix_prob=0.9,
pretrained=None,
bgr2rgb=False):

super().__init__()

# NOTE(review): ``blur_kernel`` has a mutable list default. It is only
# forwarded into the generator config below and never mutated in this
# method, but the same list object is shared by every instance.

# latent bank (StyleGANv2), with weights being fixed
self.generator = build_component(
dict(
type='StyleGANv2Generator',
out_size=out_size,
style_channels=style_channels,
num_mlps=num_mlps,
channel_multiplier=channel_multiplier,
blur_kernel=blur_kernel,
lr_mlp=lr_mlp,
default_style_mode=default_style_mode,
eval_style_mode=eval_style_mode,
mix_prob=mix_prob,
pretrained=pretrained,
bgr2rgb=bgr2rgb))
# freeze the latent bank: none of its parameters receive gradients
self.generator.requires_grad_(False)

self.in_size = in_size
self.style_channels = style_channels
# channel counts of the generator, indexed below by spatial resolution
channels = self.generator.channels

# encoder
# number of style vectors predicted by the encoder,
# e.g. out_size=1024 -> 10 * 2 - 2 = 18
num_styles = int(math.log2(out_size)) * 2 - 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, we use np for math ops

math.log2 -> np.log2

# resolutions handled by the encoder, from in_size down to 4
# (exponents log2(in_size), ..., 2)
encoder_res = [2**i for i in range(int(math.log2(in_size)), 1, -1)]
self.encoder = nn.ModuleList()
# first encoder block: RRDB feature extractor (3 input channels,
# 64 intermediate channels) followed by a 3x3 stride-1 conv that maps to
# the generator's channel count at in_size
self.encoder.append(
nn.Sequential(
RRDBFeatureExtractor(3, 64, num_blocks=23),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe adding mid_channels(defaults 64) , in_channels and out_channels to the class (GLEANStyleGANv2) will be better.

nn.Conv2d(64, channels[in_size], 3, 1, 1, bias=True),
nn.LeakyReLU(negative_slope=0.2, inplace=True)))
for res in encoder_res:
in_channels = channels[res]
if res > 4:
# stride-2 conv halves the resolution (res -> res // 2) and switches
# to the channel count of the lower resolution
out_channels = channels[res // 2]
block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 2, 1, bias=True),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
else:
# res == 4: flatten the 4x4 feature map (16 * in_channels values) and
# project it to num_styles style codes of style_channels each
block = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=True),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Flatten(),
nn.Linear(16 * in_channels, num_styles * style_channels))
self.encoder.append(block)

# additional modules for StyleGANv2
# One pair of fusion convs per encoder resolution, in ascending order
# (4, 8, ..., in_size). fusion_out takes 2 * num_channels inputs (generator
# feature concatenated with encoder feature in forward()); fusion_skip
# takes num_channels + 3 inputs (encoder feature plus the 3-channel skip
# image) and outputs 3 channels.
self.fusion_out = nn.ModuleList()
self.fusion_skip = nn.ModuleList()
for res in encoder_res[::-1]:
num_channels = channels[res]
self.fusion_out.append(
nn.Conv2d(num_channels * 2, num_channels, 3, 1, 1, bias=True))
self.fusion_skip.append(
nn.Conv2d(num_channels + 3, 3, 3, 1, 1, bias=True))

# decoder
# resolutions handled by the decoder, from in_size up to out_size inclusive
# NOTE(review): this list is empty when in_size > out_size; the constructor
# does not validate that case -- confirm callers never pass it.
decoder_res = [
2**i for i in range(
int(math.log2(in_size)), int(math.log2(out_size) + 1))
]
self.decoder = nn.ModuleList()
for res in decoder_res:
if res == in_size:
# first decoder block only receives the encoder feature
in_channels = channels[res]
else:
# later blocks receive the previous decoder output concatenated with
# the generator feature of the same resolution (see forward())
in_channels = 2 * channels[res]

if res < out_size:
# intermediate block; the third positional argument (2) is presumably
# the pixel-shuffle scale factor -- verify against PixelShufflePack
out_channels = channels[res * 2]
self.decoder.append(
PixelShufflePack(
in_channels, out_channels, 2, upsample_kernel=3))
else:
# final block at out_size: reduce to 64 channels, then to a
# 3-channel output
self.decoder.append(
nn.Sequential(
nn.Conv2d(in_channels, 64, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(64, 3, 3, 1, 1)))

def forward(self, lr):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Named LQ in other models of MMEditing

"""Forward function.

The input is encoded into multi-resolution features and a set of style
codes, the frozen StyleGANv2 generator is run with the encoder features
fused in, and the decoder combines encoder and generator features into
the output image.

Args:
lr (Tensor): Input LR image with shape (n, c, h, w). Both h and w
must equal ``in_size``.

Returns:
Tensor: Output HR image.

Raises:
AssertionError: If the height or width of ``lr`` differs from
``in_size``.
"""

n, c, h, w = lr.size()
if h != self.in_size or w != self.in_size:
raise AssertionError(
f'Spatial resolution must equal in_size ({self.in_size}).'
f' Got ({h}, {w}).')

# encoder
# run the encoder blocks sequentially, keeping every intermediate output
feat = lr
encoder_features = []
for block in self.encoder:
feat = block(feat)
encoder_features.append(feat)
# reverse so that index 0 is the output of the last block (the flattened
# style codes) and the remaining entries go from the lowest resolution up
# to in_size
encoder_features = encoder_features[::-1]

# reshape the flat encoder output to (n, -1, style_channels): one style
# vector per index along dim 1
latent = encoder_features[0].view(lr.size(0), -1, self.style_channels)
encoder_features = encoder_features[1:]

# generator
# collect the noise tensors stored on the generator as attributes
# ``injected_noise_0`` ... ``injected_noise_{num_injected_noises - 1}``
injected_noise = [
getattr(self.generator, f'injected_noise_{i}')
for i in range(self.generator.num_injected_noises)
]
# 4x4 stage
out = self.generator.constant_input(latent)
out = self.generator.conv1(out, latent[:, 0], noise=injected_noise[0])
skip = self.generator.to_rgb1(out, latent[:, 1])

# index into ``latent`` of the style used by the next upsampling conv;
# advanced by 2 per generator stage
_index = 1

# 8x8 ---> higher res
generator_features = []
for up_conv, conv, noise1, noise2, to_rgb in zip(
self.generator.convs[::2], self.generator.convs[1::2],
injected_noise[1::2], injected_noise[2::2],
self.generator.to_rgbs):

# feature fusion by channel-wise concatenation
# only while the generator feature is not larger than the input image,
# i.e. while an encoder feature of matching resolution exists
if out.size(2) <= self.in_size:
# _index takes the values 1, 3, 5, ... so this yields 0, 1, 2, ...
fusion_index = (_index - 1) // 2
feat = encoder_features[fusion_index]

out = torch.cat([out, feat], dim=1)
out = self.fusion_out[fusion_index](out)

skip = torch.cat([skip, feat], dim=1)
skip = self.fusion_skip[fusion_index](skip)

# original StyleGAN operations
out = up_conv(out, latent[:, _index], noise=noise1)
out = conv(out, latent[:, _index + 1], noise=noise2)
skip = to_rgb(out, latent[:, _index + 2], skip)

# store features for decoder
# only features larger than the input resolution are passed on
if out.size(2) > self.in_size:
generator_features.append(out)

_index += 2

# decoder
# start from the encoder feature at in_size (last entry after reversal);
# every block after the first also receives the generator feature stored
# for that stage
hr = encoder_features[-1]
for i, block in enumerate(self.decoder):
if i > 0:
hr = torch.cat([hr, generator_features[i - 1]], dim=1)
hr = block(hr)

return hr

def init_weights(self, pretrained=None, strict=True):
    """Optionally initialise the model from a checkpoint file.

    Args:
        pretrained (str, optional): Path of the checkpoint to load. When
            ``None`` nothing is loaded. Defaults to None.
        strict (bool, optional): Whether the checkpoint is loaded strictly.
            Defaults to True.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    # nothing requested: leave the current weights untouched
    if pretrained is None:
        return

    if not isinstance(pretrained, str):
        raise TypeError(f'"pretrained" must be a str or None. '
                        f'But received {type(pretrained)}.')

    load_checkpoint(
        self, pretrained, strict=strict, logger=get_root_logger())


class RRDBFeatureExtractor(nn.Module):
    """RRDB-based feature extractor.

    A stack of Residual-in-Residual Dense Blocks (RRDBs) wrapped by an input
    convolution and a trunk convolution, with a residual connection around
    the trunk. It is equivalent to ESRGAN with the upsampling module removed.

    Args:
        in_channels (int): Channel number of inputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64
        num_blocks (int): Block number in the trunk network. Default: 23
        growth_channels (int): Channels for each growth. Default: 32.
    """

    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 num_blocks=23,
                 growth_channels=32):

        super().__init__()

        # 3x3 convs with stride 1 and padding 1 keep the spatial size
        self.conv_first = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        # trunk made of ``num_blocks`` RRDBs
        rrdb_kwargs = dict(
            mid_channels=mid_channels, growth_channels=growth_channels)
        self.body = make_layer(RRDB, num_blocks, **rrdb_kwargs)
        self.conv_body = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Extract features from the input.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """

        shallow = self.conv_first(x)
        trunk = self.conv_body(self.body(shallow))
        # residual connection around the RRDB trunk
        return shallow + trunk
4 changes: 3 additions & 1 deletion mmedit/models/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .discriminators import (DeepFillv1Discriminators, GLDiscs, ModifiedVGG,
MultiLayerDiscriminator, PatchDiscriminator)
from .refiners import DeepFillRefiner, PlainRefiner
from .stylegan2 import StyleGAN2Discriminator, StyleGANv2Generator

__all__ = [
'PlainRefiner', 'GLDiscs', 'ModifiedVGG', 'MultiLayerDiscriminator',
'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator'
'DeepFillv1Discriminators', 'DeepFillRefiner', 'PatchDiscriminator',
'StyleGAN2Discriminator', 'StyleGANv2Generator'
]
4 changes: 4 additions & 0 deletions mmedit/models/components/stylegan2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .generator_discriminator import (StyleGAN2Discriminator,
StyleGANv2Generator)

__all__ = ['StyleGANv2Generator', 'StyleGAN2Discriminator']
Loading