Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp2023-653] Add new configs of Real BasicVSR #2030

Merged
merged 2 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.

# Please refer to https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta for more details. # noqa
# mmcv >= 2.0.1
# mmengine >= 0.8.0

from mmengine.config import read_base
from mmengine.optim.optimizer import OptimWrapper
from mmengine.runner.loops import IterBasedTrainLoop
from torch.optim.adam import Adam

from mmagic.engine import MultiOptimWrapperConstructor
from mmagic.models.data_preprocessors import DataPreprocessor
from mmagic.models.editors import (RealBasicVSR, RealBasicVSRNet,
UNetDiscriminatorWithSpectralNorm)
from mmagic.models.losses import GANLoss, L1Loss, PerceptualLoss

with read_base():
from .realbasicvsr_wogan_c64b20_2x30x8_8xb2_lr1e_4_300k_reds import *

experiment_name = 'realbasicvsr_c64b20-1x30x8_8xb1-lr5e-5-150k_reds'
work_dir = f'./work_dirs/{experiment_name}'
save_dir = './work_dirs/'

# load_from = 'https://download.openmmlab.com/mmediting/restorers/real_basicvsr/realbasicvsr_wogan_c64b20_2x30x8_lr1e-4_300k_reds_20211027-0e2ff207.pth' # noqa

scale = 4

# model settings
model.update(
dict(
type=RealBasicVSR,
generator=dict(
type=RealBasicVSRNet,
mid_channels=64,
num_propagation_blocks=20,
num_cleaning_blocks=20,
dynamic_refine_thres=255, # change to 5 for test
spynet_pretrained=
'https://download.openmmlab.com/mmediting/restorers/'
'basicvsr/spynet_20210409-c6c1bd09.pth',
is_fix_cleaning=False,
is_sequential_cleaning=False),
discriminator=dict(
type=UNetDiscriminatorWithSpectralNorm,
in_channels=3,
mid_channels=64,
skip_connection=True),
pixel_loss=dict(type=L1Loss, loss_weight=1.0, reduction='mean'),
cleaning_loss=dict(type=L1Loss, loss_weight=1.0, reduction='mean'),
perceptual_loss=dict(
type=PerceptualLoss,
layer_weights={
'2': 0.1,
'7': 0.1,
'16': 1.0,
'25': 1.0,
'34': 1.0,
},
vgg_type='vgg19',
perceptual_weight=1.0,
style_weight=0,
norm_img=False),
gan_loss=dict(
type=GANLoss,
gan_type='vanilla',
loss_weight=5e-2,
real_label_val=1.0,
fake_label_val=0),
is_use_sharpened_gt_in_pixel=True,
is_use_sharpened_gt_in_percep=True,
is_use_sharpened_gt_in_gan=False,
is_use_ema=True,
data_preprocessor=dict(
type=DataPreprocessor,
mean=[0., 0., 0.],
std=[255., 255., 255.],
)))

# optimizer
optim_wrapper.update(
dict(
_delete_=True,
constructor=MultiOptimWrapperConstructor,
generator=dict(
type=OptimWrapper,
optimizer=dict(type=Adam, lr=5e-5, betas=(0.9, 0.99))),
discriminator=dict(
type=OptimWrapper,
optimizer=dict(type=Adam, lr=1e-4, betas=(0.9, 0.99))),
))

train_cfg.update(
dict(type=IterBasedTrainLoop, max_iters=150_000, val_interval=5000))
Loading
Loading