Skip to content

Commit

Permalink
Add Index.narrow() (#9287)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored May 4, 2024
1 parent c0e1459 commit 4839cb8
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
Expand Down
13 changes: 13 additions & 0 deletions test/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,16 @@ def test_index_select(dtype, device):
assert out.data_ptr() == inplace.data_ptr()
assert not isinstance(out, Index)
assert not isinstance(inplace, Index)


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_narrow(dtype, device):
kwargs = dict(dtype=dtype, device=device)
index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)

out = index.narrow(0, start=1, length=2)
assert isinstance(out, Index)
assert out.equal(tensor([1, 1], device=device))
assert out.dim_size == 3
assert out.is_sorted
39 changes: 39 additions & 0 deletions torch_geometric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,16 @@ def __repr__(self) -> str: # type: ignore
return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
indent, force_newline=False)

# Helpers #################################################################

def _shallow_copy(self) -> 'Index':
out = Index(self._data)
out._dim_size = self._dim_size
out._is_sorted = self._is_sorted
out._indptr = self._indptr
out._cat_metadata = self._cat_metadata
return out


def apply_(
tensor: Index,
Expand Down Expand Up @@ -526,3 +536,32 @@ def _index_select(
out._dim_size = input.dim_size

return out


@implements(aten.slice.Tensor)
def _slice(
input: Index,
dim: int,
start: Optional[int] = None,
end: Optional[int] = None,
step: int = 1,
) -> Index:

if ((start is None or start <= 0)
and (end is None or end > input.size(dim)) and step == 1):
return input._shallow_copy() # No-op.

data = aten.slice.Tensor(input._data, dim, start, end, step)

if step != 1:
data = data.contiguous()

out = Index(data)
out._dim_size = input.dim_size
# NOTE We could potentially maintain the `indptr` attribute here,
# but it is not really clear if this is worth it. The most important
# information `is_sorted` needs to be maintained though:
if step >= 0:
out._is_sorted = input.is_sorted

return out

0 comments on commit 4839cb8

Please sign in to comment.