-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgtsrb_aug.py
176 lines (156 loc) · 4.91 KB
/
gtsrb_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
""" Models obtained from https://github.com/poojahira/gtsrb-pytorch """
from __future__ import print_function
import zipfile
import os
import torchvision.transforms as transforms
# Data augmentation pipelines for training and test time.
#
# Every pipeline resizes the input PIL image and finishes with ToTensor(),
# which converts it to a float tensor scaled to [0, 1]. NOTE: no
# transforms.Normalize step is applied anywhere in this module, so the
# resulting tensors are NOT standardized to zero mean / unit variance.
# NOTE(review): if the downstream models expect standardized input, they must
# normalize internally — confirm against the model definitions.

# Side lengths (height, width) every pipeline except data_center resizes to.
_IMAGE_SIZE = (32, 32)


def _resize_then(*steps):
    """Return a Compose: resize to _IMAGE_SIZE, apply *steps* in order, then ToTensor()."""
    return transforms.Compose(
        [transforms.Resize(_IMAGE_SIZE), *steps, transforms.ToTensor()]
    )


# Baseline: resize and convert to a tensor, no augmentation.
data_transforms = _resize_then()

# Brightness factor drawn uniformly from [max(0, 1 - 5), 1 + 5] = [0, 6].
data_jitter_brightness = _resize_then(transforms.ColorJitter(brightness=5))

# Saturation factor drawn uniformly from [0, 6].
data_jitter_saturation = _resize_then(transforms.ColorJitter(saturation=5))

# Contrast factor drawn uniformly from [0, 6].
data_jitter_contrast = _resize_then(transforms.ColorJitter(contrast=5))

# Hue shift drawn uniformly from [-0.4, 0.4].
data_jitter_hue = _resize_then(transforms.ColorJitter(hue=0.4))

# Random rotation by an angle drawn from [-15, 15] degrees.
data_rotate = _resize_then(transforms.RandomRotation(15))

# With p=1 both "random" flips always fire, i.e. the image is rotated by 180°.
data_hvflip = _resize_then(
    transforms.RandomHorizontalFlip(1),
    transforms.RandomVerticalFlip(1),
)

# Always mirrors the image left-right (p=1).
data_hflip = _resize_then(transforms.RandomHorizontalFlip(1))

# Always mirrors the image top-bottom (p=1).
data_vflip = _resize_then(transforms.RandomVerticalFlip(1))

# Random rotation in [-15, 15] degrees plus an x-axis shear in [-2, 2] degrees.
data_shear = _resize_then(transforms.RandomAffine(degrees=15, shear=2))

# Random rotation in [-15, 15] degrees plus a translation of up to 10% of the
# image width / height.
data_translate = _resize_then(
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1))
)

# Deterministic zoom-in: resize to 36x36, then keep the central 32x32 crop.
data_center = transforms.Compose(
    [
        transforms.Resize((36, 36)),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
    ]
)

# Grayscale, replicated to 3 channels so the output keeps the same channel
# count as the other pipelines.
data_grayscale = _resize_then(transforms.Grayscale(num_output_channels=3))
# File-name prefixes whose images are moved into the validation split.
# NOTE(review): presumably these select the first three tracks of each class —
# confirm against the dataset's naming scheme.
_VAL_PREFIXES = ("00000", "00001", "00002")


def _extract_if_missing(zip_path, target_dir, dest):
    """Extract *zip_path* into *dest* unless *target_dir* already exists."""
    if os.path.isdir(target_dir):
        return
    print(target_dir + " not found, extracting " + zip_path)
    # The context manager closes the archive even if extraction raises.
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(dest)


def _split_validation(train_folder, val_folder):
    """Create *val_folder* and move prefixed files out of each train class dir.

    Only class directories whose names start with ``"000"`` are processed; for
    each, files starting with one of ``_VAL_PREFIXES`` are moved (not copied)
    into a same-named subdirectory of *val_folder*.
    """
    os.mkdir(val_folder)
    for class_dir in os.listdir(train_folder):
        if not class_dir.startswith("000"):
            continue
        src_dir = os.path.join(train_folder, class_dir)
        dst_dir = os.path.join(val_folder, class_dir)
        os.mkdir(dst_dir)
        for name in os.listdir(src_dir):
            if name.startswith(_VAL_PREFIXES):
                os.rename(os.path.join(src_dir, name), os.path.join(dst_dir, name))


def initialize_data(folder):
    """Unpack the GTSRB archives in *folder* and carve out a validation split.

    Expects ``train_images.zip`` and ``test_images.zip`` inside *folder*. Each
    archive is extracted into *folder* only if its directory
    (``train_images`` / ``test_images``) does not exist yet. If ``val_images``
    does not exist, it is created and, for every ``train_images`` class
    directory whose name starts with ``"000"``, files whose names start with
    ``00000``, ``00001`` or ``00002`` are moved into ``val_images/<class>``.

    Note: the validation step is skipped entirely whenever ``val_images``
    already exists, so an interrupted split is not resumed.

    Args:
        folder: Directory containing the two zip archives.

    Raises:
        RuntimeError: If either archive is missing.
    """
    train_zip = os.path.join(folder, "train_images.zip")
    test_zip = os.path.join(folder, "test_images.zip")
    if not (os.path.exists(train_zip) and os.path.exists(test_zip)):
        raise RuntimeError(
            "Could not find "
            + train_zip
            + " and "
            + test_zip
            + ", please download them from https://www.kaggle.com/c/nyu-cv-fall-2017/data "
        )

    # Extract train_images.zip / test_images.zip into *folder* when needed.
    train_folder = os.path.join(folder, "train_images")
    _extract_if_missing(train_zip, train_folder, folder)
    test_folder = os.path.join(folder, "test_images")
    _extract_if_missing(test_zip, test_folder, folder)

    val_folder = os.path.join(folder, "val_images")
    if not os.path.isdir(val_folder):
        print(val_folder + " not found, making a validation set")
        _split_validation(train_folder, val_folder)