-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathzipdataset.py
110 lines (84 loc) · 4.28 KB
/
zipdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
from PIL import Image
import pandas as pd
from zipfile import ZipFile
import logging
import random
from functools import partial
class ImageZipDataset(torch.utils.data.Dataset):
    """Dataset that loads samples from a zip archive and labels from a metadata table.

    The metadata table (loaded with ``load_fn``) provides, per row, the name of
    an archive member and its label. Rows are shuffled once with a seeded RNG
    and then reduced to the requested ``'train'`` or ``'val'`` subset, either by
    a simple proportion split or by k-fold cross-validation.
    """

    def split_data(self, split='train', n_crossval_split=-1, n_crossval=None, eval_proportion=.2):
        """Restrict ``self.indices`` in place to the requested split.

        Args:
            split: ``'train'`` selects the training part; any other value
                selects the evaluation part.
            n_crossval_split: Index of the fold held out for evaluation when
                cross-validation is used. ``-1`` picks a fold at random using
                the module-level ``random`` generator (so it is not governed
                by the dataset's ``random_seed``).
            n_crossval: Number of cross-validation folds, or ``None`` to split
                by ``eval_proportion`` instead.
            eval_proportion: Fraction of the data placed in the evaluation
                split when ``n_crossval`` is ``None``.

        Raises:
            AssertionError: If ``n_crossval`` is given and ``n_crossval_split``
                is not an int inside ``range(n_crossval)`` (after resolving -1).
        """
        if n_crossval is not None:
            assert isinstance(n_crossval_split, int)
            if n_crossval_split == -1:
                # Resolve the "random fold" sentinel *before* the range check;
                # randrange excludes the upper bound, so the result is always
                # a valid fold index.
                n_crossval_split = random.randrange(n_crossval)
            assert n_crossval_split in range(n_crossval)

        if isinstance(n_crossval, int):
            # Folds have equal size; leftover samples (len % n_crossval) never
            # land in an evaluation fold and stay in the training part.
            fold_size = len(self.indices) // n_crossval
            fold_start = n_crossval_split * fold_size
            fold_end = fold_start + fold_size
            logging.info(f"Initializing the {n_crossval_split}th - {n_crossval}-fold crossvalidation {split}-split")
            if split != 'train':
                self.indices = self.indices[fold_start:fold_end]
            else:
                self.indices = self.indices[:fold_start] + self.indices[fold_end:]
        else:
            # Proportion split: the head of the shuffled order is training
            # data, the tail is evaluation data.
            cut = int((1 - eval_proportion) * len(self.indices))
            if split == 'train':
                logging.info(f"Initializing the {(1-eval_proportion)*100:.0f}% {split} split")
                self.indices = self.indices[:cut]
            else:
                logging.info(f"Initializing the {(eval_proportion)*100:.0f}% {split} split")
                self.indices = self.indices[cut:]

    def __init__(self, zip_path, info_path, transform=None, target_transform=None, delimiter='\t',
                 split='train', n_crossval_split=-1, n_crossval=None, eval_proportion=.2, random_seed=3,
                 load_fn=pd.read_csv, f_key='filename', l_key='label',
                 load_sample=lambda f: Image.open(f).convert('RGB'), return_indexes=False,
                 custom_split_method=None, **custom_split_kwargs):
        """Load the metadata table and select the sample indices for ``split``.

        Args:
            zip_path: Location of the zip archive; handed to ``ZipFile`` on
                first item access.
            info_path: Location of the metadata table; handed to ``load_fn``.
            transform: Optional callable applied to each loaded sample.
            target_transform: Optional callable applied to each target.
            delimiter: Passed to ``load_fn`` as the ``delimiter`` keyword.
            split: Either ``'train'`` or ``'val'``.
            n_crossval_split: Forwarded to the split method.
            n_crossval: Forwarded to the split method.
            eval_proportion: Forwarded to the split method; must be a float.
            random_seed: Seed for the deterministic shuffle of row indices.
            load_fn: Callable ``load_fn(info_path, delimiter=...)`` returning
                the metadata table (``pandas.read_csv`` by default).
            f_key: Metadata column holding archive member names.
            l_key: Metadata column holding labels.
            load_sample: Callable turning an opened archive member into a
                sample (by default opens it with PIL and converts to RGB).
            return_indexes: If true, ``__getitem__`` also returns the
                metadata row index of the sample.
            custom_split_method: Optional replacement for ``split_data``; it
                is called with this dataset as first argument and is expected
                to update ``self.indices`` itself.
            **custom_split_kwargs: Extra keyword arguments passed only to
                ``custom_split_method``.

        Raises:
            AssertionError: If ``split`` is not ``'train'``/``'val'`` or
                ``eval_proportion`` is not a float.
        """
        assert split in ['train', 'val']
        assert isinstance(eval_proportion, float)
        self._zip_path = zip_path
        self.transform = transform
        self.target_transform = target_transform
        # The archive is opened lazily in __getitem__, not here.
        self.zip_file = None
        self.load_sample = load_sample
        self.return_indexes = return_indexes
        self.metadata = load_fn(info_path, delimiter=delimiter)

        # prepare deterministic random index order
        self.indices = list(range(len(self.metadata)))
        random.Random(random_seed).shuffle(self.indices)

        # prepare split
        if custom_split_method is None:
            self.split_data(split=split, n_crossval_split=n_crossval_split,
                            n_crossval=n_crossval, eval_proportion=eval_proportion)
        else:
            custom_split_method(self, split=split, n_crossval_split=n_crossval_split,
                                n_crossval=n_crossval, eval_proportion=eval_proportion,
                                **custom_split_kwargs)

        # Classes and samples are taken from the whole metadata table, not only
        # from the rows selected for this split; self.indices maps into them.
        self.classes = self.metadata[l_key].unique()
        # NOTE(review): max(...)+1 presumes labels are integers 0..K-1 - confirm.
        print(f"number classes {max(self.classes)+1}")
        self.samples = list(zip(self.metadata[f_key], self.metadata[l_key]))

    def __getitem__(self, index):
        """Load one sample of the current split.

        Args:
            index (int): Position within the current split.

        Returns:
            tuple: ``(sample, target)``, or ``(sample, (target, row_index))``
            when ``return_indexes`` is set, where ``row_index`` is the
            sample's position in the metadata table.
        """
        # resolve the split-local position to a metadata row index
        index = self.indices[index]
        if self.zip_file is None:
            # Opened on first access and then reused for later items.
            self.zip_file = ZipFile(self._zip_path, 'r')
        path, target = self.samples[index]
        target = torch.as_tensor(target)
        with self.zip_file.open(path) as f:
            sample = self.load_sample(f)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if not self.return_indexes:
            return sample, target
        return sample, (target, index)

    def __len__(self):
        """Return the number of samples in the current split."""
        return len(self.indices)

    def __repr__(self):
        """Return a multi-line summary: class name, size and configured transforms."""
        transform_label = ' Transforms (if any): '
        target_label = ' Target Transforms (if any): '
        # Continuation lines of a multi-line transform repr are indented so
        # they line up under the text that follows the label.
        transform_text = repr(self.transform).replace('\n', '\n' + ' ' * len(transform_label))
        target_text = repr(self.target_transform).replace('\n', '\n' + ' ' * len(target_label))
        return (
            f'Dataset {self.__class__.__name__}\n'
            f' Number of datapoints: {len(self)}\n'
            f'{transform_label}{transform_text}\n'
            f'{target_label}{target_text}'
        )