
[Feature] Support video input and output in restoration demo #622

Merged
8 commits merged on Dec 29, 2021
42 changes: 36 additions & 6 deletions demo/restoration_video_demo.py
@@ -1,12 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import cv2
import mmcv
import numpy as np
import torch

from mmedit.apis import init_model, restoration_video_inference
from mmedit.core import tensor2img

VIDEO_EXTENSIONS = ('.mp4', '.mov')


def parse_args():
    parser = argparse.ArgumentParser(description='Restoration demo')
@@ -28,26 +33,51 @@ def parse_args():
        type=int,
        default=0,
        help='window size if sliding-window framework is used')
    parser.add_argument(
        '--max_seq_len',
        type=int,
        default=None,
        help='maximum sequence length if recurrent framework is used')
    parser.add_argument('--device', type=int, default=0, help='CUDA device id')
    args = parser.parse_args()
    return args


def main():
    """Demo for video restoration models.

    Video input and output are accepted when 'input_dir' / 'output_dir'
    points to a video file. Note that saving to video applies compression,
    which lowers the visual quality; for lossless results, save the frames
    as separate images (.png) instead.
    """

    args = parse_args()

    model = init_model(
        args.config, args.checkpoint, device=torch.device('cuda', args.device))

    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size, args.start_idx,
                                         args.filename_tmpl, args.max_seq_len)

    file_extension = os.path.splitext(args.output_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:  # save as video
        h, w = output.shape[-2:]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(args.output_dir, fourcc, 25, (w, h))
        for i in range(0, output.size(1)):
            img = tensor2img(output[:, i, :, :, :])
            video_writer.write(img.astype(np.uint8))
        cv2.destroyAllWindows()
        video_writer.release()
    else:  # save as separate images
        for i in range(args.start_idx, args.start_idx + output.size(1)):
            output_i = output[:, i - args.start_idx, :, :, :]
            output_i = tensor2img(output_i)
            save_path_i = f'{args.output_dir}/{args.filename_tmpl.format(i)}'

            mmcv.imwrite(output_i, save_path_i)


if __name__ == '__main__':
    main()
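
For context, a minimal sketch of how the updated demo API might be driven programmatically; the config, checkpoint, and input paths below are placeholders chosen by the editor, not part of this PR:

import torch
from mmedit.apis import init_model, restoration_video_inference

# placeholder paths; any BasicVSR-style config/checkpoint would do
model = init_model('configs/restorers/basicvsr/basicvsr_reds4.py',
                   'basicvsr_reds4.pth',
                   device=torch.device('cuda', 0))

# video in -> tensor out; max_seq_len=20 bounds per-forward memory
output = restoration_video_inference(model, 'inputs/lq.mp4',
                                     window_size=0, start_idx=0,
                                     filename_tmpl='{:08d}.png',
                                     max_seq_len=20)
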
87 changes: 62 additions & 25 deletions mmedit/apis/restoration_video_inference.py
@@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os

import mmcv
import numpy as np
import torch
from mmcv.parallel import collate, scatter

from mmedit.datasets.pipelines import Compose

VIDEO_EXTENSIONS = ('.mp4', '.mov')


def pad_sequence(data, window_size):
    padding = window_size // 2
@@ -19,8 +24,12 @@ def pad_sequence(data, window_size):
    return data
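
The collapsed hunk hides the padding itself. As a rough illustration (an editor's assumption, since the body is elided here), reflection padding of window_size // 2 frames at each end gives every original frame a full sliding window:

import torch

# toy (n=1, t=5) "sequence"; real inputs are (n, t, c, h, w)
data = torch.arange(5).view(1, 5)
padding = 3 // 2  # window_size = 3 -> one reflected frame per side
padded = torch.cat(
    [data[:, 1:1 + padding].flip(1), data, data[:, -1 - padding:-1].flip(1)],
    dim=1)
print(padded)  # tensor([[1, 0, 1, 2, 3, 4, 3]])
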


def restoration_video_inference(model,
                                img_dir,
                                window_size,
                                start_idx,
                                filename_tmpl,
                                max_seq_len=None):
"""Inference image with the model.

    Args:
@@ -32,6 +41,10 @@ def restoration_video_inference(model, img_dir, window_size, start_idx,
        start_idx (int): The index corresponding to the first frame in the
            sequence.
        filename_tmpl (str): Template for file name.
        max_seq_len (int | None): The maximum sequence length that the model
            processes. If the sequence length is larger than this number,
            the sequence is split into multiple segments. If it is None,
            the entire sequence is processed at once.

    Returns:
        Tensor: The predicted restoration result.
@@ -47,41 +60,65 @@ def restoration_video_inference(model, img_dir, window_size, start_idx,
    else:
        test_pipeline = model.cfg.val_pipeline

    # check if the input is a video
    file_extension = os.path.splitext(img_dir)[1]
    if file_extension in VIDEO_EXTENSIONS:
        video_reader = mmcv.VideoReader(img_dir)
        # load the images
        data = dict(lq=[], lq_path=None, key=img_dir)
        for frame in video_reader:
            data['lq'].append(np.flip(frame, axis=2))

        # remove the data loading pipeline
        tmp_pipeline = []
        for pipeline in test_pipeline:
            if pipeline['type'] not in [
                    'GenerateSegmentIndices', 'LoadImageFromFileList'
            ]:
                tmp_pipeline.append(pipeline)
        test_pipeline = tmp_pipeline
    else:
        # the first element in the pipeline must be 'GenerateSegmentIndices'
        if test_pipeline[0]['type'] != 'GenerateSegmentIndices':
            raise TypeError('The first element in the pipeline must be '
                            f'"GenerateSegmentIndices", but got '
                            f'"{test_pipeline[0]["type"]}".')

        # specify start_idx and filename_tmpl
        test_pipeline[0]['start_idx'] = start_idx
        test_pipeline[0]['filename_tmpl'] = filename_tmpl

        # prepare data
        sequence_length = len(glob.glob(f'{img_dir}/*'))
        key = img_dir.split('/')[-1]
        lq_folder = '/'.join(img_dir.split('/')[:-1])
        data = dict(
            lq_path=lq_folder,
            gt_path='',
            key=key,
            sequence_length=sequence_length)

    # compose the pipeline
    test_pipeline = Compose(test_pipeline)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'].cpu())
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            if max_seq_len is None:
                result = model(lq=data, test_mode=True)['output'].cpu()
            else:
                result = []
                for i in range(0, data.size(1), max_seq_len):
                    result.append(
                        model(lq=data[:, i:i + max_seq_len],
                              test_mode=True)['output'].cpu())
                result = torch.cat(result, dim=1)

    return result
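
A quick sanity sketch of the max_seq_len chunking described in the docstring above (an editor's illustration, assuming only torch): a 7-frame clip with max_seq_len=3 is processed as segments [0:3], [3:6], [6:7], and torch.cat restores the original frame order:

import torch

# stand-in for a (n=1, t=7) batch; real inputs are (n, t, c, h, w)
data = torch.arange(7).view(1, 7)
max_seq_len = 3
segments = [
    data[:, i:i + max_seq_len] for i in range(0, data.size(1), max_seq_len)
]
result = torch.cat(segments, dim=1)
assert torch.equal(result, data)  # chunking then concatenating is lossless
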
Binary file added tests/data/test_inference.mp4
14 changes: 14 additions & 0 deletions tests/test_inference.py
@@ -50,3 +50,17 @@ def test_restoration_video_inference():
    model.cfg.val_pipeline = model.cfg.val_pipeline[1:]
    output = restoration_video_inference(model, img_dir, window_size,
                                         start_idx, filename_tmpl)

    # video (mp4) input
    model = init_model(
        './configs/restorers/basicvsr/basicvsr_reds4.py',
        None,
        device='cuda')
    img_dir = './tests/data/test_inference.mp4'
    window_size = 0
    start_idx = 1
    filename_tmpl = 'im{}.png'

    output = restoration_video_inference(model, img_dir, window_size,
                                         start_idx, filename_tmpl)
    assert output.shape == (1, 5, 3, 256, 256)