-
Notifications
You must be signed in to change notification settings - Fork 57
/
pretrain.py
51 lines (47 loc) · 2.04 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import logging
import torch
from corrupter import BernCorrupter, BernCorrupterMulti
from read_data import index_ent_rel, graph_size, read_data
from config import config, overwrite_config_with_args
from logger_init import logger_init
from data_utils import inplace_shuffle, heads_tails
from select_gpu import select_gpu
from trans_e import TransE
from trans_d import TransD
from distmult import DistMult
from compl_ex import ComplEx
# Process-wide initialisation. Order is kept as in the original script:
# logging first, then the CUDA device, then command-line config overrides.
logger_init()
torch.cuda.set_device(select_gpu())
overwrite_config_with_args()

# The three dataset splits are read from the configured task directory.
task_dir = config().task.dir
train_path = os.path.join(task_dir, 'train.txt')
valid_path = os.path.join(task_dir, 'valid.txt')
test_path = os.path.join(task_dir, 'test.txt')

# kb_index is built from all three splits and then reused by every
# read_data call below; graph_size derives the entity/relation counts from it.
kb_index = index_ent_rel(train_path, valid_path, test_path)
n_ent, n_rel = graph_size(kb_index)

# Only the training split is shuffled (in place).
train_data = read_data(train_path, kb_index)
inplace_shuffle(*train_data)
valid_data = read_data(valid_path, kb_index)
test_data = read_data(test_path, kb_index)

# heads_tails is called with the data as returned by read_data, i.e. before
# the conversion to tensors below.
heads, tails = heads_tails(n_ent, train_data, valid_data, test_data)

valid_data = list(map(torch.LongTensor, valid_data))
test_data = list(map(torch.LongTensor, test_data))
train_data = list(map(torch.LongTensor, train_data))


def tester():
    """Run link testing on the validation split with the global model.

    `gen` is looked up when this function is called, not when it is
    defined, so it may be (and is) assigned later in the script.
    """
    return gen.test_link(valid_data, n_ent, heads, tails)


# Name of the model to pretrain and the config section with the same name.
mdl_type = config().pretrain_config
gen_config = config()[mdl_type]
# Build the corrupter and the model for the configured model type.
# TransE and TransD are paired with BernCorrupter; DistMult and ComplEx are
# paired with BernCorrupterMulti, which additionally receives
# gen_config.n_sample.
if mdl_type == 'TransE':
    corrupter = BernCorrupter(train_data, n_ent, n_rel)
    gen = TransE(n_ent, n_rel, gen_config)
elif mdl_type == 'TransD':
    corrupter = BernCorrupter(train_data, n_ent, n_rel)
    gen = TransD(n_ent, n_rel, gen_config)
elif mdl_type == 'DistMult':
    corrupter = BernCorrupterMulti(train_data, n_ent, n_rel, gen_config.n_sample)
    gen = DistMult(n_ent, n_rel, gen_config)
elif mdl_type == 'ComplEx':
    corrupter = BernCorrupterMulti(train_data, n_ent, n_rel, gen_config.n_sample)
    gen = ComplEx(n_ent, n_rel, gen_config)
else:
    # Without this branch an unrecognised model type leaves `gen` and
    # `corrupter` unbound, and the script only fails later with a NameError
    # that does not mention the actual cause.
    raise ValueError(
        'Unsupported pretrain_config {!r}; expected one of '
        'TransE, TransD, DistMult, ComplEx'.format(mdl_type)
    )
# Pretrain the selected model; `tester` (validation-split link test) is
# handed to pretrain together with the training data and the corrupter.
gen.pretrain(train_data, corrupter, tester)

# Reload the model from the configured model file inside the task directory,
# then run the link test on the test split.
# NOTE(review): presumably pretrain writes its checkpoint to this file --
# confirm in the model classes.
model_path = os.path.join(task_dir, gen_config.model_file)
gen.load(model_path)
gen.test_link(test_data, n_ent, heads, tails)