-
Notifications
You must be signed in to change notification settings - Fork 6
/
feature_loader.py
38 lines (34 loc) · 1.13 KB
/
feature_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import numpy as np
import h5py
class SimpleHDF5Dataset:
    """Indexable view over the ``all_feats`` / ``all_labels`` entries of a file.

    Args:
        file_handle: Mapping-like object (for example an open ``h5py.File``)
            exposing ``'all_feats'`` and ``'all_labels'``. When omitted, an
            empty dataset of length 0 is created.
    """

    def __init__(self, file_handle=None):
        if file_handle is None:
            self.f = ''
            self.all_feats_dset = []
            self.all_labels = []
            self.total = 0
        else:
            self.f = file_handle
            # Indexing with ``[...]`` pulls the full contents out of the
            # handle, so the arrays are still usable after the caller
            # closes the file.
            self.all_feats_dset = self.f['all_feats'][...]
            self.all_labels = self.f['all_labels'][...]
            # ``total`` backs ``__len__``; it was previously left unset on
            # this branch, which made ``len(dataset)`` raise AttributeError.
            self.total = len(self.all_labels)

    def __getitem__(self, i):
        """Return ``(features, label)`` for row ``i``.

        The features are returned as a ``torch.Tensor`` and the label as a
        plain ``int``.
        """
        return torch.Tensor(self.all_feats_dset[i, :]), int(self.all_labels[i])

    def __len__(self):
        """Return the number of labelled rows held by the dataset."""
        return self.total
def init_loader(filename):
    """Load features from an HDF5 file and group them by class label.

    Args:
        filename: Path of an HDF5 file that contains ``all_feats`` and
            ``all_labels`` datasets.

    Returns:
        A dict mapping each distinct label to the list of feature rows that
        carry that label, in file order. Keys are inserted in sorted label
        order. If every row is trimmed (or the file is empty) the result is
        an empty dict.
    """
    with h5py.File(filename, 'r') as f:
        fileset = SimpleHDF5Dataset(f)
    feats = fileset.all_feats_dset
    labels = fileset.all_labels

    # Drop trailing rows whose entries sum to zero (presumably unused,
    # pre-allocated slots -- confirm against the code that writes the file).
    # Count them first and slice once: deleting one row per iteration copies
    # the whole array each time, and indexing ``feats[-1]`` without a bound
    # check fails once no rows are left.
    n_trailing = 0
    while n_trailing < len(feats) and np.sum(feats[-1 - n_trailing]) == 0:
        n_trailing += 1
    if n_trailing:
        feats = feats[:len(feats) - n_trailing]
        labels = labels[:len(labels) - n_trailing]

    # Seed the dict from the sorted unique labels so the keys are plain
    # Python scalars and every class is present before rows are appended.
    cl_data_file = {cl: [] for cl in np.unique(np.array(labels)).tolist()}
    for label, feat in zip(labels, feats):
        cl_data_file[label].append(feat)
    return cl_data_file