diff --git a/mmedit/models/backbones/__init__.py b/mmedit/models/backbones/__init__.py index 93393a1863..924ca754b8 100644 --- a/mmedit/models/backbones/__init__.py +++ b/mmedit/models/backbones/__init__.py @@ -12,7 +12,8 @@ SimpleEncoderDecoder) # yapf: enable from .generation_backbones import ResnetGenerator, UnetGenerator -from .sr_backbones import EDSR, SRCNN, EDVRNet, MSRResNet, RRDBNet, TOFlow +from .sr_backbones import (EDSR, SRCNN, BasicVSRNet, EDVRNet, MSRResNet, + RRDBNet, TOFlow) __all__ = [ 'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder', @@ -23,5 +24,6 @@ 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder', 'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN', - 'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder' + 'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder', + 'BasicVSRNet' ] diff --git a/mmedit/models/backbones/sr_backbones/__init__.py b/mmedit/models/backbones/sr_backbones/__init__.py index 24f181f583..b54707fd1a 100644 --- a/mmedit/models/backbones/sr_backbones/__init__.py +++ b/mmedit/models/backbones/sr_backbones/__init__.py @@ -1,3 +1,4 @@ +from .basicvsr_net import BasicVSRNet from .edsr import EDSR from .edvr_net import EDVRNet from .rrdb_net import RRDBNet @@ -5,4 +6,6 @@ from .srcnn import SRCNN from .tof import TOFlow -__all__ = ['MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN'] +__all__ = [ + 'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'BasicVSRNet' +] diff --git a/mmedit/models/backbones/sr_backbones/basicvsr_net.py b/mmedit/models/backbones/sr_backbones/basicvsr_net.py new file mode 100644 index 0000000000..371e18c700 --- /dev/null +++ b/mmedit/models/backbones/sr_backbones/basicvsr_net.py @@ -0,0 +1,418 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmcv.runner import load_checkpoint + +from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN, + flow_warp, make_layer) +from mmedit.models.registry import BACKBONES +from mmedit.utils import get_root_logger + + +@BACKBONES.register_module() +class BasicVSRNet(nn.Module): + """BasicVSR network structure for video super-resolution. + + Support only x4 upsampling. + Paper: + BasicVSR: The Search for Essential Components in Video Super-Resolution + and Beyond, CVPR, 2021 + + Args: + mid_channels (int): Channel number of the intermediate features. + Default: 64. + num_blocks (int): Number of residual blocks in each propagation branch. + Default: 30. + spynet_pretrained (str): Pre-trained model path of SPyNet. + Default: None. + """ + + def __init__(self, mid_channels=64, num_blocks=30, spynet_pretrained=None): + + super().__init__() + + self.mid_channels = mid_channels + + # optical flow network for feature alignment + self.spynet = SPyNet(pretrained=spynet_pretrained) + + # propagation branches + self.backward_resblocks = ResidualBlocksWithInputConv( + mid_channels + 3, mid_channels, num_blocks) + self.forward_resblocks = ResidualBlocksWithInputConv( + mid_channels + 3, mid_channels, num_blocks) + + # upsample + self.fusion = nn.Conv2d( + mid_channels * 2, mid_channels, 1, 1, 0, bias=True) + self.upsample1 = PixelShufflePack( + mid_channels, mid_channels, 2, upsample_kernel=3) + self.upsample2 = PixelShufflePack( + mid_channels, 64, 2, upsample_kernel=3) + self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) + self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + self.img_upsample = nn.Upsample( + scale_factor=4, mode='bilinear', align_corners=False) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def check_if_mirror_extended(self, lrs): + """Check whether the input is a mirror-extended sequence. + + If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the + (t-1-i)-th frame. + + Args: + lrs (tensor): Input LR images with shape (n, t, c, h, w) + """ + + self.is_mirror_extended = False + if lrs.size(1) % 2 == 0: + lrs_1, lrs_2 = torch.chunk(lrs, 2, dim=1) + if torch.norm(lrs_1 - lrs_2.flip(1)) == 0: + self.is_mirror_extended = True + + def compute_flow(self, lrs): + """Compute optical flow using SPyNet for feature warping. + + Note that if the input is an mirror-extended sequence, 'flows_forward' + is not needed, since it is equal to 'flows_backward.flip(1)'. + + Args: + lrs (tensor): Input LR images with shape (n, t, c, h, w) + + Return: + tuple(Tensor): Optical flow. 'flows_forward' corresponds to the + flows used for forward-time propagation (current to previous). + 'flows_backward' corresponds to the flows used for + backward-time propagation (current to next). + """ + + n, t, c, h, w = lrs.size() + lrs_1 = lrs[:, :-1, :, :, :].reshape(-1, c, h, w) + lrs_2 = lrs[:, 1:, :, :, :].reshape(-1, c, h, w) + + flows_backward = self.spynet(lrs_1, lrs_2).view(n, t - 1, 2, h, w) + + if self.is_mirror_extended: # flows_forward = flows_backward.flip(1) + flows_forward = None + else: + flows_forward = self.spynet(lrs_2, lrs_1).view(n, t - 1, 2, h, w) + + return flows_forward, flows_backward + + def forward(self, lrs): + """Forward function for BasicVSR. + + Args: + lrs (Tensor): Input LR sequence with shape (n, t, c, h, w). + + Returns: + Tensor: Output HR sequence with shape (n, t, c, 4h, 4w). + """ + + n, t, c, h, w = lrs.size() + assert h >= 64 and w >= 64, ( + 'The height and width of inputs should be at least 64, ' + f'but got {h} and {w}.') + + # check whether the input is an extended sequence + self.check_if_mirror_extended(lrs) + + # compute optical flow + flows_forward, flows_backward = self.compute_flow(lrs) + + # backward-time propgation + outputs = [] + feat_prop = lrs.new_zeros(n, self.mid_channels, h, w) + for i in range(t - 1, -1, -1): + if i < t - 1: # no warping required for the last timestep + flow = flows_backward[:, i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + + feat_prop = torch.cat([lrs[:, i, :, :, :], feat_prop], dim=1) + feat_prop = self.backward_resblocks(feat_prop) + + outputs.append(feat_prop) + outputs = outputs[::-1] + + # forward-time propagation and upsampling + feat_prop = torch.zeros_like(feat_prop) + for i in range(0, t): + lr_curr = lrs[:, i, :, :, :] + if i > 0: # no warping required for the first timestep + if flows_forward is not None: + flow = flows_forward[:, i - 1, :, :, :] + else: + flow = flows_backward[:, -i, :, :, :] + feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1)) + + feat_prop = torch.cat([lr_curr, feat_prop], dim=1) + feat_prop = self.forward_resblocks(feat_prop) + + # upsampling given the backward and forward features + out = torch.cat([outputs[i], feat_prop], dim=1) + out = self.lrelu(self.fusion(out)) + out = self.lrelu(self.upsample1(out)) + out = self.lrelu(self.upsample2(out)) + out = self.lrelu(self.conv_hr(out)) + out = self.conv_last(out) + base = self.img_upsample(lr_curr) + out += base + outputs[i] = out + + return torch.stack(outputs, dim=1) + + def init_weights(self, pretrained=None, strict=True): + """Init weights for models. + + Args: + pretrained (str, optional): Path for pretrained weights. If given + None, pretrained weights will not be loaded. Defaults: None. + strict (boo, optional): Whether strictly load the pretrained model. + Defaults to True. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=strict, logger=logger) + elif pretrained is not None: + raise TypeError(f'"pretrained" must be a str or None. ' + f'But received {type(pretrained)}.') + + +class ResidualBlocksWithInputConv(nn.Module): + """Residual blocks with a convolution in front. + + Args: + in_channels (int): Number of input channels of the first conv. + out_channels (int): Number of channels of the residual blocks. + Default: 64. + num_blocks (int): Number of residual blocks. Default: 30. + """ + + def __init__(self, in_channels, out_channels=64, num_blocks=30): + super().__init__() + + main = [] + + # a convolution used to match the channels of the residual blocks + main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True)) + main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True)) + + # residual blocks + main.append( + make_layer( + ResidualBlockNoBN, num_blocks, mid_channels=out_channels)) + + self.main = nn.Sequential(*main) + + def forward(self, feat): + """ + Forward function for ResidualBlocksWithInputConv. + + Args: + feat (Tensor): Input feature with shape (n, in_channels, h, w) + + Returns: + Tensor: Output feature with shape (n, out_channels, h, w) + """ + return self.main(feat) + + +class SPyNet(nn.Module): + """SPyNet network structure. + + The difference to the SPyNet in [tof.py] is that + 1. more SPyNetBasicModule is used in this version, and + 2. no batch normalization is used in this version. + + Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + + Args: + pretrained (str): path for pre-trained SPyNet. Default: None. + """ + + def __init__(self, pretrained): + super().__init__() + + self.basic_module = nn.ModuleList( + [SPyNetBasicModule() for _ in range(6)]) + + if isinstance(pretrained, str): + self.load_state_dict(torch.load(pretrained), strict=True) + elif pretrained is not None: + raise TypeError('[pretrained] should be str or None, ' + f'but got {type(pretrained)}.') + + self.register_buffer( + 'mean', + torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer( + 'std', + torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def compute_flow(self, ref, supp): + """Compute flow from ref to supp. + + Note that in this function, the images are already resized to a + multiple of 32. + + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). + """ + n, _, h, w = ref.size() + + # normalize the input images + ref = [(ref - self.mean) / self.std] + supp = [(supp - self.mean) / self.std] + + # generate downsampled frames + for level in range(5): + ref.append( + F.avg_pool2d( + input=ref[-1], + kernel_size=2, + stride=2, + count_include_pad=False)) + supp.append( + F.avg_pool2d( + input=supp[-1], + kernel_size=2, + stride=2, + count_include_pad=False)) + ref = ref[::-1] + supp = supp[::-1] + + # flow computation + flow = ref[0].new_zeros(n, 2, h // 32, w // 32) + for level in range(len(ref)): + if level == 0: + flow_up = flow + else: + flow_up = F.interpolate( + input=flow, + scale_factor=2, + mode='bilinear', + align_corners=True) * 2.0 + + # add the residue to the upsampled flow + flow = flow_up + self.basic_module[level]( + torch.cat([ + ref[level], + flow_warp( + supp[level], + flow_up.permute(0, 2, 3, 1), + padding_mode='border'), flow_up + ], 1)) + + return flow + + def forward(self, ref, supp): + """Forward function of SPyNet. + + This function computes the optical flow from ref to supp. + + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). + """ + + # upsize to a multiple of 32 + h, w = ref.shape[2:4] + w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) + h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) + ref = F.interpolate( + input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False) + supp = F.interpolate( + input=supp, + size=(h_up, w_up), + mode='bilinear', + align_corners=False) + + # compute flow, and resize back to the original resolution + flow = F.interpolate( + input=self.compute_flow(ref, supp), + size=(h, w), + mode='bilinear', + align_corners=False) + + # adjust the flow values + flow[:, 0, :, :] *= float(w) / float(w_up) + flow[:, 1, :, :] *= float(h) / float(h_up) + + return flow + + +class SPyNetBasicModule(nn.Module): + """Basic Module for SPyNet. + + Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + """ + + def __init__(self): + super().__init__() + + self.basic_module = nn.Sequential( + ConvModule( + in_channels=8, + out_channels=32, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=32, + out_channels=64, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=64, + out_channels=32, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=32, + out_channels=16, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=16, + out_channels=2, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=None)) + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. + + Returns: + Tensor: Refined flow with shape (b, 2, h, w) + """ + return self.basic_module(tensor_input) diff --git a/mmedit/models/restorers/__init__.py b/mmedit/models/restorers/__init__.py index 4caa1b7519..40eccbc07d 100644 --- a/mmedit/models/restorers/__init__.py +++ b/mmedit/models/restorers/__init__.py @@ -1,7 +1,8 @@ from .basic_restorer import BasicRestorer +from .basicvsr import BasicVSR from .edvr import EDVR from .esrgan import ESRGAN from .liif import LIIF from .srgan import SRGAN -__all__ = ['BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF'] +__all__ = ['BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR'] diff --git a/mmedit/models/restorers/basicvsr.py b/mmedit/models/restorers/basicvsr.py new file mode 100644 index 0000000000..9e0004421c --- /dev/null +++ b/mmedit/models/restorers/basicvsr.py @@ -0,0 +1,185 @@ +import numbers +import os.path as osp + +import mmcv +import numpy as np +import torch + +from mmedit.core import tensor2img +from ..registry import MODELS +from .basic_restorer import BasicRestorer + + +@MODELS.register_module() +class BasicVSR(BasicRestorer): + """BasicVSR model for video super-resolution. + + Note that this model is used for IconVSR. + + Paper: + BasicVSR: The Search for Essential Components in Video Super-Resolution + and Beyond, CVPR, 2021 + + Args: + generator (dict): Config for the generator structure. + pixel_loss (dict): Config for pixel-wise loss. + train_cfg (dict): Config for training. Default: None. + test_cfg (dict): Config for testing. Default: None. + pretrained (str): Path for pretrained model. Default: None. + """ + + def __init__(self, + generator, + pixel_loss, + train_cfg=None, + test_cfg=None, + pretrained=None): + super().__init__(generator, pixel_loss, train_cfg, test_cfg, + pretrained) + + # fix pre-trained networks + self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0 + self.generator.find_unused_parameters = False + + # count training steps + self.register_buffer('step_counter', torch.zeros(1)) + + def train_step(self, data_batch, optimizer): + """Train step. + + Args: + data_batch (dict): A batch of data. + optimizer (obj): Optimizer. + + Returns: + dict: Returned output. + """ + # fix SPyNet and EDVR at the beginning + if self.step_counter < self.fix_iter: + if not self.generator.find_unused_parameters: + self.generator.find_unused_parameters = True + for k, v in self.generator.named_parameters(): + if 'spynet' in k or 'edvr' in k: + v.requires_grad_(False) + elif self.step_counter == self.fix_iter: + # train all the parameters + self.generator.find_unused_parameters = False + self.generator.requires_grad_(True) + + outputs = self(**data_batch, test_mode=False) + loss, log_vars = self.parse_losses(outputs.pop('losses')) + + # optimize + optimizer['generator'].zero_grad() + loss.backward() + optimizer['generator'].step() + + self.step_counter += 1 + + outputs.update({'log_vars': log_vars}) + return outputs + + def evaluate(self, output, gt): + """Evaluation function. + + If the output contains multiple frames, we compute the metric + one by one and take an average. + + Args: + output (Tensor): Model output with shape (n, t, c, h, w). + gt (Tensor): GT Tensor with shape (n, t, c, h, w). + + Returns: + dict: Evaluation results. + """ + crop_border = self.test_cfg.crop_border + eval_result = dict() + for metric in self.test_cfg.metrics: + if output.ndim == 5: # a sequence: (n, t, c, h, w) + avg = [] + for i in range(0, output.size(1)): + output_i = tensor2img(output[:, i, :, :, :]) + gt_i = tensor2img(gt[:, i, :, :, :]) + avg.append(self.allowed_metrics[metric](output_i, gt_i, + crop_border)) + eval_result[metric] = np.mean(avg) + elif output.ndim == 4: # an image: (n, c, t, w), for Vimeo-90K-T + output_img = tensor2img(output) + gt_img = tensor2img(gt) + value = self.allowed_metrics[metric](output_img, gt_img, + crop_border) + eval_result[metric] = value + + return eval_result + + def forward_test(self, + lq, + gt=None, + meta=None, + save_image=False, + save_path=None, + iteration=None): + """Testing forward function. + + Args: + lq (Tensor): LQ Tensor with shape (n, t, c, h, w). + gt (Tensor): GT Tensor with shape (n, t, c, h, w). Default: None. + save_image (bool): Whether to save image. Default: False. + save_path (str): Path to save image. Default: None. + iteration (int): Iteration for the saving image name. + Default: None. + + Returns: + dict: Output results. + """ + with torch.no_grad(): + output = self.generator(lq) + + # If the GT is an image (i.e. the cetner frame), the output sequence is + # turned to an image. + if gt is not None and gt.ndim == 4: + t = output.size(1) + if self.generator.is_mirror_extended: # with mirror extension + output = 0.5 * (output[:, t // 4] + output[:, -1 - t // 4]) + else: # without mirror extension + output = output[:, t // 2] + + if self.test_cfg is not None and self.test_cfg.get('metrics', None): + assert gt is not None, ( + 'evaluation with metrics must have gt images.') + results = dict(eval_result=self.evaluate(output, gt)) + else: + results = dict(lq=lq.cpu(), output=output.cpu()) + if gt is not None: + results['gt'] = gt.cpu() + + # save image + if save_image: + if output.ndim == 4: # an image, key = 000001/0000 (Vimeo-90K) + img_name = meta[0]['key'].replace('/', '_') + if isinstance(iteration, numbers.Number): + save_path = osp.join( + save_path, f'{img_name}-{iteration + 1:06d}.png') + elif iteration is None: + save_path = osp.join(save_path, f'{img_name}.png') + else: + raise ValueError('iteration should be number or None, ' + f'but got {type(iteration)}') + mmcv.imwrite(tensor2img(output), save_path) + elif output.ndim == 5: # a sequence, key = 000 + folder_name = meta[0]['key'].split('/')[0] + for i in range(0, output.size(1)): + if isinstance(iteration, numbers.Number): + save_path_i = osp.join( + save_path, folder_name, + f'{i:08d}-{iteration + 1:06d}.png') + elif iteration is None: + save_path_i = osp.join(save_path, folder_name, + f'{i:08d}.png') + else: + raise ValueError('iteration should be number or None, ' + f'but got {type(iteration)}') + mmcv.imwrite( + tensor2img(output[:, i, :, :, :]), save_path_i) + + return results diff --git a/tests/test_basicvsr_model.py b/tests/test_basicvsr_model.py new file mode 100644 index 0000000000..b6b0ec117f --- /dev/null +++ b/tests/test_basicvsr_model.py @@ -0,0 +1,147 @@ +import tempfile + +import mmcv +import pytest +import torch +from mmcv.runner import obj_from_dict + +from mmedit.models import build_model +from mmedit.models.backbones.sr_backbones import BasicVSRNet +from mmedit.models.losses import MSELoss + + +def test_basicvsr_model(): + + model_cfg = dict( + type='BasicVSR', + generator=dict( + type='BasicVSRNet', + mid_channels=64, + num_blocks=30, + spynet_pretrained=None), + pixel_loss=dict(type='MSELoss', loss_weight=1.0, reduction='sum'), + ) + + train_cfg = dict(fix_iter=1) + train_cfg = mmcv.Config(train_cfg) + test_cfg = None + + # build restorer + restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg) + + # test attributes + assert restorer.__class__.__name__ == 'BasicVSR' + assert isinstance(restorer.generator, BasicVSRNet) + assert isinstance(restorer.pixel_loss, MSELoss) + + # prepare data + inputs = torch.rand(1, 5, 3, 64, 64) + targets = torch.rand(1, 5, 3, 256, 256) + + if torch.cuda.is_available(): + inputs = inputs.cuda() + targets = targets.cuda() + restorer = restorer.cuda() + + # prepare data and optimizer + data_batch = {'lq': inputs, 'gt': targets} + optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.9, 0.999)) + optimizer = { + 'generator': + obj_from_dict(optim_cfg, torch.optim, + dict(params=getattr(restorer, 'generator').parameters())) + } + + # train_step (wihout updating spynet) + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + assert isinstance(outputs['log_vars']['loss_pix'], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq'].cpu()) + assert torch.equal(outputs['results']['gt'], data_batch['gt'].cpu()) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() == (1, 5, 3, 256, 256) + + # train with spynet updated + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + assert isinstance(outputs['log_vars']['loss_pix'], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq'].cpu()) + assert torch.equal(outputs['results']['gt'], data_batch['gt'].cpu()) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() == (1, 5, 3, 256, 256) + + # test forward_dummy + with torch.no_grad(): + output = restorer.forward_dummy(data_batch['lq']) + assert torch.is_tensor(output) + assert output.size() == (1, 5, 3, 256, 256) + + # forward_test + with torch.no_grad(): + outputs = restorer(**data_batch, test_mode=True) + assert torch.equal(outputs['lq'], data_batch['lq'].cpu()) + assert torch.equal(outputs['gt'], data_batch['gt'].cpu()) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 5, 3, 256, 256) + + with torch.no_grad(): + outputs = restorer(inputs, test_mode=True) + assert torch.equal(outputs['lq'], data_batch['lq'].cpu()) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 5, 3, 256, 256) + + # test with metric and save image + train_cfg = mmcv.ConfigDict(fix_iter=1) + test_cfg = dict(metrics=('PSNR', 'SSIM'), crop_border=0) + test_cfg = mmcv.Config(test_cfg) + + data_batch = { + 'lq': inputs, + 'gt': targets, + 'meta': [{ + 'gt_path': 'fake_path/fake_name.png', + 'key': '000' + }] + } + + restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg) + + with pytest.raises(AssertionError): + # evaluation with metrics must have gt images + restorer(lq=inputs, test_mode=True) + + with tempfile.TemporaryDirectory() as tmpdir: + outputs = restorer( + **data_batch, + test_mode=True, + save_image=True, + save_path=tmpdir, + iteration=None) + assert isinstance(outputs, dict) + assert isinstance(outputs['eval_result'], dict) + assert isinstance(outputs['eval_result']['PSNR'], float) + assert isinstance(outputs['eval_result']['SSIM'], float) + + outputs = restorer( + **data_batch, + test_mode=True, + save_image=True, + save_path=tmpdir, + iteration=100) + assert isinstance(outputs, dict) + assert isinstance(outputs['eval_result'], dict) + assert isinstance(outputs['eval_result']['PSNR'], float) + assert isinstance(outputs['eval_result']['SSIM'], float) + + with pytest.raises(ValueError): + # iteration should be number or None + restorer( + **data_batch, + test_mode=True, + save_image=True, + save_path=tmpdir, + iteration='100') diff --git a/tests/test_basicvsr_net.py b/tests/test_basicvsr_net.py new file mode 100644 index 0000000000..95458e55ec --- /dev/null +++ b/tests/test_basicvsr_net.py @@ -0,0 +1,34 @@ +import pytest +import torch + +from mmedit.models.backbones.sr_backbones.basicvsr_net import BasicVSRNet + + +def test_basicvsr_net(): + """Test BasicVSR.""" + + # cpu + basicvsr = BasicVSRNet( + mid_channels=64, num_blocks=30, spynet_pretrained=None) + input_tensor = torch.rand(1, 5, 3, 64, 64) + basicvsr.init_weights(pretrained=None) + output = basicvsr(input_tensor) + assert output.shape == (1, 5, 3, 256, 256) + + # gpu + if torch.cuda.is_available(): + basicvsr = BasicVSRNet( + mid_channels=64, num_blocks=30, spynet_pretrained=None).cuda() + input_tensor = torch.rand(1, 5, 3, 64, 64).cuda() + basicvsr.init_weights(pretrained=None) + output = basicvsr(input_tensor) + assert output.shape == (1, 5, 3, 256, 256) + + with pytest.raises(AssertionError): + # The height and width of inputs should be at least 64 + input_tensor = torch.rand(1, 5, 3, 61, 61) + basicvsr(input_tensor) + + with pytest.raises(TypeError): + # pretrained should be str or None + basicvsr.init_weights(pretrained=[1])