diff --git a/demo/restoration_video_demo.py b/demo/restoration_video_demo.py
index af1cc1b1b1..261f65203e 100644
--- a/demo/restoration_video_demo.py
+++ b/demo/restoration_video_demo.py
@@ -1,12 +1,17 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os
 
+import cv2
 import mmcv
+import numpy as np
 import torch
 
 from mmedit.apis import init_model, restoration_video_inference
 from mmedit.core import tensor2img
 
+VIDEO_EXTENSIONS = ('.mp4', '.mov')
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Restoration demo')
@@ -28,12 +33,25 @@ def parse_args():
         type=int,
         default=0,
         help='window size if sliding-window framework is used')
+    parser.add_argument(
+        '--max_seq_len',
+        type=int,
+        default=None,
+        help='maximum sequence length if recurrent framework is used')
     parser.add_argument('--device', type=int, default=0, help='CUDA device id')
     args = parser.parse_args()
     return args
 
 
 def main():
+    """Demo for video restoration models.
+
+    Note that video input/output is accepted when 'input_dir'/'output_dir'
+    is set to the path of a video file. However, writing a video applies
+    video compression, which lowers the visual quality. For the best
+    quality, save the results as separate images (.png) instead.
+    """
+
     args = parse_args()
 
     model = init_model(
@@ -41,13 +59,25 @@ def main():
 
     output = restoration_video_inference(model, args.input_dir,
                                          args.window_size, args.start_idx,
-                                         args.filename_tmpl)
-    for i in range(args.start_idx, args.start_idx + output.size(1)):
-        output_i = output[:, i - args.start_idx, :, :, :]
-        output_i = tensor2img(output_i)
-        save_path_i = f'{args.output_dir}/{args.filename_tmpl.format(i)}'
+                                         args.filename_tmpl, args.max_seq_len)
+
+    file_extension = os.path.splitext(args.output_dir)[1]
+    if file_extension in VIDEO_EXTENSIONS:  # save as video
+        h, w = output.shape[-2:]
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        video_writer = cv2.VideoWriter(args.output_dir, fourcc, 25, (w, h))
+        for i in range(0, output.size(1)):
+            img = tensor2img(output[:, i, :, :, :])
+            video_writer.write(img.astype(np.uint8))
+        cv2.destroyAllWindows()
+        video_writer.release()
+    else:
+        for i in range(args.start_idx, args.start_idx + output.size(1)):
+            output_i = output[:, i - args.start_idx, :, :, :]
+            output_i = tensor2img(output_i)
+            save_path_i = f'{args.output_dir}/{args.filename_tmpl.format(i)}'
 
-        mmcv.imwrite(output_i, save_path_i)
+            mmcv.imwrite(output_i, save_path_i)
 
 
 if __name__ == '__main__':
diff --git a/mmedit/apis/restoration_video_inference.py b/mmedit/apis/restoration_video_inference.py
index 24f0e0aef9..e57a800a49 100644
--- a/mmedit/apis/restoration_video_inference.py
+++ b/mmedit/apis/restoration_video_inference.py
@@ -1,11 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import glob
+import os
 
+import mmcv
+import numpy as np
 import torch
 from mmcv.parallel import collate, scatter
 
 from mmedit.datasets.pipelines import Compose
 
+VIDEO_EXTENSIONS = ('.mp4', '.mov')
+
 
 def pad_sequence(data, window_size):
     padding = window_size // 2
@@ -19,8 +24,12 @@ def pad_sequence(data, window_size):
     return data
 
 
-def restoration_video_inference(model, img_dir, window_size, start_idx,
-                                filename_tmpl):
+def restoration_video_inference(model,
+                                img_dir,
+                                window_size,
+                                start_idx,
+                                filename_tmpl,
+                                max_seq_len=None):
     """Inference image with the model.
 
     Args:
@@ -32,6 +41,10 @@ def restoration_video_inference(model, img_dir, window_size, start_idx,
         start_idx (int): The index corresponds to the first frame in the
             sequence.
         filename_tmpl (str): Template for file name.
+        max_seq_len (int | None): The maximum sequence length that the model
+            processes. If the sequence length is larger than this number,
+            the sequence is split into multiple segments. If it is None,
+            the entire sequence is processed at once.
 
     Returns:
         Tensor: The predicted restoration result.
@@ -47,31 +60,48 @@ def restoration_video_inference(model, img_dir, window_size, start_idx,
     else:
         test_pipeline = model.cfg.val_pipeline
 
-    # the first element in the pipeline must be 'GenerateSegmentIndices'
-    if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
-        raise TypeError('The first element in the pipeline must be '
-                        f'"GenerateSegmentIndices", but got '
-                        f'"{test_pipeline[0]["type"]}".')
-
-    # specify start_idx and filename_tmpl
-    test_pipeline[0]['start_idx'] = start_idx
-    test_pipeline[0]['filename_tmpl'] = filename_tmpl
+    # check if the input is a video
+    file_extension = os.path.splitext(img_dir)[1]
+    if file_extension in VIDEO_EXTENSIONS:
+        video_reader = mmcv.VideoReader(img_dir)
+        # load the images
+        data = dict(lq=[], lq_path=None, key=img_dir)
+        for frame in video_reader:
+            data['lq'].append(np.flip(frame, axis=2))
+
+        # remove the data loading pipeline
+        tmp_pipeline = []
+        for pipeline in test_pipeline:
+            if pipeline['type'] not in [
+                    'GenerateSegmentIndices', 'LoadImageFromFileList'
+            ]:
+                tmp_pipeline.append(pipeline)
+        test_pipeline = tmp_pipeline
+    else:
+        # the first element in the pipeline must be 'GenerateSegmentIndices'
+        if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
+            raise TypeError('The first element in the pipeline must be '
+                            f'"GenerateSegmentIndices", but got '
+                            f'"{test_pipeline[0]["type"]}".')
+
+        # specify start_idx and filename_tmpl
+        test_pipeline[0]['start_idx'] = start_idx
+        test_pipeline[0]['filename_tmpl'] = filename_tmpl
+
+        # prepare data
+        sequence_length = len(glob.glob(f'{img_dir}/*'))
+        key = img_dir.split('/')[-1]
+        lq_folder = '/'.join(img_dir.split('/')[:-1])
+        data = dict(
+            lq_path=lq_folder,
+            gt_path='',
+            key=key,
+            sequence_length=sequence_length)
 
     # compose the pipeline
     test_pipeline = Compose(test_pipeline)
-
-    # prepare data
-    sequence_length = len(glob.glob(f'{img_dir}/*'))
-    key = img_dir.split('/')[-1]
-    lq_folder = '/'.join(img_dir.split('/')[:-1])
-    data = dict(
-        lq_path=lq_folder,
-        gt_path='',
-        key=key,
-        sequence_length=sequence_length)
     data = test_pipeline(data)
     data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']
-
     # forward the model
     with torch.no_grad():
         if window_size > 0:  # sliding window framework
@@ -79,9 +109,16 @@ def restoration_video_inference(model, img_dir, window_size, start_idx,
             result = []
             for i in range(0, data.size(1) - 2 * (window_size // 2)):
                 data_i = data[:, i:i + window_size]
-                result.append(model(lq=data_i, test_mode=True)['output'])
+                result.append(model(lq=data_i, test_mode=True)['output'].cpu())
             result = torch.stack(result, dim=1)
         else:  # recurrent framework
-            result = model(lq=data, test_mode=True)['output']
-
+            if max_seq_len is None:
+                result = model(lq=data, test_mode=True)['output'].cpu()
+            else:
+                result = []
+                for i in range(0, data.size(1), max_seq_len):
+                    result.append(
+                        model(lq=data[:, i:i + max_seq_len],
+                              test_mode=True)['output'].cpu())
+                result = torch.cat(result, dim=1)
     return result
diff --git a/tests/data/test_inference.mp4 b/tests/data/test_inference.mp4
new file mode 100644
index 0000000000..1c61f2a2fb
Binary files /dev/null and b/tests/data/test_inference.mp4 differ
diff --git a/tests/test_inference.py b/tests/test_inference.py
index c55e0e103d..2f4273472f 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -50,3 +50,17 @@ def test_restoration_video_inference():
         model.cfg.val_pipeline = model.cfg.val_pipeline[1:]
         output = restoration_video_inference(model, img_dir, window_size,
                                              start_idx, filename_tmpl)
+
+        # video (mp4) input
+        model = init_model(
+            './configs/restorers/basicvsr/basicvsr_reds4.py',
+            None,
+            device='cuda')
+        img_dir = './tests/data/test_inference.mp4'
+        window_size = 0
+        start_idx = 1
+        filename_tmpl = 'im{}.png'
+
+        output = restoration_video_inference(model, img_dir, window_size,
+                                             start_idx, filename_tmpl)
+        assert output.shape == (1, 5, 3, 256, 256)
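
Taken together, the diff enables the following end-to-end flow: pass an `.mp4`/`.mov` path directly to `restoration_video_inference`, bound GPU memory on recurrent models with `max_seq_len`, and write the restored frames back out. Below is a minimal sketch mirroring the new test; the config and test-video paths come from this diff, `checkpoint=None` skips weight loading exactly as the test does (so the output is only meaningful in shape), and the `./outputs/` directory, `max_seq_len=2`, and a CUDA device are illustrative assumptions.

```python
import mmcv

from mmedit.apis import init_model, restoration_video_inference
from mmedit.core import tensor2img

# BasicVSR is recurrent, so window_size=0 selects the recurrent branch;
# max_seq_len=2 splits the 5-frame clip into segments of at most 2 frames,
# each moved to CPU after its forward pass to cap GPU memory usage.
model = init_model(
    './configs/restorers/basicvsr/basicvsr_reds4.py', None, device='cuda')
output = restoration_video_inference(
    model,
    './tests/data/test_inference.mp4',  # video input, new in this diff
    window_size=0,
    start_idx=1,
    filename_tmpl='im{}.png',
    max_seq_len=2)

# output has shape (n, t, c, h, w); save each frame losslessly as .png,
# which avoids the video-compression caveat noted in the demo docstring
for i in range(output.size(1)):
    mmcv.imwrite(tensor2img(output[:, i, :, :, :]), f'./outputs/{i:08d}.png')
```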