renet.py
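"""Example of the Recurrent Event Network (RENet) on the ICEWS18 and GDELT
temporal knowledge graph datasets.

The model is trained to predict the object and subject entity of each event
and is evaluated with MRR and Hits@1/3/10 on the test split, e.g.:

    python renet.py --dataset ICEWS18 --seq_len 10
"""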
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F

from torch_geometric.datasets import GDELT, ICEWS18
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.re_net import RENet

parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset',
    type=str,
    default='GDELT',
    choices=['ICEWS18', 'GDELT'],
)
parser.add_argument('--seq_len', type=int, default=10)
args = parser.parse_args()

# Load the dataset and precompute history objects:
path = osp.dirname(osp.realpath(__file__))
path = osp.join(path, '..', 'data', args.dataset)
pre_transform = RENet.pre_transform(args.seq_len)
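# The pre-transform attaches the recent interaction history of each event's
# subject and object entity (the `h_sub`/`h_obj` attributes used below).
# It only runs once; its output is cached in the processed dataset folder.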
if args.dataset == 'ICEWS18':
    train_dataset = ICEWS18(path, pre_transform=pre_transform)
    test_dataset = ICEWS18(path, split='test', pre_transform=pre_transform)
elif args.dataset == 'GDELT':
    train_dataset = GDELT(path, pre_transform=pre_transform)
    test_dataset = GDELT(path, split='test', pre_transform=pre_transform)

# Create dataloaders for the training and test datasets.
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True,
                          follow_batch=['h_sub', 'h_obj'], num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False,
                         follow_batch=['h_sub', 'h_obj'], num_workers=6)
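# `follow_batch` additionally creates the batch assignment vectors
# `h_sub_batch` and `h_obj_batch`, which RENet uses to map history entries
# back to their events inside a mini-batch.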

# Initialize model and optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RENet(
    train_dataset.num_nodes,
    train_dataset.num_rels,
    hidden_channels=200,
    seq_len=args.seq_len,
    dropout=0.5,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                             weight_decay=0.00001)


def train():
    model.train()

    # Train the model via multi-class classification against the
    # corresponding object and subject entities.
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        log_prob_obj, log_prob_sub = model(data)
        loss_obj = F.nll_loss(log_prob_obj, data.obj)
        loss_sub = F.nll_loss(log_prob_sub, data.sub)
        loss = loss_obj + loss_sub
        loss.backward()
        optimizer.step()


def test(loader):
    model.eval()

    # Compute Mean Reciprocal Rank (MRR) and Hits@1/3/10.
    result = torch.tensor([0, 0, 0, 0], dtype=torch.float)
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            log_prob_obj, log_prob_sub = model(data)
        result += model.test(log_prob_obj, data.obj) * data.obj.size(0)
        result += model.test(log_prob_sub, data.sub) * data.sub.size(0)
    result = result / (2 * len(loader.dataset))
    return result.tolist()
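

# Train for 20 epochs, evaluating on the test set after every epoch and
# tracking the wall-clock time per epoch.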
times = []
for epoch in range(1, 21):
    start = time.time()
    train()
    mrr, hits1, hits3, hits10 = test(test_loader)
    print(f'Epoch: {epoch:02d}, MRR: {mrr:.4f}, Hits@1: {hits1:.4f}, '
          f'Hits@3: {hits3:.4f}, Hits@10: {hits10:.4f}')
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")