import os
import math
import multiprocessing
import subprocess
import random
import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets

from utils import *


def norm_mean_and_std(args):
    """Return the per-channel ``(mean, std)`` lists used by ``transforms.Normalize``.

    The statistics are selected by substring match on ``args.model``:

    * names containing ``"resnet"`` or ``"DeiT"`` use the ImageNet statistics,
      and only ``args.dataset == "imagenet"`` is accepted;
    * names containing ``"ViT"``, ``"mlpmixer"`` or ``"Beit"`` use 0.5 for
      every channel (``args.dataset`` is not consulted in this case).

    Args:
        args: Object exposing ``model`` (str) and, for resnet/DeiT models,
            ``dataset`` (str).

    Returns:
        A ``(normalize_mean, normalize_std)`` pair of three-element lists.

    Raises:
        ValueError: If the model family is not recognised, or if a
            resnet/DeiT model is combined with a dataset other than
            ``"imagenet"``.
    """
    # Reference for the ImageNet statistics:
    # https://gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    model_name = args.model

    if "resnet" in model_name or "DeiT" in model_name:
        if args.dataset != "imagenet":
            raise ValueError("transforms.Normalize error 1 !!!")
        return [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    if "ViT" in model_name or "mlpmixer" in model_name or "Beit" in model_name:
        return [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

    raise ValueError("transforms.Normalize error 2 !!!")


class _WorkerSeeder:
    """Picklable ``worker_init_fn`` that seeds NumPy and ``random`` in a worker.

    This is a module-level callable rather than a closure so that the
    DataLoader can pickle it when worker processes are started with the
    ``spawn`` method (a locally defined function cannot be pickled).
    """

    def __init__(self, seed):
        self.seed = seed

    def __call__(self, worker_id):
        # NOTE(review): every worker receives the same seed and ``worker_id``
        # is unused; this mirrors the previous behaviour of this file. The
        # PyTorch reproducibility recipe instead derives the seed inside the
        # worker from ``torch.initial_seed()`` -- confirm which is intended
        # before adding any randomised transform.
        np.random.seed(self.seed)
        random.seed(self.seed)


def setup_data_loader(args, minibatch_size, data, corruption_name=None, severity=None, shuffle=True):
    """Build the DataLoader(s) for ImageNet or one ImageNet-C corruption split.

    Args:
        args: Configuration object. The attributes read here are
            ``random_seed``, ``strict_fix_of_dataloader_seed_flag``,
            ``max_num_worker``, ``image_size``, ``image_crop_size``,
            ``data_path`` and ``dataloader_drop_last``; ``model`` and
            ``dataset`` are read through ``norm_mean_and_std``.
        minibatch_size: Batch size passed to every DataLoader.
        data: Either ``"imagenet"`` or ``"imagenet-c"``.
        corruption_name: Sub-directory name of the corruption; required when
            ``data == "imagenet-c"``.
        severity: Severity level; converted with ``str()`` and used as the
            innermost directory name for ``"imagenet-c"``.
        shuffle: Passed as ``shuffle`` to every loader created here,
            including the validation/test loaders.

    Returns:
        ``(train_loader, test_loader)`` when ``data == "imagenet"``;
        a single test loader when ``data == "imagenet-c"``.

    Raises:
        ValueError: If ``data`` is neither ``"imagenet"`` nor ``"imagenet-c"``
            (or if ``norm_mean_and_std`` rejects the model/dataset).
        TypeError: If ``data == "imagenet-c"`` and ``corruption_name`` is None.
    """
    # With "fix_seed", the data order becomes the same across all the methods
    # within a model. But if the model (network architecture) is changed, the
    # data order is changed. The workaround is enabling
    # "strict_fix_of_dataloader_seed_flag" as handled below.
    # (fix_seed is not defined in this file; presumably provided by `utils`.)
    fix_seed(args.random_seed)

    # Fix randomness of the data loader to strictly ensure reproducibility.
    # https://pytorch.org/docs/stable/notes/randomness.html
    if args.strict_fix_of_dataloader_seed_flag:
        print("strict_fix_of_dataloader_seed")
        worker_seed = torch.initial_seed() % 2**32
        print("worker_seed : {}".format(worker_seed))
        seed_worker = _WorkerSeeder(worker_seed)
        g = torch.Generator()
        g.manual_seed(worker_seed)
    else:
        seed_worker = None
        g = None

    # Use every CPU core, capped by the user-provided limit.
    dataloader_num_workers = min(multiprocessing.cpu_count(), args.max_num_worker)
    print("dataloader_num_workers: " + str(dataloader_num_workers))

    normalize_mean, normalize_std = norm_mean_and_std(args)
    print("transforms.Normalize")
    print(normalize_mean)
    print(normalize_std)

    # Plain resize to a square of side args.image_size (no cropping).
    transform_without_da = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
    ])

    # Resize followed by a center crop.
    transform_without_da_imagenet = transforms.Compose([
        transforms.Resize(args.image_crop_size),  # e.g. 256
        transforms.CenterCrop(args.image_size),   # e.g. 224
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
    ])

    if data == "imagenet":
        transform = transform_without_da_imagenet
    elif data == "imagenet-c":
        transform = transform_without_da
    else:
        raise ValueError("transforms setting error !!!")

    # Settings shared by every loader built below. The same generator object
    # `g` is handed to each of them.
    loader_kwargs = dict(
        shuffle=shuffle,
        batch_size=minibatch_size,
        drop_last=args.dataloader_drop_last,
        num_workers=dataloader_num_workers,
        worker_init_fn=seed_worker,
        generator=g,
        pin_memory=True,
    )

    # os.path.join tolerates a data_path with or without a trailing separator;
    # the result is identical to plain concatenation when the separator is
    # already present.
    imagenet_root = os.path.join(args.data_path, 'imagenet2012/')

    if data == "imagenet":
        imagenet_trainset = torchvision.datasets.ImageNet(
            root=imagenet_root, split='train', transform=transform)
        imagenet_testset = torchvision.datasets.ImageNet(
            root=imagenet_root, split='val', transform=transform)

        imagenet_train_loader = torch.utils.data.DataLoader(imagenet_trainset, **loader_kwargs)
        imagenet_test_loader = torch.utils.data.DataLoader(imagenet_testset, **loader_kwargs)
        return imagenet_train_loader, imagenet_test_loader

    # From here on data == "imagenet-c".
    if corruption_name is None:
        # Previously this surfaced as an opaque str-concatenation TypeError.
        raise TypeError("corruption_name must be given when data == 'imagenet-c'")

    # Layout: <data_path>/imagenet2012/val_c/<corruption_name>/<severity>/
    # The trailing '' keeps the trailing separator of the original path.
    imagenetc_path = os.path.join(imagenet_root, 'val_c', corruption_name, str(severity), '')

    # Background on using ImageFolder for this layout:
    # https://zenn.dev/hidetoshi/articles/20210717_pytorch_dataset_for_imagenet
    imagenet_c_testset = torchvision.datasets.ImageFolder(
        root=imagenetc_path, transform=transform)

    imagenet_c_test_loader = torch.utils.data.DataLoader(imagenet_c_testset, **loader_kwargs)
    return imagenet_c_test_loader