Skip to content

Commit

Permalink
refactor(data): simplify remote backend num_nodes computation (#5307)
Browse files Browse the repository at this point in the history
  • Loading branch information
mananshah99 authored Aug 30, 2022
1 parent 96fbf43 commit be471ee
Show file tree
Hide file tree
Showing 15 changed files with 161 additions and 133 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088), [#5270](https://github.com/pyg-team/pytorch_geometric/pull/5270))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088), [#5270](https://github.com/pyg-team/pytorch_geometric/pull/5270), [#5307](https://github.com/pyg-team/pytorch_geometric/pull/5307))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
4 changes: 0 additions & 4 deletions test/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,3 @@ def assert_equal_tensor_tuple(expected, actual):
# Get attrs:
edge_attrs = data.get_all_edge_attrs()
assert len(edge_attrs) == 3

# Get num nodes:
assert data.num_src_nodes() == 3
assert data.num_dst_nodes() == 3
8 changes: 0 additions & 8 deletions test/data/test_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,10 @@ def test_graph_store_conversion():
# Put all edge indices:
graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'),
layout='coo', size=(100, 100), is_sorted=True)
assert graph_store.num_src_nodes(edge_type=('v', '1', 'v')) == 100
assert graph_store.num_dst_nodes(edge_type=('v', '1', 'v')) == 100

graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'),
layout='csr', size=(100, 100))
assert graph_store.num_src_nodes(edge_type=('v', '2', 'v')) == 100
assert graph_store.num_dst_nodes(edge_type=('v', '2', 'v')) == 100

graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'),
layout='csc', size=(100, 100))
assert graph_store.num_src_nodes(edge_type=('v', '3', 'v')) == 100
assert graph_store.num_dst_nodes(edge_type=('v', '3', 'v')) == 100

def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor):
assert torch.equal(sort_edge_index(expected), sort_edge_index(actual))
Expand Down
4 changes: 0 additions & 4 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,3 @@ def assert_equal_tensor_tuple(expected, actual):
# Get attrs:
edge_attrs = data.get_all_edge_attrs()
assert len(edge_attrs) == 3

# Get num nodes:
assert data.num_src_nodes(edge_type=('a', 'to', 'b')) == 3
assert data.num_dst_nodes(edge_type=('a', 'to', 'c')) == 3
36 changes: 36 additions & 0 deletions test/data/test_remote_backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
import torch

from torch_geometric.data import HeteroData
from torch_geometric.data.remote_backend_utils import num_nodes, size
from torch_geometric.testing.feature_store import MyFeatureStore
from torch_geometric.testing.graph_store import MyGraphStore


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_num_nodes_size(FeatureStore, GraphStore):
feature_store = FeatureStore()
graph_store = GraphStore()

# Infer num nodes from features:
x = torch.arange(100)
feature_store.put_tensor(x, group_name='x', attr_name='x', index=None)
assert num_nodes(feature_store, graph_store, 'x') == 100

# Infer num nodes and size from edges:
xy = get_edge_index(100, 50, 20)
graph_store.put_edge_index(xy, edge_type=('x', 'to', 'y'), layout='coo',
size=(100, 50))
assert num_nodes(feature_store, graph_store, 'y') == 50
assert size(feature_store, graph_store, ('x', 'to', 'y')) == (100, 50)

# Throw an error if we cannot infer for an unknown node type:
with pytest.raises(ValueError, match="Unable to accurately infer"):
_ = num_nodes(feature_store, graph_store, 'z')
3 changes: 1 addition & 2 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch_sparse import SparseTensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphConv, to_hetero
from torch_geometric.testing import withPackage
Expand Down Expand Up @@ -397,7 +396,7 @@ def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
loader2 = NeighborLoader(
(feature_store, graph_store),
num_neighbors=[-1, -1],
input_nodes=TensorAttr(group_name='paper', attr_name='x'),
input_nodes='paper',
time_attr='time',
batch_size=128,
)
Expand Down
6 changes: 0 additions & 6 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,12 +893,6 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:

return edge_attrs

def _num_src_nodes(self, edge_attr: EdgeAttr) -> int:
return self.num_nodes

def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int:
return self.num_nodes


###############################################################################

Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/data/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import numpy as np
import torch

from torch_geometric.typing import FeatureTensorType
from torch_geometric.typing import FeatureTensorType, NodeType
from torch_geometric.utils.mixin import CastMixin

_field_status = Enum("FieldStatus", "UNSET")
Expand All @@ -52,7 +52,7 @@ class TensorAttr(CastMixin):
"""

# The group name that the tensor corresponds to. Defaults to UNSET.
group_name: Optional[str] = _field_status.UNSET
group_name: Optional[NodeType] = _field_status.UNSET

# The name of the tensor within its group. Defaults to UNSET.
attr_name: Optional[str] = _field_status.UNSET
Expand Down
23 changes: 5 additions & 18 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.typing import Adj, EdgeTensorType, OptTensor
from torch_geometric.typing import Adj, EdgeTensorType, EdgeType, OptTensor
from torch_geometric.utils.mixin import CastMixin

# The output of converting between two types in the GraphStore is a Tuple of
Expand All @@ -59,7 +59,7 @@ class EdgeAttr(CastMixin):
r"""Defines the attributes of an :obj:`GraphStore` edge."""

# The type of the edge
edge_type: Optional[Any]
edge_type: Optional[EdgeType]

# The layout of the edge representation
layout: Optional[EdgeLayout] = None
Expand Down Expand Up @@ -154,22 +154,6 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
f"found")
return edge_index

@abstractmethod
def _num_src_nodes(self, edge_attr: EdgeAttr) -> int:
pass

def num_src_nodes(self, *args, **kwargs) -> int:
edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
return self._num_src_nodes(edge_attr)

@abstractmethod
def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int:
pass

def num_dst_nodes(self, *args, **kwargs) -> int:
edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
return self._num_dst_nodes(edge_attr)

# Layout Conversion #######################################################

def _edge_to_layout(
Expand Down Expand Up @@ -347,6 +331,9 @@ def __getitem__(self, key: EdgeAttr) -> Optional[EdgeTensorType]:
key = self._edge_attr_cls.cast(key)
return self.get_edge_index(key)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'


# Data and HeteroData utilities ###############################################

Expand Down
8 changes: 0 additions & 8 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,14 +846,6 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:

return out

def _num_src_nodes(self, edge_attr: EdgeAttr) -> int:
src, _, _ = self._to_canonical(edge_attr.edge_type)
return self[src].num_nodes

def _num_dst_nodes(self, edge_attr: EdgeAttr) -> int:
_, _, dst = self._to_canonical(edge_attr.edge_type)
return self[dst].num_nodes


# Helper functions ############################################################

Expand Down
2 changes: 0 additions & 2 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,6 @@ def __init__(
time_attr=kwargs.get('time_attr', None),
is_sorted=kwargs.get('is_sorted', False),
neg_sampling_ratio=kwargs.get('neg_sampling_ratio', 0.0),
num_src_nodes=kwargs.get('num_src_nodes', None),
num_dst_nodes=kwargs.get('num_dst_nodes', None),
share_memory=num_workers > 0,
)

Expand Down
97 changes: 97 additions & 0 deletions torch_geometric/data/remote_backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# This file defines a set of utilities for remote backends (backends that are
# characterize as Tuple[FeatureStore, GraphStore]).
from typing import Tuple, Union

from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.typing import EdgeType, NodeType


# NOTE PyG also supports querying by a relation type `rel` in an edge type
# (src, rel, dst). It may be worth supporting this in remote backends as well.
def _internal_num_nodes(
feature_store: FeatureStore,
graph_store: GraphStore,
query: Union[NodeType, EdgeType],
) -> Union[int, Tuple[int, int]]:
r"""Returns the number of nodes in the node type or the number of source
and destination nodes in an edge type by sequentially accessing attributes
in the feature and graph stores that reveal this number."""
def _matches_edge_type(query: Union[NodeType, EdgeType],
edge_type: EdgeType) -> bool:
if isinstance(query, (list, tuple)): # EdgeType
return query == edge_type
else:
return query == edge_type[0] or query == edge_type[-1]

def _matches_node_type(query: Union[NodeType, EdgeType],
node_type: NodeType) -> bool:
if isinstance(query, (list, tuple)): # EdgeType
return query[0] == node_type or query[-1] == node_type
else:
return query == node_type

node_query = isinstance(query, NodeType)

# TODO: In general, a feature store and graph store should be able to
# expose methods that allow for easy access to individual attributes,
# instead of requiring iteration to identify a particular attribute.
# Implementing this should reduce the iteration below.

# 1. Check GraphStore:
edge_attrs = graph_store.get_all_edge_attrs()
for edge_attr in edge_attrs:
if (_matches_edge_type(query, edge_attr.edge_type)
and edge_attr.size is not None):
if node_query:
return edge_attr.size[0] if query == edge_attr.edge_type[
0] else edge_attr.size[1]
else:
return edge_attr.size

# 2. Check FeatureStore:
tensor_attrs = feature_store.get_all_tensor_attrs()
matching_attrs = [
attr for attr in tensor_attrs
if _matches_node_type(query, attr.group_name)
]
if node_query:
if len(matching_attrs) > 0:
return feature_store.get_tensor_size(matching_attrs[0])[0]
else:
matching_src_attrs = [
attr for attr in matching_attrs if attr.group_name == query[0]
]
matching_dst_attrs = [
attr for attr in matching_attrs if attr.group_name == query[-1]
]
if len(matching_src_attrs) > 0 and len(matching_dst_attrs) > 0:
return (feature_store.get_tensor_size(matching_src_attrs[0])[0],
feature_store.get_tensor_size(matching_dst_attrs[0])[0])

raise ValueError(
f"Unable to accurately infer the number of nodes corresponding to "
f"query {query} from feature store {feature_store} and graph store "
f"{graph_store}. Please consider either adding an edge containing "
f"the nodes in this query or feature tensors for the nodes in this "
f"query.")


def num_nodes(
feature_store: FeatureStore,
graph_store: GraphStore,
query: NodeType,
) -> int:
r"""Returns the number of nodes in a given node type stored in a remote
backend."""
return _internal_num_nodes(feature_store, graph_store, query)


def size(
feature_store: FeatureStore,
graph_store: GraphStore,
query: EdgeType,
) -> Tuple[int, int]:
r"""Returns the size of an edge (number of source nodes, number of
destination nodes) in an edge stored in a remote backend."""
return _internal_num_nodes(feature_store, graph_store, query)
53 changes: 11 additions & 42 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor
from torch_scatter import scatter_min

from torch_geometric.data import Data, HeteroData
from torch_geometric.data import Data, HeteroData, remote_backend_utils
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
Expand All @@ -24,8 +24,6 @@ def __init__(
data,
*args,
neg_sampling_ratio: float = 0.0,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
**kwargs,
):
super().__init__(data, *args, **kwargs)
Expand All @@ -34,37 +32,16 @@ def __init__(
# TODO if self.edge_time is not None and
# `src` or `dst` nodes don't have time attribute
# i.e node_time_dict[input_type[0/-1]] doesn't exist
# set it to largest representabel torch.long.
self.num_src_nodes = num_src_nodes
self.num_dst_nodes = num_dst_nodes

if self.num_src_nodes is None or self.num_dst_nodes is None:
if self.data_cls == 'custom':
_, graph_store = data
edge_attrs = graph_store.get_all_edge_attrs()
edge_types = [attr.edge_type for attr in edge_attrs]

# Edge label index is part of the graph:
if self.input_type in edge_types:
self.num_src_nodes = graph_store.num_src_nodes(
edge_type=self.input_type)
self.num_dst_nodes = graph_store.num_dst_nodes(
edge_type=self.input_type)
else:
# We do not support querying the number of nodes by the
# feature store, so we throw an error here:
raise ValueError(
f"Use of a remote backend with "
f"{self.__class__.__name__} requires the "
f"specification of source and destination nodes, as "
f"the edge label index {self.input_type} is not part "
f"of the specified graph.")

elif issubclass(self.data_cls, Data):
self.num_src_nodes = self.num_dst_nodes = data.num_nodes
else: # issubclass(self.data_cls, HeteroData):
self.num_src_nodes = data[self.input_type[0]].num_nodes
self.num_dst_nodes = data[self.input_type[-1]].num_nodes
# set it to largest representable torch.long.
if self.data_cls == 'custom':
self.num_src_nodes, self.num_dst_nodes = \
remote_backend_utils.num_nodes(*data, self.input_type)

elif issubclass(self.data_cls, Data):
self.num_src_nodes = self.num_dst_nodes = data.num_nodes
else: # issubclass(self.data_cls, HeteroData):
self.num_src_nodes = data[self.input_type[0]].num_nodes
self.num_dst_nodes = data[self.input_type[-1]].num_nodes

def _add_negative_samples(self, edge_label_index, edge_label,
edge_label_time):
Expand Down Expand Up @@ -249,10 +226,6 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
constraints, *i.e.*, neighbors have an earlier timestamp than
the ouput edge. The :obj:`time_attr` needs to be set for this
to work. (default: :obj:`None`)
num_src_nodes (optional, int): The number of source nodes in the edge
label index. Inferred if not provided.
num_dst_nodes (optional, int): The number of destination nodes in the
edge label index. Inferred if not provided.
replace (bool, optional): If set to :obj:`True`, will sample with
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
Expand Down Expand Up @@ -302,8 +275,6 @@ def __init__(
edge_label_index: InputEdges = None,
edge_label: OptTensor = None,
edge_label_time: OptTensor = None,
num_src_nodes: Optional[int] = None,
num_dst_nodes: Optional[int] = None,
replace: bool = False,
directed: bool = True,
neg_sampling_ratio: float = 0.0,
Expand Down Expand Up @@ -354,8 +325,6 @@ def __init__(
input_type=edge_type,
is_sorted=is_sorted,
neg_sampling_ratio=self.neg_sampling_ratio,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
time_attr=time_attr,
share_memory=kwargs.get('num_workers', 0) > 0,
)
Expand Down
Loading

0 comments on commit be471ee

Please sign in to comment.