-
Notifications
You must be signed in to change notification settings - Fork 16
/
data.py
68 lines (55 loc) · 2.03 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pickle
from glob import glob
from typing import Tuple, List
import torch
import torch.utils.data as data
class AV(data.Dataset):
    """In-memory base dataset of (audio, image, label) triples.

    Subclasses are expected to fill ``self.data`` with tuples of
    ``(audio_tensor, image_tensor, label)``; this base class only provides
    the ``Dataset`` protocol (``__len__`` / ``__getitem__``) over that list.
    """

    def __init__(self, path: str):
        # Directory the samples are loaded from (kept for reference).
        self.path = path
        # Flat sample list; populated by subclass constructors.
        self.data: List[Tuple[torch.Tensor, torch.Tensor, int]] = []

    def __len__(self) -> int:
        # Dataset size is simply the number of collected triples.
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # Plain list indexing; an out-of-range idx raises IndexError,
        # which is the behavior DataLoader-style consumers expect.
        return self.data[idx]
class AudioVideo(AV):
    """Dataset of (audio, image, label) samples loaded from pickle files.

    Every ``*.pkl`` file directly under ``path`` must unpickle to a tuple
    ``(audios, images, label)`` where ``audios`` and ``images`` are
    parallel sequences; one dataset sample is emitted per index.

    Sample format (per the original author's note):
        (torch.Tensor (1, 96, 64), torch.Tensor (3, 224, 224), int label)
    """

    def __init__(self, path: str):
        super().__init__(path)
        # sorted() makes sample order deterministic; raw glob() order is
        # filesystem-dependent.
        for file_path in sorted(glob(f'{path}/*.pkl')):
            # Context manager closes the handle promptly — the original
            # bare open() leaked it until garbage collection.
            with open(file_path, 'rb') as f:
                audios, images, label = pickle.load(f)
            self.data += [(audios[i], images[i], label)
                          for i in range(len(audios))]
class AudioVideo3D(AV):
    """Dataset of (audio, temporal image stack, label) samples from pickles.

    Every ``*.pkl`` file directly under ``path`` must unpickle to a tuple
    ``(audios, images, label)`` with parallel ``audios``/``images``
    sequences.  For sample ``i`` the image part is the trailing window of
    up to ``frames`` images ending at ``images[i]``, zero-padded at the
    front, arranged as a ``(3, frames, 224, 224)`` tensor.

    Sample format (per the original author's note):
        (torch.Tensor (1, 96, 64), torch.Tensor (3, 16, 224, 224), int label)
    """

    def __init__(self, path: str):
        super().__init__(path)
        frames = 16  # temporal window length per sample
        # sorted() makes sample order deterministic; raw glob() order is
        # filesystem-dependent.
        for file_path in sorted(glob(f'{path}/*.pkl')):
            # Context manager closes the handle promptly — the original
            # bare open() leaked it until garbage collection.
            with open(file_path, 'rb') as f:
                audios, images, label = pickle.load(f)
            images_temporal = self._process_temporal_tensor(images, frames)
            self.data += [(audios[i], images_temporal[i], label)
                          for i in range(len(audios))]

    @staticmethod
    def _process_temporal_tensor(images: List[torch.Tensor],
                                 frames: int) -> List[torch.Tensor]:
        """Build one zero-padded temporal stack per input image.

        For index ``i`` the last time slot holds ``images[i]``, the slot
        before it ``images[i - 1]``, and so on for up to ``frames`` images;
        slots with no history remain zero.

        Args:
            images: sequence of (3, 224, 224) tensors in temporal order.
            frames: number of time steps in each output stack.

        Returns:
            List of (3, frames, 224, 224) tensors, one per input image.
        """
        out = []
        for i in range(len(images)):
            e = torch.zeros((frames, 3, 224, 224))
            # BUG FIX: the original wrote images[j] (after pre-filling the
            # last slot with images[0]), so every window started at the
            # clip's first image regardless of i.  Fill backwards from the
            # current image instead; min(i + 1, frames) keeps both the
            # slot index and the history index in range.
            for j in range(min(i + 1, frames)):
                e[-1 - j] = images[i - j]
            # (time, C, H, W) -> (C, time, H, W), matching the documented
            # (3, 16, 224, 224) sample layout.
            out.append(e.permute((1, 0, 2, 3)))
        return out