
Add demo for video super-resolution methods (#275)
* Add demo for video super-resolution methods

* Add video restoration to README

* Update demo
ckkelvinchan authored May 20, 2021
1 parent a0eaf22 commit 764e606
Showing 4 changed files with 143 additions and 3 deletions.
43 changes: 43 additions & 0 deletions demo/restoration_video_demo.py
@@ -0,0 +1,43 @@
import argparse

import mmcv
import torch

from mmedit.apis import init_model, restoration_video_inference
from mmedit.core import tensor2img


def parse_args():
    parser = argparse.ArgumentParser(description='Restoration demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('input_dir', help='directory of the input video')
    parser.add_argument('output_dir', help='directory of the output video')
    parser.add_argument(
        '--window_size',
        type=int,
        default=0,
        help='window size if sliding-window framework is used')
    parser.add_argument('--device', type=int, default=0, help='CUDA device id')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    model = init_model(
        args.config, args.checkpoint, device=torch.device('cuda', args.device))

    output = restoration_video_inference(model, args.input_dir,
                                         args.window_size)

    # save each restored frame as a zero-padded PNG, e.g. output/00000000.png
    for i in range(0, output.size(1)):
        output_i = output[:, i, :, :, :]
        output_i = tensor2img(output_i)
        save_path_i = f'{args.output_dir}/{i:08d}.png'

        mmcv.imwrite(output_i, save_path_i)


if __name__ == '__main__':
    main()
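For orientation, the loop above assumes `restoration_video_inference` returns a 5-D tensor of shape (n, t, c, h, w) and slices out one frame per iteration. A minimal sketch of that conversion step on dummy data (shapes inferred from the slicing code, not stated in the commit):

```python
import torch

from mmedit.core import tensor2img

# dummy restoration result: 1 clip, 3 frames, 3 channels, 64x64 pixels
output = torch.rand(1, 3, 3, 64, 64)  # (n, t, c, h, w)

frame_0 = tensor2img(output[:, 0, :, :, :])  # (n, c, h, w) -> HWC uint8
print(frame_0.shape, frame_0.dtype)  # (64, 64, 3) uint8
```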
24 changes: 22 additions & 2 deletions docs/demo.md
@@ -34,7 +34,7 @@ python demo/matting_demo.py configs/mattors/dim/dim_stage3_v16_pln_1x1_1000k_com

The predicted alpha matte will be saved in `tests/data/pred/GT05.png`.

#### Restoration
#### Restoration (Image)

You can use the following commands to test an image for restoration.

@@ -47,8 +47,28 @@ If `--imshow` is specified, the demo will also show the image with OpenCV. Examples:
```shell
python demo/restoration_demo.py configs/restorer/esrgan/esrgan_x4c64b23g32_1x16_400k_div2k.py work_dirs/esrgan_x4c64b23g32_1x16_400k_div2k/latest.pth tests/data/lq/baboon_x4.png demo/demo_out_baboon.png
```

The restored image will be saved in `demo/demo_out_baboon.png`.

#### Restoration (Video)

You can use the following commands to test a video for restoration.

```shell
python demo/restoration_video_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} ${INPUT_DIR} ${OUTPUT_DIR} [--window_size=${WINDOW_SIZE}] [--device ${GPU_ID}]
```

It supports both the sliding-window framework and the recurrent framework. Examples:

EDVR:
```shell
python demo/restoration_video_demo.py ./configs/restorers/edvr/edvrm_wotsa_x4_g8_600k_reds.py https://download.openmmlab.com/mmediting/restorers/edvr/edvrm_wotsa_x4_8x4_600k_reds_20200522-0570e567.pth data/Vid4/BIx4/calendar/ ./output --window_size=5
```

BasicVSR:
```shell
python demo/restoration_video_demo.py ./configs/restorers/basicvsr/basicvsr_reds4.py https://download.openmmlab.com/mmediting/restorers/basicvsr/basicvsr_reds4_20120409-0e599677.pth data/Vid4/BIx4/calendar/ ./output
```

The restored video will be saved in `output/`.
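The demo writes individual frames rather than a video container. If a playable file is wanted, the frames can be stitched together afterwards, for example with mmcv's `frames2video` helper (a sketch, assuming the demo's 8-digit zero-padded naming; `fps` is an arbitrary choice here):

```python
import mmcv

# stitch output/00000000.png, output/00000001.png, ... into one video file
mmcv.frames2video(
    'output', 'restored.avi', fps=25, filename_tmpl='{:08d}.png')
```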

#### Generation

3 changes: 2 additions & 1 deletion mmedit/apis/__init__.py
@@ -2,11 +2,12 @@
from .inpainting_inference import inpainting_inference
from .matting_inference import init_model, matting_inference
from .restoration_inference import restoration_inference
from .restoration_video_inference import restoration_video_inference
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model

__all__ = [
    'train_model', 'set_random_seed', 'init_model', 'matting_inference',
    'inpainting_inference', 'restoration_inference', 'generation_inference',
    'multi_gpu_test', 'single_gpu_test'
    'multi_gpu_test', 'single_gpu_test', 'restoration_video_inference'
]
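With the export above in place, the new helper can be imported from the package root alongside the existing inference APIs, e.g.:

```python
from mmedit.apis import init_model, restoration_video_inference
```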
76 changes: 76 additions & 0 deletions mmedit/apis/restoration_video_inference.py
@@ -0,0 +1,76 @@
import glob

import torch
from mmcv.parallel import collate, scatter

from mmedit.datasets.pipelines import Compose


def pad_sequence(data, window_size):
    # pad both ends with flipped slices of the sequence so that each of the
    # original frames can sit at the centre of a full-size window
    padding = window_size // 2

    data = torch.cat([
        data[:, 1 + padding:1 + 2 * padding].flip(1), data,
        data[:, -1 - 2 * padding:-1 - padding].flip(1)
    ],
                     dim=1)

    return data


def restoration_video_inference(model, img_dir, window_size):
    """Inference a video with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video frames.
        window_size (int): The window size used in the sliding-window
            framework. This value should be set according to the settings
            of the network. A value of zero or smaller means the recurrent
            framework is used.

    Returns:
        Tensor: The predicted restoration result.
    """
    device = next(model.parameters()).device  # model device

    # build the test pipeline: load the frames, rescale them to [0, 1] and
    # collect them into a (n, t, c, h, w) tensor
    test_pipeline = [
        dict(type='GenerateSegmentIndices', interval_list=[1]),
        dict(
            type='LoadImageFromFileList',
            io_backend='disk',
            key='lq',
            channel_order='rgb'),
        dict(type='RescaleToZeroOne', keys=['lq']),
        dict(type='FramesToTensor', keys=['lq']),
        dict(type='Collect', keys=['lq'], meta_keys=['lq_path', 'key'])
    ]
    test_pipeline = Compose(test_pipeline)

    # prepare data
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    key = img_dir.split('/')[-1]
    lq_folder = '/'.join(img_dir.split('/')[:-1])
    data = dict(
        lq_path=lq_folder,
        gt_path='',
        key=key,
        sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding-window framework
            data = pad_sequence(data, window_size)
            result = []
            # padding adds 2 * (window_size // 2) extra frames, so this
            # loop produces exactly one output per original input frame
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result
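To make the padding above concrete, here is a small illustrative check (my own sketch, not part of the commit): `pad_sequence` prepends and appends flipped slices taken near each end, adding `2 * (window_size // 2)` frames in total, which is what keeps the sliding-window loop producing one output per input frame:

```python
import torch

frames = torch.arange(6.).view(1, 6, 1, 1, 1)  # (n, t, c, h, w), t = 6
padded = pad_sequence(frames, window_size=5)

print(padded.flatten().tolist())
# [4.0, 3.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 1.0]
print(padded.size(1) - 2 * (5 // 2))  # 6 -> one window per original frame
```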
