Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add validate() and fill_cache_() to torch_geometric.Index #9278

Merged
merged 9 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277))
- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
Expand Down
34 changes: 34 additions & 0 deletions test/test_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
from torch import tensor

from torch_geometric import Index
from torch_geometric.testing import withCUDA
Expand Down Expand Up @@ -54,3 +55,36 @@ def test_identity(dtype, device):
out = Index(index, dim_size=4, is_sorted=False)
assert out.dim_size == 4
assert out.is_sorted == index.is_sorted


def test_validate():
    """Malformed inputs must be rejected at construction or by `validate()`."""
    # Errors raised eagerly during construction:
    with pytest.raises(ValueError, match="unsupported data type"):
        Index([0.0, 1.0])
    with pytest.raises(ValueError, match="needs to be one-dimensional"):
        Index([[0], [1]])
    with pytest.raises(TypeError, match="invalid combination of arguments"):
        Index(torch.tensor([0, 1]), torch.long)
    with pytest.raises(TypeError, match="invalid keyword arguments"):
        Index(torch.tensor([0, 1]), dtype=torch.long)

    # Errors raised lazily, only when `validate()` is called:
    cases = [
        ([-1, 0], {}, "contains negative indices"),
        ([0, 10], dict(dim_size=2), "than its registered size"),
        ([1, 0], dict(is_sorted=True), "not sorted"),
    ]
    for data, kwargs, pattern in cases:
        with pytest.raises(ValueError, match=pattern):
            Index(data, **kwargs).validate()


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_fill_cache_(dtype, device):
    """`fill_cache_()` populates `dim_size`, and `_indptr` only when sorted."""
    factory_kwargs = dict(dtype=dtype, device=device)

    # Sorted input: both `dim_size` and the CSR-style `_indptr` are cached.
    sorted_index = Index([0, 1, 1, 2], is_sorted=True, **factory_kwargs)
    sorted_index.validate().fill_cache_()
    assert sorted_index.dim_size == 3
    assert sorted_index._indptr.dtype == dtype
    assert sorted_index._indptr.equal(tensor([0, 1, 3, 4], device=device))

    # Unsorted input: only `dim_size` is cached; `_indptr` stays unset.
    unsorted_index = Index([1, 0, 2, 1], **factory_kwargs)
    unsorted_index.validate().fill_cache_()
    assert unsorted_index.dim_size == 3
    assert unsorted_index._indptr is None
42 changes: 37 additions & 5 deletions torch_geometric/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,24 @@ def __new__(

def validate(self) -> 'Index':
    r"""Validates the :class:`Index` representation.

    In particular, it ensures that the underlying data

    * has a supported integer dtype and is one-dimensional and contiguous,
    * does not contain negative indices,
    * does not contain indices that exceed the registered ``dim_size``,
    * is sorted in ascending order if ``is_sorted=True`` was given.

    Returns itself so calls can be chained, e.g.
    ``index.validate().fill_cache_()``.

    Raises:
        ValueError: If any of the above invariants is violated.
    """
    assert_valid_dtype(self._data)
    assert_one_dimensional(self._data)
    assert_contiguous(self._data)

    if self.numel() > 0 and self._data.min() < 0:
        # Consistently report via `self._data` (as the branches below do):
        raise ValueError(f"'{self.__class__.__name__}' contains negative "
                         f"indices (got {int(self._data.min())})")

    if (self.numel() > 0 and self.dim_size is not None
            and self._data.max() >= self.dim_size):
        raise ValueError(f"'{self.__class__.__name__}' contains larger "
                         f"indices than its registered size "
                         f"(got {int(self._data.max())}, but expected "
                         f"values smaller than {self.dim_size})")

    if self.is_sorted and (self._data.diff() < 0).any():
        raise ValueError(f"'{self.__class__.__name__}' is not sorted")

    return self

# Properties ##############################################################
Expand All @@ -185,20 +203,34 @@ def is_sorted(self) -> bool:

def get_dim_size(self) -> int:
    r"""Returns the size of the underlying sparse vector.

    If not explicitly registered, the size is inferred lazily as
    ``max(index) + 1`` (or ``0`` for an empty index) and cached.
    """
    # NOTE(review): the diffed-out `raise NotImplementedError` made the body
    # below unreachable; it is removed here.
    if self._dim_size is None:
        dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0
        self._dim_size = dim_size

    assert isinstance(self._dim_size, int)
    return self._dim_size

def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
    r"""TODO."""
    # Not yet implemented (only one raise — the duplicated diff line is gone):
    raise NotImplementedError  # TODO

@assert_sorted
def get_indptr(self) -> Tensor:
    r"""Returns the compressed (CSR-style) index pointer representation.

    Computed lazily from the sorted index via ``index2ptr`` and cached in
    ``self._indptr``.  Requires the index to be sorted (``@assert_sorted``).
    """
    # NOTE(review): the diffed-out `raise NotImplementedError` made the body
    # below unreachable; it is removed here.
    if self._indptr is None:
        self._indptr = index2ptr(self._data, self.get_dim_size())

    assert isinstance(self._indptr, Tensor)
    return self._indptr

def fill_cache_(self) -> 'Index':
    r"""Fills the cache with (meta)data information in-place.

    Computes and caches ``dim_size``, and — when the index is sorted — the
    compressed index pointer as well.  Returns itself to allow chaining.
    """
    # NOTE(review): the diffed-out `raise NotImplementedError` made the body
    # below unreachable; it is removed here.
    self.get_dim_size()

    if self.is_sorted:
        self.get_indptr()

    return self

# Methods #################################################################

Expand All @@ -219,7 +251,7 @@ def as_tensor(self) -> Tensor:

def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index':
    r"""TODO."""
    # Not yet implemented (only one raise — the duplicated diff line is gone):
    raise NotImplementedError  # TODO

# PyTorch/Python builtins #################################################

Expand Down
Loading