-
Notifications
You must be signed in to change notification settings - Fork 18
/
eval_tf2.py
130 lines (118 loc) · 4.67 KB
/
eval_tf2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
'''
Code derived from https://github.com/tsc2017/Inception-Score
Arguments:
--model: generator architecture
--dataset: the dataset to compare with
--z_dim: latent dimensionality
--dim: dim of latent tensor
--begin: the iteration number to begin from
--end: the iteration number to end with
--step: the step of iteration
--model_dir: the directory of model checkpoints
--log_dir: the directory to save the results
--eval_is: True or False to evaluate IS
--eval_fid: True or False to evaluate FID
Outputs:
csv file that records the scores of checkpoints
'''
import os
import torch
import csv
import numpy as np
from metrics.InceptionScore_tf2 import get_inception_score
from GANs.models import GoodGenerator, DC_generator, dc_G
from GANs import ResNet32Generator
from utils import eval_parser
# Number of latent vectors pushed through the generator per forward pass
# in cal_inception_score.
BATCH_SIZE = 100
# Channels per generated image; used when reshaping generator output.
N_CHANNEL = 3
# Height and width of each generated image; used when reshaping generator output.
RESOLUTION = 32
# Total number of images generated to compute one Inception Score.
NUM_SAMPLES = 50000
def cal_inception_score(G, device, z_dim, num_samples=None, batch_size=None):
    """Compute the Inception Score of images sampled from generator ``G``.

    Draws ``num_samples`` standard-normal latent vectors of size ``z_dim``,
    runs them through ``G`` in mini-batches on ``device``, maps the outputs
    to uint8 pixels and passes them to ``get_inception_score``.

    Args:
        G: generator module, called as ``G(z)`` on a ``(batch, z_dim)`` tensor.
            Its outputs are assumed to lie in [-1, 1] (e.g. a tanh head) —
            that is what the ``(x + 1) / 2 * 255`` mapping below expects.
        device: torch device the latent batches are moved to.
        z_dim: latent dimensionality.
        num_samples: number of images to generate; defaults to ``NUM_SAMPLES``.
        batch_size: generator batch size; defaults to ``BATCH_SIZE``.

    Returns:
        The result of ``get_inception_score``; callers index it as
        ``(mean, std)``.
    """
    if num_samples is None:
        num_samples = NUM_SAMPLES
    if batch_size is None:
        batch_size = BATCH_SIZE

    noise = torch.randn(num_samples, z_dim)
    batches = []
    # Evaluation only: skip building an autograd graph for every forward pass.
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            z = noise[start:start + batch_size].to(device=device)
            batches.append(G(z).cpu().numpy())
    all_samples = np.concatenate(batches, axis=0)

    # Clip to [0, 1] before scaling and casting: converting floats outside the
    # uint8 range wraps around / is undefined and would corrupt pixel values.
    # For in-range outputs this produces exactly the same pixels as before.
    all_samples = (np.clip((all_samples + 1) / 2, 0.0, 1.0) * 255).astype(np.uint8)
    all_samples = all_samples.reshape((-1, N_CHANNEL, RESOLUTION, RESOLUTION))
    return get_inception_score(all_samples)
class Evalor(object):
    """Evaluate a series of generator checkpoints and log metrics to CSV.

    Each evaluated iteration loads generator weights from
    ``model_dir + '<iter>.pth'`` and appends one row to
    ``<log_path>/<dataset>_metrics.csv`` with columns
    ``iter, is_mean, is_std, FID score`` (columns of disabled metrics are
    left empty).
    """

    def __init__(self, G, z_dim, dataset,
                 model_dir,
                 log_path, device):
        """Store configuration and open the CSV log.

        Args:
            G: generator module whose weights are replaced by each checkpoint.
            z_dim: latent dimensionality used when sampling for metrics.
            dataset: name used in the CSV file name.
            model_dir: prefix concatenated directly with ``'<iter>.pth'``, so
                it should end with a path separator when it is a directory.
            log_path: directory for the CSV file; created if missing.
            device: torch device used for sampling and checkpoint loading.

        Side effects: creates ``log_path`` and opens/truncates the CSV file.
        """
        self.is_flag = False   # compute Inception Score; set by eval_metrics
        self.fid_flag = False  # FID requested; not computed by this script
        self.log_path = log_path
        self.device = device
        self.G = G
        self.z_dim = z_dim
        self.model_dir = model_dir
        self.dataset = dataset
        self.init_writer()

    def init_writer(self):
        """Create the log directory if needed and open the CSV writer."""
        os.makedirs(self.log_path, exist_ok=True)
        # os.path.join keeps the file inside the directory just created even
        # when log_path has no trailing separator; newline='' is what the csv
        # module requires to avoid spurious blank rows on some platforms.
        csv_path = os.path.join(self.log_path, '%s_metrics.csv' % self.dataset)
        self.f = open(csv_path, 'w', newline='')
        fieldnames = ['iter',
                      'is_mean', 'is_std',
                      'FID score']
        self.writer = csv.DictWriter(self.f, fieldnames=fieldnames)
        self.writer.writeheader()

    def load_model(self, model_path):
        """Load generator weights from the ``'G'`` entry of a checkpoint.

        Note: ``torch.load`` unpickles the file — only load trusted checkpoints.
        """
        print('loading model from %s' % model_path)
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        chkpoint = torch.load(model_path, map_location=self.device)
        self.G.load_state_dict(chkpoint['G'])

    def get_metrics(self, count):
        """Compute enabled metrics for the loaded generator and log one CSV row.

        Args:
            count: checkpoint iteration number, written to the ``iter`` column.
        """
        print('===Iter %d===' % count)
        content = {'iter': count}
        if self.is_flag:
            is_score = cal_inception_score(G=self.G, device=self.device, z_dim=self.z_dim)
            np.set_printoptions(precision=5)
            print('Inception score mean: {}, std: {}'.format(is_score[0], is_score[1]))
            content.update({'is_mean': is_score[0],
                            'is_std': is_score[1]})
        # FID is not implemented in this TF2 script; the former hook was:
        # if self.fid_flag:
        #     if self.dataset == 'lsun-bedroom':
        #         fid_score = lsun_fid_score(G=self.G, device=self.device, z_dim=self.z_dim)
        #     elif self.dataset == 'cifar10':
        #         fid_score = cal_fid_score(G=self.G, device=self.device, z_dim=self.z_dim)
        #     np.set_printoptions(precision=5)
        #     print('FID score : {}'.format(fid_score))
        #     content.update({'FID score': fid_score})
        self.writer.writerow(content)
        # Flush per row so partial results survive an interrupted run.
        self.f.flush()

    def eval_metrics(self, begin, end, step, is_flag=True, fid_flag=False, dataname='CIFAR10'):
        """Evaluate checkpoints ``begin, begin + step, ...`` up to ``end`` inclusive.

        Args:
            begin: first checkpoint iteration.
            end: last iteration to consider; never exceeded.
            step: iteration stride (must be non-zero).
            is_flag: compute the Inception Score.
            fid_flag: request FID; not computed here, a warning is printed.
            dataname: accepted for interface compatibility; unused.

        The CSV file is closed when this returns or raises.
        """
        print('%d ==> %d, step: %d' % (begin, end, step))
        self.is_flag = is_flag
        self.fid_flag = fid_flag
        if fid_flag:
            print('Warning: FID evaluation is not implemented in this script; '
                  'the FID column will be left empty.')
        # Stop just past `end` (not `end + step`) so the last iteration never
        # overshoots `end` when (end - begin) is not a multiple of `step`.
        stop = end + 1 if step > 0 else end - 1
        try:
            for i in range(begin, stop, step):
                self.load_model(model_path=self.model_dir + '%d.pth' % i)
                self.get_metrics(i)
        finally:
            # Close even on failure (e.g. a missing checkpoint) so rows already
            # written are not lost and the handle does not leak.
            self.f.close()
if __name__ == '__main__':
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    parser = eval_parser()
    config = vars(parser.parse_args())
    print(config)
    model = config['model']
    z_dim = config['z_dim']
    if model == 'dc':
        G = GoodGenerator()
        # GoodGenerator is built without a z_dim argument here; keep the
        # latent size this script has always sampled for it.
        # NOTE(review): presumably GoodGenerator's default latent size is 128 — confirm.
        z_dim = 128
    elif model == 'ResGAN':
        G = ResNet32Generator(z_dim=z_dim, num_filters=128, batchnorm=True)
    elif model == 'DCGAN':
        G = DC_generator(z_dim=z_dim)
    elif model == 'mnist':
        G = dc_G(z_dim=z_dim)
    else:
        # Fail clearly instead of with a NameError on the undefined `G` below.
        raise ValueError('Unknown model: %s' % model)
    G.to(device)
    G.eval()
    # Sample with the same latent size the generator was built with (this was
    # previously hard-coded to 128 regardless of --z_dim).
    # NOTE(review): dataset stays 'cifar10' (it only names the CSV file);
    # config['dataset'] is passed to eval_metrics, which ignores it.
    evalor = Evalor(G=G, z_dim=z_dim, dataset='cifar10',
                    model_dir=config['model_dir'],
                    log_path=config['logdir'],
                    device=device)
    evalor.eval_metrics(begin=config['begin'], end=config['end'], step=config['step'],
                        is_flag=config['eval_is'], fid_flag=config['eval_fid'],
                        dataname=config['dataset'])