Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sampler): consolidate sampling interface, part 1 #5312

Merged
merged 30 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0b976ee
init
mananshah99 Aug 29, 2022
015743e
Merge branch 'master' of github.com:pyg-team/pytorch_geometric into r…
mananshah99 Aug 29, 2022
b111754
update
mananshah99 Aug 29, 2022
9a36b3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2022
f23586b
update
mananshah99 Aug 29, 2022
5c35cd6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2022
6545ced
update
mananshah99 Aug 29, 2022
aebf0d5
more cleanup
mananshah99 Aug 29, 2022
6741272
update
mananshah99 Aug 29, 2022
aacc361
fix
mananshah99 Aug 29, 2022
4b8d88f
init
mananshah99 Aug 30, 2022
1b0bfc8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2022
d7ec97c
merge
mananshah99 Aug 30, 2022
5a88c3d
Merge branch 'remote_backend_2' of github.com:pyg-team/pytorch_geomet…
mananshah99 Aug 30, 2022
71b76e3
rm
mananshah99 Aug 30, 2022
dfa34dc
update
mananshah99 Sep 6, 2022
0058f7f
minor
mananshah99 Sep 6, 2022
fc4794f
merge
mananshah99 Sep 6, 2022
d7d1e5e
udpate
mananshah99 Sep 6, 2022
46aacc0
update
mananshah99 Sep 6, 2022
cf79d7b
merge
mananshah99 Sep 7, 2022
d849e63
update
mananshah99 Sep 7, 2022
16bdf4e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
97c16ff
Merge branch 'master' into remote_backend_2
mananshah99 Sep 7, 2022
f22ba98
flake8
mananshah99 Sep 7, 2022
4778346
hgt
mananshah99 Sep 7, 2022
ac4a319
update
mananshah99 Sep 7, 2022
e35f499
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2022
884771d
update
mananshah99 Sep 7, 2022
873f403
Merge branch 'remote_backend_2' of github.com:pyg-team/pytorch_geomet…
mananshah99 Sep 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Consolidated sampler routines behind `torch_geometric.sampler`, enabling ease of extensibility in the future ([#5312](https://github.com/pyg-team/pytorch_geometric/pull/5312))
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293))
mananshah99 marked this conversation as resolved.
Show resolved Hide resolved
- Added `pyg_lib.segment_matmul` integration within `HeteroLinear` ([#5330](https://github.com/pyg-team/pytorch_geometric/pull/5330), [#5347](https://github.com/pyg-team/pytorch_geometric/pull/5347)))
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293), [#5341](https://github.com/pyg-team/pytorch_geometric/pull/5341))
- Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290))
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def to_csc(
f"Edge {edge_attr.edge_type} cannot be converted "
f"to a different type without specifying 'size' for "
f"the source and destination node types (got {size}). "
f"Please specify these parameters for successful execution. ")
f"Please specify these parameters for successful execution.")
(row, col) = adj
if not is_sorted:
perm = (col * size[0]).add_(row).argsort()
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/loader/hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from torch_geometric.data import HeteroData
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import filter_hetero_data, to_hetero_csc
from torch_geometric.loader.utils import filter_hetero_data
from torch_geometric.sampler.utils import to_hetero_csc
from torch_geometric.typing import NodeType


Expand Down
8 changes: 7 additions & 1 deletion torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.neighbor_loader import NeighborSampler
from torch_geometric.loader.utils import (
filter_custom_store,
filter_data,
filter_hetero_data,
)
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor


# TODO(manan) clean this up, align with NeighborSampler interface and
# implementation:
class LinkNeighborSampler(NeighborSampler):
def __init__(
self,
Expand Down Expand Up @@ -93,6 +95,10 @@ def update_time_(node_time_dict, index, node_type, num_nodes):
self.num_dst_nodes)
return node_time_dict

def sample(self, index):
# TODO(manan): remove after proper integration with interface
pass

def __call__(self, query: List[Tuple[Tensor]]):
query = [torch.stack(s, dim=0) for s in zip(*query)]
edge_label_index = torch.stack(query[:2], dim=0)
Expand Down
199 changes: 3 additions & 196 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union

import torch
from torch import Tensor
Expand All @@ -9,208 +9,14 @@
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
edge_type_to_str,
filter_custom_store,
filter_data,
filter_hetero_data,
to_csc,
to_hetero_csc,
)
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputNodes, NumNeighbors


class NeighborSampler:
def __init__(
self,
data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
num_neighbors: NumNeighbors,
replace: bool = False,
directed: bool = True,
input_type: Optional[Any] = None,
time_attr: Optional[str] = None,
is_sorted: bool = False,
share_memory: bool = False,
):
self.data_cls = data.__class__ if isinstance(
data, (Data, HeteroData)) else 'custom'
self.num_neighbors = num_neighbors
self.replace = replace
self.directed = directed
self.node_time = None

# TODO Unify the following conditionals behind the `FeatureStore`
# and `GraphStore` API

# If we are working with a `Data` object, convert the edge_index to
# CSC and store it:
if isinstance(data, Data):
if time_attr is not None:
# TODO `time_attr` support for homogeneous graphs
raise ValueError(
f"'time_attr' attribute not yet supported for "
f"'{data.__class__.__name__}' object")

# Convert the graph data into a suitable format for sampling.
out = to_csc(data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted)
self.colptr, self.row, self.perm = out
assert isinstance(num_neighbors, (list, tuple))

# If we are working with a `HeteroData` object, convert each edge
# type's edge_index to CSC and store it:
elif isinstance(data, HeteroData):
if time_attr is not None:
self.node_time_dict = data.collect(time_attr)
else:
self.node_time_dict = None

# Convert the graph data into a suitable format for sampling.
# NOTE: Since C++ cannot take dictionaries with tuples as key as
# input, edge type triplets are converted into single strings.
out = to_hetero_csc(data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted)
self.colptr_dict, self.row_dict, self.perm_dict = out

self.node_types, self.edge_types = data.metadata()
self._set_num_neighbors_and_num_hops(num_neighbors)

assert input_type is not None
self.input_type = input_type

# If we are working with a `Tuple[FeatureStore, GraphStore]` object,
# obtain edges from GraphStore and convert them to CSC if necessary,
# storing the resulting representations:
elif isinstance(data, tuple):
# TODO support `FeatureStore` with no edge types (e.g. `Data`)
feature_store, graph_store = data

# TODO support `collect` on `FeatureStore`
self.node_time_dict = None
if time_attr is not None:
# We need to obtain all features with 'attr_name=time_attr'
# from the feature store and store them in node_time_dict. To
# do so, we make an explicit feature store GET call here with
# the relevant 'TensorAttr's
time_attrs = [
attr for attr in feature_store.get_all_tensor_attrs()
if attr.attr_name == time_attr
]
for attr in time_attrs:
attr.index = None
time_tensors = feature_store.multi_get_tensor(time_attrs)
self.node_time_dict = {
time_attr.group_name: time_tensor
for time_attr, time_tensor in zip(time_attrs, time_tensors)
}

# Obtain all node and edge metadata:
node_attrs = feature_store.get_all_tensor_attrs()
edge_attrs = graph_store.get_all_edge_attrs()

self.node_types = list(
set(node_attr.group_name for node_attr in node_attrs))
self.edge_types = list(
set(edge_attr.edge_type for edge_attr in edge_attrs))

# Set other required parameters:
self._set_num_neighbors_and_num_hops(num_neighbors)

assert input_type is not None
self.input_type = input_type

# Obtain CSC representations for in-memory sampling:
row_dict, colptr_dict, perm_dict = graph_store.csc()
self.row_dict = {
edge_type_to_str(k): v
for k, v in row_dict.items()
}
self.colptr_dict = {
edge_type_to_str(k): v
for k, v in colptr_dict.items()
}
self.perm_dict = {
edge_type_to_str(k): v
for k, v in perm_dict.items()
}

else:
raise TypeError(f'NeighborLoader found invalid type: {type(data)}')

def _set_num_neighbors_and_num_hops(self, num_neighbors):
if isinstance(num_neighbors, (list, tuple)):
num_neighbors = {key: num_neighbors for key in self.edge_types}
assert isinstance(num_neighbors, dict)
self.num_neighbors = {
edge_type_to_str(key): value
for key, value in num_neighbors.items()
}
# Add at least one element to the list to ensure `max` is well-defined
self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])

def _sparse_neighbor_sample(self, index: Tensor):
fn = torch.ops.torch_sparse.neighbor_sample
node, row, col, edge = fn(
self.colptr,
self.row,
index,
self.num_neighbors,
self.replace,
self.directed,
)
return node, row, col, edge

def _hetero_sparse_neighbor_sample(self, index_dict: Dict[str, Tensor],
**kwargs):
if self.node_time_dict is None:
fn = torch.ops.torch_sparse.hetero_neighbor_sample
node_dict, row_dict, col_dict, edge_dict = fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
index_dict,
self.num_neighbors,
self.num_hops,
self.replace,
self.directed,
)
else:
try:
fn = torch.ops.torch_sparse.hetero_temporal_neighbor_sample
except RuntimeError as e:
raise RuntimeError(
"'torch_sparse' operator "
"'hetero_temporal_neighbor_sample' not "
"found. Currently requires building "
"'torch_sparse' from master.", e)

node_dict, row_dict, col_dict, edge_dict = fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
index_dict,
self.num_neighbors,
kwargs.get('node_time_dict', self.node_time_dict),
self.num_hops,
self.replace,
self.directed,
)
return node_dict, row_dict, col_dict, edge_dict

def __call__(self, index: Union[List[int], Tensor]):
if not isinstance(index, torch.LongTensor):
index = torch.LongTensor(index)

if self.data_cls != 'custom' and issubclass(self.data_cls, Data):
return self._sparse_neighbor_sample(index) + (index.numel(), )

elif self.data_cls == 'custom' or issubclass(self.data_cls,
HeteroData):
return self._hetero_sparse_neighbor_sample(
{self.input_type: index}) + (index.numel(), )


class NeighborLoader(torch.utils.data.DataLoader):
r"""A data loader that performs neighbor sampling as introduced in the
`"Inductive Representation Learning on Large Graphs"
Expand Down Expand Up @@ -400,6 +206,7 @@ def __init__(
super().__init__(input_nodes, collate_fn=self.collate_fn, **kwargs)

def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
# TODO(manan): remove special access of input_type and perm_dict here:
if isinstance(self.data, Data):
node, row, col, edge, batch_size = out
data = filter_data(self.data, node, row, col, edge,
Expand Down
Loading