[Model] Update CuGraphRelGraphConv to use pylibcugraphops=23.02 (#5217)

* update cugraph_relgraphconv

* update equality test

* update cugraph rgcn example

* update RelGraphConvAgg based on latest API changes

* enable fallback option to full-graph (fg) aggregation when fanout is large

---------

Co-authored-by: Mufei Li <[email protected]>
tingyu66 and mufeili authored Feb 15, 2023
1 parent 5f1babf commit 19b3cea
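The core API shift this commit tracks: with pylibcugraphops 23.02, `max_in_degree` is no longer a constructor argument of `CuGraphRelGraphConv`; it travels with each `forward` call, and leaving it as `None` falls back to full-graph aggregation. A minimal sketch of the new usage, inferred from the diff below (the toy graph, feature sizes, and `num_rels` are illustrative only; running it requires a CUDA build of DGL with pylibcugraphops 23.02):

# Sketch of the updated CuGraphRelGraphConv usage (illustrative sizes).
import dgl
import torch
from dgl.nn import CuGraphRelGraphConv

device = torch.device("cuda")
g = dgl.graph(([0, 1, 2], [1, 2, 3])).to(device)  # toy graph
block = dgl.to_block(g)                           # message-flow graph (MFG)
feat = torch.randn(block.num_src_nodes(), 16, device=device)
etypes = torch.zeros(block.num_edges(), dtype=torch.int64, device=device)

conv = CuGraphRelGraphConv(
    16, 16, num_rels=4,
    regularizer="basis", num_bases=2,
    self_loop=True, apply_norm=True,
).to(device)

# max_in_degree is now a forward-time argument; None (the default)
# selects the full-graph code path instead of a degree-bounded kernel.
h = conv(block, feat, etypes, max_in_degree=4)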
Showing 3 changed files with 187 additions and 330 deletions.
72 changes: 37 additions & 35 deletions examples/advanced/cugraph/rgcn.py
@@ -8,19 +8,20 @@
 code changes from the current `entity_sample.py` example.
 """

+import argparse
+
+import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchmetrics.functional import accuracy
-import dgl
-from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
-from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
+from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
+from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
 from dgl.nn import CuGraphRelGraphConv
-import argparse
+from torchmetrics.functional import accuracy


 class RGCN(nn.Module):
-    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
+    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases):
         super().__init__()
         self.emb = nn.Embedding(num_nodes, h_dim)
         # two-layer RGCN
@@ -30,46 +31,45 @@ def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
             num_rels,
             regularizer="basis",
             num_bases=num_bases,
-            self_loop=False,
-            max_in_degree=fanouts[0]
+            self_loop=True,
+            apply_norm=True,
         )
         self.conv2 = CuGraphRelGraphConv(
             h_dim,
             out_dim,
             num_rels,
             regularizer="basis",
             num_bases=num_bases,
-            self_loop=False,
-            max_in_degree=fanouts[1]
+            self_loop=True,
+            apply_norm=True,
         )

-    def forward(self, g):
+    def forward(self, g, fanouts=[None, None]):
         x = self.emb(g[0].srcdata[dgl.NID])
-        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE],
-                   norm=g[0].edata["norm"]))
-        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], norm=g[1].edata["norm"])
+        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], fanouts[0]))
+        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], fanouts[1])
         return h

-    def update_max_in_degree(self, fanouts):
-        self.conv1.max_in_degree = fanouts[0]
-        self.conv2.max_in_degree = fanouts[1]
-

 def evaluate(model, labels, dataloader, inv_target):
     model.eval()
     eval_logits = []
     eval_seeds = []
     with torch.no_grad():
-        for input_nodes, output_nodes, blocks in dataloader:
+        for _, output_nodes, blocks in dataloader:
             output_nodes = inv_target[output_nodes.type(torch.int64)]
-            for block in blocks:
-                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
             logits = model(blocks)
             eval_logits.append(logits.cpu().detach())
             eval_seeds.append(output_nodes.cpu().detach())
+    num_classes = eval_logits[0].shape[1]
     eval_logits = torch.cat(eval_logits)
     eval_seeds = torch.cat(eval_seeds)
-    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()
+    return accuracy(
+        eval_logits.argmax(dim=1),
+        labels[eval_seeds].cpu(),
+        task="multiclass",
+        num_classes=num_classes,
+    ).item()


 def train(device, g, target_idx, labels, train_mask, model, fanouts):
@@ -96,14 +96,12 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
         batch_size=100,
         shuffle=False,
     )
-    for epoch in range(100):
+    for epoch in range(50):
         model.train()
         total_loss = 0
-        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
+        for it, (_, output_nodes, blocks) in enumerate(train_loader):
             output_nodes = inv_target[output_nodes.type(torch.int64)]
-            for block in blocks:
-                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
-            logits = model(blocks)
+            logits = model(blocks, fanouts=fanouts)
             loss = loss_fcn(logits, labels[output_nodes])
             optimizer.zero_grad()
             loss.backward()
@@ -124,7 +122,7 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
         "--dataset",
         type=str,
         default="aifb",
-        choices=['aifb', 'mutag', 'bgs', 'am'],
+        choices=["aifb", "mutag", "bgs", "am"],
     )
     args = parser.parse_args()
     device = torch.device("cuda")
@@ -168,15 +166,19 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
     out_size = data.num_classes
     num_bases = 20
     fanouts = [4, 4]
-    model = RGCN(in_size, 16, out_size, num_rels, num_bases, fanouts).to(device)
+    model = RGCN(in_size, 16, out_size, num_rels, num_bases).to(device)

-    train(device, g, target_idx, labels, train_mask, model, fanouts)
+    train(
+        device,
+        g,
+        target_idx,
+        labels,
+        train_mask,
+        model,
+        fanouts,
+    )
     test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
-    # Note: cugraph-ops aggregators are designed for sampled graphs (MFGs) and
-    # expect max_in_degree as input for performance considerations. Hence, we
-    # have to update max_in_degree with the fanouts of test_sampler.
-    test_sampler = MultiLayerNeighborSampler([500, 500])
-    model.update_max_in_degree(test_sampler.fanouts)
+    test_sampler = MultiLayerNeighborSampler([-1, -1])
     test_loader = DataLoader(
         g,
         target_idx[test_idx].type(g.idtype),
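One consequence visible throughout the diff: the per-block `edata["norm"]` bookkeeping disappears because the layer now normalizes internally via `apply_norm=True`. For reference, a sketch of what the removed lines computed (toy graph is illustrative):

import dgl

# What the deleted loop attached by hand: a 1/in-degree coefficient per
# edge, indexed by destination node.
g = dgl.graph(([0, 1, 2], [2, 2, 3]))
block = dgl.to_block(g)
norm = dgl.norm_by_dst(block).unsqueeze(1)  # shape: (num_edges, 1)
# With apply_norm=True, CuGraphRelGraphConv applies this normalization
# itself, so callers no longer populate block.edata["norm"].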
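The `evaluate()` change is independent of cuGraph: newer torchmetrics releases (0.11 and later, to my recollection of the cutoff) require an explicit `task` and, for multiclass, `num_classes`. A self-contained sketch:

import torch
from torchmetrics.functional import accuracy

preds = torch.tensor([0, 2, 1, 3])
target = torch.tensor([0, 1, 1, 3])
# task= and num_classes= are mandatory in recent torchmetrics; the old
# two-argument call in this example stopped working there.
acc = accuracy(preds, target, task="multiclass", num_classes=4)
print(acc)  # tensor(0.7500)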

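Finally, the "fallback to full graph" bullet shows up at test time: a fanout of -1 in `MultiLayerNeighborSampler` keeps every in-neighbor, and the model is then called without `fanouts`, so `forward`'s default `[None, None]` routes both layers onto the full-graph path. A sketch of that wiring, reusing `g`, `target_idx`, `test_idx`, `model`, and `device` from the example (batch size is illustrative):

from dgl.dataloading import DataLoader, MultiLayerNeighborSampler

# -1 means "keep all neighbors", so test-time blocks are not
# degree-bounded and no max_in_degree cap can be assumed.
test_sampler = MultiLayerNeighborSampler([-1, -1])
test_loader = DataLoader(
    g,
    target_idx[test_idx].type(g.idtype),
    test_sampler,
    device=device,
    batch_size=100,
    shuffle=False,
)
for _, output_nodes, blocks in test_loader:
    logits = model(blocks)  # fanouts defaults to [None, None]: full-graph kernels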
