-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
rgcn.py
95 lines (76 loc) · 3.04 KB
/
rgcn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os.path as osp
import time
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Entities
from torch_geometric.nn import FastRGCNConv, RGCNConv
from torch_geometric.utils import k_hop_subgraph
# Command-line interface: pick which Entities benchmark graph to train on.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--dataset',
    type=str,
    default='AIFB',
    choices=['AIFB', 'MUTAG', 'BGS', 'AM'],
)
args = parser.parse_args()
# FastRGCNConv trades higher memory consumption for faster computation, so it
# is only selected for the AIFB and MUTAG graphs; every other dataset (BGS, AM)
# uses the plain RGCNConv layer instead.
Conv = FastRGCNConv if args.dataset in ('AIFB', 'MUTAG') else RGCNConv
# Load the requested Entities graph from ../data/Entities relative to the
# directory containing this script.
script_dir = osp.dirname(osp.realpath(__file__))
path = osp.join(script_dir, '..', 'data', 'Entities')
dataset = Entities(path, args.dataset)
data = dataset[0]

# BGS and AM graphs are too big to be processed in a full-batch fashion.
# Because the two-layer model only sees a small receptive field, the graph is
# restricted (for every dataset) to nodes at most 2 hops away from any
# training or test node. Seeds are passed as [train nodes..., test nodes...].
num_train = data.train_idx.size(0)
seed_nodes = torch.cat([data.train_idx, data.test_idx], dim=0)
node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(
    seed_nodes, 2, data.edge_index, relabel_nodes=True)

data.num_nodes = node_idx.size(0)
data.edge_index = edge_index
data.edge_type = data.edge_type[edge_mask]
# `mapping` gives the relabelled positions of the seed nodes; the split below
# relies on it following the seed order (train first, then test).
data.train_idx, data.test_idx = mapping[:num_train], mapping[num_train:]
class Net(torch.nn.Module):
    """Two-layer relational graph convolution network for node classification.

    The layer type comes from the module-level ``Conv`` choice. Sizes are read
    from the module-level ``data`` / ``dataset``: the first layer's input size
    is ``data.num_nodes`` and it is called with ``x=None`` (no node features),
    the hidden size is 16, and the output size is ``dataset.num_classes``.
    Both layers use ``dataset.num_relations`` relation types and
    ``num_bases=30``.
    """

    def __init__(self):
        super().__init__()
        num_relations = dataset.num_relations
        self.conv1 = Conv(data.num_nodes, 16, num_relations, num_bases=30)
        self.conv2 = Conv(16, dataset.num_classes, num_relations,
                          num_bases=30)

    def forward(self, edge_index, edge_type):
        """Return per-node log-probabilities (log-softmax over dim 1)."""
        hidden = self.conv1(None, edge_index, edge_type)
        hidden = F.relu(hidden)
        logits = self.conv2(hidden, edge_index, edge_type)
        return F.log_softmax(logits, dim=1)
# Prefer CUDA, then Apple MPS (when this torch build exposes it), else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# The AM dataset is always run on the CPU, whatever accelerator was found
# (presumably because of its graph size -- confirm).
if args.dataset == 'AM':
    device = torch.device('cpu')

# Net() is built before `data` is rebound, exactly as in a tuple-assignment.
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
    """Perform one full-batch optimisation step on the training nodes.

    Returns:
        float: the negative log-likelihood loss on the training nodes,
        computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    log_probs = model(data.edge_index, data.edge_type)
    train_loss = F.nll_loss(log_probs[data.train_idx], data.train_y)
    train_loss.backward()
    optimizer.step()
    return train_loss.item()
@torch.no_grad()
def test():
    """Evaluate the model without tracking gradients.

    Returns:
        tuple[float, float]: accuracy on the training nodes and accuracy on
        the test nodes, in that order.
    """
    model.eval()
    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)
    splits = ((data.train_idx, data.train_y), (data.test_idx, data.test_y))
    return tuple(float((pred[idx] == target).float().mean())
                 for idx, target in splits)
# Train for 50 epochs; each recorded time covers one training step plus the
# evaluation pass that follows it.
times = []
for epoch in range(1, 51):
    tic = time.time()
    loss = train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f} Test: {test_acc:.4f}')
    times.append(time.time() - tic)

median_time = torch.tensor(times).median()
print(f"Median time per epoch: {median_time:.4f}s")