diff --git a/mmedit/models/backbones/sr_backbones/ttsr_net.py b/mmedit/models/backbones/sr_backbones/ttsr_net.py
index cb181b8a60..b2733ef2df 100644
--- a/mmedit/models/backbones/sr_backbones/ttsr_net.py
+++ b/mmedit/models/backbones/sr_backbones/ttsr_net.py
@@ -329,33 +329,30 @@ def __init__(self,
         # end, merge features
         self.merge_features = MergeFeatures(mid_channels, out_channels)
 
-    def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
+    def forward(self, x, soft_attention, textures):
         """Forward function.
 
         Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
-            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
-            t_level3 (Tensor): Transferred HR texture T in level3.
-                (n, 4c, h, w)
-            t_level2 (Tensor): Transferred HR texture T in level2.
-                (n, 2c, 2h, 2w)
-            t_level1 (Tensor): Transferred HR texture T in level1.
-                (n, c, 4h, 4w)
+            soft_attention (Tensor): Soft-Attention tensor with shape
+                (n, 1, h, w).
+            textures (Tuple[Tensor]): Transferred HR texture tensors.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
 
         Returns:
             Tensor: Forward results.
         """
 
-        assert t_level1.shape[1] == self.texture_channels
+        assert textures[-1].shape[1] == self.texture_channels
 
         x1 = self.sfe(x)
 
         # stage 1
-        x1_res = torch.cat((x1, t_level3), dim=1)
+        x1_res = torch.cat((x1, textures[0]), dim=1)
         x1_res = self.conv_first1(x1_res)
 
         # soft-attention
-        x1 = x1 + x1_res * s
+        x1 = x1 + x1_res * soft_attention
 
         x1_res = self.res_block1(x1)
         x1_res = self.conv_last1(x1_res)
@@ -367,12 +364,15 @@ def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
         x22 = self.up1(x1)
         x22 = F.relu(x22)
 
-        x22_res = torch.cat((x22, t_level2), dim=1)
+        x22_res = torch.cat((x22, textures[1]), dim=1)
         x22_res = self.conv_first2(x22_res)
 
         # soft-attention
         x22_res = x22_res * F.interpolate(
-            s, scale_factor=2, mode='bicubic', align_corners=False)
+            soft_attention,
+            scale_factor=2,
+            mode='bicubic',
+            align_corners=False)
         x22 = x22 + x22_res
 
         x21_res, x22_res = self.csfi2(x21, x22)
@@ -392,12 +392,15 @@ def forward(self, x, s=None, t_level3=None, t_level2=None, t_level1=None):
         x33 = self.up2(x22)
         x33 = F.relu(x33)
 
-        x33_res = torch.cat((x33, t_level1), dim=1)
+        x33_res = torch.cat((x33, textures[2]), dim=1)
         x33_res = self.conv_first3(x33_res)
 
         # soft-attention
         x33_res = x33_res * F.interpolate(
-            s, scale_factor=4, mode='bicubic', align_corners=False)
+            soft_attention,
+            scale_factor=4,
+            mode='bicubic',
+            align_corners=False)
         x33 = x33 + x33_res
 
         x31_res, x32_res, x33_res = self.csfi3(x31, x32, x33)
diff --git a/mmedit/models/extractors/lte.py b/mmedit/models/extractors/lte.py
index 5eb0c225e4..0c43c480af 100644
--- a/mmedit/models/extractors/lte.py
+++ b/mmedit/models/extractors/lte.py
@@ -63,10 +63,10 @@ def forward(self, x):
             x (Tensor): Input tensor with shape (n, 3, h, w).
 
         Returns:
-            Forward results in 3 levels.
-            x_level1 (Tensor): Forward results in level 1 (n, 64, h, w).
-            x_level2 (Tensor): Forward results in level 2 (n, 128, h/2, w/2).
-            x_level3 (Tensor): Forward results in level 3 (n, 256, h/4, w/4).
+            Tuple[Tensor]: Forward results in 3 levels.
+                x_level3: Forward results in level 3 (n, 256, h/4, w/4).
+                x_level2: Forward results in level 2 (n, 128, h/2, w/2).
+                x_level1: Forward results in level 1 (n, 64, h, w).
         """
 
         x = self.img_normalize(x)
@@ -75,7 +75,7 @@ def forward(self, x):
         x_level2 = x = self.slice2(x)
         x_level3 = x = self.slice3(x)
 
-        return x_level1, x_level2, x_level3
+        return [x_level3, x_level2, x_level1]
 
     def init_weights(self, pretrained=None, strict=True):
         """Init weights for models.
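Editor's note (not part of the diff): a minimal usage sketch of the refactored interfaces above, mirroring the shapes used in `test_ttsr_net` in the updated tests. The transferred textures are now passed as a single coarse-to-fine tuple whose order matches the new `LTE` return order, and `textures[-1]` must carry `texture_channels` channels.

```python
# Sketch only: builds the backbone the same way the updated test does.
import torch
from mmedit.models import build_backbone

ttsr_cfg = dict(
    type='TTSRNet',
    in_channels=3,
    out_channels=3,
    mid_channels=16,
    texture_channels=16)
ttsr_net = build_backbone(ttsr_cfg)

lq = torch.rand(2, 3, 24, 24)
soft_attention = torch.rand(2, 1, 24, 24)
# Coarse-to-fine, matching LTE's new return order [level3, level2, level1].
textures = (
    torch.rand(2, 64, 24, 24),  # level 3: (N, 4C, H, W)
    torch.rand(2, 32, 48, 48),  # level 2: (N, 2C, 2H, 2W)
    torch.rand(2, 16, 96, 96),  # level 1: (N, C, 4H, 4W)
)
output = ttsr_net(lq, soft_attention, textures)
assert output.shape == (2, 3, 96, 96)
```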
diff --git a/mmedit/models/restorers/ttsr.py b/mmedit/models/restorers/ttsr.py
index fca5b81e3e..617e509823 100644
--- a/mmedit/models/restorers/ttsr.py
+++ b/mmedit/models/restorers/ttsr.py
@@ -6,6 +6,7 @@
 from mmedit.core import tensor2img
 from ..builder import build_backbone, build_component, build_loss
+from ..common import set_requires_grad
 from ..registry import MODELS
 from .basic_restorer import BasicRestorer
 
@@ -21,6 +22,11 @@ class TTSR(BasicRestorer):
         extractor (dict): Config for the extractor.
         transformer (dict): Config for the transformer.
         pixel_loss (dict): Config for the pixel loss.
+        discriminator (dict): Config for the discriminator. Default: None.
+        perceptual_loss (dict): Config for the perceptual loss. Default: None.
+        transferal_perceptual_loss (dict): Config for the transferal
+            perceptual loss. Default: None.
+        gan_loss (dict): Config for the GAN loss. Default: None.
         train_cfg (dict): Config for train. Default: None.
         test_cfg (dict): Config for testing. Default: None.
         pretrained (str): Path for pretrained model. Default: None.
@@ -31,6 +37,10 @@ def __init__(self,
                  extractor,
                  transformer,
                  pixel_loss,
+                 discriminator=None,
+                 perceptual_loss=None,
+                 transferal_perceptual_loss=None,
+                 gan_loss=None,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
@@ -43,13 +53,31 @@ def __init__(self,
         self.generator = build_backbone(generator)
         self.transformer = build_component(transformer)
         self.extractor = build_component(extractor)
 
+        # discriminator
+        if discriminator and gan_loss:
+            self.discriminator = build_component(discriminator)
+            self.gan_loss = build_loss(gan_loss)
+        else:
+            self.discriminator = None
+            self.gan_loss = None
         # loss
         self.pixel_loss = build_loss(pixel_loss)
-
+        self.perceptual_loss = build_loss(
+            perceptual_loss) if perceptual_loss else None
+        if transferal_perceptual_loss:
+            self.transferal_perceptual_loss = build_loss(
+                transferal_perceptual_loss)
+        else:
+            self.transferal_perceptual_loss = None
         # pretrained
         self.init_weights(pretrained)
 
+        # fix pre-trained networks
+        self.register_buffer('step_counter', torch.zeros(1))
+        self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0
+        self.disc_steps = train_cfg.get('disc_steps', 1) if train_cfg else 1
+
     def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
         """Forward of networks.
@@ -64,28 +92,23 @@ def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
 
         Returns:
             pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w).
-            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
-            t_level3 (Tensor): Transformed HR texture T in level3.
-                (n, 4c, h, w)
-            t_level2 (Tensor): Transformed HR texture T in level2.
-                (n, 2c, 2h, 2w)
-            t_level1 (Tensor): Transformed HR texture T in level1.
-                (n, c, 4h, 4w)
+            soft_attention (Tensor): Soft-Attention tensor with shape
+                (n, 1, h, w).
+            textures (Tuple[Tensor]): Transferred GT textures.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
""" - _, _, lq_up_level3 = self.extractor(lq_up) - _, _, ref_downup_level3 = self.extractor(ref_downup) - ref_level1, ref_level2, ref_level3 = self.extractor(ref) + lq_up, _, _ = self.extractor(lq_up) + ref_downup, _, _ = self.extractor(ref_downup) + refs = self.extractor(ref) - s, t_level3, t_level2, t_level1 = self.transformer( - lq_up_level3, ref_downup_level3, ref_level1, ref_level2, - ref_level3) + soft_attention, textures = self.transformer(lq_up, ref_downup, refs) - pred = self.generator(lq, s, t_level3, t_level2, t_level1) + pred = self.generator(lq, soft_attention, textures) if only_pred: return pred - return pred, s, t_level3, t_level2, t_level1 + return pred, soft_attention, textures def forward(self, lq, gt=None, test_mode=False, **kwargs): """Forward function. @@ -123,20 +146,68 @@ def train_step(self, data_batch, optimizer): ref_downup = data_batch['ref_downup'] # generate - pred = self.forward_dummy(lq, lq_up, ref, ref_downup) + pred, soft_attention, textures = self( + lq, lq_up=lq_up, ref=ref, ref_downup=ref_downup, only_pred=False) # loss losses = dict() + log_vars = dict() + + # no updates to discriminator parameters. + set_requires_grad(self.discriminator, False) losses['loss_pix'] = self.pixel_loss(pred, gt) + if self.step_counter >= self.fix_iter: + # perceptual loss + if self.perceptual_loss: + loss_percep, loss_style = self.perceptual_loss(pred, gt) + if loss_percep is not None: + losses['loss_perceptual'] = loss_percep + if loss_style is not None: + losses['loss_style'] = loss_style + if self.transferal_perceptual_loss: + set_requires_grad(self.extractor, False) + sr_textures = self.extractor((pred + 1.) / 2.) + losses['loss_transferal'] = self.transferal_perceptual_loss( + sr_textures, soft_attention, textures) + # gan loss for generator + if self.gan_loss: + fake_g_pred = self.discriminator(pred) + losses['loss_gan'] = self.gan_loss( + fake_g_pred, target_is_real=True, is_disc=False) # parse loss - loss, log_vars = self.parse_losses(losses) + loss_g, log_vars_g = self.parse_losses(losses) + log_vars.update(log_vars_g) # optimize - optimizer.zero_grad() - loss.backward() - optimizer.step() + optimizer['generator'].zero_grad() + loss_g.backward() + optimizer['generator'].step() + + if self.discriminator and self.step_counter >= self.fix_iter: + # discriminator + set_requires_grad(self.discriminator, True) + for _ in range(self.disc_steps): + # real + real_d_pred = self.discriminator(gt) + loss_d_real = self.gan_loss( + real_d_pred, target_is_real=True, is_disc=True) + loss_d, log_vars_d = self.parse_losses( + dict(loss_d_real=loss_d_real)) + optimizer['discriminator'].zero_grad() + loss_d.backward() + log_vars.update(log_vars_d) + # fake + fake_d_pred = self.discriminator(pred.detach()) + loss_d_fake = self.gan_loss( + fake_d_pred, target_is_real=False, is_disc=True) + loss_d, log_vars_d = self.parse_losses( + dict(loss_d_fake=loss_d_fake)) + loss_d.backward() + log_vars.update(log_vars_d) + + optimizer['discriminator'].step() log_vars.pop('loss') # remove the unnecessary 'loss' outputs = dict( @@ -145,6 +216,8 @@ def train_step(self, data_batch, optimizer): results=dict( lq=lq.cpu(), gt=gt.cpu(), ref=ref.cpu(), output=pred.cpu())) + self.step_counter += 1 + return outputs def forward_test(self, diff --git a/mmedit/models/transformers/search_transformer.py b/mmedit/models/transformers/search_transformer.py index 5df10f5cf2..db7f2be26a 100644 --- a/mmedit/models/transformers/search_transformer.py +++ b/mmedit/models/transformers/search_transformer.py @@ -35,13 +35,12 
diff --git a/mmedit/models/transformers/search_transformer.py b/mmedit/models/transformers/search_transformer.py
index 5df10f5cf2..db7f2be26a 100644
--- a/mmedit/models/transformers/search_transformer.py
+++ b/mmedit/models/transformers/search_transformer.py
@@ -35,13 +35,12 @@ def gather(self, inputs, dim, index):
 
         return outputs
 
-    def forward(self, lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
-                ref_level3):
+    def forward(self, lq_up, ref_downup, refs):
         """Texture transformer
 
         Q = LTE(lq_up)
         K = LTE(ref_downup)
-        V = LTE(ref), from V_level1 to V_level3
+        V = LTE(ref), from V_level_n to V_level_1
 
         Relevance embedding aims to embed the relevance between the LQ and
         Ref image by estimating the similarity between Q and K.
@@ -51,41 +50,40 @@ def forward(self, lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
         features T and the LQ features F from the backbone.
 
         Args:
-            All args are features come from extractor (sucn as LTE).
+            All args are features that come from the extractor (such as LTE).
                 These features contain 3 levels.
                 When upscale_factor=4, the size ratio of these features is
-                level1:level2:level3 = 4:2:1.
-            lq_up_level3 (Tensor): level3 feature of 4x bicubic-upsampled lq
-                image. (N, 4C, H, W)
-            ref_downup_level3 (Tensor): level3 feature of ref_downup.
-                ref_downup is obtained by applying bicubic down-sampling and
-                up-sampling with factor 4x on ref. (N, 4C, H, W)
-            ref_level1 (Tensor): level1 feature of ref image. (N, C, 4H, 4W)
-            ref_level2 (Tensor): level2 feature of ref image. (N, 2C, 2H, 2W)
-            ref_level3 (Tensor): level3 feature of ref image. (N, 4C, H, W)
+                level3:level2:level1 = 1:2:4.
+            lq_up (Tensor): Tensor of 4x bicubic-upsampled lq image.
+                (N, C, H, W)
+            ref_downup (Tensor): Tensor of ref_downup. ref_downup is obtained
+                by applying bicubic down-sampling and up-sampling with factor
+                4x on ref. (N, C, H, W)
+            refs (Tuple[Tensor]): Tuple of ref tensors.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
 
         Returns:
-            s (Tensor): Soft-Attention tensor. (N, 1, H, W)
-            t_level3 (Tensor): Transferred GT texture T in level3.
-                (N, 4C, H, W)
-            t_level2 (Tensor): Transferred GT texture T in level2.
-                (N, 2C, 2H, 2W)
-            t_level1 (Tensor): Transferred GT texture T in level1.
-                (N, C, 4H, 4W)
+            soft_attention (Tensor): Soft-Attention tensor. (N, 1, H, W)
+            textures (Tuple[Tensor]): Transferred GT textures.
+                [(N, C, H, W), (N, C/2, 2H, 2W), ...]
         """
+
+        levels = len(refs)
         # query
-        query = F.unfold(lq_up_level3, kernel_size=(3, 3), padding=1)
+        query = F.unfold(lq_up, kernel_size=(3, 3), padding=1)
 
         # key
-        key = F.unfold(ref_downup_level3, kernel_size=(3, 3), padding=1)
+        key = F.unfold(ref_downup, kernel_size=(3, 3), padding=1)
         key_t = key.permute(0, 2, 1)
 
         # values
-        value_level3 = F.unfold(ref_level3, kernel_size=(3, 3), padding=1)
-        value_level2 = F.unfold(
-            ref_level2, kernel_size=(6, 6), padding=2, stride=2)
-        value_level1 = F.unfold(
-            ref_level1, kernel_size=(12, 12), padding=4, stride=4)
+        values = [
+            F.unfold(
+                refs[i],
+                kernel_size=3 * pow(2, i),
+                padding=pow(2, i),
+                stride=pow(2, i)) for i in range(levels)
+        ]
 
         key_t = F.normalize(key_t, dim=2)  # [N, H*W, C*k*k]
         query = F.normalize(query, dim=1)  # [N, C*k*k, H*W]
@@ -95,30 +93,19 @@ def forward(self, lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
         max_val, max_index = torch.max(rel_embedding, dim=1)  # [N, H*W]
 
         # hard-attention
-        t_level3_unfold = self.gather(value_level3, 2, max_index)
-        t_level2_unfold = self.gather(value_level2, 2, max_index)
-        t_level1_unfold = self.gather(value_level1, 2, max_index)
+        textures = [self.gather(value, 2, max_index) for value in values]
 
         # to tensor
-        t_level3 = F.fold(
-            t_level3_unfold,
-            output_size=lq_up_level3.size()[-2:],
-            kernel_size=(3, 3),
-            padding=1) / (3. * 3.)
-        t_level2 = F.fold(
-            t_level2_unfold,
-            output_size=(lq_up_level3.size(2) * 2, lq_up_level3.size(3) * 2),
-            kernel_size=(6, 6),
-            padding=2,
-            stride=2) / (3. * 3.)
-        t_level1 = F.fold(
-            t_level1_unfold,
-            output_size=(lq_up_level3.size(2) * 4, lq_up_level3.size(3) * 4),
-            kernel_size=(12, 12),
-            padding=4,
-            stride=4) / (3. * 3.)
-
-        s = max_val.view(
-            max_val.size(0), 1, lq_up_level3.size(2), lq_up_level3.size(3))
-
-        return s, t_level3, t_level2, t_level1
+        h, w = lq_up.size()[-2:]
+        textures = [
+            F.fold(
+                textures[i],
+                output_size=(h * pow(2, i), w * pow(2, i)),
+                kernel_size=3 * pow(2, i),
+                padding=pow(2, i),
+                stride=pow(2, i)) / 9. for i in range(levels)
+        ]
+
+        soft_attention = max_val.view(max_val.size(0), 1, h, w)
+
+        return soft_attention, textures
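Editor's note (not part of the diff): why a single `max_index` can gather textures at every level in the generalized loop above. With `kernel_size=3 * 2**i`, `padding=2**i`, `stride=2**i`, `F.unfold` yields the same number of patch columns (`H * W`) at each level, so the hard-attention index computed from the level-3 relevance map addresses the corresponding patch in every value tensor. A small sketch under assumed feature sizes:

```python
import torch
import torch.nn.functional as F

n, c, h, w = 2, 32, 16, 16  # assumed level-3 feature size
refs = [
    torch.rand(n, c, h, w),               # level 3
    torch.rand(n, c // 2, 2 * h, 2 * w),  # level 2
    torch.rand(n, c // 4, 4 * h, 4 * w),  # level 1
]
values = [
    F.unfold(
        refs[i],
        kernel_size=3 * pow(2, i),
        padding=pow(2, i),
        stride=pow(2, i)) for i in range(len(refs))
]
# Each level produces h * w patch columns; only the patch length differs,
# so one index tensor of shape (n, h * w) can gather from every level.
assert all(value.shape[-1] == h * w for value in values)
```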
diff --git a/tests/test_models/test_common/test_img_normalize.py b/tests/test_models/test_common/test_img_normalize.py
index fd425af917..d8a16f17e2 100644
--- a/tests/test_models/test_common/test_img_normalize.py
+++ b/tests/test_models/test_common/test_img_normalize.py
@@ -19,7 +19,3 @@ def test_normalize_layer():
     std_y = y.std(dim=1)
     assert sum(torch.div(std_x, std_y) - rgb_std) < 1e-5
     assert sum(torch.div(mean_x - rgb_mean, rgb_std) - mean_y) < 1e-5
-
-
-if __name__ == '__main__':
-    test_normalize_layer()
diff --git a/tests/test_models/test_extractors/test_lte.py b/tests/test_models/test_extractors/test_lte.py
index 23ddc5f0d6..803de236a1 100644
--- a/tests/test_models/test_extractors/test_lte.py
+++ b/tests/test_models/test_extractors/test_lte.py
@@ -13,7 +13,7 @@ def test_lte():
     x = torch.rand(2, 3, 64, 64)
 
-    x_level1, x_level2, x_level3 = lte(x)
+    x_level3, x_level2, x_level1 = lte(x)
     assert x_level1.shape == (2, 64, 64, 64)
     assert x_level2.shape == (2, 128, 32, 32)
     assert x_level3.shape == (2, 256, 16, 16)
@@ -22,7 +22,7 @@ def test_lte():
     with pytest.raises(IOError):
         model_cfg['pretrained'] = ''
         lte = build_component(model_cfg)
-        x_level1, x_level2, x_level3 = lte(x)
+        x_level3, x_level2, x_level1 = lte(x)
         lte.init_weights('')
     with pytest.raises(TypeError):
         lte.init_weights(1)
diff --git a/tests/test_models/test_restorers/test_ttsr.py b/tests/test_models/test_restorers/test_ttsr.py
index 28e19a3f26..f0758b2df9 100644
--- a/tests/test_models/test_restorers/test_ttsr.py
+++ b/tests/test_models/test_restorers/test_ttsr.py
@@ -44,7 +44,7 @@ def test_merge_features():
 
 def test_ttsr_net():
     inputs = torch.rand(2, 3, 24, 24)
-    s = torch.rand(2, 1, 24, 24)
+    soft_attention = torch.rand(2, 1, 24, 24)
     t_level3 = torch.rand(2, 64, 24, 24)
     t_level2 = torch.rand(2, 32, 48, 48)
     t_level1 = torch.rand(2, 16, 96, 96)
@@ -56,13 +56,12 @@ def test_ttsr_net():
         mid_channels=16,
         texture_channels=16)
     ttsr = build_backbone(ttsr_cfg)
-    outputs = ttsr(inputs, s, t_level3, t_level2, t_level1)
+    outputs = ttsr(inputs, soft_attention, (t_level3, t_level2, t_level1))
     assert outputs.shape == (2, 3, 96, 96)
 
 
 def test_ttsr():
-
     model_cfg = dict(
         type='TTSR',
         generator=dict(
@@ -82,6 +81,44 @@ def test_ttsr():
     # build restorer
     restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
 
+    model_cfg = dict(
+        type='TTSR',
+        generator=dict(
+            type='TTSRNet',
+            in_channels=3,
+            out_channels=3,
+            mid_channels=64,
+            num_blocks=(16, 16, 8, 4)),
+        extractor=dict(type='LTE'),
+        transformer=dict(type='SearchTransformer'),
+        discriminator=dict(type='TTSRDiscriminator', in_size=64),
+        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
+        perceptual_loss=dict(
+            type='PerceptualLoss',
+            layer_weights={'29': 1.0},
+            vgg_type='vgg19',
+            perceptual_weight=1e-2,
+            style_weight=0.001,
+            criterion='mse'),
+        transferal_perceptual_loss=dict(
+            type='TransferalPerceptualLoss',
+            loss_weight=1e-2,
+            use_attention=False,
+            criterion='mse'),
+        gan_loss=dict(
+            type='GANLoss',
+            gan_type='vanilla',
+            loss_weight=1e-3,
+            real_label_val=1.0,
+            fake_label_val=0))
+
+    scale = 4
+    train_cfg = None
+    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))
+
+    # build restorer
+    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
+
     # test attributes
     assert restorer.__class__.__name__ == 'TTSR'
 
@@ -98,9 +135,13 @@ def test_ttsr():
     }
 
     # prepare optimizer
-    optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
-    optimizer = obj_from_dict(optim_cfg, torch.optim,
-                              dict(params=restorer.parameters()))
+    optim_cfg_g = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
+    optim_cfg_d = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
+    optimizer = dict(
+        generator=obj_from_dict(optim_cfg_g, torch.optim,
+                                dict(params=restorer.parameters())),
+        discriminator=obj_from_dict(optim_cfg_d, torch.optim,
+                                    dict(params=restorer.parameters())))
 
     # test train_step and forward_test (cpu)
     outputs = restorer.train_step(data_batch, optimizer)
@@ -125,8 +166,11 @@ def test_ttsr():
     }
 
     # train_step
-    optimizer = obj_from_dict(optim_cfg, torch.optim,
-                              dict(params=restorer.parameters()))
+    optimizer = dict(
+        generator=obj_from_dict(optim_cfg_g, torch.optim,
+                                dict(params=restorer.parameters())),
+        discriminator=obj_from_dict(optim_cfg_d, torch.optim,
+                                    dict(params=restorer.parameters())))
     outputs = restorer.train_step(data_batch, optimizer)
     assert isinstance(outputs, dict)
     assert isinstance(outputs['log_vars'], dict)
diff --git a/tests/test_models/test_transformer/test_search_transformer.py b/tests/test_models/test_transformer/test_search_transformer.py
index ff132233c2..29d802acde 100644
--- a/tests/test_models/test_transformer/test_search_transformer.py
+++ b/tests/test_models/test_transformer/test_search_transformer.py
@@ -1,6 +1,6 @@
 import torch
 
-from mmedit.models import build_component
+from mmedit.models.builder import build_component
 
 
 def test_search_transformer():
@@ -13,8 +13,9 @@ def test_search_transformer():
     ref_level2 = torch.randn((2, 16, 64, 64))
     ref_level1 = torch.randn((2, 8, 128, 128))
 
-    s, t_level3, t_level2, t_level1 = model(lr_pad_level3, ref_pad_level3,
-                                            ref_level1, ref_level2, ref_level3)
+    s, textures = model(lr_pad_level3, ref_pad_level3,
+                        (ref_level3, ref_level2, ref_level1))
+    t_level3, t_level2, t_level1 = textures
 
     assert s.shape == (2, 1, 32, 32)
     assert t_level3.shape == (2, 32, 32, 32)
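Editor's note (not part of the diff): for reference, a hypothetical training-config fragment showing the pieces `TTSR.train_step` now reads, one optimizer per submodule plus the `fix_iter`/`disc_steps` keys in `train_cfg`. The key names follow the code above; the numeric values are placeholders, not taken from this PR.

```python
# Hypothetical config sketch; values are illustrative only.
optimizers = dict(
    generator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)),
    discriminator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)))
train_cfg = dict(fix_iter=0, disc_steps=1)  # GAN losses start after fix_iter steps
```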