-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils_data.py
134 lines (114 loc) · 5.72 KB
/
utils_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import math
import numpy as np
import multiprocessing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import subprocess
import random
import datetime
from torchvision import transforms, datasets
from utils import *
def norm_mean_and_std(args):
    """Return the per-channel ``(mean, std)`` lists for ``transforms.Normalize``.

    The statistics are chosen by substring-matching ``args.model``, tested in
    this order:

    * ``"resnet"`` or ``"DeiT"``: the standard ImageNet statistics. These are
      only accepted when ``args.dataset == "imagenet"``.
    * ``"ViT"``, ``"mlpmixer"`` or ``"Beit"``: 0.5 for every channel.
      ``args.dataset`` is not consulted in this case.

    Args:
        args: object exposing ``model`` (str) and, for the first family,
            ``dataset``.

    Returns:
        A tuple ``(mean, std)`` of two freshly built 3-element lists.

    Raises:
        ValueError: "error 1" for a resnet/DeiT model on a non-ImageNet
            dataset; "error 2" when the model name matches no known family.
    """
    # Reference values: https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    model_name = args.model

    # Checked first, so a name containing both e.g. "resnet" and "ViT" lands here.
    if any(tag in model_name for tag in ("resnet", "DeiT")):
        if args.dataset != "imagenet":
            raise ValueError("transforms.Normalize error 1 !!!")
        return [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    if any(tag in model_name for tag in ("ViT", "mlpmixer", "Beit")):
        return [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

    raise ValueError("transforms.Normalize error 2 !!!")
def _seed_worker(worker_id):
    """Seed ``numpy`` and ``random`` inside a DataLoader worker process.

    Follows the PyTorch reproducibility recipe. Inside a worker,
    ``torch.initial_seed()`` is the loader's base seed plus ``worker_id``,
    and the base seed is drawn from the loader's ``generator``. Each worker
    therefore gets a distinct but reproducible seed. The function is defined
    at module level so it can be pickled when workers are started with the
    ``spawn``/``forkserver`` methods.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def setup_data_loader(args, minibatch_size, data, corruption_name=None, severity=None, shuffle=True):
    """Build DataLoader(s) for ImageNet or for one ImageNet-C corruption/severity.

    Args:
        args: run configuration. Attributes read here: ``random_seed``,
            ``strict_fix_of_dataloader_seed_flag``, ``max_num_worker``,
            ``model``/``dataset`` (via ``norm_mean_and_std``), ``image_size``,
            ``image_crop_size`` (ImageNet only), ``data_path`` and
            ``dataloader_drop_last``.
        minibatch_size: batch size used by every returned loader.
        data: ``"imagenet"`` or ``"imagenet-c"``.
        corruption_name: ImageNet-C corruption sub-directory. Required for
            ``"imagenet-c"``.
        severity: ImageNet-C severity sub-directory. Required for
            ``"imagenet-c"``.
        shuffle: whether the loaders shuffle. This also applies to the
            evaluation loaders.

    Returns:
        ``(train_loader, val_loader)`` for ``"imagenet"``, or a single loader
        for ``"imagenet-c"``.

    Raises:
        ValueError: for an unknown ``data`` value, a model/dataset pair that
            ``norm_mean_and_std`` rejects, or missing ImageNet-C arguments.
    """
    # fix_seed makes the data order identical across all methods for a given
    # model. The order still changes when the network architecture changes;
    # enabling "strict_fix_of_dataloader_seed_flag" additionally pins the
    # DataLoader's own RNG (https://pytorch.org/docs/stable/notes/randomness.html).
    fix_seed(args.random_seed)

    if args.strict_fix_of_dataloader_seed_flag:
        print("strict_fix_of_dataloader_seed")
        base_seed = torch.initial_seed() % 2**32
        print("worker_seed : {}".format(base_seed))
        worker_init_fn = _seed_worker
        generator = torch.Generator()
        generator.manual_seed(base_seed)
    else:
        worker_init_fn = None
        generator = None

    dataloader_num_workers = min(multiprocessing.cpu_count(), args.max_num_worker)
    print("dataloader_num_workers: " + str(dataloader_num_workers))

    normalize_mean, normalize_std = norm_mean_and_std(args)
    print("transforms.Normalize")
    print(normalize_mean)
    print(normalize_std)
    normalize = transforms.Normalize(mean=normalize_mean, std=normalize_std)

    if data not in ("imagenet", "imagenet-c"):
        raise ValueError("transforms setting error !!!")

    # Shared by every loader built below.
    # NOTE: when strict seeding is on, all loaders share one generator object,
    # so iterating one loader advances the RNG state seen by the others.
    loader_kwargs = dict(
        shuffle=shuffle,
        batch_size=minibatch_size,
        drop_last=args.dataloader_drop_last,
        num_workers=dataloader_num_workers,
        worker_init_fn=worker_init_fn,
        generator=generator,
        pin_memory=True,
    )
    imagenet_root = os.path.join(args.data_path, "imagenet2012")

    if data == "imagenet":
        # Resize, then center-crop (original notes: image_crop_size 256 -> image_size 224).
        transform = transforms.Compose([
            transforms.Resize(args.image_crop_size),
            transforms.CenterCrop(args.image_size),
            transforms.ToTensor(),
            normalize,
        ])
        imagenet_trainset = torchvision.datasets.ImageNet(
            root=imagenet_root, split="train", transform=transform)
        imagenet_testset = torchvision.datasets.ImageNet(
            root=imagenet_root, split="val", transform=transform)
        return (torch.utils.data.DataLoader(imagenet_trainset, **loader_kwargs),
                torch.utils.data.DataLoader(imagenet_testset, **loader_kwargs))

    # data == "imagenet-c"
    if corruption_name is None or severity is None:
        raise ValueError("imagenet-c requires both corruption_name and severity")
    # Resize straight to image_size x image_size, with no crop.
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        normalize,
    ])
    imagenetc_path = os.path.join(imagenet_root, "val_c", corruption_name, str(severity))
    # ImageFolder usage: https://zenn.dev/hidetoshi/articles/20210717_pytorch_dataset_for_imagenet
    imagenet_c_testset = torchvision.datasets.ImageFolder(root=imagenetc_path, transform=transform)
    return torch.utils.data.DataLoader(imagenet_c_testset, **loader_kwargs)