Skip to content

Commit

Permalink
[Feature] Add TTSR-GAN (#383)
Browse files Browse the repository at this point in the history
* Update

* re-update all

Co-authored-by: liyinshuo <[email protected]>
  • Loading branch information
Yshuo-Li and liyinshuo authored Jun 29, 2021
1 parent 2efae2b commit 50da141
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 110 deletions.
33 changes: 18 additions & 15 deletions mmedit/models/backbones/sr_backbones/ttsr_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,33 +329,30 @@ def __init__(self,
# end, merge features
self.merge_features = MergeFeatures(mid_channels, out_channels)

def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
def forward(self, x, soft_attention, textures):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
t_level3 (Tensor): Transferred HR texture T in level3.
(n, 4c, h, w)
t_level2 (Tensor): Transferred HR texture T in level2.
(n, 2c, 2h, 2w)
t_level1 (Tensor): Transferred HR texture T in level1.
(n, c, 4h, 4w)
soft_attention (Tensor): Soft-Attention tensor with shape
(n, 1, h, w).
textures (Tuple[Tensor]): Transferred HR texture tensors.
[(N, C, H, W), (N, C/2, 2H, 2W), ...]
Returns:
Tensor: Forward results.
"""

assert t_level1.shape[1] == self.texture_channels
assert textures[-1].shape[1] == self.texture_channels

x1 = self.sfe(x)

# stage 1
x1_res = torch.cat((x1, t_level3), dim=1)
x1_res = torch.cat((x1, textures[0]), dim=1)
x1_res = self.conv_first1(x1_res)

# soft-attention
x1 = x1 + x1_res * s
x1 = x1 + x1_res * soft_attention

x1_res = self.res_block1(x1)
x1_res = self.conv_last1(x1_res)
Expand All @@ -367,12 +364,15 @@ def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
x22 = self.up1(x1)
x22 = F.relu(x22)

x22_res = torch.cat((x22, t_level2), dim=1)
x22_res = torch.cat((x22, textures[1]), dim=1)
x22_res = self.conv_first2(x22_res)

# soft-attention
x22_res = x22_res * F.interpolate(
s, scale_factor=2, mode='bicubic', align_corners=False)
soft_attention,
scale_factor=2,
mode='bicubic',
align_corners=False)
x22 = x22 + x22_res

x21_res, x22_res = self.csfi2(x21, x22)
Expand All @@ -392,12 +392,15 @@ def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
x33 = self.up2(x22)
x33 = F.relu(x33)

x33_res = torch.cat((x33, t_level1), dim=1)
x33_res = torch.cat((x33, textures[2]), dim=1)
x33_res = self.conv_first3(x33_res)

# soft-attention
x33_res = x33_res * F.interpolate(
s, scale_factor=4, mode='bicubic', align_corners=False)
soft_attention,
scale_factor=4,
mode='bicubic',
align_corners=False)
x33 = x33 + x33_res

x31_res, x32_res, x33_res = self.csfi3(x31, x32, x33)
Expand Down
10 changes: 5 additions & 5 deletions mmedit/models/extractors/lte.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def forward(self, x):
x (Tensor): Input tensor with shape (n, 3, h, w).
Returns:
Forward results in 3 levels.
x_level1 (Tensor): Forward results in level 1 (n, 64, h, w).
x_level2 (Tensor): Forward results in level 2 (n, 128, h/2, w/2).
x_level3 (Tensor): Forward results in level 3 (n, 256, h/4, w/4).
Tuple[Tensor]: Forward results in 3 levels.
x_level3: Forward results in level 3 (n, 256, h/4, w/4).
x_level2: Forward results in level 2 (n, 128, h/2, w/2).
x_level1: Forward results in level 1 (n, 64, h, w).
"""

x = self.img_normalize(x)
Expand All @@ -75,7 +75,7 @@ def forward(self, x):
x_level2 = x = self.slice2(x)
x_level3 = x = self.slice3(x)

return x_level1, x_level2, x_level3
return [x_level3, x_level2, x_level1]

def init_weights(self, pretrained=None, strict=True):
"""Init weights for models.
Expand Down
115 changes: 94 additions & 21 deletions mmedit/models/restorers/ttsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from mmedit.core import tensor2img
from ..builder import build_backbone, build_component, build_loss
from ..common import set_requires_grad
from ..registry import MODELS
from .basic_restorer import BasicRestorer

Expand All @@ -21,6 +22,11 @@ class TTSR(BasicRestorer):
extractor (dict): Config for the extractor.
transformer (dict): Config for the transformer.
pixel_loss (dict): Config for the pixel loss.
discriminator (dict): Config for the discriminator. Default: None.
perceptual_loss (dict): Config for the perceptual loss. Default: None.
transferal_perceptual_loss (dict): Config for the transferal perceptual
loss. Default: None.
gan_loss (dict): Config for the GAN loss. Default: None.
train_cfg (dict): Config for train. Default: None.
test_cfg (dict): Config for testing. Default: None.
pretrained (str): Path for pretrained model. Default: None.
Expand All @@ -31,6 +37,10 @@ def __init__(self,
extractor,
transformer,
pixel_loss,
discriminator=None,
perceptual_loss=None,
transferal_perceptual_loss=None,
gan_loss=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
Expand All @@ -43,13 +53,31 @@ def __init__(self,
self.generator = build_backbone(generator)
self.transformer = build_component(transformer)
self.extractor = build_component(extractor)
# discriminator
if discriminator and gan_loss:
self.discriminator = build_component(discriminator)
self.gan_loss = build_loss(gan_loss)
else:
self.discriminator = None
self.gan_loss = None

# loss
self.pixel_loss = build_loss(pixel_loss)

self.perceptual_loss = build_loss(
perceptual_loss) if perceptual_loss else None
if transferal_perceptual_loss:
self.transferal_perceptual_loss = build_loss(
transferal_perceptual_loss)
else:
self.transferal_perceptual_loss = None
# pretrained
self.init_weights(pretrained)

# fix pre-trained networks
self.register_buffer('step_counter', torch.zeros(1))
self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0
self.disc_steps = train_cfg.get('disc_steps', 1) if train_cfg else 1

def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
"""Forward of networks.
Expand All @@ -64,28 +92,23 @@ def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
Returns:
pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w).
s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
t_level3 (Tensor): Transformed HR texture T in level3.
(n, 4c, h, w)
t_level2 (Tensor): Transformed HR texture T in level2.
(n, 2c, 2h, 2w)
t_level1 (Tensor): Transformed HR texture T in level1.
(n, c, 4h, 4w)
soft_attention (Tensor): Soft-Attention tensor with shape
(n, 1, h, w).
textures (Tuple[Tensor]): Transferred GT textures.
[(N, C, H, W), (N, C/2, 2H, 2W), ...]
"""

_, _, lq_up_level3 = self.extractor(lq_up)
_, _, ref_downup_level3 = self.extractor(ref_downup)
ref_level1, ref_level2, ref_level3 = self.extractor(ref)
lq_up, _, _ = self.extractor(lq_up)
ref_downup, _, _ = self.extractor(ref_downup)
refs = self.extractor(ref)

s, t_level3, t_level2, t_level1 = self.transformer(
lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
ref_level3)
soft_attention, textures = self.transformer(lq_up, ref_downup, refs)

pred = self.generator(lq, s, t_level3, t_level2, t_level1)
pred = self.generator(lq, soft_attention, textures)

if only_pred:
return pred
return pred, s, t_level3, t_level2, t_level1
return pred, soft_attention, textures

def forward(self, lq, gt=None, test_mode=False, **kwargs):
"""Forward function.
Expand Down Expand Up @@ -123,20 +146,68 @@ def train_step(self, data_batch, optimizer):
ref_downup = data_batch['ref_downup']

# generate
pred = self.forward_dummy(lq, lq_up, ref, ref_downup)
pred, soft_attention, textures = self(
lq, lq_up=lq_up, ref=ref, ref_downup=ref_downup, only_pred=False)

# loss
losses = dict()
log_vars = dict()

# no updates to discriminator parameters.
set_requires_grad(self.discriminator, False)

losses['loss_pix'] = self.pixel_loss(pred, gt)
if self.step_counter >= self.fix_iter:
# perceptual loss
if self.perceptual_loss:
loss_percep, loss_style = self.perceptual_loss(pred, gt)
if loss_percep is not None:
losses['loss_perceptual'] = loss_percep
if loss_style is not None:
losses['loss_style'] = loss_style
if self.transferal_perceptual_loss:
set_requires_grad(self.extractor, False)
sr_textures = self.extractor((pred + 1.) / 2.)
losses['loss_transferal'] = self.transferal_perceptual_loss(
sr_textures, soft_attention, textures)
# gan loss for generator
if self.gan_loss:
fake_g_pred = self.discriminator(pred)
losses['loss_gan'] = self.gan_loss(
fake_g_pred, target_is_real=True, is_disc=False)

# parse loss
loss, log_vars = self.parse_losses(losses)
loss_g, log_vars_g = self.parse_losses(losses)
log_vars.update(log_vars_g)

# optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer['generator'].zero_grad()
loss_g.backward()
optimizer['generator'].step()

if self.discriminator and self.step_counter >= self.fix_iter:
# discriminator
set_requires_grad(self.discriminator, True)
for _ in range(self.disc_steps):
# real
real_d_pred = self.discriminator(gt)
loss_d_real = self.gan_loss(
real_d_pred, target_is_real=True, is_disc=True)
loss_d, log_vars_d = self.parse_losses(
dict(loss_d_real=loss_d_real))
optimizer['discriminator'].zero_grad()
loss_d.backward()
log_vars.update(log_vars_d)
# fake
fake_d_pred = self.discriminator(pred.detach())
loss_d_fake = self.gan_loss(
fake_d_pred, target_is_real=False, is_disc=True)
loss_d, log_vars_d = self.parse_losses(
dict(loss_d_fake=loss_d_fake))
loss_d.backward()
log_vars.update(log_vars_d)

optimizer['discriminator'].step()

log_vars.pop('loss') # remove the unnecessary 'loss'
outputs = dict(
Expand All @@ -145,6 +216,8 @@ def train_step(self, data_batch, optimizer):
results=dict(
lq=lq.cpu(), gt=gt.cpu(), ref=ref.cpu(), output=pred.cpu()))

self.step_counter += 1

return outputs

def forward_test(self,
Expand Down
Loading

0 comments on commit 50da141

Please sign in to comment.