Introduce torch_geometric.Index #9276

Merged
merged 3 commits on May 4, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276))
- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))
- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
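For orientation, a minimal usage sketch of the new class, based only on the `__new__` signature and `as_tensor()` method added in `torch_geometric/index.py` below. The top-level import path follows the CHANGELOG entry; if `Index` is not yet re-exported at the package root in this commit, `from torch_geometric.index import Index` is the fallback. Most other methods (`validate`, `get_dim_size`, `get_indptr`, ...) are still `NotImplementedError` stubs here and are fleshed out in follow-up PRs:

```python
import torch

from torch_geometric import Index  # import path per the CHANGELOG entry

# Wrap a one-dimensional, contiguous integer sequence and attach metadata:
index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)

# `as_tensor()` zero-copies back to the underlying plain tensor:
assert isinstance(index.as_tensor(), torch.Tensor)
assert index.as_tensor().dtype == torch.long
```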
1 change: 1 addition & 0 deletions test/test_edge_index.py
@@ -88,6 +88,7 @@ def test_identity(dtype, device, is_undirected):
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs)

    out = EdgeIndex(adj)
    assert not isinstance(out.as_tensor(), EdgeIndex)
    assert out.data_ptr() == adj.data_ptr()
    assert out.dtype == adj.dtype
    assert out.device == adj.device
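The new assertion pins down the metadata-inheritance change in `torch_geometric/edge_index.py` below: wrapping an existing `EdgeIndex` now reuses the wrapped instance's underlying `_data` tensor, so `as_tensor()` must hand back a plain `torch.Tensor` rather than a nested `EdgeIndex`. A minimal sketch of that contract, mirroring the test above:

```python
from torch_geometric.edge_index import EdgeIndex

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))
out = EdgeIndex(adj)  # wraps `adj` and inherits its metadata and `_data`

assert not isinstance(out.as_tensor(), EdgeIndex)  # plain tensor, not nested
assert out.data_ptr() == adj.data_ptr()            # zero-copy
```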
33 changes: 17 additions & 16 deletions torch_geometric/edge_index.py
@@ -138,13 +138,13 @@

 def assert_sorted(func: Callable) -> Callable:
     @functools.wraps(func)
-    def wrapper(*args: Any, **kwargs: Any) -> Any:
-        if not args[0].is_sorted:
-            cls_name = args[0].__class__.__name__
+    def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any:
+        if not self.is_sorted:
+            cls_name = self.__class__.__name__
             raise ValueError(
                 f"Cannot call '{func.__name__}' since '{cls_name}' is not "
                 f"sorted. Please call `{cls_name}.sort_by(...)` first.")
-        return func(*args, **kwargs)
+        return func(self, *args, **kwargs)

     return wrapper

@@ -222,7 +222,7 @@
     # The size of the underlying sparse matrix:
     _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)

-    # Whether the `edge_index` represented is non-sorted (`None`), or sorted
+    # Whether the `edge_index` representation is non-sorted (`None`), or sorted
     # based on row or column values.
     _sort_order: Optional[SortOrder] = None

@@ -344,6 +344,7 @@
        out._indptr = indptr

        if isinstance(data, cls):  # If passed `EdgeIndex`, inherit metadata:
            out._data = data._data
            out._T_perm = data._T_perm
            out._T_index = data._T_index
            out._T_indptr = data._T_indptr
@@ -1102,6 +1103,17 @@
        edge_index._indptr = colptr
        return edge_index

    def to_vector(self) -> Tensor:
        r"""Converts :class:`EdgeIndex` into a one-dimensional index
        vector representation.
        """
        num_rows, num_cols = self.get_sparse_size()

        if num_rows * num_cols > torch_geometric.typing.MAX_INT64:
            raise ValueError("'to_vector()' will result in an overflow")

        return self._data[0] * num_rows + self._data[1]

    # PyTorch/Python builtins #################################################

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
@@ -1228,17 +1240,6 @@
         self._cat_metadata = None
         return self

-    def to_vector(self) -> Tensor:
-        r"""Converts :class:`EdgeIndex` into a one-dimensional index
-        vector representation.
-        """
-        num_rows, num_cols = self.get_sparse_size()
-
-        if num_rows * num_cols > torch_geometric.typing.MAX_INT64:
-            raise ValueError("'to_vector()' will result in an overflow")
-
-        return self._data[0] * num_rows + self._data[1]
-

 class SortReturnType(NamedTuple):
     values: EdgeIndex
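A short usage sketch of the relocated `to_vector()` method, reusing the `EdgeIndex` construction from the test file above; `get_sparse_size()` supplies the sparse matrix shape, and the `MAX_INT64` check guards against overflow of the flattened per-edge index:

```python
from torch_geometric.edge_index import EdgeIndex

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))

# One scalar index per edge, computed from the (row, col) pairs:
vec = adj.to_vector()
assert vec.numel() == 4  # one entry per edge
```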
192 changes: 191 additions & 1 deletion torch_geometric/index.py
@@ -1,8 +1,15 @@
-from typing import Optional
+import functools
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union

 import torch
 from torch import Tensor

+from torch_geometric.typing import INDEX_DTYPES
+
+aten = torch.ops.aten
+
+HANDLED_FUNCTIONS: Dict[Callable, Callable] = {}
+

 def ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:
     index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
@@ -15,3 +22,186 @@

    return torch._convert_indices_from_coo_to_csr(
        index, size, out_int32=index.dtype != torch.int64)


class CatMetadata(NamedTuple):
    nnz: List[int]
    dim_size: List[Optional[int]]
    is_sorted: List[bool]


def implements(torch_function: Callable) -> Callable:
    r"""Registers a :pytorch:`PyTorch` function override."""
    @functools.wraps(torch_function)
    def decorator(my_function: Callable) -> Callable:
        HANDLED_FUNCTIONS[torch_function] = my_function
        return my_function

    return decorator


def assert_valid_dtype(tensor: Tensor) -> None:
    if tensor.dtype not in INDEX_DTYPES:
        raise ValueError(f"'Index' holds an unsupported data type "
                         f"(got '{tensor.dtype}', but expected one of "
                         f"{INDEX_DTYPES})")


def assert_one_dimensional(tensor: Tensor) -> None:
    if tensor.dim() != 1:
        raise ValueError(f"'Index' needs to be one-dimensional "
                         f"(got {tensor.dim()} dimensions)")


def assert_contiguous(tensor: Tensor) -> None:
    if not tensor.is_contiguous():
        raise ValueError("'Index' needs to be contiguous. Please call "
                         "`index.contiguous()` before proceeding.")


def assert_sorted(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:
        if not self.is_sorted:
            cls_name = self.__class__.__name__
            raise ValueError(
                f"Cannot call '{func.__name__}' since '{cls_name}' is not "
                f"sorted. Please call `{cls_name}.sort()` first.")
        return func(self, *args, **kwargs)

    return wrapper


class Index(Tensor):
    r"""TODO."""
    # See "https://pytorch.org/docs/stable/notes/extending.html"
    # for a basic tutorial on how to subclass `torch.Tensor`.

    # The underlying tensor representation:
    _data: Tensor

    # The size of the underlying sparse vector, e.g. `_data.max() + 1`:
    _dim_size: Optional[int] = None

    # Whether the `index` representation is sorted:
    _is_sorted: bool = False

    # A cache for its compressed representation:
    _indptr: Optional[Tensor] = None

    # Whenever we perform a concatenation of indices, we cache the original
    # metadata to be able to reconstruct individual indices:
    _cat_metadata: Optional[CatMetadata] = None

    @staticmethod
    def __new__(
        cls: Type,
        data: Any,
        *args: Any,
        dim_size: Optional[int] = None,
        is_sorted: bool = False,
        **kwargs: Any,
    ) -> 'Index':
        if not isinstance(data, Tensor):
            data = torch.tensor(data, *args, **kwargs)
        elif len(args) > 0:
            raise TypeError(
                f"new() received an invalid combination of arguments - got "
                f"(Tensor, {', '.join(str(type(arg)) for arg in args)})")
        elif len(kwargs) > 0:
            raise TypeError(f"new() received invalid keyword arguments - got "
                            f"{set(kwargs.keys())})")

        assert isinstance(data, Tensor)

        indptr: Optional[Tensor] = None

        if isinstance(data, cls):  # If passed `Index`, inherit metadata:
            indptr = data._indptr
            dim_size = dim_size or data.dim_size()
            is_sorted = is_sorted or data.is_sorted

        assert_valid_dtype(data)
        assert_one_dimensional(data)
        assert_contiguous(data)

        out = Tensor._make_wrapper_subclass(  # type: ignore
            cls,
            size=data.size(),
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            layout=data.layout,
            requires_grad=False,
        )
        assert isinstance(out, Index)

        # Attach metadata:
        out._data = data
        out._dim_size = dim_size
        out._is_sorted = is_sorted
        out._indptr = indptr

        if isinstance(data, cls):
            out._data = data._data

            # Reset metadata if cache is invalidated:
            if dim_size is not None and dim_size != data.dim_size():
                out._indptr = None

        return out

    # Validation ##############################################################

    def validate(self) -> 'Index':
        raise NotImplementedError

    # Properties ##############################################################

    @property
    def dim_size(self) -> Optional[int]:
        raise NotImplementedError

    @property
    def is_sorted(self) -> bool:
        raise NotImplementedError

    # Cache Interface #########################################################

    def get_dim_size(self) -> int:
        raise NotImplementedError

    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':
        raise NotImplementedError

    @assert_sorted
    def get_indptr(self) -> Tensor:
        raise NotImplementedError

    def fill_cache_(self) -> 'Index':
        raise NotImplementedError

    # Methods #################################################################

    def share_memory_(self) -> 'Index':
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        return self

    def is_shared(self) -> bool:
        return self._data.is_shared()

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`Index` representation back to a
        :class:`torch.Tensor` representation.
        """
        return self._data

    def dim_narrow(self, start: Union[int, Tensor], length: int) -> 'Index':
        raise NotImplementedError

    # def sort(self) -> None:
    #     # TODO MOVE BEHIND TORCH DISPATCH
    #     raise NotImplementedError
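Nothing is registered in `HANDLED_FUNCTIONS` yet in this commit (note the commented-out `sort()` stub above), but the dictionary and the `implements` decorator mirror the `__torch_function__` extension pattern already used by `EdgeIndex`. A purely illustrative registration, assuming dispatch gets wired up in a follow-up PR; the `torch.cat` override and its body are hypothetical and not part of this change:

```python
import torch
from torch import Tensor

from torch_geometric.index import HANDLED_FUNCTIONS, Index, implements


@implements(torch.cat)  # hypothetical override, not part of this PR
def _cat(tensors, dim=0, *, out=None) -> Tensor:
    # Concatenate the underlying data tensors and rebuild an `Index`,
    # dropping cached metadata such as `_indptr`:
    data = torch.cat([t.as_tensor() for t in tensors], dim=dim)
    return Index(data)


# The decorator simply records the override for later dispatch:
assert HANDLED_FUNCTIONS[torch.cat] is _cat
```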