Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 9, 2022
1 parent 4fa432b commit f76caa5
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906))
- Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903))
- Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886))
- Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))
Expand Down
45 changes: 45 additions & 0 deletions test/utils/test_spmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.utils import spmm


def test_spmm():
src = torch.randn(5, 4)
other = torch.randn(4, 8)

out1 = src @ other
out2 = spmm(src.to_sparse(), other, reduce='sum')
out3 = spmm(SparseTensor.from_dense(src), other, reduce='sum')
assert out1.size() == (5, 8)
assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)

for reduce in ['mean', 'min', 'max']:
out = spmm(SparseTensor.from_dense(src), other, reduce)
assert out.size() == (5, 8)

with pytest.raises(ValueError, match="not supported"):
spmm(src.to_sparse(), other, reduce)


def test_spmm():
@torch.jit.script
def jit_torch_sparse(src: SparseTensor, other: Tensor) -> Tensor:
return spmm(src, other)

@torch.jit.script
def jit_torch(src: Tensor, other: Tensor) -> Tensor:
return spmm(src, other, reduce='sum')

src = torch.randn(5, 4)
other = torch.randn(4, 8)

out1 = src @ other
out2 = jit_torch_sparse(SparseTensor.from_dense(src), other)
out3 = jit_torch(src.to_sparse(), other)
assert out1.size() == (5, 8)
assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
37 changes: 19 additions & 18 deletions torch_geometric/utils/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_sparse import SparseTensor, matmul

from .torch_sparse_tensor import is_torch_sparse_tensor

Expand All @@ -27,27 +27,28 @@ def spmm(
"""Matrix product of sparse matrix with dense matrix.
Args:
src (Union[SparseTensor, Tensor]): The input sparse matrix,
either of :obj:`torch_sparse.SparseTensor` or
`torch.sparse.Tensor`.
src (Tensor or torch_sparse.SparseTensor]): The input sparse matrix,
either a :class:`torch_sparse.SparseTensor` or a
:class:`torch.sparse.Tensor`.
other (Tensor): The input dense matrix.
reduce (str, optional): The reduce operation to use for merging edge
features (:obj:`"sum"`, :obj:`"add"`, :obj:`"mean"`, :obj:`"max"`,
:obj:`"min"`). (default: :obj:`"sum"`)
reduce (str, optional): The reduce operation to use
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
(default: :obj:`"sum"`)
:rtype: :class:`Tensor`
"""
assert reduce in {"sum", "add", "mean", "max", "min"}
assert reduce in ['sum', 'add', 'mean', 'min', 'max']

if isinstance(src, SparseTensor):
return src.spmm(other, reduce)
return matmul(src, other, reduce)

if not is_torch_sparse_tensor(src):
raise ValueError("`src` must be a torch_sparse SparseTensor "
f"or a PyTorch SparseTensor, but got {type(src)}.")
if reduce in {"sum", "add"}:
raise ValueError("`src` must be a `torch_sparse.SparseTensor` "
f"or a `torch.sparse.Tensor` (got {type(src)}).")

if reduce in ['sum', 'add']:
return torch.sparse.mm(src, other)
elif reduce == "mean":
# TODO: Support `mean` reduction for PyTorch SparseTensor
raise NotImplementedError
else:
raise NotImplementedError("`max` and `min` reduction are not supported"
"for PyTorch SparseTensor input.")

# TODO: Support `mean` reduction for PyTorch SparseTensor
raise ValueError(f"`{reduce}` reduction is not supported for "
f"`torch.sparse.Tensor`.")
10 changes: 5 additions & 5 deletions torch_geometric/utils/torch_sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from torch import Tensor


def is_torch_sparse_tensor(x: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`x` is
PyTorch SparseTensor (COO or CSR format).
def is_torch_sparse_tensor(src: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`x` is a PyTorch
:obj:`SparseTensor` (in any sparse format).
Args:
x (Any): The input object to be checked.
src (Any): The input object to be checked.
"""
return isinstance(x, Tensor) and x.is_sparse
return isinstance(src, Tensor) and src.is_sparse

0 comments on commit f76caa5

Please sign in to comment.