Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Index into edge_index.select #9296

Merged
merged 7 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Integrate `torch_geometric.Index` into `torch_geometric.EdgeIndex` ([#9296](https://github.com/pyg-team/pytorch_geometric/pull/9296))
- Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
Expand Down
64 changes: 63 additions & 1 deletion test/test_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor, tensor

import torch_geometric
from torch_geometric import EdgeIndex
from torch_geometric import EdgeIndex, Index
from torch_geometric.edge_index import (
ReduceType,
SortReturnType,
Expand Down Expand Up @@ -553,6 +553,68 @@ def test_getitem(dtype, device, is_undirected):
assert not isinstance(out, EdgeIndex)


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_select(dtype, device):
kwargs = dict(dtype=dtype, device=device)

adj = EdgeIndex(
[[0, 1, 1, 2], [1, 0, 2, 1]],
sort_order='row',
sparse_size=(4, 5),
**kwargs,
).fill_cache_()

out = adj[0]
assert isinstance(out, Index)
assert out.equal(tensor([0, 1, 1, 2], device=device))
assert out.dim_size == 4
assert out.is_sorted
assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))

out = adj[-1]
assert isinstance(out, Index)
assert out.equal(tensor([1, 0, 2, 1], device=device))
assert out.dim_size == 5
assert not out.is_sorted
assert out._indptr is None

out = adj[-2, 2:4]
assert isinstance(out, Index)
assert out.equal(tensor([1, 2], device=device))
assert out.dim_size == 4
assert out.is_sorted
assert out._indptr is None

adj = EdgeIndex(
[[1, 0, 2, 1], [0, 1, 1, 2]],
sort_order='col',
sparse_size=(5, 4),
**kwargs,
).fill_cache_()

out = adj[1]
assert isinstance(out, Index)
assert out.equal(tensor([0, 1, 1, 2], device=device))
assert out.dim_size == 4
assert out.is_sorted
assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))

out = adj[-2]
assert isinstance(out, Index)
assert out.equal(tensor([1, 0, 2, 1], device=device))
assert out.dim_size == 5
assert not out.is_sorted
assert out._indptr is None

out = adj[-1, 2:4]
assert isinstance(out, Index)
assert out.equal(tensor([1, 2], device=device))
assert out.dim_size == 4
assert out.is_sorted
assert out._indptr is None


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('value_dtype', [None, torch.double])
Expand Down
29 changes: 26 additions & 3 deletions torch_geometric/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torch import Tensor

import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric import Index, is_compiling
from torch_geometric.index import index2ptr, ptr2index
from torch_geometric.typing import INDEX_DTYPES, SparseTensor

Expand Down Expand Up @@ -1519,6 +1519,29 @@ def _index(
return out


@implements(aten.select.int)
def _select(input: EdgeIndex, dim: int, index: int) -> Union[Tensor, Index]:
out = aten.select.int(input._data, dim, index)

if dim == 0 or dim == -2:
out = Index(out)

if index == 0 or index == -2: # Row-select:
out._dim_size = input.sparse_size(0)
out._is_sorted = input.is_sorted_by_row
if input.is_sorted_by_row:
out._indptr = input._indptr

else: # Col-select:
assert index == 1 or index == -1
out._dim_size = input.sparse_size(1)
out._is_sorted = input.is_sorted_by_col
if input.is_sorted_by_col:
out._indptr = input._indptr

return out


@implements(aten.add.Tensor)
def _add(
input: EdgeIndex,
Expand Down Expand Up @@ -1725,13 +1748,13 @@ def _torch_sparse_spmm(
if not transpose:
assert input.is_sorted_by_row
(rowptr, col), _ = input.get_csr()
row = input[0]
row = input._data[0]
if other.requires_grad and reduce in ['sum', 'mean']:
(colptr, _), perm = input.get_csc()
else:
assert input.is_sorted_by_col
(rowptr, col), _ = input.get_csc()
row = input[1]
row = input._data[1]
if other.requires_grad and reduce in ['sum', 'mean']:
(colptr, _), perm = input.get_csr()

Expand Down
20 changes: 14 additions & 6 deletions torch_geometric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,15 +598,20 @@ def _flip(

@implements(aten.index_select.default)
def _index_select(
input: Index,
input: Union[Index, Tensor],
dim: int,
index: Tensor,
) -> Index:
index: Union[Index, Tensor],
) -> Union[Index, Tensor]:

data = aten.index_select.default(input._data, dim, index)
out = aten.index_select.default(
input._data if isinstance(input, Index) else input,
dim,
index._data if isinstance(index, Index) else index,
)

out = Index(data)
out._dim_size = input.dim_size
if isinstance(input, Index):
out = Index(out)
out._dim_size = input.dim_size

return out

Expand Down Expand Up @@ -652,6 +657,9 @@ def _index(

data = aten.index.Tensor(input._data, indices)

if data.dim() != 1:
return data

assert len(indices) == 1
index = indices[0]
assert index is not None
Expand Down
Loading