-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add demo for video super-resolution methods (#275)
* Add demo for video super-resolution methods * Add video restoration to README * Update demo
- Loading branch information
1 parent
a0eaf22
commit 764e606
Showing
4 changed files
with
143 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import argparse | ||
|
||
import mmcv | ||
import torch | ||
|
||
from mmedit.apis import init_model, restoration_video_inference | ||
from mmedit.core import tensor2img | ||
|
||
|
||
def parse_args():
    """Parse command-line arguments for the video restoration demo.

    Returns:
        argparse.Namespace: Parsed arguments (config, checkpoint,
            input_dir, output_dir, window_size, device).
    """
    arg_parser = argparse.ArgumentParser(description='Restoration demo')
    arg_parser.add_argument('config', help='test config file path')
    arg_parser.add_argument('checkpoint', help='checkpoint file')
    arg_parser.add_argument('input_dir', help='directory of the input video')
    arg_parser.add_argument('output_dir', help='directory of the output video')
    arg_parser.add_argument(
        '--window_size',
        type=int,
        default=0,
        help='window size if sliding-window framework is used')
    arg_parser.add_argument(
        '--device', type=int, default=0, help='CUDA device id')
    return arg_parser.parse_args()
|
||
|
||
def main():
    """Restore a video and write each output frame as a numbered PNG.

    Loads the model described by the CLI arguments, runs inference on the
    frame directory, and saves frame ``i`` to ``<output_dir>/<i:08d>.png``.
    """
    args = parse_args()

    device = torch.device('cuda', args.device)
    model = init_model(args.config, args.checkpoint, device=device)

    result = restoration_video_inference(model, args.input_dir,
                                         args.window_size)

    # result is (n, t, c, h, w); emit one image per temporal index t.
    for frame_idx in range(result.size(1)):
        frame = tensor2img(result[:, frame_idx, :, :, :])
        mmcv.imwrite(frame, f'{args.output_dir}/{frame_idx:08d}.png')
|
||
|
||
# Script entry point: run the demo only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import glob | ||
|
||
import torch | ||
from mmcv.parallel import collate, scatter | ||
|
||
from mmedit.datasets.pipelines import Compose | ||
|
||
|
||
def pad_sequence(data, window_size):
    """Reflect-pad a video tensor along its temporal dimension.

    Mirrors ``window_size // 2`` frames (excluding the boundary frame
    itself) onto each end of the sequence, so a sliding window of size
    ``window_size`` can be centered on every original frame.

    Args:
        data (Tensor): Input of shape (n, t, c, h, w).
        window_size (int): Sliding-window size; only ``window_size // 2``
            is used as the per-side padding amount.

    Returns:
        Tensor: Padded tensor of shape (n, t + 2 * (window_size // 2), c, h, w).
    """
    pad = window_size // 2

    # Reflection without repeating the edge frame: frames 1..pad reversed
    # go in front, frames -(pad+1)..-2 reversed go at the back.
    head = data[:, 1 + pad:1 + 2 * pad].flip(1)
    tail = data[:, -1 - 2 * pad:-1 - pad].flip(1)

    return torch.cat((head, data, tail), dim=1)
|
||
|
||
def restoration_video_inference(model, img_dir, window_size):
    """Inference a video sequence with the model.

    Args:
        model (nn.Module): The loaded model.
        img_dir (str): Directory of the input video frames.
        window_size (int): The window size used in sliding-window framework.
            This value should be set according to the settings of the network.
            A value not greater than 0 means using recurrent framework.

    Returns:
        Tensor: The predicted restoration result of shape (n, t, c, h, w).
    """
    device = next(model.parameters()).device  # model device

    # pipeline: load all frames of the sequence, normalize to [0, 1],
    # and collect them as a single 'lq' tensor
    test_pipeline = [
        dict(type='GenerateSegmentIndices', interval_list=[1]),
        dict(
            type='LoadImageFromFileList',
            io_backend='disk',
            key='lq',
            channel_order='rgb'),
        dict(type='RescaleToZeroOne', keys=['lq']),
        dict(type='FramesToTensor', keys=['lq']),
        dict(type='Collect', keys=['lq'], meta_keys=['lq_path', 'key'])
    ]

    # build the data pipeline
    test_pipeline = Compose(test_pipeline)

    # prepare data: the pipeline expects the parent folder as 'lq_path'
    # and the sequence directory name as 'key'
    sequence_length = len(glob.glob(f'{img_dir}/*'))
    key = img_dir.split('/')[-1]
    lq_folder = '/'.join(img_dir.split('/')[:-1])
    data = dict(
        lq_path=lq_folder,
        gt_path='',
        key=key,
        sequence_length=sequence_length)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]['lq']

    # forward the model
    with torch.no_grad():
        if window_size > 0:  # sliding window framework
            data = pad_sequence(data, window_size)
            result = []
            # pad_sequence adds window_size // 2 frames on each side, so the
            # padded length is t + 2 * (window_size // 2).  Slide one window
            # per ORIGINAL frame: the loop must run exactly t times, i.e.
            # padded_length - 2 * (window_size // 2) iterations.  (Using
            # `- 2 * window_size` here would drop output frames and raise on
            # short sequences via torch.stack on an empty list.)
            for i in range(0, data.size(1) - 2 * (window_size // 2)):
                data_i = data[:, i:i + window_size]
                result.append(model(lq=data_i, test_mode=True)['output'])
            result = torch.stack(result, dim=1)
        else:  # recurrent framework
            result = model(lq=data, test_mode=True)['output']

    return result