gnn.py
"""
Directly taken from OGB for computing baselines.
"""
import torch
from torch_geometric.nn import (
global_add_pool,
global_mean_pool,
global_max_pool,
GlobalAttention,
Set2Set,
)
from conv import GNN_node, GNN_node_Virtualnode, GNNNodeFlag
class GNN(torch.nn.Module):
    def __init__(
        self,
        num_tasks,
        num_layer=5,
        emb_dim=300,
        gnn_type="gin",
        virtual_node=True,
        residual=False,
        drop_ratio=0.5,
        jk="last",
        graph_pooling="mean",
    ):
        """
        num_tasks (int): number of labels to be predicted
        virtual_node (bool): whether to add a virtual node or not
        """
        super().__init__()

        if num_layer <= 1:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.jk = jk
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        # GNN to generate node embeddings
        gnn_cls = GNN_node_Virtualnode if virtual_node else GNN_node
        self.gnn_node = gnn_cls(
            num_layer,
            emb_dim,
            jk=jk,
            drop_ratio=drop_ratio,
            residual=residual,
            gnn_type=gnn_type,
        )

        # Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(
                    torch.nn.Linear(emb_dim, 2 * emb_dim),
                    torch.nn.BatchNorm1d(2 * emb_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(2 * emb_dim, 1),
                )
            )
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set concatenates two readouts, so its output is 2 * emb_dim wide
        adj = 2 if graph_pooling == "set2set" else 1
        self.graph_pred_linear = torch.nn.Linear(adj * self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)             # per-node embeddings
        h_graph = self.pool(h_node, batched_data.batch)  # per-graph embeddings
        return self.graph_pred_linear(h_graph)
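
# Usage sketch (illustrative, not part of the original OGB code): given a
# torch_geometric `Batch` with `x`, `edge_index`, `edge_attr`, and `batch`
# attributes, as produced by the OGB mol datasets, a forward pass looks like:
#
#     model = GNN(num_tasks=1, num_layer=5, emb_dim=300, gnn_type="gin")
#     out = model(batched_data)  # shape: [num_graphs, num_tasks]

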
class GNNFlag(torch.nn.Module):
    def __init__(
        self,
        num_tasks,
        num_layer=5,
        emb_dim=300,
        gnn_type="gin",
        virtual_node=True,
        residual=False,
        drop_ratio=0.5,
        JK="last",
        graph_pooling="mean",
    ):
        """
        GNN variant whose forward pass accepts an optional perturbation tensor
        (FLAG adversarial augmentation) that is forwarded to GNNNodeFlag.

        num_tasks (int): number of labels to be predicted
        virtual_node (bool): whether to add a virtual node or not
        """
        super().__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # GNN to generate node embeddings
        if virtual_node:
            # A virtual-node variant of GNNNodeFlag is not implemented.
            raise NotImplementedError
        else:
            self.gnn_node = GNNNodeFlag(
                num_layer,
                emb_dim,
                JK=JK,
                drop_ratio=drop_ratio,
                residual=residual,
                gnn_type=gnn_type,
            )

        # Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(
                    torch.nn.Linear(emb_dim, 2 * emb_dim),
                    torch.nn.BatchNorm1d(2 * emb_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(2 * emb_dim, 1),
                )
            )
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # Set2Set concatenates two readouts, so its output is 2 * emb_dim wide
        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data, perturb=None):
        h_node = self.gnn_node(batched_data, perturb)    # optionally perturbed node embeddings
        h_graph = self.pool(h_node, batched_data.batch)  # per-graph embeddings
        return self.graph_pred_linear(h_graph)
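

# Minimal smoke test (illustrative sketch, not from the original file). It
# assumes conv.py follows the OGB mol encoders, i.e. 9 integer atom features
# per node and 3 integer bond features per edge; adjust if conv.py differs.
if __name__ == "__main__":
    from torch_geometric.data import Batch, Data

    data = Data(
        x=torch.zeros(4, 9, dtype=torch.long),          # 4 atoms
        edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),
        edge_attr=torch.zeros(4, 3, dtype=torch.long),  # 4 bonds
    )
    batched_data = Batch.from_data_list([data])

    model = GNN(num_tasks=1, num_layer=2, emb_dim=64, virtual_node=False)
    print(model(batched_data).shape)  # expected: torch.Size([1, 1])

    # FLAG variant: perturb has one row per node and emb_dim columns.
    flag_model = GNNFlag(num_tasks=1, num_layer=2, emb_dim=64, virtual_node=False)
    perturb = torch.zeros(batched_data.x.size(0), 64).uniform_(-1e-2, 1e-2).requires_grad_()
    print(flag_model(batched_data, perturb).shape)  # expected: torch.Size([1, 1])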