-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_adamf_mat.py
77 lines (72 loc) · 2.25 KB
/
run_adamf_mat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from email.generator import Generator
import torch
import mmkgc
from mmkgc.config import Tester, AdvMixTrainer
from mmkgc.module.model import AdvMixRotatE
from mmkgc.module.loss import SigmoidLoss
from mmkgc.module.strategy import NegativeSampling
from mmkgc.data import TrainDataLoader, TestDataLoader
from mmkgc.adv.modules import MultiGenerator
from args import get_args
def main():
    """Train an AdvMixRotatE knowledge-graph model, save it, and evaluate it.

    Every setting comes from ``get_args()``. The pipeline is:

    1. seed torch's CPU generator and all CUDA generators with ``args.seed``;
    2. build the training loader and the link-prediction test loader over
       ``./benchmarks/<dataset>/``;
    3. load the precomputed visual and textual entity embeddings from
       ``./embeddings/<dataset>-visual.pth`` and ``-textual.pth``;
    4. wrap an ``AdvMixRotatE`` scorer in ``NegativeSampling`` with a
       ``SigmoidLoss``, then train it with ``AdvMixTrainer`` alongside an
       adversarial ``MultiGenerator``;
    5. save the scorer to ``args.save``, reload it from that path, and run
       link prediction on the test split (no type constraint).
    """
    args = get_args()
    print(args)

    # Seed torch on the CPU and on every CUDA device (the CUDA call is a no-op
    # when CUDA is unavailable).
    # NOTE(review): Python's `random`, NumPy and any native sampler used by
    # TrainDataLoader are not seeded here. Confirm whether runs are meant to
    # be fully reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Both loaders read the same dataset directory. Keep the trailing slash:
    # the original passed it, and the loaders presumably append file names
    # directly onto this prefix (confirm in mmkgc.data).
    benchmark_dir = "./benchmarks/" + args.dataset + "/"

    # Training data loader.
    # NOTE(review): neg_ent / neg_rel presumably control how many negatives are
    # produced per positive by corrupting entities / relations. Verify against
    # TrainDataLoader.
    train_dataloader = TrainDataLoader(
        in_path=benchmark_dir,
        batch_size=args.batch_size,
        threads=8,
        sampling_mode="normal",
        bern_flag=1,
        filter_flag=1,
        neg_ent=args.neg_num,
        neg_rel=0,
    )
    # Test data loader for the "link" (link-prediction) task.
    test_dataloader = TestDataLoader(benchmark_dir, "link")

    # Precomputed multimodal entity features, one file per modality.
    embedding_prefix = "./embeddings/" + args.dataset
    img_emb = torch.load(embedding_prefix + "-visual.pth")
    text_emb = torch.load(embedding_prefix + "-textual.pth")

    # Scoring model. Entity and relation counts come from the training data.
    kge_score = AdvMixRotatE(
        ent_tot=train_dataloader.get_ent_tot(),
        rel_tot=train_dataloader.get_rel_tot(),
        dim=args.dim,
        margin=args.margin,
        epsilon=2.0,
        img_emb=img_emb,
        text_emb=text_emb,
    )
    print(kge_score)

    # Training objective: negative sampling with a sigmoid loss.
    # adv_temperature is taken from args.adv_temp.
    model = NegativeSampling(
        model=kge_score,
        loss=SigmoidLoss(adv_temperature=args.adv_temp),
        batch_size=train_dataloader.get_batch_size(),
    )

    # Adversarial feature generator.
    # NOTE(review): structure_dim / img_dim are 2 * args.dim. Presumably this
    # matches the width of AdvMixRotatE's (real + imaginary) entity
    # embeddings. Confirm against the model.
    adv_generator = MultiGenerator(
        noise_dim=64,
        structure_dim=2 * args.dim,
        img_dim=2 * args.dim,
    )

    # Joint training of the KGE model and the generator.
    # Adam optimizer, GPU required (use_gpu=True).
    trainer = AdvMixTrainer(
        model=model,
        data_loader=train_dataloader,
        train_times=args.epoch,
        alpha=args.learning_rate,
        use_gpu=True,
        opt_method="Adam",
        generator=adv_generator,
        lrg=args.lrg,
        mu=args.mu,
    )
    trainer.run()
    kge_score.save_checkpoint(args.save)

    # Reload from args.save so the evaluation below uses the persisted weights.
    kge_score.load_checkpoint(args.save)
    tester = Tester(model=kge_score, data_loader=test_dataloader, use_gpu=True)
    tester.run_link_prediction(type_constrain=False)


if __name__ == "__main__":
    main()