[Model] Update CuGraphRelGraphConv to use pylibcugraphops=23.02 #5217

Merged · 8 commits · Feb 15, 2023
72 changes: 37 additions & 35 deletions examples/advanced/cugraph/rgcn.py
@@ -8,19 +8,20 @@
 code changes from the current `entity_sample.py` example.

Member:
Did you see similar performance numbers after running this script?

Contributor (Author):
Yes, performance is the same as before.

"""

import argparse

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
from dgl.nn import CuGraphRelGraphConv
import argparse
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases):
super().__init__()
self.emb = nn.Embedding(num_nodes, h_dim)
# two-layer RGCN
@@ -30,46 +31,45 @@ def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
             num_rels,
             regularizer="basis",
             num_bases=num_bases,
-            self_loop=False,
-            max_in_degree=fanouts[0]
+            self_loop=True,
+            apply_norm=True,
         )
         self.conv2 = CuGraphRelGraphConv(
             h_dim,
             out_dim,
             num_rels,
             regularizer="basis",
             num_bases=num_bases,
-            self_loop=False,
-            max_in_degree=fanouts[1]
+            self_loop=True,
+            apply_norm=True,
         )

-    def forward(self, g):
+    def forward(self, g, fanouts=[None, None]):
         x = self.emb(g[0].srcdata[dgl.NID])
-        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE],
-                   norm=g[0].edata["norm"]))
-        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], norm=g[1].edata["norm"])
+        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], fanouts[0]))
+        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], fanouts[1])
         return h

-    def update_max_in_degree(self, fanouts):
-        self.conv1.max_in_degree = fanouts[0]
-        self.conv2.max_in_degree = fanouts[1]
-

 def evaluate(model, labels, dataloader, inv_target):
     model.eval()
     eval_logits = []
     eval_seeds = []
     with torch.no_grad():
-        for input_nodes, output_nodes, blocks in dataloader:
+        for _, output_nodes, blocks in dataloader:
             output_nodes = inv_target[output_nodes.type(torch.int64)]
-            for block in blocks:
-                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
             logits = model(blocks)
             eval_logits.append(logits.cpu().detach())
             eval_seeds.append(output_nodes.cpu().detach())
+    num_classes = eval_logits[0].shape[1]
     eval_logits = torch.cat(eval_logits)
     eval_seeds = torch.cat(eval_seeds)
-    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()
+    return accuracy(
+        eval_logits.argmax(dim=1),
+        labels[eval_seeds].cpu(),
+        task="multiclass",
+        num_classes=num_classes,
+    ).item()


 def train(device, g, target_idx, labels, train_mask, model, fanouts):
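
Aside on the expanded `accuracy(...)` call above: torchmetrics 0.11 and later removed the implicit-task API, so the metric now needs `task="multiclass"` and `num_classes` spelled out. A minimal, self-contained illustration of the new calling convention (the toy tensors here are made up for this note):

```python
import torch
from torchmetrics.functional import accuracy

# Toy predictions (class indices) and ground-truth labels for 4 samples.
preds = torch.tensor([2, 0, 1, 2])
target = torch.tensor([2, 0, 2, 2])

# torchmetrics >= 0.11 requires the task to be named explicitly; for
# multiclass accuracy, num_classes is mandatory as well.
acc = accuracy(preds, target, task="multiclass", num_classes=3)
print(acc.item())  # 3 of 4 predictions correct -> 0.75
```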
@@ -96,14 +96,12 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
         batch_size=100,
         shuffle=False,
     )
-    for epoch in range(100):
+    for epoch in range(50):
         model.train()
         total_loss = 0
-        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
+        for it, (_, output_nodes, blocks) in enumerate(train_loader):
             output_nodes = inv_target[output_nodes.type(torch.int64)]
-            for block in blocks:
-                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
-            logits = model(blocks)
+            logits = model(blocks, fanouts=fanouts)
             loss = loss_fcn(logits, labels[output_nodes])
             optimizer.zero_grad()
             loss.backward()
@@ -124,7 +122,7 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
         "--dataset",
         type=str,
         default="aifb",
-        choices=['aifb', 'mutag', 'bgs', 'am'],
+        choices=["aifb", "mutag", "bgs", "am"],
     )
     args = parser.parse_args()
     device = torch.device("cuda")
@@ -168,15 +166,19 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
     out_size = data.num_classes
     num_bases = 20
     fanouts = [4, 4]
-    model = RGCN(in_size, 16, out_size, num_rels, num_bases, fanouts).to(device)
+    model = RGCN(in_size, 16, out_size, num_rels, num_bases).to(device)

-    train(device, g, target_idx, labels, train_mask, model, fanouts)
+    train(
+        device,
+        g,
+        target_idx,
+        labels,
+        train_mask,
+        model,
+        fanouts,
+    )
     test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
-    # Note: cugraph-ops aggregators are designed for sampled graphs (MFGs) and
-    # expect max_in_degree as input for performance considerations. Hence, we
-    # have to update max_in_degree with the fanouts of test_sampler.
-    test_sampler = MultiLayerNeighborSampler([500, 500])
-    model.update_max_in_degree(test_sampler.fanouts)
+    test_sampler = MultiLayerNeighborSampler([-1, -1])
     test_loader = DataLoader(
         g,
         target_idx[test_idx].type(g.idtype),
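
Taken together, the diff changes the `CuGraphRelGraphConv` contract: `max_in_degree` is no longer fixed in the constructor (and the `update_max_in_degree` helper is gone), the per-layer fanout is passed to each `forward` call with `None` meaning unrestricted in-degree, and edge normalization is handled by `apply_norm=True` instead of a precomputed `"norm"` edge feature. Below is a minimal sketch of the new calling convention, based only on the calls visible in this diff; the toy block and feature sizes are illustrative, and running it requires a CUDA device with pylibcugraphops 23.02 installed:

```python
import dgl
import torch
from dgl.nn import CuGraphRelGraphConv

device = torch.device("cuda")

# Constructor: no max_in_degree argument; normalization is now a flag.
conv = CuGraphRelGraphConv(
    16,                  # input feature size
    16,                  # output feature size
    4,                   # number of relations
    regularizer="basis",
    num_bases=2,
    self_loop=True,
    apply_norm=True,
).to(device)

# A toy sampled block (MFG) with 5 source nodes and 3 destination nodes.
src = torch.tensor([0, 1, 2, 3, 4])
dst = torch.tensor([0, 1, 2, 0, 1])
block = dgl.create_block((src, dst), num_src_nodes=5, num_dst_nodes=3).to(device)

x = torch.randn(5, 16, device=device)                    # source-node features
etypes = torch.randint(0, 4, (block.num_edges(),), device=device)

# The in-degree cap is a per-call argument now: pass the sampler fanout
# during training, or None for full-neighborhood (unrestricted) inference,
# mirroring model(blocks, fanouts=fanouts) vs. model(blocks) above.
h_train = conv(block, x, etypes, 4)
h_test = conv(block, x, etypes, None)
```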