-
Notifications
You must be signed in to change notification settings - Fork 844
/
eval2.py
70 lines (52 loc) · 1.93 KB
/
eval2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -*- coding: utf-8 -*-
# /usr/bin/python2
from __future__ import print_function
import tensorflow as tf
from models import Net2
import argparse
from hparam import hparam as hp
from tensorpack.predict.base import OfflinePredictor
from tensorpack.predict.config import PredictConfig
from tensorpack.tfutils.sessinit import SaverRestore
from tensorpack.tfutils.sessinit import ChainInit
from data_load import Net2DataFlow
def get_eval_input_names():
    """Return the names of the graph inputs fed during evaluation.

    A new list is built on every call, so callers may mutate the result.
    """
    input_names = ['x_mfccs', 'y_spec']
    return input_names
def get_eval_output_names():
    """Return the names of the graph outputs fetched during evaluation.

    A new list is built on every call, so callers may mutate the result.
    """
    output_names = ['net2/eval/summ_loss']
    return output_names
def eval(logdir1, logdir2):
    """Evaluate Net2 on one test batch and write the loss summary.

    Looks up the latest checkpoint in each log directory, builds an offline
    predictor that restores from whichever checkpoints exist, runs a single
    batch from the test dataflow through it, and writes the fetched summary
    to an event file in ``logdir2``.

    NOTE: the name shadows the ``eval`` builtin inside this module; it is
    kept unchanged because it is this script's public entry point.

    Args:
        logdir1: Directory searched for the latest train1 checkpoint.
        logdir2: Directory searched for the latest train2 checkpoint; the
            summary event file is also written here.
    """
    # Load graph
    model = Net2()

    # dataflow
    df = Net2DataFlow(hp.test2.data_path, hp.test2.batch_size)

    # latest_checkpoint yields a falsy value when nothing is found, in which
    # case that directory contributes no restore step.
    ckpt1 = tf.train.latest_checkpoint(logdir1)
    ckpt2 = tf.train.latest_checkpoint(logdir2)

    # Order is significant: the train2 checkpoint is queued first, then the
    # train1 checkpoint (with 'global_step' excluded from its restore).
    session_inits = []
    if ckpt2:
        session_inits.append(SaverRestore(ckpt2))
    if ckpt1:
        session_inits.append(SaverRestore(ckpt1, ignore=['global_step']))

    if not session_inits:
        # Previously this case was silent; make it visible but keep going so
        # existing callers are not turned into a hard failure.
        print('Warning: no checkpoint found in {} or {}; '
              'evaluating without restored weights.'.format(logdir1, logdir2))

    pred_conf = PredictConfig(
        model=model,
        input_names=get_eval_input_names(),
        output_names=get_eval_output_names(),
        session_init=ChainInit(session_inits))
    predictor = OfflinePredictor(pred_conf)

    # Only the first batch is evaluated; the third element of each datapoint
    # is not an input of the predictor and is discarded.
    x_mfccs, y_spec, _ = next(df().get_data())
    summ_loss, = predictor(x_mfccs, y_spec)

    writer = tf.summary.FileWriter(logdir2)
    try:
        writer.add_summary(summ_loss)
    finally:
        # Always release the event file, even if add_summary raises.
        writer.close()
def get_arguments():
    """Parse the command line and return the resulting namespace.

    Two positional string arguments are required:

    * ``case1`` -- experiment case name of train1
    * ``case2`` -- experiment case name of train2
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('case1', type=str,
                     help='experiment case name of train1')
    cli.add_argument('case2', type=str,
                     help='experiment case name of train2')
    return cli.parse_args()
if __name__ == '__main__':
    # Read the two experiment case names from the command line.
    cli_args = get_arguments()

    # Hyperparameters are loaded for the train2 case; this must happen
    # before hp.logdir_path is read below.
    hp.set_hparam_yaml(cli_args.case2)

    # Log directories are laid out as <logdir_path>/<case>/<stage>.
    train1_dir = '{}/{}/train1'.format(hp.logdir_path, cli_args.case1)
    train2_dir = '{}/{}/train2'.format(hp.logdir_path, cli_args.case2)

    eval(logdir1=train1_dir, logdir2=train2_dir)

    print("Done")