diff --git a/examples/advanced/cugraph/rgcn.py b/examples/advanced/cugraph/rgcn.py
index 535ad0ded639..fda0badb6af9 100644
--- a/examples/advanced/cugraph/rgcn.py
+++ b/examples/advanced/cugraph/rgcn.py
@@ -8,19 +8,20 @@
 code changes from the current `entity_sample.py` example.
 """
+import argparse
+
+import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchmetrics.functional import accuracy
-import dgl
-from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
-from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
+from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
+from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
 from dgl.nn import CuGraphRelGraphConv
-import argparse
+from torchmetrics.functional import accuracy
 
 
 class RGCN(nn.Module):
-    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
+    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases):
         super().__init__()
         self.emb = nn.Embedding(num_nodes, h_dim)
         # two-layer RGCN
@@ -30,8 +31,8 @@ def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
             num_rels,
             regularizer="basis",
             num_bases=num_bases,
-            self_loop=False,
-            max_in_degree=fanouts[0]
+            self_loop=True,
+            apply_norm=True,
         )
         self.conv2 = CuGraphRelGraphConv(
             h_dim,
@@ -39,37 +40,36 @@ def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases, fanouts):
             num_rels,
             regularizer="basis",
             num_bases=num_bases,
-            self_loop=False,
-            max_in_degree=fanouts[1]
+            self_loop=True,
+            apply_norm=True,
         )
 
-    def forward(self, g):
+    def forward(self, g, fanouts=[None, None]):
         x = self.emb(g[0].srcdata[dgl.NID])
-        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE],
-                   norm=g[0].edata["norm"]))
-        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], norm=g[1].edata["norm"])
+        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], fanouts[0]))
+        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], fanouts[1])
         return h
 
-    def update_max_in_degree(self, fanouts):
-        self.conv1.max_in_degree = fanouts[0]
-        self.conv2.max_in_degree = fanouts[1]
-
 
 def evaluate(model, labels, dataloader, inv_target):
     model.eval()
     eval_logits = []
     eval_seeds = []
     with torch.no_grad():
-        for input_nodes, output_nodes, blocks in dataloader:
+        for _, output_nodes, blocks in dataloader:
             output_nodes = inv_target[output_nodes.type(torch.int64)]
-            for block in blocks:
-                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
             logits = model(blocks)
             eval_logits.append(logits.cpu().detach())
             eval_seeds.append(output_nodes.cpu().detach())
+    num_classes = eval_logits[0].shape[1]
     eval_logits = torch.cat(eval_logits)
     eval_seeds = torch.cat(eval_seeds)
-    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()
+    return accuracy(
+        eval_logits.argmax(dim=1),
+        labels[eval_seeds].cpu(),
+        task="multiclass",
+        num_classes=num_classes,
+    ).item()
 
 
 def train(device, g, target_idx, labels, train_mask, model, fanouts):
@@ -96,14 +96,12 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
         batch_size=100,
         shuffle=False,
     )
-    for epoch in range(100):
+    for epoch in range(50):
         model.train()
         total_loss = 0
-        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
+        for it, (_, output_nodes, blocks) in enumerate(train_loader):
             output_nodes = inv_target[output_nodes.type(torch.int64)]
-            for block in blocks:
-                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
-            logits = model(blocks)
+            logits = model(blocks, fanouts=fanouts)
             loss = loss_fcn(logits, labels[output_nodes])
             optimizer.zero_grad()
             loss.backward()
@@ -124,7 +122,7 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
         "--dataset",
         type=str,
         default="aifb",
-        choices=['aifb', 'mutag', 'bgs', 'am'],
+        choices=["aifb", "mutag", "bgs", "am"],
     )
     args = parser.parse_args()
     device = torch.device("cuda")
@@ -168,15 +166,19 @@ def train(device, g, target_idx, labels, train_mask, model, fanouts):
     out_size = data.num_classes
     num_bases = 20
     fanouts = [4, 4]
-    model = RGCN(in_size, 16, out_size, num_rels, num_bases, fanouts).to(device)
+    model = RGCN(in_size, 16, out_size, num_rels, num_bases).to(device)
 
-    train(device, g, target_idx, labels, train_mask, model, fanouts)
+    train(
+        device,
+        g,
+        target_idx,
+        labels,
+        train_mask,
+        model,
+        fanouts,
+    )
 
     test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
-    # Note: cugraph-ops aggregators are designed for sampled graphs (MFGs) and
-    # expect max_in_degree as input for performance considerations. Hence, we
-    # have to update max_in_degree with the fanouts of test_sampler.
-    test_sampler = MultiLayerNeighborSampler([500, 500])
-    model.update_max_in_degree(test_sampler.fanouts)
+    test_sampler = MultiLayerNeighborSampler([-1, -1])
     test_loader = DataLoader(
         g,
         target_idx[test_idx].type(g.idtype),
diff --git a/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py b/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py
index c20f63871577..df92f0cdf3c9 100644
--- a/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py
+++ b/python/dgl/nn/pytorch/conv/cugraph_relgraphconv.py
@@ -3,138 +3,20 @@
 # pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
 import math
 
-import torch as th
+import torch
 from torch import nn
 
 try:
-    from pylibcugraphops import make_mfg_csr_hg
-    from pylibcugraphops.operators import (
-        agg_hg_basis_mfg_n2n_post_bwd as agg_bwd,
-    )
-    from pylibcugraphops.operators import (
-        agg_hg_basis_mfg_n2n_post_fwd as agg_fwd,
+    from pylibcugraphops import make_fg_csr_hg, make_mfg_csr_hg
+    from pylibcugraphops.torch.autograd import (
+        agg_hg_basis_n2n_post as RelGraphConvAgg,
     )
 except ImportError:
     has_pylibcugraphops = False
-
-    def make_mfg_csr_hg(*args):
-        r"""A dummy function to help raise error in RelGraphConvAgg when
-        pylibcugraphops is not found."""
-
-        raise NotImplementedError(
-            "RelGraphConvAgg requires pylibcugraphops to be installed."
-        )
-
 else:
     has_pylibcugraphops = True
 
 
-class RelGraphConvAgg(th.autograd.Function):
-    r"""Custom autograd function for R-GCN aggregation layer that uses the
-    aggregation functions in cugraph-ops."""
-
-    @staticmethod
-    def forward(ctx, g, num_rels, edge_types, max_in_degree, feat, coeff):
-        r"""Compute the forward pass of R-GCN aggregation layer.
-
-        Parameters
-        ----------
-        ctx : torch.autograd.function.BackwardCFunction
-            Context object used to stash information for backward computation.
-        g : DGLGraph
-            The graph.
-        num_rels : int
-            Number of relations.
-        edge_types : torch.Tensor
-            A 1D tensor of edge types.
-        max_in_degree : int
-            Maximum number of sampled neighbors of a destination node.
-        feat : torch.Tensor
-            A 2D tensor of node features. Shape: (num_src_nodes, in_feat).
-        coeff : torch.Tensor
-            A 2D tensor of the coefficient matrix used in basis-decomposition
-            regularization. Shape: (num_rels, num_bases). It should be set to
-            ``None`` when no regularization is applied.
-
-        Returns
-        -------
-        agg_output : torch.Tensor
-            A 2D tensor of aggregation output. Shape: (num_dst_nodes,
-            num_rels * in_feat) when ``coeff=None``; Shape: (num_dst_nodes,
-            num_bases * in_feat) otherwise.
-        """
-
-        in_feat = feat.shape[-1]
-        indptr, indices, edge_ids = g.adj_sparse("csc")
-        # Edge_ids is in a mixed order, need to permutate incoming etypes.
-        ctx.edge_types_perm = edge_types[edge_ids.long()].int()
-
-        mfg = make_mfg_csr_hg(
-            g.dstnodes(),
-            g.srcnodes(),
-            indptr,
-            indices,
-            max_in_degree,
-            n_node_types=0,
-            n_edge_types=num_rels,
-            out_node_types=None,
-            in_node_types=None,
-            edge_types=ctx.edge_types_perm,
-        )
-        ctx.mfg = mfg
-
-        if coeff is None:
-            leading_dimension = num_rels * in_feat
-        else:
-            num_bases = coeff.shape[-1]
-            leading_dimension = num_bases * in_feat
-
-        agg_output = th.empty(
-            g.num_dst_nodes(),
-            leading_dimension,
-            dtype=th.float32,
-            device=feat.device,
-        )
-
-        if coeff is None:
-            agg_fwd(agg_output, feat.detach(), None, mfg)
-        else:
-            agg_fwd(agg_output, feat.detach(), coeff.detach(), mfg)
-
-        ctx.save_for_backward(feat, coeff)
-        return agg_output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        r"""Compute the backward pass of R-GCN aggregation layer.
-
-        Parameters
-        ----------
-        ctx : torch.autograd.function.BackwardCFunction
-            Context object used to stash information for backward computation.
-        grad_output : torch.Tensor
-            A 2D tensor of the gradient of loss function w.r.t output.
-        """
-        feat, coeff = ctx.saved_tensors
-
-        grad_feat = th.empty_like(feat)
-        grad_coeff = None if coeff is None else th.empty_like(coeff)
-
-        if coeff is None:
-            agg_bwd(grad_feat, None, grad_output, feat.detach(), None, ctx.mfg)
-        else:
-            agg_bwd(
-                grad_feat,
-                grad_coeff,
-                grad_output,
-                feat.detach(),
-                coeff.detach(),
-                ctx.mfg,
-            )
-
-        return None, None, None, None, grad_feat, grad_coeff
-
-
 class CuGraphRelGraphConv(nn.Module):
     r"""An accelerated relational graph convolution layer from `Modeling
     Relational Data with Graph Convolutional Networks
@@ -144,14 +26,10 @@ class CuGraphRelGraphConv(nn.Module):
     See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model.
 
     This module depends on :code:`pylibcugraphops` package, which can be
-    installed via :code:`conda install -c nvidia pylibcugraphops>=22.12`.
+    installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`.
 
     .. note::
         This is an **experimental** feature.
-        Compared with :class:`dgl.nn.pytorch.conv.RelGraphConv`, this model:
-
-        * Only works on cuda devices.
-        * Only supports basis-decomposition regularization.
 
     Parameters
     ----------
@@ -171,31 +49,26 @@ class CuGraphRelGraphConv(nn.Module):
         Default: ``None``.
     bias : bool, optional
         True if bias is added. Default: ``True``.
-    activation : callable, optional
-        Activation function. Default: ``None``.
     self_loop : bool, optional
         True to include self loop message. Default: ``True``.
     dropout : float, optional
         Dropout rate. Default: ``0.0``.
-    layer_norm : bool, optional
-        True to add layer norm. Default: ``False``.
-    max_in_degree : int, optional
-        Maximum number of sampled neighbors of a destination node,
-        i.e. maximum in degree of destination nodes. If ``None``, it will be
-        calculated on the fly during :meth:`forward`.
+    apply_norm : bool, optional
+        True to normalize aggregation output by the in-degree of the destination
+        node per edge type, i.e. :math:`|\mathcal{N}^r_i|`. Default: ``True``.
 
     Examples
     --------
     >>> import dgl
-    >>> import torch as th
+    >>> import torch
     >>> from dgl.nn import CuGraphRelGraphConv
     ...
    >>> device = 'cuda'
     >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])).to(device)
-    >>> feat = th.ones(6, 10).to(device)
+    >>> feat = torch.ones(6, 10).to(device)
     >>> conv = CuGraphRelGraphConv(
     ...     10, 2, 3, regularizer='basis', num_bases=2).to(device)
-    >>> etype = th.tensor([0,1,2,0,1,2]).to(device)
+    >>> etype = torch.tensor([0,1,2,0,1,2]).to(device)
     >>> res = conv(g, feat, etype)
     >>> res
     tensor([[-1.7774, -2.0184],
@@ -205,6 +78,7 @@ class CuGraphRelGraphConv(nn.Module):
             [-1.4335, -2.3758],
             [-1.4331, -2.3295]], device='cuda:0', grad_fn=<AddBackward0>)
     """
+    MAX_IN_DEGREE_MFG = 500
 
     def __init__(
         self,
@@ -214,87 +88,68 @@ def __init__(
         regularizer=None,
         num_bases=None,
         bias=True,
-        activation=None,
         self_loop=True,
         dropout=0.0,
-        layer_norm=False,
-        max_in_degree=None,
+        apply_norm=False,
     ):
         if has_pylibcugraphops is False:
             raise ModuleNotFoundError(
-                "dgl.nn.CuGraphRelGraphConv requires pylibcugraphops "
-                "to be installed."
+                f"{self.__class__.__name__} requires pylibcugraphops >= 23.02 "
+                f"to be installed."
             )
         super().__init__()
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.num_rels = num_rels
-        self.max_in_degree = max_in_degree
+        self.apply_norm = apply_norm
+        self.dropout = nn.Dropout(dropout)
 
-        # regularizer
+        dim_self_loop = 1 if self_loop else 0
+        self.self_loop = self_loop
         if regularizer is None:
-            self.W = nn.Parameter(th.Tensor(num_rels, in_feat, out_feat))
+            self.W = nn.Parameter(
+                torch.Tensor(num_rels + dim_self_loop, in_feat, out_feat)
+            )
             self.coeff = None
         elif regularizer == "basis":
             if num_bases is None:
                 raise ValueError(
                     'Missing "num_bases" for basis regularization.'
                 )
-            self.W = nn.Parameter(th.Tensor(num_bases, in_feat, out_feat))
-            self.coeff = nn.Parameter(th.Tensor(num_rels, num_bases))
+            self.W = nn.Parameter(
+                torch.Tensor(num_bases + dim_self_loop, in_feat, out_feat)
+            )
+            self.coeff = nn.Parameter(torch.Tensor(num_rels, num_bases))
             self.num_bases = num_bases
         else:
             raise ValueError(
                 f"Supported regularizer options: 'basis' or None, but got "
-                f"{regularizer}."
+                f"'{regularizer}'."
             )
         self.regularizer = regularizer
 
-        # Initialize weights.
-        with th.no_grad():
-            if self.regularizer is None:
-                nn.init.uniform_(
-                    self.W,
-                    -1 / math.sqrt(self.in_feat),
-                    1 / math.sqrt(self.in_feat),
-                )
-            else:
-                nn.init.uniform_(
-                    self.W,
-                    -1 / math.sqrt(self.in_feat),
-                    1 / math.sqrt(self.in_feat),
-                )
-                nn.init.xavier_uniform_(
-                    self.coeff, gain=nn.init.calculate_gain("relu")
-                )
-
-        # others
-        self.bias = bias
-        self.activation = activation
-        self.self_loop = self_loop
-        self.layer_norm = layer_norm
-
-        # bias
-        if self.bias:
-            self.h_bias = nn.Parameter(th.Tensor(out_feat))
-            nn.init.zeros_(self.h_bias)
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_feat))
+        else:
+            self.register_parameter("bias", None)
 
-        # layer norm
-        if self.layer_norm:
-            self.layer_norm_weight = nn.LayerNorm(
-                out_feat, elementwise_affine=True
-            )
+        self.reset_parameters()
 
-        # weight for self_loop
-        if self.self_loop:
-            self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
+    def reset_parameters(self):
+        r"""Reinitialize learnable parameters."""
+        bound = 1 / math.sqrt(self.in_feat)
+        end = -1 if self.self_loop else None
+        nn.init.uniform_(self.W[:end], -bound, bound)
+        if self.regularizer == "basis":
             nn.init.xavier_uniform_(
-                self.loop_weight, gain=nn.init.calculate_gain("relu")
+                self.coeff, gain=nn.init.calculate_gain("relu")
             )
+        if self.self_loop:
+            nn.init.xavier_uniform_(self.W[-1], nn.init.calculate_gain("relu"))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
 
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, g, feat, etypes, norm=None):
+    def forward(self, g, feat, etypes, max_in_degree=None):
         r"""Forward computation.
 
         Parameters
@@ -309,57 +164,77 @@ def forward(self, g, feat, etypes, norm=None):
             so any input of other integer types will be casted into int32,
             thus introducing some overhead. Pass in int32 tensors directly
            for best performance.
-        norm : torch.Tensor, optional
-            A 1D tensor of edge norm value. Shape: :math:`(|E|,)`.
+        max_in_degree : int, optional
+            Maximum in-degree of destination nodes. It is only effective when
+            :attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When
+            :attr:`g` is generated from a neighbor sampler, the value should be
+            set to the corresponding :attr:`fanout`. If not given,
+            :attr:`max_in_degree` will be calculated on-the-fly.
 
         Returns
        -------
         torch.Tensor
             New node features. Shape: :math:`(|V|, D_{out})`.
         """
-        _device = next(self.parameters()).device
-        if _device.type != "cuda":
-            raise RuntimeError(
-                f"dgl.nn.CuGraphRelGraphConv requires the model to be on "
-                f"device 'cuda', but got '{_device.type}'."
-            )
-        if _device != g.device:
-            raise RuntimeError(
-                f"Expected model and graph on the same device, "
-                f"but got '{_device}' and '{g.device}'."
-            )
-        if _device != etypes.device:
-            raise RuntimeError(
-                f"Expected model and etypes on the same device, "
-                f"but got '{_device}' and '{etypes.device}'."
-            )
-        if _device != feat.device:
-            raise RuntimeError(
-                f"Expected model and feature tensor on the same device, "
-                f"but got '{_device}' and '{feat.device}'."
+        # Create csc-representation and cast etypes to int32.
+        offsets, indices, edge_ids = g.adj_sparse("csc")
+        edge_types_perm = etypes[edge_ids.long()].int()
+
+        # Create cugraph-ops graph.
+        if g.is_block:
+            if max_in_degree is None:
+                max_in_degree = g.in_degrees().max().item()
+
+            if max_in_degree < self.MAX_IN_DEGREE_MFG:
+                _graph = make_mfg_csr_hg(
+                    g.dstnodes(),
+                    offsets,
+                    indices,
+                    max_in_degree,
+                    g.num_src_nodes(),
+                    n_node_types=0,
+                    n_edge_types=self.num_rels,
+                    out_node_types=None,
+                    in_node_types=None,
+                    edge_types=edge_types_perm,
+                )
+            else:
+                offsets_fg = torch.empty(
+                    g.num_src_nodes() + 1,
+                    dtype=offsets.dtype,
+                    device=offsets.device,
+                )
+                offsets_fg[: offsets.numel()] = offsets
+                offsets_fg[offsets.numel() :] = offsets[-1]
+
+                _graph = make_fg_csr_hg(
+                    offsets_fg,
+                    indices,
+                    n_node_types=0,
+                    n_edge_types=self.num_rels,
+                    node_types=None,
+                    edge_types=edge_types_perm,
+                )
+        else:
+            _graph = make_fg_csr_hg(
+                offsets,
+                indices,
+                n_node_types=0,
+                n_edge_types=self.num_rels,
+                node_types=None,
+                edge_types=edge_types_perm,
             )
 
-        # Compute max_in_degree.
-        max_in_degree = self.max_in_degree
-        if max_in_degree is None:
-            max_in_degree = g.in_degrees().max().item()
-
-        with g.local_scope():
-            g.srcdata["h"] = feat
-            if norm is not None:
-                g.edata["norm"] = norm
-            # Message passing.
-            h = RelGraphConvAgg.apply(
-                g, self.num_rels, etypes, max_in_degree, feat, self.coeff
-            )
-            h = h @ self.W.view(-1, self.out_feat)
-            # Apply bias and activation.
-            if self.layer_norm:
-                h = self.layer_norm_weight(h)
-            if self.bias:
-                h = h + self.h_bias
-            if self.self_loop:
-                h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
-            if self.activation:
-                h = self.activation(h)
-            h = self.dropout(h)
-            return h
+        h = RelGraphConvAgg(
+            feat,
+            self.coeff,
+            _graph,
+            concat_own=self.self_loop,
+            norm_by_out_degree=self.apply_norm,
+        )[: g.num_dst_nodes()]
+        h = h @ self.W.view(-1, self.out_feat)
+        if self.bias is not None:
+            h = h + self.bias
+        h = self.dropout(h)
+
+        return h
diff --git a/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py b/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py
index 2d39a5efdd4d..ddacf8d27484 100644
--- a/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py
+++ b/tests/cugraph/cugraph-ops/test_cugraph_relgraphconv.py
@@ -1,97 +1,77 @@
+# pylint: disable=too-many-arguments, too-many-locals
+from collections import OrderedDict
+from itertools import product
+
+import dgl
 import pytest
 import torch
-import dgl
-from dgl.nn import CuGraphRelGraphConv
-from dgl.nn import RelGraphConv
+from dgl.nn import CuGraphRelGraphConv, RelGraphConv
 
 # TODO(tingyu66): Re-enable the following tests after updating cuGraph CI image.
-use_longs = [False, True]
-max_in_degrees = [None, 8]
-regularizers = [None, "basis"]
-device = "cuda"
+options = OrderedDict(
+    {
+        "idtype_int": [False, True],
+        "max_in_degree": [None, 8],
+        "num_bases": [1, 2, 5],
+        "regularizer": [None, "basis"],
+        "self_loop": [False, True],
+        "to_block": [False, True],
+    }
+)
 
 
 def generate_graph():
     u = torch.tensor([0, 1, 0, 2, 3, 0, 4, 0, 5, 0, 6, 7, 0, 8, 9])
     v = torch.tensor([1, 9, 2, 9, 9, 4, 9, 5, 9, 6, 9, 9, 8, 9, 0])
     g = dgl.graph((u, v))
-    num_rels = 3
-    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),))
     return g
 
+
 @pytest.mark.skip()
-@pytest.mark.parametrize('use_long', use_longs)
-@pytest.mark.parametrize('max_in_degree', max_in_degrees)
-@pytest.mark.parametrize("regularizer", regularizers)
-def test_full_graph(use_long, max_in_degree, regularizer):
-    in_feat, out_feat, num_rels, num_bases = 10, 2, 3, 2
+@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
+def test_relgraphconv_equality(
+    idtype_int, max_in_degree, num_bases, regularizer, self_loop, to_block
+):
+    device = "cuda:0"
+    in_feat, out_feat, num_rels = 10, 2, 3
+    args = (in_feat, out_feat, num_rels)
     kwargs = {
         "num_bases": num_bases,
         "regularizer": regularizer,
         "bias": False,
-        "self_loop": False,
+        "self_loop": self_loop,
     }
     g = generate_graph().to(device)
-    if use_long:
-        g = g.long()
-    else:
+    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).to(device)
+    if idtype_int:
         g = g.int()
-    feat = torch.ones(g.num_nodes(), in_feat).to(device)
+    if to_block:
+        g = dgl.to_block(g)
+    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)
 
     torch.manual_seed(0)
-    conv1 = RelGraphConv(in_feat, out_feat, num_rels, **kwargs).to(device)
+    conv1 = RelGraphConv(*args, **kwargs).to(device)
 
     torch.manual_seed(0)
-    conv2 = CuGraphRelGraphConv(
-        in_feat, out_feat, num_rels, max_in_degree=max_in_degree, **kwargs
-    ).to(device)
+    kwargs["apply_norm"] = False
+    conv2 = CuGraphRelGraphConv(*args, **kwargs).to(device)
 
     out1 = conv1(g, feat, g.edata[dgl.ETYPE])
-    out2 = conv2(g, feat, g.edata[dgl.ETYPE])
-
+    out2 = conv2(g, feat, g.edata[dgl.ETYPE], max_in_degree=max_in_degree)
     assert torch.allclose(out1, out2, atol=1e-06)
 
     grad_out = torch.rand_like(out1)
     out1.backward(grad_out)
     out2.backward(grad_out)
-    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad, atol=1e-6)
-    if regularizer is not None:
-        assert torch.allclose(
-            conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6
-        )
-
-
-@pytest.mark.skip()
-@pytest.mark.parametrize('max_in_degree', max_in_degrees)
-@pytest.mark.parametrize("regularizer", regularizers)
-def test_mfg(max_in_degree, regularizer):
-    in_feat, out_feat, num_rels, num_bases = 10, 2, 3, 2
-    kwargs = {
-        "num_bases": num_bases,
-        "regularizer": regularizer,
-        "bias": False,
-        "self_loop": False,
-    }
-    g = generate_graph().to(device)
-    block = dgl.to_block(g)
-    feat = torch.ones(g.num_nodes(), in_feat).to(device)
-
-    torch.manual_seed(0)
-    conv1 = RelGraphConv(in_feat, out_feat, num_rels, **kwargs).to(device)
-
-    torch.manual_seed(0)
-    conv2 = CuGraphRelGraphConv(
-        in_feat, out_feat, num_rels, max_in_degree=max_in_degree, **kwargs
-    ).to(device)
-    out1 = conv1(block, feat[block.srcdata[dgl.NID]], block.edata[dgl.ETYPE])
-    out2 = conv2(block, feat[block.srcdata[dgl.NID]], block.edata[dgl.ETYPE])
+    end = -1 if self_loop else None
+    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad[:end], atol=1e-6)
 
-    assert torch.allclose(out1, out2, atol=1e-06)
+    if self_loop:
+        assert torch.allclose(
+            conv1.loop_weight.grad, conv2.W.grad[-1], atol=1e-6
+        )
 
-    grad_out = torch.rand_like(out1)
-    out1.backward(grad_out)
-    out2.backward(grad_out)
-    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad, atol=1e-6)
     if regularizer is not None:
         assert torch.allclose(
             conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6
         )
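
---

Reviewer note (not part of the patch): a minimal usage sketch of the updated
CuGraphRelGraphConv calling convention, adapted from the doctest and the test
changes above. It assumes a CUDA device and pylibcugraphops >= 23.02; output
values depend on random initialization.

    import dgl
    import torch
    from dgl.nn import CuGraphRelGraphConv

    device = "cuda"
    g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])).to(device)
    g.edata[dgl.ETYPE] = torch.tensor([0, 1, 2, 0, 1, 2]).to(device)
    feat = torch.ones(6, 10).to(device)
    conv = CuGraphRelGraphConv(
        10, 2, 3, regularizer="basis", num_bases=2
    ).to(device)

    # Full graph: max_in_degree is not needed; the layer builds a
    # full-graph cugraph-ops CSR internally.
    res = conv(g, feat, g.edata[dgl.ETYPE])

    # Sampled block (MFG): pass the sampler fanout as max_in_degree so the
    # layer can pick the MFG code path when it is below MAX_IN_DEGREE_MFG.
    block = dgl.to_block(g)
    res_block = conv(
        block,
        feat[block.srcdata[dgl.NID]],
        block.edata[dgl.ETYPE],
        max_in_degree=8,
    )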