# spatial_dataloader.py
import pickle
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import random
from split_train_test_video import *
from skimage import io, color, exposure
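
# Spatial-stream dataloader for the two-stream UCF101 pipeline: spatial_dataset
# serves single RGB frames (three random frames per video in 'train' mode, one
# fixed frame in 'val' mode), and spatial_dataloader builds train/val DataLoaders
# from the UCF101 split lists plus a precomputed frame-count pickle.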


class spatial_dataset(Dataset):

    def __init__(self, dic, root_dir, mode, transform=None):
        # Materialize the dict views as lists so they can be indexed in __getitem__.
        self.keys = list(dic.keys())
        self.values = list(dic.values())
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.keys)

    def load_ucf_image(self, video_name, index):
        # UCF101 capitalizes this class inconsistently ('HandstandPushups' for the
        # folder vs 'HandStandPushups' in the frame file names), so rebuild the name.
        if video_name.split('_')[0] == 'HandstandPushups':
            n, g = video_name.split('_', 1)
            name = 'HandStandPushups_' + g
            path = self.root_dir + 'HandstandPushups/separated_images/v_' + name + '/v_' + name + '_'
        else:
            path = self.root_dir + video_name.split('_')[0] + '/separated_images/v_' + video_name + '/v_' + video_name + '_'

        img = Image.open(path + str(index) + '.jpg')
        transformed_img = self.transform(img)
        img.close()

        return transformed_img
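
    # Frames are assumed to live on disk as
    #   <root_dir>/<class>/separated_images/v_<video>/v_<video>_<index>.jpg
    # matching the path built in load_ucf_image above.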
    def __getitem__(self, idx):
        if self.mode == 'train':
            video_name, nb_clips = self.keys[idx].split(' ')
            nb_clips = int(nb_clips)
            # Draw one random frame index from each third of the video.
            clips = []
            clips.append(random.randint(1, nb_clips // 3))
            clips.append(random.randint(nb_clips // 3, nb_clips * 2 // 3))
            clips.append(random.randint(nb_clips * 2 // 3, nb_clips + 1))
        elif self.mode == 'val':
            video_name, index = self.keys[idx].split(' ')
            index = abs(int(index))
        else:
            raise ValueError('mode must be "train" or "val"')

        label = self.values[idx]
        label = int(label) - 1  # the split lists use 1-based labels

        if self.mode == 'train':
            data = {}
            for i in range(len(clips)):
                key = 'img' + str(i)
                index = clips[i]
                data[key] = self.load_ucf_image(video_name, index)
            sample = (data, label)
        elif self.mode == 'val':
            data = self.load_ucf_image(video_name, index)
            sample = (video_name, data, label)
        else:
            raise ValueError('mode must be "train" or "val"')

        return sample
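
# Key format (inferred from spatial_dataloader below): training keys are
# '<video> <usable frame count>', validation keys are '<video> <frame index>',
# and values are the 1-based class labels that __getitem__ shifts to 0-based.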


class spatial_dataloader():

    def __init__(self, BATCH_SIZE, num_workers, path, ucf_list, ucf_split):
        self.BATCH_SIZE = BATCH_SIZE
        self.num_workers = num_workers
        self.data_path = path
        self.frame_count = {}
        # Split the videos into training and testing sets.
        splitter = UCF101_splitter(path=ucf_list, split=ucf_split)
        self.train_video, self.test_video = splitter.split_video()

    def load_frame_count(self):
        # print('==> Loading frame number of each video')
        with open('dic/frame_count.pickle', 'rb') as file:
            dic_frame = pickle.load(file)

        for line in dic_frame:
            videoname = line.split('_', 1)[1].split('.', 1)[0]
            n, g = videoname.split('_', 1)
            if n == 'HandStandPushups':
                videoname = 'HandstandPushups_' + g
            self.frame_count[videoname] = dic_frame[line]
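
    # The pickle presumably maps raw video file names (e.g. 'v_ApplyEyeMakeup_g01_c01.avi')
    # to frame counts; the parsing above strips the 'v_' prefix and the extension.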

    def run(self):
        self.load_frame_count()
        self.get_training_dic()
        self.val_sample20()
        train_loader = self.train()
        val_loader = self.validate()

        return train_loader, val_loader, self.test_video

    def get_training_dic(self):
        # print('==> Generate frame numbers of each training video')
        self.dic_training = {}
        for video in self.train_video:
            # print(video)
            nb_frame = self.frame_count[video] - 10 + 1
            key = video + ' ' + str(nb_frame)
            self.dic_training[key] = self.train_video[video]

    def val_sample20(self):
        print('==> sampling testing frames')
        self.dic_testing = {}
        for video in self.test_video:
            nb_frame = self.frame_count[video] - 10 + 1
            interval = nb_frame // 19
            for i in range(19):
                frame = i * interval
                key = video + ' ' + str(frame + 1)
                self.dic_testing[key] = self.test_video[video]
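
    # The '- 10 + 1' offset presumably keeps frame indices compatible with the
    # motion stream's 10-frame optical-flow stacks; val_sample20 then picks 19
    # evenly spaced frames per test video.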

    def train(self):
        training_set = spatial_dataset(dic=self.dic_training, root_dir=self.data_path, mode='train',
                                       transform=transforms.Compose([
                                           transforms.RandomCrop(224),
                                           transforms.RandomHorizontalFlip(),
                                           transforms.ToTensor(),
                                           transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                                       ]))
        print('==> Training data :', len(training_set), 'frames')
        print(training_set[1][0]['img1'].size())

        train_loader = DataLoader(
            dataset=training_set,
            batch_size=self.BATCH_SIZE,
            shuffle=True,
            num_workers=self.num_workers)
        return train_loader

    def validate(self):
        validation_set = spatial_dataset(dic=self.dic_testing, root_dir=self.data_path, mode='val',
                                         transform=transforms.Compose([
                                             # transforms.Scale was deprecated in favor of transforms.Resize.
                                             transforms.Resize([224, 224]),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                                         ]))
        print('==> Validation data :', len(validation_set), 'frames')
        print(validation_set[1][1].size())

        val_loader = DataLoader(
            dataset=validation_set,
            batch_size=self.BATCH_SIZE,
            shuffle=False,
            num_workers=self.num_workers)
        return val_loader
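
# Note: the Normalize mean/std used above are the standard ImageNet statistics
# expected by torchvision's ImageNet-pretrained models.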


if __name__ == '__main__':
    dataloader = spatial_dataloader(BATCH_SIZE=1, num_workers=1,
                                    path='/home/ubuntu/data/UCF101/spatial_no_sampled/',
                                    ucf_list='/home/ubuntu/cvlab/pytorch/ucf101_two_stream/github/UCF_list/',
                                    ucf_split='01')
    train_loader, val_loader, test_video = dataloader.run()
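
    # Sanity-check sketch (assumes the data paths above exist on disk):
    # pull a single training batch and inspect its shapes.
    for data, label in train_loader:
        print(data['img0'].size(), label.size())  # expected: [1, 3, 224, 224] and [1]
        break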