diff --git a/CHANGELOG.md b/CHANGELOG.md index 79a055c8fae7..6148d3fbc1ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,7 +56,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908)) - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files)) -- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088), [#5270](https://github.com/pyg-team/pytorch_geometric/pull/5270)) +- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), 
[#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088), [#5270](https://github.com/pyg-team/pytorch_geometric/pull/5270), [#5307](https://github.com/pyg-team/pytorch_geometric/pull/5307)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850)) - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838)) diff --git a/test/data/test_data.py b/test/data/test_data.py index 9e8ac4e8cab6..faaf0e21d390 100644 --- a/test/data/test_data.py +++ b/test/data/test_data.py @@ -315,7 +315,3 @@ def assert_equal_tensor_tuple(expected, actual): # Get attrs: edge_attrs = data.get_all_edge_attrs() assert len(edge_attrs) == 3 - - # Get num nodes: - assert data.num_src_nodes() == 3 - assert data.num_dst_nodes() == 3 diff --git a/test/data/test_graph_store.py b/test/data/test_graph_store.py index 266f945fa2a1..c59120f7b4da 100644 --- a/test/data/test_graph_store.py +++ b/test/data/test_graph_store.py @@ -63,18 +63,10 @@ def test_graph_store_conversion(): # Put all edge indices: graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'), 
layout='coo', size=(100, 100), is_sorted=True) - assert graph_store.num_src_nodes(edge_type=('v', '1', 'v')) == 100 - assert graph_store.num_dst_nodes(edge_type=('v', '1', 'v')) == 100 - graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'), layout='csr', size=(100, 100)) - assert graph_store.num_src_nodes(edge_type=('v', '2', 'v')) == 100 - assert graph_store.num_dst_nodes(edge_type=('v', '2', 'v')) == 100 - graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'), layout='csc', size=(100, 100)) - assert graph_store.num_src_nodes(edge_type=('v', '3', 'v')) == 100 - assert graph_store.num_dst_nodes(edge_type=('v', '3', 'v')) == 100 def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor): assert torch.equal(sort_edge_index(expected), sort_edge_index(actual)) diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py index 98bd13084fe1..3f61bbb678ae 100644 --- a/test/data/test_hetero_data.py +++ b/test/data/test_hetero_data.py @@ -493,7 +493,3 @@ def assert_equal_tensor_tuple(expected, actual): # Get attrs: edge_attrs = data.get_all_edge_attrs() assert len(edge_attrs) == 3 - - # Get num nodes: - assert data.num_src_nodes(edge_type=('a', 'to', 'b')) == 3 - assert data.num_dst_nodes(edge_type=('a', 'to', 'c')) == 3 diff --git a/test/data/test_remote_backend_utils.py b/test/data/test_remote_backend_utils.py new file mode 100644 index 000000000000..a7f217d102ec --- /dev/null +++ b/test/data/test_remote_backend_utils.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from torch_geometric.data import HeteroData +from torch_geometric.data.remote_backend_utils import num_nodes, size +from torch_geometric.testing.feature_store import MyFeatureStore +from torch_geometric.testing.graph_store import MyGraphStore + + +def get_edge_index(num_src_nodes, num_dst_nodes, num_edges): + row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long) + col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long) + 
return torch.stack([row, col], dim=0) + + +@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData]) +@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData]) +def test_num_nodes_size(FeatureStore, GraphStore): + feature_store = FeatureStore() + graph_store = GraphStore() + + # Infer num nodes from features: + x = torch.arange(100) + feature_store.put_tensor(x, group_name='x', attr_name='x', index=None) + assert num_nodes(feature_store, graph_store, 'x') == 100 + + # Infer num nodes and size from edges: + xy = get_edge_index(100, 50, 20) + graph_store.put_edge_index(xy, edge_type=('x', 'to', 'y'), layout='coo', + size=(100, 50)) + assert num_nodes(feature_store, graph_store, 'y') == 50 + assert size(feature_store, graph_store, ('x', 'to', 'y')) == (100, 50) + + # Throw an error if we cannot infer for an unknown node type: + with pytest.raises(ValueError, match="Unable to accurately infer"): + _ = num_nodes(feature_store, graph_store, 'z') diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index dca87b3fbaf2..0a4cb0a099d7 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -4,7 +4,6 @@ from torch_sparse import SparseTensor from torch_geometric.data import Data, HeteroData -from torch_geometric.data.feature_store import TensorAttr from torch_geometric.loader import NeighborLoader from torch_geometric.nn import GraphConv, to_hetero from torch_geometric.testing import withPackage @@ -397,7 +396,7 @@ def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore, loader2 = NeighborLoader( (feature_store, graph_store), num_neighbors=[-1, -1], - input_nodes=TensorAttr(group_name='paper', attr_name='x'), + input_nodes='paper', time_attr='time', batch_size=128, ) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 1861cc523058..777bed2f3b30 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -893,12 +893,6 @@ 
def get_all_edge_attrs(self) -> List[EdgeAttr]: return edge_attrs - def _num_src_nodes(self, edge_attr: EdgeAttr) -> int: - return self.num_nodes - - def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int: - return self.num_nodes - ############################################################################### diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index 0a8768f3ff75..dd2a283eee8e 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -29,7 +29,7 @@ import numpy as np import torch -from torch_geometric.typing import FeatureTensorType +from torch_geometric.typing import FeatureTensorType, NodeType from torch_geometric.utils.mixin import CastMixin _field_status = Enum("FieldStatus", "UNSET") @@ -52,7 +52,7 @@ class TensorAttr(CastMixin): """ # The group name that the tensor corresponds to. Defaults to UNSET. - group_name: Optional[str] = _field_status.UNSET + group_name: Optional[NodeType] = _field_status.UNSET # The name of the tensor within its group. Defaults to UNSET. 
attr_name: Optional[str] = _field_status.UNSET diff --git a/torch_geometric/data/graph_store.py b/torch_geometric/data/graph_store.py index aaeb30f00845..8d015278f670 100644 --- a/torch_geometric/data/graph_store.py +++ b/torch_geometric/data/graph_store.py @@ -32,7 +32,7 @@ from torch import Tensor from torch_sparse import SparseTensor -from torch_geometric.typing import Adj, EdgeTensorType, OptTensor +from torch_geometric.typing import Adj, EdgeTensorType, EdgeType, OptTensor from torch_geometric.utils.mixin import CastMixin # The output of converting between two types in the GraphStore is a Tuple of @@ -59,7 +59,7 @@ class EdgeAttr(CastMixin): r"""Defines the attributes of an :obj:`GraphStore` edge.""" # The type of the edge - edge_type: Optional[Any] + edge_type: Optional[EdgeType] # The layout of the edge representation layout: Optional[EdgeLayout] = None @@ -154,22 +154,6 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType: f"found") return edge_index - @abstractmethod - def _num_src_nodes(self, edge_attr: EdgeAttr) -> int: - pass - - def num_src_nodes(self, *args, **kwargs) -> int: - edge_attr = self._edge_attr_cls.cast(*args, **kwargs) - return self._num_src_nodes(edge_attr) - - @abstractmethod - def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int: - pass - - def num_dst_nodes(self, *args, **kwargs) -> int: - edge_attr = self._edge_attr_cls.cast(*args, **kwargs) - return self._num_dst_nodes(edge_attr) - # Layout Conversion ####################################################### def _edge_to_layout( @@ -347,6 +331,9 @@ def __getitem__(self, key: EdgeAttr) -> Optional[EdgeTensorType]: key = self._edge_attr_cls.cast(key) return self.get_edge_index(key) + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + # Data and HeteroData utilities ############################################### diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 944495849ab5..9b023a3f6b2a 100644 --- 
a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -846,14 +846,6 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]: return out - def _num_src_nodes(self, edge_attr: EdgeAttr) -> int: - src, _, _ = self._to_canonical(edge_attr.edge_type) - return self[src].num_nodes - - def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int: - _, _, dst = self._to_canonical(edge_attr.edge_type) - return self[dst].num_nodes - # Helper functions ############################################################ diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 71f3bc96a740..83730c828124 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -496,8 +496,6 @@ def __init__( time_attr=kwargs.get('time_attr', None), is_sorted=kwargs.get('is_sorted', False), neg_sampling_ratio=kwargs.get('neg_sampling_ratio', 0.0), - num_src_nodes=kwargs.get('num_src_nodes', None), - num_dst_nodes=kwargs.get('num_dst_nodes', None), share_memory=num_workers > 0, ) diff --git a/torch_geometric/data/remote_backend_utils.py b/torch_geometric/data/remote_backend_utils.py new file mode 100644 index 000000000000..1955087f1083 --- /dev/null +++ b/torch_geometric/data/remote_backend_utils.py @@ -0,0 +1,97 @@ +# This file defines a set of utilities for remote backends (backends that are +# characterized as Tuple[FeatureStore, GraphStore]). +from typing import Tuple, Union + +from torch_geometric.data.feature_store import FeatureStore +from torch_geometric.data.graph_store import GraphStore +from torch_geometric.typing import EdgeType, NodeType + + +# NOTE PyG also supports querying by a relation type `rel` in an edge type +# (src, rel, dst). It may be worth supporting this in remote backends as well. 
+def _internal_num_nodes( + feature_store: FeatureStore, + graph_store: GraphStore, + query: Union[NodeType, EdgeType], +) -> Union[int, Tuple[int, int]]: + r"""Returns the number of nodes in the node type or the number of source + and destination nodes in an edge type by sequentially accessing attributes + in the feature and graph stores that reveal this number.""" + def _matches_edge_type(query: Union[NodeType, EdgeType], + edge_type: EdgeType) -> bool: + if isinstance(query, (list, tuple)): # EdgeType + return query == edge_type + else: + return query == edge_type[0] or query == edge_type[-1] + + def _matches_node_type(query: Union[NodeType, EdgeType], + node_type: NodeType) -> bool: + if isinstance(query, (list, tuple)): # EdgeType + return query[0] == node_type or query[-1] == node_type + else: + return query == node_type + + node_query = isinstance(query, NodeType) + + # TODO: In general, a feature store and graph store should be able to + # expose methods that allow for easy access to individual attributes, + # instead of requiring iteration to identify a particular attribute. + # Implementing this should reduce the iteration below. + + # 1. Check GraphStore: + edge_attrs = graph_store.get_all_edge_attrs() + for edge_attr in edge_attrs: + if (_matches_edge_type(query, edge_attr.edge_type) + and edge_attr.size is not None): + if node_query: + return edge_attr.size[0] if query == edge_attr.edge_type[ + 0] else edge_attr.size[1] + else: + return edge_attr.size + + # 2. 
Check FeatureStore: + tensor_attrs = feature_store.get_all_tensor_attrs() + matching_attrs = [ + attr for attr in tensor_attrs + if _matches_node_type(query, attr.group_name) + ] + if node_query: + if len(matching_attrs) > 0: + return feature_store.get_tensor_size(matching_attrs[0])[0] + else: + matching_src_attrs = [ + attr for attr in matching_attrs if attr.group_name == query[0] + ] + matching_dst_attrs = [ + attr for attr in matching_attrs if attr.group_name == query[-1] + ] + if len(matching_src_attrs) > 0 and len(matching_dst_attrs) > 0: + return (feature_store.get_tensor_size(matching_src_attrs[0])[0], + feature_store.get_tensor_size(matching_dst_attrs[0])[0]) + + raise ValueError( + f"Unable to accurately infer the number of nodes corresponding to " + f"query {query} from feature store {feature_store} and graph store " + f"{graph_store}. Please consider either adding an edge containing " + f"the nodes in this query or feature tensors for the nodes in this " + f"query.") + + +def num_nodes( + feature_store: FeatureStore, + graph_store: GraphStore, + query: NodeType, +) -> int: + r"""Returns the number of nodes in a given node type stored in a remote + backend.""" + return _internal_num_nodes(feature_store, graph_store, query) + + +def size( + feature_store: FeatureStore, + graph_store: GraphStore, + query: EdgeType, +) -> Tuple[int, int]: + r"""Returns the size of an edge (number of source nodes, number of + destination nodes) in an edge stored in a remote backend.""" + return _internal_num_nodes(feature_store, graph_store, query) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 975428af03f2..0cb0817d1120 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -5,7 +5,7 @@ from torch import Tensor from torch_scatter import scatter_min -from torch_geometric.data import Data, HeteroData +from torch_geometric.data import Data, 
HeteroData, remote_backend_utils from torch_geometric.data.feature_store import FeatureStore from torch_geometric.data.graph_store import GraphStore from torch_geometric.loader.base import DataLoaderIterator @@ -24,8 +24,6 @@ def __init__( data, *args, neg_sampling_ratio: float = 0.0, - num_src_nodes: Optional[int] = None, - num_dst_nodes: Optional[int] = None, **kwargs, ): super().__init__(data, *args, **kwargs) @@ -34,37 +32,16 @@ def __init__( # TODO if self.edge_time is not None and # `src` or `dst` nodes don't have time attribute # i.e node_time_dict[input_type[0/-1]] doesn't exist - # set it to largest representabel torch.long. - self.num_src_nodes = num_src_nodes - self.num_dst_nodes = num_dst_nodes - - if self.num_src_nodes is None or self.num_dst_nodes is None: - if self.data_cls == 'custom': - _, graph_store = data - edge_attrs = graph_store.get_all_edge_attrs() - edge_types = [attr.edge_type for attr in edge_attrs] - - # Edge label index is part of the graph: - if self.input_type in edge_types: - self.num_src_nodes = graph_store.num_src_nodes( - edge_type=self.input_type) - self.num_dst_nodes = graph_store.num_dst_nodes( - edge_type=self.input_type) - else: - # We do not support querying the number of nodes by the - # feature store, so we throw an error here: - raise ValueError( - f"Use of a remote backend with " - f"{self.__class__.__name__} requires the " - f"specification of source and destination nodes, as " - f"the edge label index {self.input_type} is not part " - f"of the specified graph.") - - elif issubclass(self.data_cls, Data): - self.num_src_nodes = self.num_dst_nodes = data.num_nodes - else: # issubclass(self.data_cls, HeteroData): - self.num_src_nodes = data[self.input_type[0]].num_nodes - self.num_dst_nodes = data[self.input_type[-1]].num_nodes + # set it to largest representable torch.long. 
+ if self.data_cls == 'custom': + self.num_src_nodes, self.num_dst_nodes = \ + remote_backend_utils.size(*data, self.input_type) + + elif issubclass(self.data_cls, Data): + self.num_src_nodes = self.num_dst_nodes = data.num_nodes + else: # issubclass(self.data_cls, HeteroData): + self.num_src_nodes = data[self.input_type[0]].num_nodes + self.num_dst_nodes = data[self.input_type[-1]].num_nodes def _add_negative_samples(self, edge_label_index, edge_label, edge_label_time): @@ -249,10 +226,6 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): constraints, *i.e.*, neighbors have an earlier timestamp than the ouput edge. The :obj:`time_attr` needs to be set for this to work. (default: :obj:`None`) - num_src_nodes (optional, int): The number of source nodes in the edge - label index. Inferred if not provided. - num_dst_nodes (optional, int): The number of destination nodes in the - edge label index. Inferred if not provided. replace (bool, optional): If set to :obj:`True`, will sample with replacement. 
(default: :obj:`False`) directed (bool, optional): If set to :obj:`False`, will include all @@ -302,8 +275,6 @@ def __init__( edge_label_index: InputEdges = None, edge_label: OptTensor = None, edge_label_time: OptTensor = None, - num_src_nodes: Optional[int] = None, - num_dst_nodes: Optional[int] = None, replace: bool = False, directed: bool = True, neg_sampling_ratio: float = 0.0, @@ -354,8 +325,6 @@ def __init__( input_type=edge_type, is_sorted=is_sorted, neg_sampling_ratio=self.neg_sampling_ratio, - num_src_nodes=num_src_nodes, - num_dst_nodes=num_dst_nodes, time_attr=time_attr, share_memory=kwargs.get('num_workers', 0) > 0, ) diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 849045f21840..41b6e282cf9d 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -4,7 +4,7 @@ import torch from torch import Tensor -from torch_geometric.data import Data, HeteroData +from torch_geometric.data import Data, HeteroData, remote_backend_utils from torch_geometric.data.feature_store import FeatureStore, TensorAttr from torch_geometric.data.graph_store import GraphStore from torch_geometric.loader.base import DataLoaderIterator @@ -473,24 +473,16 @@ def to_index(tensor): return node_type, to_index(input_nodes) else: # Tuple[FeatureStore, GraphStore] - # NOTE FeatureStore and GraphStore are treated as separate - # entities, so we cannot leverage the custom structure in Data and - # HeteroData to infer the number of nodes. As a result, here we expect - # that the input nodes are either explicitly provided or can be - # directly inferred from the feature store. 
- feature_store, _ = data - + feature_store, graph_store = data assert input_nodes is not None if isinstance(input_nodes, Tensor): return None, to_index(input_nodes) - # Can't infer number of nodes from a group_name; need an attr_name if isinstance(input_nodes, str): - raise NotImplementedError( - f"Cannot infer the number of nodes from a single string " - f"(got '{input_nodes}'). Please pass a more explicit " - f"representation. ") + return input_nodes, range( + remote_backend_utils.num_nodes(feature_store, graph_store, + input_nodes)) if isinstance(input_nodes, (list, tuple)): assert len(input_nodes) == 2 @@ -498,17 +490,7 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - raise NotImplementedError( - f"Cannot infer the number of nodes from a node type alone " - f"(got '{input_nodes}'). Please pass a more explicit " - f"representation. ") + return node_type, range( + remote_backend_utils.num_nodes(feature_store, graph_store, + node_type)) return node_type, to_index(input_nodes) - - assert isinstance(input_nodes, TensorAttr) - assert input_nodes.is_set('attr_name') - - node_type = getattr(input_nodes, 'group_name', None) - if not input_nodes.is_set('index') or input_nodes.index is None: - num_nodes = feature_store.get_tensor_size(input_nodes)[0] - return node_type, range(num_nodes) - return node_type, input_nodes.index diff --git a/torch_geometric/testing/graph_store.py b/torch_geometric/testing/graph_store.py index 221c6a769a06..ab49938014e5 100644 --- a/torch_geometric/testing/graph_store.py +++ b/torch_geometric/testing/graph_store.py @@ -27,13 +27,3 @@ def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: def get_all_edge_attrs(self): return [EdgeAttr(*key) for key in self.store] - - def _num_src_nodes(self, edge_attr: EdgeAttr) -> int: - for k in self.store: - if k[0] == edge_attr.edge_type: - return k[3][0] - - def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int: - for k in self.store: - if k[0] 
== edge_attr.edge_type: - return k[3][1]