Skip to content

Commit

Permalink
Support EdgeIndex in spmm (#9026)
Browse files Browse the repository at this point in the history
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
wsad1 and rusty1s authored Mar 12, 2024
1 parent e697d26 commit d6b12b0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026))
- Added option to pre-allocate memory in GPU-based `ApproxKNN` ([#9046](https://github.com/pyg-team/pytorch_geometric/pull/9046))
- Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
Expand Down
19 changes: 19 additions & 0 deletions test/utils/test_spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor

import torch_geometric.typing
from torch_geometric import EdgeIndex
from torch_geometric.profile import benchmark
from torch_geometric.testing import withCUDA, withPackage
from torch_geometric.typing import SparseTensor
Expand Down Expand Up @@ -106,6 +107,24 @@ def jit_torch(src: Tensor, other: Tensor, reduce: str) -> Tensor:
assert torch.allclose(out2, out3, atol=1e-6)


@withCUDA
@withPackage('torch>=2.0.0')
@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])
def test_spmm_edge_index(device, reduce):
src = EdgeIndex(
[[0, 1, 1, 2], [1, 0, 2, 1]],
sparse_size=(4, 3),
sort_order='row',
device=device,
)
other = torch.rand(3, 4, device=device)
out = spmm(src, other, reduce=reduce)
assert out.size() == (4, 4)

out2 = spmm(src.to_sparse_coo(), other, reduce=reduce)
assert torch.allclose(out, out2)


if __name__ == '__main__':
import argparse

Expand Down
12 changes: 9 additions & 3 deletions torch_geometric/utils/_spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

import torch_geometric.typing
from torch_geometric import EdgeIndex
from torch_geometric.typing import Adj, SparseTensor, torch_sparse
from torch_geometric.utils import is_torch_sparse_tensor, scatter

Expand All @@ -16,9 +17,11 @@ def spmm(
r"""Matrix product of sparse matrix with dense matrix.
Args:
src (torch.Tensor or torch_sparse.SparseTensor): The input sparse
matrix, either a :pyg:`PyG` :class:`torch_sparse.SparseTensor` or a
:pytorch:`PyTorch` :class:`torch.sparse.Tensor`.
src (torch.Tensor or torch_sparse.SparseTensor or EdgeIndex):
The input sparse matrix which can be a
:pyg:`PyG` :class:`torch_sparse.SparseTensor`,
a :pytorch:`PyTorch` :class:`torch.sparse.Tensor` or
a :pyg:`PyG` :class:`EdgeIndex`.
other (torch.Tensor): The input dense matrix.
reduce (str, optional): The reduce operation to use
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
Expand All @@ -31,6 +34,9 @@ def spmm(
if reduce not in ['sum', 'mean', 'min', 'max']:
raise ValueError(f"`reduce` argument '{reduce}' not supported")

if not torch.jit.is_scripting() and isinstance(src, EdgeIndex):
return src.matmul(other=other, reduce=reduce) # type: ignore

if isinstance(src, SparseTensor):
if src.nnz() == 0:
return other.new_zeros(src.size(0), other.size(1))
Expand Down

0 comments on commit d6b12b0

Please sign in to comment.