-
Notifications
You must be signed in to change notification settings - Fork 13
/
video_loader.py
80 lines (69 loc) · 2.64 KB
/
video_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch as th
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import ffmpeg
class VideoLoader(Dataset):
    """Pytorch video loader.

    Each item corresponds to one row of a CSV file and decodes the row's
    video with ffmpeg into a float32 tensor of RGB frames. Rows whose
    feature file already exists, or whose video file is missing, are not
    decoded and yield a one-element zero tensor as a placeholder.
    """

    def __init__(
            self,
            csv,
            framerate=1,
            size=112,
            centercrop=False,
    ):
        """
        Args:
            csv: Path (or anything ``pandas.read_csv`` accepts) of a CSV
                file with a ``video_path`` and a ``feature_path`` column.
            framerate: Value handed to ffmpeg's ``fps`` filter.
            size: Either a single number, in which case the shorter video
                side is scaled to it and the aspect ratio is kept, or a
                two-element ``(height, width)`` tuple/list giving the exact
                output dimensions.
            centercrop: When true and ``size`` is a single number, crop the
                scaled frames to a centred ``size`` x ``size`` square.
        """
        self.csv = pd.read_csv(csv)
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.csv)

    def _size_is_pair(self):
        """Return True when ``self.size`` is an explicit two-element tuple/list."""
        return isinstance(self.size, (tuple, list)) and len(self.size) == 2

    def _get_video_dim(self, video_path):
        """Return ``(height, width)`` of the first video stream found by ffprobe.

        Raises:
            ValueError: If the probed file contains no video stream.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            (stream for stream in probe['streams']
             if stream['codec_type'] == 'video'),
            None,
        )
        if video_stream is None:
            # Fail with a clear message instead of subscripting None.
            raise ValueError('no video stream in: {}'.format(video_path))
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        return height, width

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` the video is scaled to.

        An explicit two-element ``size`` is returned as a tuple unchanged.
        Otherwise the shorter side becomes ``size`` and the longer side is
        scaled by the same factor (truncated to an int).
        """
        if self._size_is_pair():
            return tuple(self.size)
        if h >= w:
            return int(h * self.size / w), self.size
        return self.size, int(w * self.size / h)

    def __getitem__(self, idx):
        """Decode the video of row ``idx``.

        Returns:
            A dict with keys ``'video'``, ``'input'`` (the video path) and
            ``'output'`` (the feature path). ``'video'`` is a float32 tensor
            laid out as (frames, 3, height, width) on success, or
            ``th.zeros(1)`` when decoding was skipped or failed.
        """
        video_path = self.csv['video_path'].values[idx]
        output_file = self.csv['feature_path'].values[idx]
        video = th.zeros(1)
        if not os.path.isfile(output_file) and os.path.isfile(video_path):
            print('Decoding video: {}'.format(video_path))
            try:
                h, w = self._get_video_dim(video_path)
                height, width = self._get_output_dim(h, w)
                cmd = (
                    ffmpeg
                    .input(video_path)
                    .filter('fps', fps=self.framerate)
                    .filter('scale', width, height)
                )
                # Cropping only makes sense for a scalar size: an explicit
                # (height, width) pair is already produced exactly by the
                # scale filter, and the arithmetic below needs a number.
                if self.centercrop and not self._size_is_pair():
                    x = int((width - self.size) / 2.0)
                    y = int((height - self.size) / 2.0)
                    cmd = cmd.crop(x, y, self.size, self.size)
                    height, width = self.size, self.size
                out, _ = (
                    cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
                    .run(capture_stdout=True, quiet=True)
                )
                frames = np.frombuffer(out, np.uint8).reshape(
                    [-1, height, width, 3])
                frames = th.from_numpy(frames.astype('float32'))
                # Only replace the placeholder once conversion fully succeeded.
                video = frames.permute(0, 3, 1, 2)
            except Exception:
                # Best effort: report and fall through with the placeholder,
                # but let KeyboardInterrupt/SystemExit propagate.
                print('ffprobe failed at: {}'.format(video_path))
        return {'video': video, 'input': video_path, 'output': output_file}