Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add spmm and is_torch_sparse_tensor #5906

Merged
merged 5 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added test/nn/functional/test_spmm.py
Empty file.
12 changes: 12 additions & 0 deletions test/utils/test_torch_sparse_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.utils import is_torch_sparse_tensor


def test_is_torch_sparse_tensor():
x = torch.randn(5, 5)

assert not is_torch_sparse_tensor(x)
assert not is_torch_sparse_tensor(SparseTensor.from_dense(x))
assert is_torch_sparse_tensor(x.to_sparse())
3 changes: 2 additions & 1 deletion torch_geometric/nn/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .bro import bro
from .gini import gini
from .spmm import spmm

__all__ = ['bro', 'gini']
__all__ = ['bro', 'gini', 'spmm']
EdisonLeeeee marked this conversation as resolved.
Show resolved Hide resolved

classes = __all__
22 changes: 22 additions & 0 deletions torch_geometric/nn/functional/spmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.utils import is_torch_sparse_tensor


def spmm(src: SparseTensor, other: Tensor, reduce: str = "sum") -> Tensor:
EdisonLeeeee marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this TorchScript compatible via the overload decorator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure :)

assert reduce in {"sum", "mean", "max", "min"}
if isinstance(src, SparseTensor):
return src.spmm(other, reduce)
if not is_torch_sparse_tensor(src):
raise ValueError("`src` must be a torch_sparse SparseTensor "
f"or a PyTorch SparseTensor, but got {type(src)}.")
if reduce == "sum":
return torch.sparse.mm(src, other)
elif reduce == "mean":
# TODO: Support `mean` reduction for PyTorch SparseTensor
raise NotImplementedError
else:
raise NotImplementedError("`max` and `min` reduction are not supported"
"for PyTorch SparseTensor input.")
2 changes: 2 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
structured_negative_sampling_feasible)
from .train_test_split_edges import train_test_split_edges
from .scatter import scatter
from .torch_sparse_tensor import is_torch_sparse_tensor

__all__ = [
'degree',
Expand Down Expand Up @@ -95,6 +96,7 @@
'structured_negative_sampling_feasible',
'train_test_split_edges',
'scatter',
'is_torch_sparse_tensor',
]

classes = __all__
11 changes: 11 additions & 0 deletions torch_geometric/utils/torch_sparse_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any


def is_torch_sparse_tensor(x: Any) -> bool:
"""Returns :obj:`True` if the input :obj:`x` is
PyTorch SparseTensor (COO or CSR format).

Args:
x (Any): The input object to be checked.
"""
return getattr(x, 'is_sparse', False)
EdisonLeeeee marked this conversation as resolved.
Show resolved Hide resolved