update conv layers in cugraph-dgl for pylibcugraphops 23.04 #3360

Merged: 12 commits, Apr 5, 2023
3 changes: 1 addition & 2 deletions ci/test_python.sh
@@ -117,7 +117,7 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
--channel "${PYTHON_CHANNEL}" \
--channel pytorch \
--channel pytorch-nightly \
--channel dglteam/label/cu117 \
--channel dglteam/label/cu118 \
Review comment from a member:
Ahh, just saw that they have cu118 builds here: https://anaconda.org/dglteam/dgl

Surprisingly, it is missing from the getting started page: https://www.dgl.ai/pages/start.html

--channel nvidia \
libcugraph \
pylibcugraph \
@@ -134,7 +134,6 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
pytest \
--cache-clear \
--ignore=mg \
--ignore=nn \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cugraph-dgl.xml" \
--cov-config=../../.coveragerc \
--cov=cugraph_dgl \
50 changes: 50 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/base.py
@@ -0,0 +1,50 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cugraph.utilities.utils import import_optional

torch = import_optional("torch")
nn = import_optional("torch.nn")
ops_torch = import_optional("pylibcugraphops.pytorch")


class BaseConv(nn.Module):
r"""An abstract base class for cugraph-ops nn module."""

def __init__(self):
super().__init__()
self._cached_offsets_fg = None

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
raise NotImplementedError

def forward(self, *args):
r"""Runs the forward pass of the module."""
raise NotImplementedError

def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
r"""Pad zero-in-degree nodes to the end of offsets to reach size. This
is used to augment offset tensors from DGL blocks (MFGs) to be
compatible with cugraph-ops full-graph primitives."""
if self._cached_offsets_fg is None:
self._cached_offsets_fg = torch.empty(
size, dtype=offsets.dtype, device=offsets.device
)
elif self._cached_offsets_fg.numel() < size:
self._cached_offsets_fg.resize_(size)

self._cached_offsets_fg[: offsets.numel()] = offsets
self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

return self._cached_offsets_fg[:size]
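
For reference, a small sketch of what pad_offsets produces; the graph sizes and offset values here are illustrative assumptions, not taken from this PR:

import torch

# Hypothetical MFG with 3 destination nodes and 6 source nodes: its CSC
# offsets tensor has length num_dst_nodes + 1, while the full-graph
# (StaticCSC) primitives expect length num_src_nodes + 1.
offsets = torch.tensor([0, 2, 5, 6])  # in-degrees: 2, 3, 1

size = 6 + 1  # g.num_src_nodes() + 1
padded = torch.empty(size, dtype=offsets.dtype, device=offsets.device)
padded[: offsets.numel()] = offsets
padded[offsets.numel() :] = offsets[-1]  # padded nodes get in-degree 0

print(padded)  # tensor([0, 2, 5, 6, 6, 6, 6])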
26 changes: 15 additions & 11 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py
@@ -16,16 +16,16 @@
from __future__ import annotations
from typing import Optional

from cugraph_dgl.nn.conv.base import BaseConv
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")
ops = import_optional("pylibcugraphops")
ops_autograd = import_optional("pylibcugraphops.torch.autograd")
ops_torch = import_optional("pylibcugraphops.pytorch")


class GATConv(nn.Module):
class GATConv(BaseConv):
r"""Graph attention layer from `Graph Attention Network
<https://arxiv.org/pdf/1710.10903.pdf>`__, with the sparse aggregation
accelerated by cugraph-ops.
@@ -80,6 +80,7 @@ class GATConv(nn.Module):
[ 1.6477, -1.9986],
[ 1.1138, -1.9302]]], device='cuda:0', grad_fn=<ViewBackward0>)
"""
MAX_IN_DEGREE_MFG = 200

def __init__(
self,
@@ -144,29 +145,32 @@ def forward(
:math:`H` is the number of heads, and :math:`D_{out}` is the size of the
output feature.
"""

offsets, indices, _ = g.adj_sparse("csc")

if g.is_block:
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()
_graph = ops.make_mfg_csr(
g.dstnodes(), offsets, indices, max_in_degree, g.num_src_nodes()
)

if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = ops_torch.SampledCSC(
offsets, indices, max_in_degree, g.num_src_nodes()
)
else:
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = ops_torch.StaticCSC(offsets_fg, indices)
else:
_graph = ops.make_fg_csr(offsets, indices)
_graph = ops_torch.StaticCSC(offsets, indices)

feat_transformed = self.fc(feat)
out = ops_autograd.mha_gat_n2n(
out = ops_torch.operators.mha_gat_n2n(
feat_transformed,
self.attn_weights,
_graph,
self.num_heads,
"LeakyReLU",
self.negative_slope,
add_own_node=False,
concat_heads=True,
).view(-1, self.num_heads, self.out_feats)
).view(-1, self.num_heads, self.out_feats)[: g.num_dst_nodes()]

if self.bias is not None:
out = out + self.bias
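For reference, a minimal usage sketch of the updated layer; the toy graph and sizes are illustrative, and the constructor arguments are assumed to mirror dgl.nn.GATConv:

import dgl
import torch
from cugraph_dgl.nn import GATConv

device = "cuda"
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0])).to(device)  # 4-node cycle
feat = torch.rand(g.num_nodes(), 5, device=device)

conv = GATConv(5, 2, num_heads=3).to(device)

# Full graph: forward builds a StaticCSC straight from the CSC offsets.
out = conv(g, feat)
print(out.shape)  # torch.Size([4, 3, 2]): (num_dst_nodes, num_heads, out_feats)

# Block (MFG): a max_in_degree below MAX_IN_DEGREE_MFG takes the SampledCSC
# path; larger values fall back to pad_offsets + StaticCSC.
block = dgl.to_block(g)
out_block = conv(block, feat, max_in_degree=1)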
51 changes: 27 additions & 24 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py
@@ -17,16 +17,16 @@
import math
from typing import Optional

from cugraph_dgl.nn.conv.base import BaseConv
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")
ops = import_optional("pylibcugraphops")
ops_autograd = import_optional("pylibcugraphops.torch.autograd")
ops_torch = import_optional("pylibcugraphops.pytorch")


class RelGraphConv(nn.Module):
class RelGraphConv(BaseConv):
r"""An accelerated relational graph convolution layer from `Modeling
Relational Data with Graph Convolutional Networks
<https://arxiv.org/abs/1703.06103>`__ that leverages the highly-optimized
@@ -84,6 +84,7 @@ class RelGraphConv(nn.Module):
[-1.4335, -2.3758],
[-1.4331, -2.3295]], device='cuda:0', grad_fn=<AddBackward0>)
"""
MAX_IN_DEGREE_MFG = 500

def __init__(
self,
@@ -178,43 +179,45 @@ def forward(
torch.Tensor
New node features. Shape: :math:`(|V|, D_{out})`.
"""
# Create csc-representation and cast etypes to int32.
offsets, indices, edge_ids = g.adj_sparse("csc")
edge_types_perm = etypes[edge_ids.long()].int()

# Create cugraph-ops graph.
if g.is_block:
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()
_graph = ops.make_mfg_csr_hg(
g.dstnodes(),
offsets,
indices,
max_in_degree,
g.num_src_nodes(),
n_node_types=0,
n_edge_types=self.num_rels,
out_node_types=None,
in_node_types=None,
edge_types=edge_types_perm,
)

if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = ops_torch.SampledHeteroCSC(
offsets,
indices,
edge_types_perm,
max_in_degree,
g.num_src_nodes(),
self.num_rels,
)
else:
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = ops_torch.StaticHeteroCSC(
offsets_fg,
indices,
edge_types_perm,
self.num_rels,
)
else:
_graph = ops.make_fg_csr_hg(
_graph = ops_torch.StaticHeteroCSC(
offsets,
indices,
n_node_types=0,
n_edge_types=self.num_rels,
node_types=None,
edge_types=edge_types_perm,
edge_types_perm,
self.num_rels,
)

h = ops_autograd.agg_hg_basis_n2n_post(
h = ops_torch.operators.agg_hg_basis_n2n_post(
feat,
self.coeff,
_graph,
concat_own=self.self_loop,
norm_by_out_degree=self.apply_norm,
)
)[: g.num_dst_nodes()]
h = h @ self.W.view(-1, self.out_feats)
if self.bias is not None:
h = h + self.bias
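For reference, a minimal usage sketch; the toy graph, relation count, and constructor arguments are illustrative and assumed to mirror dgl.nn.RelGraphConv:

import dgl
import torch
from cugraph_dgl.nn import RelGraphConv

device = "cuda"
g = dgl.graph(([0, 1, 2], [1, 2, 0])).to(device)
feat = torch.rand(g.num_nodes(), 5, device=device)
etypes = torch.tensor([0, 1, 0], device=device)  # one relation id per edge

# The "basis" regularizer populates self.coeff, which forward hands to
# agg_hg_basis_n2n_post.
conv = RelGraphConv(5, 2, num_rels=2, regularizer="basis", num_bases=2).to(device)
out = conv(g, feat, etypes)
print(out.shape)  # torch.Size([3, 2])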
23 changes: 15 additions & 8 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py
@@ -16,16 +16,16 @@
from __future__ import annotations
from typing import Optional

from cugraph_dgl.nn.conv.base import BaseConv
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")
ops = import_optional("pylibcugraphops")
ops_autograd = import_optional("pylibcugraphops.torch.autograd")
ops_torch = import_optional("pylibcugraphops.pytorch")


class SAGEConv(nn.Module):
class SAGEConv(BaseConv):
r"""An accelerated GraphSAGE layer from `Inductive Representation Learning
on Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__ that leverages the
highly-optimized aggregation primitives in cugraph-ops.
@@ -68,6 +68,7 @@ class SAGEConv(nn.Module):
[-1.1690, 0.1952],
[-1.1690, 0.1952]], device='cuda:0', grad_fn=<AddmmBackward0>)
"""
MAX_IN_DEGREE_MFG = 500

def __init__(
self,
@@ -127,14 +128,20 @@ def forward(
if max_in_degree is None:
max_in_degree = g.in_degrees().max().item()

_graph = ops.make_mfg_csr(
g.dstnodes(), offsets, indices, max_in_degree, g.num_src_nodes()
)
if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = ops_torch.SampledCSC(
offsets, indices, max_in_degree, g.num_src_nodes()
)
else:
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = ops_torch.StaticCSC(offsets_fg, indices)
else:
_graph = ops.make_fg_csr(offsets, indices)
_graph = ops_torch.StaticCSC(offsets, indices)

feat = self.feat_drop(feat)
h = ops_autograd.agg_concat_n2n(feat, _graph, self.aggr)
h = ops_torch.operators.agg_concat_n2n(feat, _graph, self.aggr)[
: g.num_dst_nodes()
]
h = self.linear(h)

return h
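
For reference, a minimal usage sketch; the toy graph and sizes are illustrative, and the constructor arguments are assumed to mirror dgl.nn.SAGEConv:

import dgl
import torch
from cugraph_dgl.nn import SAGEConv

device = "cuda"
g = dgl.graph(([0, 1, 2], [1, 2, 0])).to(device)  # 3-node cycle
feat = torch.rand(g.num_nodes(), 5, device=device)

conv = SAGEConv(5, 2, "mean").to(device)

out = conv(g, feat)      # full-graph path: StaticCSC
block = dgl.to_block(g)  # MFG path: SampledCSC when max_in_degree is small
out_block = conv(block, feat, max_in_degree=1)
print(out.shape, out_block.shape)  # torch.Size([3, 2]) torch.Size([3, 2])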
2 changes: 1 addition & 1 deletion python/cugraph-dgl/tests/nn/test_gatconv.py
@@ -28,7 +28,7 @@
options = {
"idtype_int": [False, True],
"max_in_degree": [None, 8],
"num_heads": [1, 3],
"num_heads": [1, 2, 3, 7],
"to_block": [False, True],
}

14 changes: 9 additions & 5 deletions python/cugraph-dgl/tests/nn/test_sageconv.py
@@ -26,21 +26,21 @@
dgl = import_optional("dgl")

options = {
"bias": [False, True],
"idtype_int": [False, True],
"max_in_degree": [None, 8],
"to_block": [False, True],
}


@pytest.mark.parametrize(",".join(options.keys()), product(*options.values()))
def test_SAGEConv_equality(idtype_int, max_in_degree, to_block):
def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block):
SAGEConv = dgl.nn.SAGEConv
CuGraphSAGEConv = cugraph_dgl.nn.SAGEConv
device = "cuda"

in_feat, out_feat = 5, 2
# TODO(tingyu66): re-enable bias after upgrading DGL to 1.0 in conda env
kwargs = {"aggregator_type": "mean", "bias": False}
kwargs = {"aggregator_type": "mean", "bias": bias}
g = create_graph1().to(device)
if idtype_int:
g = g.int()
@@ -57,7 +57,8 @@ def test_SAGEConv_equality(idtype_int, max_in_degree, to_block):
with torch.no_grad():
conv2.linear.weight.data[:, :in_feat] = conv1.fc_neigh.weight.data
conv2.linear.weight.data[:, in_feat:] = conv1.fc_self.weight.data
# conv2.linear.bias.data[:] = conv1.fc_self.bias.data
if bias:
conv2.linear.bias.data[:] = conv1.fc_self.bias.data

out1 = conv1(g, feat)
out2 = conv2(g, feat, max_in_degree=max_in_degree)
@@ -76,4 +77,7 @@
conv2.linear.weight.grad[:, in_feat:],
atol=1e-6,
)
# assert torch.allclose(conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6)
if bias:
assert torch.allclose(
conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6
)