Skip to content

Commit

Permalink
update some interfaces for test
Browse files Browse the repository at this point in the history
  • Loading branch information
ShulinCao committed Mar 11, 2018
1 parent 45355e0 commit 99032f5
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 218,237 deletions.
118,142 changes: 0 additions & 118,142 deletions benchmarks/FB15K/test_neg.txt

This file was deleted.

100,000 changes: 0 additions & 100,000 deletions benchmarks/FB15K/valid_neg.txt

This file was deleted.

86 changes: 69 additions & 17 deletions config/Config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def __init__(self):
self.lib.getTailBatch.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.lib.testHead.argtypes = [ctypes.c_void_p]
self.lib.testTail.argtypes = [ctypes.c_void_p]
self.lib.getTestBatch.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.lib.getValidBatch.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.lib.getBestThreshold.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.lib.test_triple_classification.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.test_flag = False
self.in_path = "./"
self.out_path = "./"
Expand All @@ -42,7 +46,8 @@ def __init__(self):
self.export_steps = 0
self.opt_method = "SGD"
self.optimizer = None

self.test_link_prediction = False
self.test_triple_classification = False
def init(self):
self.trainModel = None
if self.in_path != None:
Expand All @@ -54,6 +59,8 @@ def init(self):
self.relTotal = self.lib.getRelationTotal()
self.entTotal = self.lib.getEntityTotal()
self.trainTotal = self.lib.getTrainTotal()
self.testTotal = self.lib.getTestTotal()
self.validTotal = self.lib.getValidTotal()
self.batch_size = self.lib.getTrainTotal() / self.nbatches
self.batch_seq_size = self.batch_size * (1 + self.negative_ent + self.negative_rel)
self.batch_h = np.zeros(self.batch_size * (1 + self.negative_ent + self.negative_rel), dtype = np.int64)
Expand All @@ -64,14 +71,43 @@ def init(self):
self.batch_t_addr = self.batch_t.__array_interface__['data'][0]
self.batch_r_addr = self.batch_r.__array_interface__['data'][0]
self.batch_y_addr = self.batch_y.__array_interface__['data'][0]
if self.test_flag:
if self.test_link_prediction:
self.lib.importTestFiles()
self.test_h = np.zeros(self.lib.getEntityTotal(), dtype = np.int64)
self.test_t = np.zeros(self.lib.getEntityTotal(), dtype = np.int64)
self.test_r = np.zeros(self.lib.getEntityTotal(), dtype = np.int64)
self.test_h_addr = self.test_h.__array_interface__['data'][0]
self.test_t_addr = self.test_t.__array_interface__['data'][0]
self.test_r_addr = self.test_r.__array_interface__['data'][0]
if self.test_triple_classification:
self.lib.importTestFiles()
self.lib.importTypeFiles()

self.test_pos_h = np.zeros(self.lib.getTestTotal(), dtype = np.int64)
self.test_pos_t = np.zeros(self.lib.getTestTotal(), dtype = np.int64)
self.test_pos_r = np.zeros(self.lib.getTestTotal(), dtype = np.int64)
self.test_neg_h = np.zeros(self.lib.getTestTotal(), dtype = np.int64)
self.test_neg_t = np.zeros(self.lib.getTestTotal(), dtype = np.int64)
self.test_neg_r = np.zeros(self.lib.getTestTotal(), dtype = np.int64)
self.test_pos_h_addr = self.test_pos_h.__array_interface__['data'][0]
self.test_pos_t_addr = self.test_pos_t.__array_interface__['data'][0]
self.test_pos_r_addr = self.test_pos_r.__array_interface__['data'][0]
self.test_neg_h_addr = self.test_neg_h.__array_interface__['data'][0]
self.test_neg_t_addr = self.test_neg_t.__array_interface__['data'][0]
self.test_neg_r_addr = self.test_neg_r.__array_interface__['data'][0]

self.valid_pos_h = np.zeros(self.lib.getValidTotal(), dtype = np.int64)
self.valid_pos_t = np.zeros(self.lib.getValidTotal(), dtype = np.int64)
self.valid_pos_r = np.zeros(self.lib.getValidTotal(), dtype = np.int64)
self.valid_neg_h = np.zeros(self.lib.getValidTotal(), dtype = np.int64)
self.valid_neg_t = np.zeros(self.lib.getValidTotal(), dtype = np.int64)
self.valid_neg_r = np.zeros(self.lib.getValidTotal(), dtype = np.int64)
self.valid_pos_h_addr = self.valid_pos_h.__array_interface__['data'][0]
self.valid_pos_t_addr = self.valid_pos_t.__array_interface__['data'][0]
self.valid_pos_r_addr = self.valid_pos_r.__array_interface__['data'][0]
self.valid_neg_h_addr = self.valid_neg_h.__array_interface__['data'][0]
self.valid_neg_t_addr = self.valid_neg_t.__array_interface__['data'][0]
self.valid_neg_r_addr = self.valid_neg_r.__array_interface__['data'][0]

def get_ent_total(self):
return self.entTotal
Expand All @@ -88,8 +124,11 @@ def set_optimizer(self, optimizer):
def set_opt_method(self, method):
self.opt_method = method

def set_test_flag(self, flag):
self.test_flag = flag
def set_test_link_prediction(self, flag):
self.test_link_prediction = flag

def set_test_triple_classification(self, flag):
self.test_triple_classification = flag

def set_log_on(self, flag):
self.log_on = flag
Expand Down Expand Up @@ -244,16 +283,29 @@ def run(self):
def test(self):
if self.importName != None:
self.restore_pytorch()
#self.trainModel.cuda()
total = self.lib.getTestTotal()
for epoch in range(total):
self.lib.getHeadBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr)
res = self.trainModel.predict(self.test_h, self.test_t, self.test_r)
self.lib.testHead(res.data.numpy().__array_interface__['data'][0])

self.lib.getTailBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr)
res = self.trainModel.predict(self.test_h, self.test_t, self.test_r)
self.lib.testTail(res.data.numpy().__array_interface__['data'][0])
if self.log_on:
print epoch
self.lib.test()
if self.test_link_prediction:
total = self.lib.getTestTotal()
for epoch in range(total):
self.lib.getHeadBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr)
res = self.trainModel.predict(self.test_h, self.test_t, self.test_r)
self.lib.testHead(res.data.numpy().__array_interface__['data'][0])

self.lib.getTailBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr)
res = self.trainModel.predict(self.test_h, self.test_t, self.test_r)
self.lib.testTail(res.data.numpy().__array_interface__['data'][0])
if self.log_on:
print epoch
self.lib.test_link_prediction()
if self.test_triple_classification:
self.lib.getValidBatch(self.valid_pos_h_addr, self.valid_pos_t_addr, self.valid_pos_r_addr, self.valid_neg_h_addr, self.valid_neg_t_addr, self.valid_neg_r_addr)
res_pos = self.trainModel.predict(self.valid_pos_h, self.valid_pos_t, self.valid_pos_r)
res_neg = self.trainModel.predict(self.valid_neg_h, self.valid_neg_t, self.valid_neg_r)
print "res_pos",res_pos
print "res_neg",res_neg
self.lib.getBestThreshold(res_pos.data.numpy().__array_interface__['data'][0], res_neg.data.numpy().__array_interface__['data'][0])

self.lib.getTestBatch(self.test_pos_h_addr, self.test_pos_t_addr, self.test_pos_r_addr, self.test_neg_h_addr, self.test_neg_t_addr, self.test_neg_r_addr)

res_pos = self.trainModel.predict(self.test_pos_h, self.test_pos_t, self.test_pos_r)
res_neg = self.trainModel.predict(self.test_neg_h, self.test_neg_t, self.test_neg_r)
self.lib.test_triple_classification(res_pos.data.numpy().__array_interface__['data'][0], res_neg.data.numpy().__array_interface__['data'][0])
4 changes: 2 additions & 2 deletions example_test_transe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
con = config.Config()
con.set_in_path("./benchmarks/FB15K/")
con.set_test_link_prediction(True)
con.set_test_triple_classification
con.set_test_triple_classification(True)
con.set_work_threads(4)
con.set_dimension(100)
con.init()
Expand All @@ -38,6 +38,6 @@
con.set_dimension(100)
con.init()
con.set_model(models.TransE)
con.import_variables("./res/model.vec.pt")
con.import_variables("./res/transe.pt")
con.test()
'''
6 changes: 3 additions & 3 deletions example_train_transe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(10)
con.set_train_times(1000)
con.set_nbatches(100)
con.set_alpha(0.001)
con.set_bern(0)
Expand All @@ -18,9 +18,9 @@
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/model.vec.pt")
con.set_export_files("./res/transe.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/embedding.vec.json")
con.init()
con.set_model(models.TransE)
con.run()
con.run()
17 changes: 7 additions & 10 deletions examples/train_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
import json

con = config.Config()
con.set_test_flag(True)
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_out_path("./benchmarks/FB15K/a.vec")
# con.set_export_files("res.vec")
#con.set_import_files("res.vec")
# con.set_export_steps(10)
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(1000)
Expand All @@ -20,11 +17,11 @@
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("adagrad")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/complex.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/complex.vec.json")
con.init()
con.set_model(models.ComplEx)
# f = open("a.vec", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
con.run()
con.test()

19 changes: 8 additions & 11 deletions examples/train_distmult.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,25 @@
import json

con = config.Config()
con.set_test_flag(True)
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_out_path("./benchmarks/FB15K/a.vec")
# con.set_export_files("res.vec")
#con.set_import_files("res.vec")
# con.set_export_steps(10)
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(1000)
con.set_nbatches(100)
con.set_alpha(0.1)
con.set_bern(1)
con.set_dimension(50)
con.set_dimension(100)
con.set_margin(1.0)
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("adagrad")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/distmult.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/distmult.vec.json")
con.init()
con.set_model(models.DistMult)
# f = open("a.vec", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
con.run()
con.test()

17 changes: 7 additions & 10 deletions examples/train_rescal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
import json

con = config.Config()
con.set_test_flag(True)
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_out_path("./benchmarks/FB15K/a.vec")
# con.set_export_files("res.vec")
#con.set_import_files("res.vec")
# con.set_export_steps(10)
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(1000)
Expand All @@ -20,11 +17,11 @@
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("adagrad")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/rescal.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/rescal.vec.json")
con.init()
con.set_model(models.RESCAL)
# f = open("a.vec", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
con.run()
con.test()

16 changes: 6 additions & 10 deletions examples/train_transd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
import json

con = config.Config()
con.set_test_flag(True)
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_out_path("./benchmarks/FB15K/a.vec")
# con.set_export_files("res.vec")
#con.set_import_files("res.vec")
# con.set_export_steps(10)
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(1000)
Expand All @@ -20,11 +17,10 @@
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/transd.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/transd.vec.json")
con.init()
con.set_model(models.TransD)
# f = open("a.vec", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
con.run()
con.test()
16 changes: 6 additions & 10 deletions examples/train_transe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
import json

con = config.Config()
con.set_test_flag(True)
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_out_path("./benchmarks/FB15K/a.vec")
# con.set_export_files("res.vec")
#con.set_import_files("res.vec")
# con.set_export_steps(10)
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(1000)
Expand All @@ -20,11 +17,10 @@
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/transe.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/transe.vec.json")
con.init()
con.set_model(models.TransE)
# f = open("a.vec", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
con.run()
con.test()
16 changes: 6 additions & 10 deletions examples/train_transh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
import json

con = config.Config()
con.set_test_flag(True)
#Input training files from benchmarks/FB15K/ folder.
con.set_in_path("./benchmarks/FB15K/")
con.set_out_path("./benchmarks/FB15K/a.vec")
# con.set_export_files("res.vec")
#con.set_import_files("res.vec")
# con.set_export_steps(10)
#True: Input test files from the same folder.
con.set_log_on(1)
con.set_work_threads(8)
con.set_train_times(1000)
Expand All @@ -20,11 +17,10 @@
con.set_ent_neg_rate(1)
con.set_rel_neg_rate(0)
con.set_opt_method("SGD")
#Model parameters will be exported via torch.save() automatically.
con.set_export_files("./res/transh.pt")
#Model parameters will be exported to json files automatically.
con.set_out_files("./res/transh.vec.json")
con.init()
con.set_model(models.TransH)
# f = open("a.vec", "r")
# content = json.loads(f.read())
# f.close()
# con.set_parameters(content)
con.run()
con.test()
Loading

0 comments on commit 99032f5

Please sign in to comment.