Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add spmm and is_torch_sparse_tensor #5906

Merged
merged 5 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906))
- Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), [#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903))
- Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886))
- Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))
Expand Down
45 changes: 45 additions & 0 deletions test/utils/test_spmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.utils import spmm


def test_spmm():
src = torch.randn(5, 4)
other = torch.randn(4, 8)

out1 = src @ other
out2 = spmm(src.to_sparse(), other, reduce='sum')
out3 = spmm(SparseTensor.from_dense(src), other, reduce='sum')
assert out1.size() == (5, 8)
assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)

for reduce in ['mean', 'min', 'max']:
out = spmm(SparseTensor.from_dense(src), other, reduce)
assert out.size() == (5, 8)

with pytest.raises(ValueError, match="not supported"):
spmm(src.to_sparse(), other, reduce)


def test_spmm():
@torch.jit.script
def jit_torch_sparse(src: SparseTensor, other: Tensor) -> Tensor:
return spmm(src, other)

@torch.jit.script
def jit_torch(src: Tensor, other: Tensor) -> Tensor:
return spmm(src, other, reduce='sum')

src = torch.randn(5, 4)
other = torch.randn(4, 8)

out1 = src @ other
out2 = jit_torch_sparse(SparseTensor.from_dense(src), other)
out3 = jit_torch(src.to_sparse(), other)
assert out1.size() == (5, 8)
assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
12 changes: 12 additions & 0 deletions test/utils/test_torch_sparse_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.utils import is_torch_sparse_tensor


def test_is_torch_sparse_tensor():
x = torch.randn(5, 5)

assert not is_torch_sparse_tensor(x)
assert not is_torch_sparse_tensor(SparseTensor.from_dense(x))
assert is_torch_sparse_tensor(x.to_sparse())
4 changes: 4 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
structured_negative_sampling_feasible)
from .train_test_split_edges import train_test_split_edges
from .scatter import scatter
from .torch_sparse_tensor import is_torch_sparse_tensor
from .spmm import spmm

__all__ = [
'degree',
Expand Down Expand Up @@ -95,6 +97,8 @@
'structured_negative_sampling_feasible',
'train_test_split_edges',
'scatter',
'is_torch_sparse_tensor',
'spmm',
]

classes = __all__
54 changes: 54 additions & 0 deletions torch_geometric/utils/spmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Union

import torch
from torch import Tensor
from torch_sparse import SparseTensor, matmul

from .torch_sparse_tensor import is_torch_sparse_tensor


@torch.jit._overload
def spmm(src, other, reduce):
# type: (Tensor, Tensor, str) -> Tensor
pass


@torch.jit._overload
def spmm(src, other, reduce):
# type: (SparseTensor, Tensor, str) -> Tensor
pass


def spmm(
src: Union[SparseTensor, Tensor],
other: Tensor,
reduce: str = "sum",
) -> Tensor:
"""Matrix product of sparse matrix with dense matrix.

Args:
src (Tensor or torch_sparse.SparseTensor]): The input sparse matrix,
either a :class:`torch_sparse.SparseTensor` or a
:class:`torch.sparse.Tensor`.
other (Tensor): The input dense matrix.
reduce (str, optional): The reduce operation to use
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`).
(default: :obj:`"sum"`)

:rtype: :class:`Tensor`
"""
assert reduce in ['sum', 'add', 'mean', 'min', 'max']

if isinstance(src, SparseTensor):
return matmul(src, other, reduce)

if not is_torch_sparse_tensor(src):
raise ValueError("`src` must be a `torch_sparse.SparseTensor` "
f"or a `torch.sparse.Tensor` (got {type(src)}).")

if reduce in ['sum', 'add']:
return torch.sparse.mm(src, other)

# TODO: Support `mean` reduction for PyTorch SparseTensor
raise ValueError(f"`{reduce}` reduction is not supported for "
f"`torch.sparse.Tensor`.")
13 changes: 13 additions & 0 deletions torch_geometric/utils/torch_sparse_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any

from torch import Tensor


def is_torch_sparse_tensor(src: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`x` is a PyTorch
:obj:`SparseTensor` (in any sparse format).

Args:
src (Any): The input object to be checked.
"""
return isinstance(src, Tensor) and src.is_sparse