-
Notifications
You must be signed in to change notification settings - Fork 40
/
prepare_imgnet_val.py
31 lines (24 loc) · 1023 Bytes
/
prepare_imgnet_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import argparse
from tqdm import tqdm
def get_args_parser():
    """Build the command-line parser for the ImageNet-val export script.

    Returns:
        argparse.ArgumentParser with ``--data_path`` (dataset root that
        contains a ``val`` sub-directory), ``--output_dir`` (where PNGs are
        written) and ``--img_size`` (side length of the saved square crops).
    """
    # The text is passed as ``description`` (not as the positional ``prog``),
    # and help stays enabled so ``-h/--help`` works for this standalone script.
    parser = argparse.ArgumentParser(
        description='Get ImageNet validation set for FID/IS evaluation')
    parser.add_argument('--data_path', default='./data/imagenet', type=str,
                        help='imagenet dataset path')
    parser.add_argument('--output_dir', default='imagenet-val', type=str,
                        help='output directory')
    parser.add_argument('--img_size', default=256, type=int,
                        help='side length of the saved square images')
    return parser


def sample_path(output_dir, index, total):
    """Return the PNG path for sample ``index`` out of ``total`` samples.

    The index is zero-padded to at least 5 digits. The padding grows when
    ``total`` needs more digits, so lexicographic file order always matches
    dataset order.

    Args:
        output_dir: Directory the file is written to.
        index: Zero-based position of the sample in the dataset.
        total: Number of samples in the dataset (used only for pad width).

    Returns:
        ``os.path.join(output_dir, '<zero-padded index>.png')``.
    """
    width = max(5, len(str(max(total - 1, 0))))
    return os.path.join(output_dir, '{}.png'.format(str(index).zfill(width)))


def main(argv=None):
    """Save every image of ``<data_path>/val`` as a square PNG crop.

    Each image is resized and center-cropped to ``img_size`` and written to
    ``<output_dir>/<index>.png`` in the order ``datasets.ImageFolder`` yields
    the samples. The class label is discarded.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    args = get_args_parser().parse_args(argv)

    # An int size makes Resize match the shorter edge. CenterCrop then takes
    # the central img_size x img_size square.
    transform_val = transforms.Compose([
        transforms.Resize(args.img_size),
        transforms.CenterCrop(args.img_size),
    ])
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'),
                                       transform=transform_val)

    # exist_ok avoids the race of a separate exists() check before creating.
    os.makedirs(args.output_dir, exist_ok=True)

    nsamples = len(dataset_val)
    for i in tqdm(range(nsamples)):
        img, _label = dataset_val[i]
        img.save(sample_path(args.output_dir, i, nsamples))


if __name__ == '__main__':
    main()