forked from princeton-vl/RAFT
-
Notifications
You must be signed in to change notification settings - Fork 12
/
run.py
77 lines (58 loc) · 2.05 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import sys
sys.path.append('core')
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder
# Run on the GPU when one is usable, otherwise fall back to the CPU.
# `torch.cuda.is_available` must be *called*: the bare function object is
# always truthy, which would force 'cuda' even on CPU-only machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_image(imfile):
    """Load an image file as a float tensor holding a batch of one, on ``DEVICE``.

    The image is converted to 3-channel RGB first, so grayscale, palette and
    RGBA files yield the same layout as ordinary RGB files instead of
    crashing in ``permute`` or producing a 4-channel tensor.
    Pixel values stay in the 0-255 range (cast to uint8, then to float).

    Args:
        imfile: Path (str or path-like) of the image to read.

    Returns:
        torch.Tensor of shape (1, 3, H, W), dtype float32, on ``DEVICE``.
    """
    # The context manager closes the underlying file handle promptly; PIL
    # otherwise keeps it open until the Image object is garbage collected.
    # convert() forces the pixel data to load before the file is closed.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB')).astype(np.uint8)
    # (H, W, 3) -> (3, H, W), then add a leading batch dimension.
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)
def viz(img, flo, i, out_dir='/home/sharif/Documents/RAFT/test_vis'):
    """Save an image stacked above its colour-coded flow as ``<out_dir>/<i>.png``.

    Only the first batch element of ``img`` and ``flo`` is used. Each is
    permuted from channels-first to channels-last. The flow is rendered with
    ``flow_viz.flow_to_image``, and the two arrays are concatenated along
    axis 0 (rows) before being drawn with matplotlib and saved.

    Args:
        img: Batched image tensor; ``img[0]`` is presumably (C, H, W) with
            values in 0-255 (the result is divided by 255 for display) —
            confirm against caller.
        flo: Batched flow tensor; ``flo[0]`` is presumably (2, H, W) —
            verify against ``flow_viz.flow_to_image``'s expected input.
        i: Identifier used as the output file stem.
        out_dir: Directory the PNG is written to. Defaults to the path that
            was previously hard-coded, so existing callers are unaffected.
    """
    img = img[0].permute(1, 2, 0).cpu().numpy()
    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # Map the flow field to an RGB image.
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    plt.imshow(img_flo / 255.0)
    plt.savefig(Path(out_dir) / f'{i}.png')
    # Clear the current figure so the next call starts from a blank canvas.
    # (clf() already drops the axes, so a separate cla() is unnecessary.)
    plt.clf()
def run(args):
    """Estimate RAFT optical flow between consecutive frames and save it as .npy.

    All ``*.png`` and ``*.jpg`` files in ``args.images_dir`` are sorted by
    path. Flow is computed for each consecutive pair (frame k -> k+1), and
    the upsampled flow for pair k is written to ``args.output_dir/<k>.npy``.
    Padding added for the network is removed before saving, so each saved
    array has the frames' height and width.

    Args:
        args: argparse-style namespace. Reads ``images_dir`` and
            ``output_dir``. It is also passed unchanged to ``RAFT(args)``,
            which may read further attributes of its own.
    """
    # The model is wrapped in DataParallel only to load the checkpoint, whose
    # keys presumably carry DataParallel's 'module.' prefix; it is unwrapped
    # right afterwards.
    model = torch.nn.DataParallel(RAFT(args))
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on
    # CPU-only machines; the model is moved to DEVICE just below.
    model.load_state_dict(torch.load('models/raft-things.pth', map_location='cpu'))
    model = model.module
    model.to(DEVICE)
    model.eval()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    images_dir = Path(args.images_dir)
    # Lexicographic path order: frame numbers in file names must be
    # zero-padded (e.g. 0001.png), otherwise frame 10 sorts before frame 2.
    images = sorted(list(images_dir.glob('*.png')) + list(images_dir.glob('*.jpg')))

    if len(images) < 2:
        # Fewer than two frames: there is no pair to estimate flow for.
        return

    with torch.no_grad():
        # Each frame is decoded once and reused as the first image of the
        # following pair.
        image1 = load_image(str(images[0]))
        for i in range(len(images) - 1):
            image2 = load_image(str(images[i + 1]))

            padder = InputPadder(image1.shape)
            padded1, padded2 = padder.pad(image1, image2)
            _, flow_up = model(padded1, padded2, iters=20, test_mode=True)

            # Strip the padding so the flow lines up pixel-for-pixel with
            # the original frames.
            flow = padder.unpad(flow_up)
            np.save(output_dir / f'{i}.npy', flow.cpu().numpy())

            image1 = image2
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute RAFT optical flow between consecutive frames.')
    # Both directories are mandatory: run() passes them to Path(), which
    # raises TypeError on the None default argparse would otherwise supply.
    parser.add_argument('--images_dir', required=True,
                        help="directory with your images")
    parser.add_argument('--output_dir', required=True,
                        help="optical flow images will be stored here as .npy files")
    # NOTE(review): upstream princeton-vl RAFT reads args.small in
    # RAFT.__init__ and args.mixed_precision in forward(); without these
    # flags RAFT(args) raises AttributeError. Confirm against this fork's
    # core/raft.py.
    parser.add_argument('--small', action='store_true',
                        help='use the small RAFT model')
    parser.add_argument('--mixed_precision', action='store_true',
                        help='use mixed precision')
    args = parser.parse_args()
    run(args)