Skip to content

Commit

Permalink
Add PyTorch SparseTensor support for SGConv, SSGConv and TAGConv (#6514)
Browse files Browse the repository at this point in the history

For the roadmap
#5867
  • Loading branch information
EdisonLeeeee authored Jan 26, 2023
1 parent 239c710 commit 77693c9
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514))
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
Expand Down
5 changes: 5 additions & 0 deletions test/nn/conv/test_sg_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ def test_sg_conv():
value = torch.rand(row.size(0))
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj1 = adj2.set_value(None)
adj3 = adj1.to_torch_sparse_coo_tensor()
adj4 = adj2.to_torch_sparse_coo_tensor()

conv = SGConv(16, 32, K=10)
assert conv.__repr__() == 'SGConv(16, 32, K=10)'
out1 = conv(x, edge_index)
assert out1.size() == (4, 32)
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
out2 = conv(x, edge_index, value)
assert out2.size() == (4, 32)
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, OptTensor) -> Tensor'
Expand All @@ -37,3 +41,4 @@ def test_sg_conv():
conv(x, edge_index)
assert conv(x, edge_index).tolist() == out1.tolist()
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
5 changes: 5 additions & 0 deletions test/nn/conv/test_ssg_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ def test_ssg_conv():
value = torch.rand(row.size(0))
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj1 = adj2.set_value(None)
adj3 = adj1.to_torch_sparse_coo_tensor()
adj4 = adj2.to_torch_sparse_coo_tensor()

conv = SSGConv(16, 32, alpha=0.1, K=10)
assert conv.__repr__() == 'SSGConv(16, 32, K=10, alpha=0.1)'
out1 = conv(x, edge_index)
assert out1.size() == (4, 32)
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
out2 = conv(x, edge_index, value)
assert out2.size() == (4, 32)
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, OptTensor) -> Tensor'
Expand All @@ -37,3 +41,4 @@ def test_ssg_conv():
conv(x, edge_index)
assert conv(x, edge_index).tolist() == out1.tolist()
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
4 changes: 4 additions & 0 deletions test/nn/conv/test_tag_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ def test_tag_conv():
value = torch.rand(row.size(0))
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
adj1 = adj2.set_value(None)
adj3 = adj1.to_torch_sparse_coo_tensor()
adj4 = adj2.to_torch_sparse_coo_tensor()

conv = TAGConv(16, 32)
assert conv.__repr__() == 'TAGConv(16, 32, K=3)'
out1 = conv(x, edge_index)
assert out1.size() == (4, 32)
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)
out2 = conv(x, edge_index, value)
assert out2.size() == (4, 32)
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, OptTensor) -> Tensor'
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/sg_conv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional

from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import spmm


class SGConv(MessagePassing):
Expand Down Expand Up @@ -104,7 +105,7 @@ def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
return edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
    """Fused message-and-aggregate step.

    Performs a single sparse-dense matrix multiplication of the
    (transposed) adjacency with the node features, reduced by
    ``self.aggr``.

    Args:
        adj_t: Transposed sparse adjacency matrix. ``spmm`` dispatches on
            its type, so both ``torch_sparse.SparseTensor`` and native
            ``torch.sparse`` tensors are supported (the point of this
            change; the annotation still says ``SparseTensor`` — matches
            upstream).
        x: Dense node feature matrix.

    Returns:
        The aggregated neighborhood features.
    """
    # `spmm` replaces the torch_sparse-only `matmul`; the stale duplicate
    # `return matmul(...)` line (diff residue) is removed — it was dead
    # code shadowing this return and referenced an unimported name.
    return spmm(adj_t, x, reduce=self.aggr)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/ssg_conv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional

from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import spmm


class SSGConv(MessagePassing):
Expand Down Expand Up @@ -115,7 +116,7 @@ def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
return edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
    """Fused message-and-aggregate step.

    Multiplies the (transposed) sparse adjacency with the dense node
    features in one sparse-dense matmul, reduced by ``self.aggr``.

    Args:
        adj_t: Transposed sparse adjacency. ``spmm`` handles both
            ``torch_sparse.SparseTensor`` and native ``torch.sparse``
            tensors, enabling the PyTorch SparseTensor support added here.
        x: Dense node feature matrix.

    Returns:
        The aggregated neighborhood features.
    """
    # Only the `spmm` call belongs here; the preceding `return matmul(...)`
    # was pre-change diff residue (dead code, unimported name) and is gone.
    return spmm(adj_t, x, reduce=self.aggr)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/tag_conv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch
from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_sparse import SparseTensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import spmm


class TAGConv(MessagePassing):
Expand Down Expand Up @@ -99,7 +100,7 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
    """Fused message-and-aggregate step.

    Computes one sparse-dense product of the (transposed) adjacency with
    the node features, reduced by ``self.aggr``.

    Args:
        adj_t: Transposed sparse adjacency. ``spmm`` dispatches on the
            adjacency type, so both ``torch_sparse.SparseTensor`` and
            native ``torch.sparse`` tensors work.
        x: Dense node feature matrix.

    Returns:
        The aggregated neighborhood features.
    """
    # `spmm` supersedes `matmul`; the duplicated pre-change return line
    # (dead code referencing the now-unimported `matmul`) is removed.
    return spmm(adj_t, x, reduce=self.aggr)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
Expand Down

0 comments on commit 77693c9

Please sign in to comment.