-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy patheval_embedding_progress.py
69 lines (56 loc) · 1.86 KB
/
eval_embedding_progress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from init import *
#from Network import embed
#from Analysis import RatingCorrelator
from evalulate import dir_rating_delta_rmse, l2
## ===================== ##
## ======= Setup ======= ##
## ===================== ##
network = 'dirR'
wRuns = ['011X']  # '011XXX'
wEpchs = [5, 15, 25, 35, 45, 55, 65, 75, 85, 95]
pooling = 'rmac'

# Maps the DataSubSet id to the subset tag used as `post` (and passed as
# `dset`/`post` to the loaders below):
#   0 Test, 1 Validation, 2 Training
SUBSET_NAMES = {0: "Test", 1: "Valid", 2: "Train"}
DataSubSet = 2
try:
    post = SUBSET_NAMES[DataSubSet]
except KeyError:
    # Explicit raise instead of `assert False`: asserts are stripped under
    # `python -O`, which would leave `post` undefined and fail much later.
    raise ValueError(
        "DataSubSet must be one of {}, got {!r}".format(sorted(SUBSET_NAMES), DataSubSet)
    ) from None

print("{} Set Analysis".format(post))
print('='*15)
## ========================= ##
## ======= Evaluate ======= ##
## ========================= ##
# Select the image with the largest delta-rmse between epochs e0 and e1 for
# run wRuns[0], then plot how much its embedding changes between each pair of
# consecutive epochs in wEpchs.
start = timer()
try:
    e0, e1 = 35, 65
    # NOTE(review): assumed to return one delta-rmse value per image of the
    # `post` subset -- confirm against evalulate.dir_rating_delta_rmse.
    delta = dir_rating_delta_rmse(wRuns[0], post, e0, e1)
    image_id = np.argmax(delta)
    print("selected #{} with delta-rmse {}".format(image_id, delta[image_id]))

    # The loader does not depend on the epoch, so build it once.
    # Element [1] of each loaded result is indexed by image id
    # (presumably the embedding matrix -- confirm against FileManager.Embed.load).
    embed_loader = FileManager.Embed(network)
    embd = [
        embed_loader.load(run=wRuns[0], epoch=e, dset=post)[1][image_id]
        for e in wEpchs
    ]
    # l2 distance between the embeddings of each pair of consecutive epochs.
    diff = [l2(prev, cur) for prev, cur in zip(embd, embd[1:])]

    plt.figure()
    plt.title('nod#{}'.format(image_id))
    # diff[i] compares wEpchs[i] with wEpchs[i+1]; plot it at the later epoch.
    plt.plot(wEpchs[1:], diff, '-*')
    plt.xlabel('epochs')
    # The curve is an l2 distance between embeddings, not a rating RMSE.
    plt.ylabel('embedding l2 change')

    # Earlier exploratory analysis, kept for reference (disabled):
    #Embd = embed.Embeder(network, pooling=pooling)
    #embedding = Embd.generate_timeline_embedding(runs=wRuns, post=post, data_subset_id=DataSubSet, epochs=wEpchs)
    #
    #W = FileManager.Embed(network)(wRuns[0], e0, post)
    # correlation plot
    #Reg = RatingCorrelator(W)
    #Reg.evaluate_embed_distance_matrix(method='l2')
    #Reg.evaluate_rating_space(norm='Scale')
    #Reg.evaluate_rating_distance_matrix(method='l2')
    #p, s, k = Reg.correlate_retrieval('embed', 'rating')
    #delta = embedding[:, :, e1] - embedding[:, :, e0]
finally:
    # Stop the clock before plt.show(): it blocks until the figure window is
    # closed, which would otherwise be counted as analysis runtime.
    total_time = (timer() - start) / 60 / 60
    plt.show()
print("Total runtime is {:.1f} hours".format(total_time))