From 6982512d813958bd537fc98296bc884bb867e402 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 4 May 2024 09:58:17 +0000 Subject: [PATCH 1/7] update --- test/test_edge_index.py | 1 + torch_geometric/edge_index.py | 29 ++--- torch_geometric/index.py | 192 +++++++++++++++++++++++++++++++++- 3 files changed, 207 insertions(+), 15 deletions(-) diff --git a/test/test_edge_index.py b/test/test_edge_index.py index 37fe4609c5b7..7a396c6a8f6e 100644 --- a/test/test_edge_index.py +++ b/test/test_edge_index.py @@ -88,6 +88,7 @@ def test_identity(dtype, device, is_undirected): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs) out = EdgeIndex(adj) + print(out._data) assert out.data_ptr() == adj.data_ptr() assert out.dtype == adj.dtype assert out.device == adj.device diff --git a/torch_geometric/edge_index.py b/torch_geometric/edge_index.py index dfe72ba7fac7..d418a51a5731 100644 --- a/torch_geometric/edge_index.py +++ b/torch_geometric/edge_index.py @@ -138,8 +138,8 @@ def assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None: def assert_sorted(func: Callable) -> Callable: @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - if not args[0].is_sorted: + def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any: + if not self.is_sorted: cls_name = args[0].__class__.__name__ raise ValueError( f"Cannot call '{func.__name__}' since '{cls_name}' is not " @@ -222,7 +222,7 @@ class EdgeIndex(Tensor): # The size of the underlying sparse matrix: _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None) - # Whether the `edge_index` represented is non-sorted (`None`), or sorted + # Whether the `edge_index` representation is non-sorted (`None`), or sorted # based on row or column values. _sort_order: Optional[SortOrder] = None @@ -344,6 +344,7 @@ def __new__( out._indptr = indptr if isinstance(data, cls): # If passed `EdgeIndex`, inherit metadata: + out._data = data._data out._T_perm = data._T_perm out._T_index = data._T_index out._T_indptr = data._T_indptr @@ -1102,6 +1103,17 @@ def sparse_narrow( edge_index._indptr = colptr return edge_index + def to_vector(self) -> Tensor: + r"""Converts :class:`EdgeIndex` into a one-dimensional index + vector representation. + """ + num_rows, num_cols = self.get_sparse_size() + + if num_rows * num_cols > torch_geometric.typing.MAX_INT64: + raise ValueError("'to_vector()' will result in an overflow") + + return self._data[0] * num_rows + self._data[1] + # PyTorch/Python builtins ################################################# def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]: @@ -1228,17 +1240,6 @@ def _clear_metadata(self) -> 'EdgeIndex': self._cat_metadata = None return self - def to_vector(self) -> Tensor: - r"""Converts :class:`EdgeIndex` into a one-dimensional index - vector representation. 
- """ - num_rows, num_cols = self.get_sparse_size() - - if num_rows * num_cols > torch_geometric.typing.MAX_INT64: - raise ValueError("'to_vector()' will result in an overflow") - - return self._data[0] * num_rows + self._data[1] - class SortReturnType(NamedTuple): values: EdgeIndex diff --git a/torch_geometric/index.py b/torch_geometric/index.py index 6c05e330ac3c..e5cf6b2d65b6 100644 --- a/torch_geometric/index.py +++ b/torch_geometric/index.py @@ -1,8 +1,15 @@ -from typing import Optional +import functools +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union import torch from torch import Tensor +from torch_geometric.typing import INDEX_DTYPES + +aten = torch.ops.aten + +HANDLED_FUNCTIONS: Dict[Callable, Callable] = {} + def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor: index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device) @@ -15,3 +22,186 @@ def index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor: return torch._convert_indices_from_coo_to_csr( index, size, out_int32=index.dtype != torch.int64) + + +class CatMetadata(NamedTuple): + nnz: List[int] + dim_size: List[Optional[int]] + is_sorted: List[bool] + + +def implements(torch_function: Callable) -> Callable: + r"""Registers a :pytorch:`PyTorch` function override.""" + @functools.wraps(torch_function) + def decorator(my_function: Callable) -> Callable: + HANDLED_FUNCTIONS[torch_function] = my_function + return my_function + + return decorator + + +def assert_valid_dtype(tensor: Tensor) -> None: + if tensor.dtype not in INDEX_DTYPES: + raise ValueError(f"'Index' holds an unsupported data type " + f"(got '{tensor.dtype}', but expected one of " + f"{INDEX_DTYPES})") + + +def assert_one_dimensional(tensor: Tensor) -> None: + if tensor.dim() != 1: + raise ValueError(f"'Index' needs to be one-dimensional " + f"(got {tensor.dim()} dimensions)") + + +def assert_contiguous(tensor: Tensor) -> None: + if not tensor.is_contiguous(): + raise ValueError("'Index' needs to be contiguous. Please call " + "`index.contiguous()` before proceeding.") + + +def assert_sorted(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any: + if not self.is_sorted: + cls_name = self.__class__.__name__ + raise ValueError( + f"Cannot call '{func.__name__}' since '{cls_name}' is not " + f"sorted. Please call `{cls_name}.sort()` first.") + return func(*args, **kwargs) + + return wrapper + + +class Index(Tensor): + r"""TODO.""" + # See "https://pytorch.org/docs/stable/notes/extending.html" + # for a basic tutorial on how to subclass `torch.Tensor`. + + # The underlying tensor representation: + _data: Tensor + + # The size of the underlying sparse vector, e.g. 
`_data.max() + 1` : + _dim_size: Optional[int] = None + + # Whether the `index` representation is sorted: + _is_sorted: bool = False + + # A cache for its compressed representation: + _indptr: Optional[Tensor] = None + + # Whenever we perform a concatenation of indices, we cache the original + # metadata to be able to reconstruct individual indices: + _cat_metadata: Optional[CatMetadata] = None + + @staticmethod + def __new__( + cls: Type, + data: Any, + *args: Any, + dim_size: Optional[int] = None, + is_sorted: bool = False, + **kwargs: Any, + ) -> 'Index': + if not isinstance(data, Tensor): + data = torch.tensor(data, *args, **kwargs) + elif len(args) > 0: + raise TypeError( + f"new() received an invalid combination of arguments - got " + f"(Tensor, {', '.join(str(type(arg)) for arg in args)})") + elif len(kwargs) > 0: + raise TypeError(f"new() received invalid keyword arguments - got " + f"{set(kwargs.keys())})") + + assert isinstance(data, Tensor) + + indptr: Optional[Tensor] = None + + if isinstance(data, cls): # If passed `Index`, inherit metadata: + indptr = data._indptr + dim_size = dim_size or data.dim_size() + is_sorted = is_sorted or data.is_sorted + + assert_valid_dtype(data) + assert_one_dimensional(data) + assert_contiguous(data) + + out = Tensor._make_wrapper_subclass( # type: ignore + cls, + size=data.size(), + strides=data.stride(), + dtype=data.dtype, + device=data.device, + layout=data.layout, + requires_grad=False, + ) + assert isinstance(out, Index) + + # Attach metadata: + out._data = data + out._dim_size = dim_size + out._is_sorted = is_sorted + out._indptr = indptr + + if isinstance(data, cls): + out._data = data._data + + # Reset metadata if cache is invalidated: + if dim_size is not None and dim_size != data.dim_size(): + out._indptr = None + + return out + + # Validation ############################################################## + + def validate(self) -> 'Index': + raise NotImplementedError + + # Properties ############################################################## + + @property + def dim_size(self) -> Optional[int]: + raise NotImplementedError + + @property + def is_sorted(self) -> bool: + raise NotImplementedError + + # Cache Interface ######################################################### + + def get_dim_size(self) -> int: + raise NotImplementedError + + def dim_resize_(self, dim_size: Optional[int]) -> 'Index': + raise NotImplementedError + + @assert_sorted + def get_indptr(self) -> Tensor: + raise NotImplementedError + + def fill_cache_(self) -> 'Index': + raise NotImplementedError + + # Methods ################################################################# + + def share_memory_(self) -> 'Index': + self._data.share_memory_() + if self._indptr is not None: + self._indptr.share_memory_() + return self + + def is_shared(self) -> bool: + return self._data.is_shared() + + def as_tensor(self) -> Tensor: + r"""Zero-copies the :class:`Index` representation back to a + :class:`torch.Tensor` representation. 
+ """ + return self._data + + def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index': + raise NotImplementedError + + +def sort(self) -> None: + # TODO MOVE BEHIND TORCH DISPATCH + raise NotImplementedError From 8230a4f0791009f77382d5394b30cca03dd20ac1 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 4 May 2024 10:03:34 +0000 Subject: [PATCH 2/7] update --- CHANGELOG.md | 1 + test/test_edge_index.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 104ebfa3b8f1..75826e1ccda6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276)) - Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240)) - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131)) - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090)) diff --git a/test/test_edge_index.py b/test/test_edge_index.py index 7a396c6a8f6e..2353062bf4ff 100644 --- a/test/test_edge_index.py +++ b/test/test_edge_index.py @@ -88,7 +88,7 @@ def test_identity(dtype, device, is_undirected): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs) out = EdgeIndex(adj) - print(out._data) + assert not isinstance(out.to_tensor(), EdgeIndex) assert out.data_ptr() == adj.data_ptr() assert out.dtype == adj.dtype assert out.device == adj.device From c02e2fbd30dd725eb4286d71301f162f94200ad3 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 4 May 2024 10:13:04 +0000 Subject: [PATCH 3/7] update --- test/test_edge_index.py | 2 +- torch_geometric/edge_index.py | 4 ++-- torch_geometric/index.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_edge_index.py b/test/test_edge_index.py index 2353062bf4ff..d9fdb404a2a4 100644 --- a/test/test_edge_index.py +++ b/test/test_edge_index.py @@ -88,7 +88,7 @@ def test_identity(dtype, device, is_undirected): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs) out = EdgeIndex(adj) - assert not isinstance(out.to_tensor(), EdgeIndex) + assert not isinstance(out.as_tensor(), EdgeIndex) assert out.data_ptr() == adj.data_ptr() assert out.dtype == adj.dtype assert out.device == adj.device diff --git a/torch_geometric/edge_index.py b/torch_geometric/edge_index.py index d418a51a5731..aef3777115dd 100644 --- a/torch_geometric/edge_index.py +++ b/torch_geometric/edge_index.py @@ -140,11 +140,11 @@ def assert_sorted(func: Callable) -> Callable: @functools.wraps(func) def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any: if not self.is_sorted: - cls_name = args[0].__class__.__name__ + cls_name = self.__class__.__name__ raise ValueError( f"Cannot call '{func.__name__}' since '{cls_name}' is not " f"sorted. Please call `{cls_name}.sort_by(...)` first.") - return func(*args, **kwargs) + return func(self, *args, **kwargs) return wrapper diff --git a/torch_geometric/index.py b/torch_geometric/index.py index e5cf6b2d65b6..a0b013186825 100644 --- a/torch_geometric/index.py +++ b/torch_geometric/index.py @@ -67,7 +67,7 @@ def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any: raise ValueError( f"Cannot call '{func.__name__}' since '{cls_name}' is not " f"sorted. 
Please call `{cls_name}.sort()` first.") - return func(*args, **kwargs) + return func(self, *args, **kwargs) return wrapper @@ -202,6 +202,6 @@ def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index': raise NotImplementedError -def sort(self) -> None: - # TODO MOVE BEHIND TORCH DISPATCH - raise NotImplementedError +# def sort(self) -> None: +# # TODO MOVE BEHIND TORCH DISPATCH +# raise NotImplementedError From 2034e9a413917d2a631421ff2a41e8056bd40b4b Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 4 May 2024 10:34:15 +0000 Subject: [PATCH 4/7] update --- CHANGELOG.md | 2 +- test/test_edge_index.py | 8 +-- test/test_index.py | 56 +++++++++++++++++ torch_geometric/__init__.py | 2 + torch_geometric/index.py | 116 ++++++++++++++++++++++++++++++++++-- 5 files changed, 171 insertions(+), 13 deletions(-) create mode 100644 test/test_index.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 75826e1ccda6..38edaa119be5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276)) +- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277)) - Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240)) - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131)) - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090)) diff --git a/test/test_edge_index.py b/test/test_edge_index.py index d9fdb404a2a4..e7f82fe92596 100644 --- a/test/test_edge_index.py +++ b/test/test_edge_index.py @@ -46,12 +46,8 @@ def test_basic(dtype, device): adj.validate() assert isinstance(adj, EdgeIndex) - if torch_geometric.typing.WITH_PT112: - assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],\n' - ' [1, 0, 2, 1]], ') - else: - assert str(adj).startswith('tensor([[0, 1, 1, 2],\n' - ' [1, 0, 2, 1]], ') + assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],\n' + ' [1, 0, 2, 1]], ') assert 'sparse_size=(3, 3), nnz=4' in str(adj) assert (f"device='{device}'" in str(adj)) == adj.is_cuda assert (f'dtype={dtype}' in str(adj)) == (dtype != torch.long) diff --git a/test/test_index.py b/test/test_index.py new file mode 100644 index 000000000000..c9928a4a76a7 --- /dev/null +++ b/test/test_index.py @@ -0,0 +1,56 @@ +import pytest +import torch + +from torch_geometric import Index +from torch_geometric.testing import withCUDA +from torch_geometric.typing import INDEX_DTYPES + +DTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in INDEX_DTYPES] + + +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_basic(dtype, device): + kwargs = dict(dtype=dtype, device=device, dim_size=3) + index = Index([0, 1, 1, 2], **kwargs) + index.validate() + assert isinstance(index, Index) + + assert str(index).startswith('Index([0, 1, 1, 2], ') + assert 'dim_size=3' in str(index) + assert (f"device='{device}'" in str(index)) == index.is_cuda + assert (f'dtype={dtype}' in str(index)) == (dtype != torch.long) + + assert index.dtype == dtype + assert index.device == device + assert index.dim_size == 3 + assert not index.is_sorted + + out = index.as_tensor() + assert not isinstance(out, Index) + assert out.dtype == dtype + assert out.device == device + + out = index * 1 + assert not 
isinstance(out, Index) + assert out.dtype == dtype + assert out.device == device + + +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_identity(dtype, device): + kwargs = dict(dtype=dtype, device=device, dim_size=3, is_sorted=True) + index = Index([0, 1, 1, 2], **kwargs) + + out = Index(index) + assert not isinstance(out.as_tensor(), Index) + assert out.data_ptr() == index.data_ptr() + assert out.dtype == index.dtype + assert out.device == index.device + assert out.dim_size == index.dim_size + assert out.is_sorted == index.is_sorted + + out = Index(index, dim_size=4, is_sorted=False) + assert out.dim_size == 4 + assert out.is_sorted == index.is_sorted diff --git a/torch_geometric/__init__.py b/torch_geometric/__init__.py index 29a636833e4b..381c3524e168 100644 --- a/torch_geometric/__init__.py +++ b/torch_geometric/__init__.py @@ -1,4 +1,5 @@ from ._compile import compile, is_compiling +from .index import Index from .edge_index import EdgeIndex from .seed import seed_everything from .home import get_home_dir, set_home_dir @@ -25,6 +26,7 @@ __version__ = '2.6.0' __all__ = [ + 'Index', 'EdgeIndex', 'seed_everything', 'get_home_dir', diff --git a/torch_geometric/index.py b/torch_geometric/index.py index a0b013186825..2382ec2ee33c 100644 --- a/torch_geometric/index.py +++ b/torch_geometric/index.py @@ -1,7 +1,19 @@ import functools -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + Type, + Union, +) import torch +import torch.utils._pytree as pytree from torch import Tensor from torch_geometric.typing import INDEX_DTYPES @@ -118,7 +130,7 @@ def __new__( if isinstance(data, cls): # If passed `Index`, inherit metadata: indptr = data._indptr - dim_size = dim_size or data.dim_size() + dim_size = dim_size or data.dim_size is_sorted = is_sorted or data.is_sorted assert_valid_dtype(data) @@ -146,7 +158,7 @@ def __new__( out._data = data._data # Reset metadata if cache is invalidated: - if dim_size is not None and dim_size != data.dim_size(): + if dim_size is not None and dim_size != data.dim_size: out._indptr = None return out @@ -154,31 +166,38 @@ def __new__( # Validation ############################################################## def validate(self) -> 'Index': - raise NotImplementedError + r"""TODO.""" + return self # Properties ############################################################## @property def dim_size(self) -> Optional[int]: - raise NotImplementedError + r"""TODO.""" + return self._dim_size @property def is_sorted(self) -> bool: - raise NotImplementedError + r"""TODO.""" + return self._is_sorted # Cache Interface ######################################################### def get_dim_size(self) -> int: + r"""TODO.""" raise NotImplementedError def dim_resize_(self, dim_size: Optional[int]) -> 'Index': + r"""TODO.""" raise NotImplementedError @assert_sorted def get_indptr(self) -> Tensor: + r"""TODO.""" raise NotImplementedError def fill_cache_(self) -> 'Index': + r"""TODO.""" raise NotImplementedError # Methods ################################################################# @@ -199,8 +218,93 @@ def as_tensor(self) -> Tensor: return self._data def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index': + r"""TODO.""" raise NotImplementedError + # PyTorch/Python builtins ################################################# + + def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]: + attrs = ['_data'] + if 
self._indptr is not None:
+            attrs.append('_indptr')
+
+        ctx = (
+            self._dim_size,
+            self._is_sorted,
+            self._cat_metadata,
+        )
+
+        return attrs, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(
+        inner_tensors: Dict[str, Any],
+        ctx: Tuple[Any, ...],
+        outer_size: Tuple[int, ...],
+        outer_stride: Tuple[int, ...],
+    ) -> 'Index':
+        index = Index(
+            inner_tensors['_data'],
+            dim_size=ctx[0],
+            is_sorted=ctx[1],
+        )
+
+        index._indptr = inner_tensors.get('_indptr', None)
+        index._cat_metadata = ctx[2]
+
+        return index
+
+    # Prevent auto-wrapping outputs back into the proper subclass type:
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_dispatch__(
+        cls: Type,
+        func: Callable[..., Any],
+        types: Iterable[Type[Any]],
+        args: Iterable[Tuple[Any, ...]] = (),
+        kwargs: Optional[Dict[Any, Any]] = None,
+    ) -> Any:
+        # `Index` should be treated as a regular PyTorch tensor for all
+        # standard PyTorch functionalities. However,
+        # * some of its metadata can be transferred to new functions, e.g.,
+        #   `torch.narrow()` can inherit the `is_sorted` property.
+        # * not all operations lead to valid `Index` tensors again, e.g.,
+        #   `torch.sum()` does not yield an `Index` as its output, or
+        #   `torch.stack()` violates the [*] shape assumption.
+
+        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that
+        # implement specific functions for valid `Index` routines.
+        if func in HANDLED_FUNCTIONS:
+            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))
+
+        # For all other PyTorch functions, we treat them as vanilla tensors.
+        args = pytree.tree_map_only(Index, lambda x: x._data, args)
+        if kwargs is not None:
+            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)
+        return func(*args, **(kwargs or {}))
+
+    def __repr__(self) -> str:  # type: ignore
+        prefix = f'{self.__class__.__name__}('
+        indent = len(prefix)
+        tensor_str = torch._tensor_str._tensor_str(self._data, indent)
+
+        suffixes = []
+        if self.dim_size is not None:
+            suffixes.append(f'dim_size={self.dim_size}')
+        if (self.device.type != torch._C._get_default_device()
+                or (self.device.type == 'cuda'
+                    and torch.cuda.current_device() != self.device.index)
+                or (self.device.type == 'mps')):
+            suffixes.append(f"device='{self.device}'")
+        if self.dtype != torch.int64:
+            suffixes.append(f'dtype={self.dtype}')
+        if self.is_sorted:
+            suffixes.append('is_sorted=True')
+
+        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
+                                               indent, force_newline=False)
+
 
 # def sort(self) -> None:
 #     # TODO MOVE BEHIND TORCH DISPATCH

From 4b90896bf42461d860d33c7ac64fe61e4dd5df74 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 4 May 2024 10:53:24 +0000
Subject: [PATCH 5/7] update

---
 test/test_index.py       | 34 ++++++++++++++++++++++++++++++++
 torch_geometric/index.py | 42 +++++++++++++++++++++++++++++++++++-----
 2 files changed, 71 insertions(+), 5 deletions(-)

diff --git a/test/test_index.py b/test/test_index.py
index c9928a4a76a7..8b2b34cbb46b 100644
--- a/test/test_index.py
+++ b/test/test_index.py
@@ -1,5 +1,6 @@
 import pytest
 import torch
+from torch import tensor
 
 from torch_geometric import Index
 from torch_geometric.testing import withCUDA
@@ -54,3 +55,36 @@ def test_identity(dtype, device):
     out = Index(index, dim_size=4, is_sorted=False)
     assert out.dim_size == 4
     assert out.is_sorted == index.is_sorted
+
+
+def test_validate():
+    with pytest.raises(ValueError, match="unsupported data type"):
+        Index([0.0, 1.0])
+    with pytest.raises(ValueError, match="needs to be one-dimensional"):
+        
Index([[0], [1]]) + with pytest.raises(TypeError, match="invalid combination of arguments"): + Index(torch.tensor([0, 1]), torch.long) + with pytest.raises(TypeError, match="invalid keyword arguments"): + Index(torch.tensor([0, 1]), dtype=torch.long) + with pytest.raises(ValueError, match="contains negative indices"): + Index([-1, 0]).validate() + with pytest.raises(ValueError, match="than its registered size"): + Index([0, 10], dim_size=2).validate() + with pytest.raises(ValueError, match="not sorted"): + Index([1, 0], is_sorted=True).validate() + + +@withCUDA +@pytest.mark.parametrize('dtype', DTYPES) +def test_fill_cache_(dtype, device): + kwargs = dict(dtype=dtype, device=device) + index = Index([0, 1, 1, 2], is_sorted=True, **kwargs) + index.validate().fill_cache_() + assert index.dim_size == 3 + assert index._indptr.dtype == dtype + assert index._indptr.equal(tensor([0, 1, 3, 4], device=device)) + + index = Index([1, 0, 2, 1], **kwargs) + index.validate().fill_cache_() + assert index.dim_size == 3 + assert index._indptr is None diff --git a/torch_geometric/index.py b/torch_geometric/index.py index 2382ec2ee33c..328257b7d4aa 100644 --- a/torch_geometric/index.py +++ b/torch_geometric/index.py @@ -167,6 +167,24 @@ def __new__( def validate(self) -> 'Index': r"""TODO.""" + assert_valid_dtype(self._data) + assert_one_dimensional(self._data) + assert_contiguous(self._data) + + if self.numel() > 0 and self._data.min() < 0: + raise ValueError(f"'{self.__class__.__name__}' contains negative " + f"indices (got {int(self.min())})") + + if (self.numel() > 0 and self.dim_size is not None + and self._data.max() >= self.dim_size): + raise ValueError(f"'{self.__class__.__name__}' contains larger " + f"indices than its registered size " + f"(got {int(self._data.max())}, but expected " + f"values smaller than {self.dim_size})") + + if self.is_sorted and (self._data.diff() < 0).any(): + raise ValueError(f"'{self.__class__.__name__}' is not sorted") + return self # Properties ############################################################## @@ -185,20 +203,34 @@ def is_sorted(self) -> bool: def get_dim_size(self) -> int: r"""TODO.""" - raise NotImplementedError + if self._dim_size is None: + dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0 + self._dim_size = dim_size + + assert isinstance(self._dim_size, int) + return self._dim_size def dim_resize_(self, dim_size: Optional[int]) -> 'Index': r"""TODO.""" - raise NotImplementedError + raise NotImplementedError # TODO @assert_sorted def get_indptr(self) -> Tensor: r"""TODO.""" - raise NotImplementedError + if self._indptr is None: + self._indptr = index2ptr(self._data, self.get_dim_size()) + + assert isinstance(self._indptr, Tensor) + return self._indptr def fill_cache_(self) -> 'Index': r"""TODO.""" - raise NotImplementedError + self.get_dim_size() + + if self.is_sorted: + self.get_indptr() + + return self # Methods ################################################################# @@ -219,7 +251,7 @@ def as_tensor(self) -> Tensor: def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index': r"""TODO.""" - raise NotImplementedError + raise NotImplementedError # TODO # PyTorch/Python builtins ################################################# From e87df9b4d648637451a59d22370ecea728b874f9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 4 May 2024 10:54:03 +0000 Subject: [PATCH 6/7] update --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38edaa119be5..901bb4c99644 
100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277)) +- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278)) - Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240)) - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131)) - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090)) From c0fdeef6364e3fe8a9a2cd4e77bd3436941fdec9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sat, 4 May 2024 10:56:19 +0000 Subject: [PATCH 7/7] update --- torch_geometric/index.py | 84 ---------------------------------------- 1 file changed, 84 deletions(-) diff --git a/torch_geometric/index.py b/torch_geometric/index.py index b5a130009fc5..328257b7d4aa 100644 --- a/torch_geometric/index.py +++ b/torch_geometric/index.py @@ -337,90 +337,6 @@ def __repr__(self) -> str: # type: ignore return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes, indent, force_newline=False) - # PyTorch/Python builtins ################################################# - - def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]: - attrs = ['_data'] - if self._indptr is not None: - attrs.append('_indptr') - - ctx = ( - self._dim_size, - self._is_sorted, - self._cat_metadata, - ) - - return attrs, ctx - - @staticmethod - def __tensor_unflatten__( - inner_tensors: Dict[str, Any], - ctx: Tuple[Any, ...], - outer_size: Tuple[int, ...], - outer_stride: Tuple[int, ...], - ) -> 'Index': - index = Index( - inner_tensors['_data'], - dim_size=ctx[0], - is_sorted=ctx[1], - ) - - index._indptr = inner_tensors.get('_indptr', None) - index._cat_metadata = ctx[2] - - return index - - # Prevent auto-wrapping outputs back into the proper subclass type: - __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_dispatch__( - cls: Type, - func: Callable[..., Any], - types: Iterable[Type[Any]], - args: Iterable[Tuple[Any, ...]] = (), - kwargs: Optional[Dict[Any, Any]] = None, - ) -> Any: - # `Index` should be treated as a regular PyTorch tensor for all - # standard PyTorch functionalities. However, - # * some of its metadata can be transferred to new functions, e.g., - # `torch.narrow()` can inherit the `is_sorted` property. - # * not all operations lead to valid `Index` tensors again, e.g., - # `torch.sum()` does not yield a `Index` as its output, or - # `torch.stack() violates the [*] shape assumption. - - # To account for this, we hold a number of `HANDLED_FUNCTIONS` that - # implement specific functions for valid `Index` routines. - if func in HANDLED_FUNCTIONS: - return HANDLED_FUNCTIONS[func](*args, **(kwargs or {})) - - # For all other PyTorch functions, we treat them as vanilla tensors. 
- args = pytree.tree_map_only(Index, lambda x: x._data, args) - if kwargs is not None: - kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs) - return func(*args, **(kwargs or {})) - - def __repr__(self) -> str: # type: ignore - prefix = f'{self.__class__.__name__}(' - indent = len(prefix) - tensor_str = torch._tensor_str._tensor_str(self._data, indent) - - suffixes = [] - if self.dim_size is not None: - suffixes.append(f'dim_size={self.dim_size}') - if (self.device.type != torch._C._get_default_device() - or (self.device.type == 'cuda' - and torch.cuda.current_device() != self.device.index) - or (self.device.type == 'mps')): - suffixes.append(f"device='{self.device}'") - if self.dtype != torch.int64: - suffixes.append(f'dtype={self.dtype}') - if self.is_sorted: - suffixes.append('is_sorted=True') - - return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes, - indent, force_newline=False) - # def sort(self) -> None: # # TODO MOVE BEHIND TORCH DISPATCH
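
A minimal usage sketch of the `Index` class added by this series, limited to behavior that the `test/test_index.py` patches above already exercise (the concrete values follow `test_basic`, `test_identity`, and `test_fill_cache_`):

import torch

from torch_geometric import Index

index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)
index.validate()     # Checks dtype, dimensionality, bounds, and sort order.

assert index.dim_size == 3
assert index.is_sorted

index.fill_cache_()  # Caches `dim_size` and the CSR-style `_indptr`.
assert index._indptr.equal(torch.tensor([0, 1, 3, 4]))

out = index.as_tensor()  # Zero-copy view as a plain `torch.Tensor`.
assert not isinstance(out, Index)

out = index * 1  # Unhandled operations fall through to vanilla tensors.
assert not isinstance(out, Index)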
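
The `implements` decorator and the `HANDLED_FUNCTIONS` table form the extension point that `__torch_dispatch__` consults before unwrapping to plain tensors. No override is registered in these patches yet, so the following sketch is purely illustrative: the choice of `aten.clone.default` as the dispatch key and the `_clone` helper name are assumptions, not part of the series. It would live inside `torch_geometric/index.py`:

@implements(aten.clone.default)
def _clone(
    self: 'Index',
    *,
    memory_format: torch.memory_format = torch.preserve_format,
) -> 'Index':
    # Cloning preserves both the value range and the ordering, so the
    # `dim_size` and `is_sorted` metadata can be carried over as-is:
    return Index(
        self._data.clone(memory_format=memory_format),
        dim_size=self.dim_size,
        is_sorted=self.is_sorted,
    )

With such a registration in place, `index.clone()` would return a metadata-preserving `Index` instead of falling back to a plain tensor.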