refactor(loader): NodeLoader and LinkLoader as base implementation classes, part 4 (#5404)

This PR continues the effort to consolidate PyG's sampling interface in preparation for moving `sample(...)` behind the `GraphStore` interface. Because this effort is large in scope, it is broken into multiple PRs for ease of review. This PR builds on #5402 and takes a significant step toward abstracting data loading behind a `data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]]` argument and a `sampler: BaseSampler`.

It does so by introducing two base implementation classes: `NodeLoader` and `LinkLoader`. `NodeLoader` performs sampling from nodes (using `sample_from_nodes`), and `LinkLoader` does the same from edges (using `sample_from_edges`). Both expose initializer parameters intended for **loading**, that is, the process of using a sampler to obtain subgraphs, using a feature fetcher to obtain features, and joining the two to construct a `HeteroData` object to pass downstream. Samplers, in turn, are intended to expose the parameters used for **sampling**, that is, those particular to the sampling method.

The implementations of `NeighborLoader` and `LinkNeighborLoader` are now very simple: in `__init__`, they construct a `NeighborSampler` and pass it, along with any necessary initialization parameters, to their respective base classes, with no other changes (see the sketch below).
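To make the pattern concrete, here is a minimal sketch of a node-level loader built on top of `NodeLoader`. This is illustrative only, not code from this PR: the `node_sampler` parameter name mirrors `link_sampler` in `LinkLoader` below, and the availability and exact constructor arguments of `torch_geometric.sampler.NeighborSampler` are assumptions.

```python
from torch_geometric.loader import NodeLoader
from torch_geometric.sampler import NeighborSampler  # assumed import path


class MyNeighborLoader(NodeLoader):
    # Hypothetical subclass illustrating the described split:
    # sampling parameters go to the sampler, while loading parameters and
    # torch.utils.data.DataLoader kwargs go to the NodeLoader base class.
    def __init__(self, data, num_neighbors, input_nodes=None, replace=False,
                 directed=True, transform=None, **kwargs):
        node_sampler = NeighborSampler(
            data,
            num_neighbors=num_neighbors,  # sampling: fan-out per hop
            replace=replace,              # sampling: with/without replacement
            directed=directed,            # sampling: directed vs. undirected
        )
        super().__init__(
            data=data,
            node_sampler=node_sampler,    # generic BaseSampler implementation
            input_nodes=input_nodes,      # loading: seed nodes per mini-batch
            transform=transform,          # loading: post-processing hook
            **kwargs,                     # batch_size, shuffle, num_workers, ...
        )
```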
mananshah99 authored Sep 13, 2022
1 parent 373c0ef commit 69f85c2
Showing 11 changed files with 546 additions and 350 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.2.0] - 2022-MM-DD
### Added
- Added `PositionalEncoding` ([#5381](https://github.com/pyg-team/pytorch_geometric/pull/5381))
- Consolidated sampler routines behind `torch_geometric.sampler`, enabling ease of extensibility in the future ([#5312](https://github.com/pyg-team/pytorch_geometric/pull/5312), [#5365](https://github.com/pyg-team/pytorch_geometric/pull/5365), [#5402](https://github.com/pyg-team/pytorch_geometric/pull/5402)))
- Consolidated sampler routines behind `torch_geometric.sampler`, enabling ease of extensibility in the future ([#5312](https://github.com/pyg-team/pytorch_geometric/pull/5312), [#5365](https://github.com/pyg-team/pytorch_geometric/pull/5365), [#5402](https://github.com/pyg-team/pytorch_geometric/pull/5402), [#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404)))
- Added `pyg-lib` neighbor sampling ([#5384](https://github.com/pyg-team/pytorch_geometric/pull/5384), [#5388](https://github.com/pyg-team/pytorch_geometric/pull/5388))
- Added `pyg_lib.segment_matmul` integration within `HeteroLinear` ([#5330](https://github.com/pyg-team/pytorch_geometric/pull/5330), [#5347](https://github.com/pyg-team/pytorch_geometric/pull/5347)))
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293), [#5341](https://github.com/pyg-team/pytorch_geometric/pull/5341))
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Breaking change: removed `num_neighbors` as an attribute of loader ([#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404))
- `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395))
- Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317))
- Replace in-place operations with out-of-place ones to align with `torch.scatter_reduce` API ([#5353](https://github.com/pyg-team/pytorch_geometric/pull/5353))
25 changes: 13 additions & 12 deletions benchmark/inference/inference_benchmark.py
@@ -58,11 +58,12 @@ def run(args: argparse.ArgumentParser) -> None:
             )

         for layers in args.num_layers:
+            num_neighbors = [args.hetero_num_neighbors] * layers
             if hetero:
+                # batch-wise inference
                 subgraph_loader = NeighborLoader(
                     data,
-                    num_neighbors=[args.hetero_num_neighbors] *
-                    layers,  # batch-wise inference
+                    num_neighbors=num_neighbors,
                     input_nodes=mask,
                     batch_size=batch_size,
                     shuffle=False,
@@ -71,12 +72,11 @@ def run(args: argparse.ArgumentParser) -> None:

             for hidden_channels in args.num_hidden_channels:
                 print('----------------------------------------------')
-                print(
-                    f'Batch size={batch_size}, '
-                    f'Layers amount={layers}, '
-                    f'Num_neighbors={subgraph_loader.num_neighbors}, '
-                    f'Hidden features size={hidden_channels}, '
-                    f'Sparse tensor={args.use_sparse_tensor}')
+                print(f'Batch size={batch_size}, '
+                      f'Layers amount={layers}, '
+                      f'Num_neighbors={num_neighbors}, '
+                      f'Hidden features size={hidden_channels}, '
+                      f'Sparse tensor={args.use_sparse_tensor}')
                 params = {
                     'inputs_channels': inputs_channels,
                     'hidden_channels': hidden_channels,
@@ -111,10 +111,11 @@ def run(args: argparse.ArgumentParser) -> None:
                     with torch_profile():
                         model.inference(subgraph_loader, device,
                                         progress_bar=True)
-                    rename_profile_file(
-                        model_name, dataset_name, str(batch_size),
-                        str(layers), str(hidden_channels),
-                        str(subgraph_loader.num_neighbors))
+                    rename_profile_file(model_name, dataset_name,
+                                        str(batch_size),
+                                        str(layers),
+                                        str(hidden_channels),
+                                        str(num_neighbors))


 if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/loader/test_link_neighbor_loader.py
@@ -195,7 +195,7 @@ def test_temporal_heterogeneous_link_neighbor_loader():
     data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
     data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)

-    with pytest.raises(ValueError, match=r'`edge_label_time` is specified .*'):
+    with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"):
         loader = LinkNeighborLoader(data, num_neighbors=[-1] * 2,
                                     edge_label_index=('paper', 'paper'),
                                     batch_size=32, time_attr='time')
4 changes: 4 additions & 0 deletions torch_geometric/loader/__init__.py
@@ -13,9 +13,13 @@
 from .neighbor_sampler import NeighborSampler
 from .imbalanced_sampler import ImbalancedSampler
 from .dynamic_batch_sampler import DynamicBatchSampler
+from .node_loader import NodeLoader
+from .link_loader import LinkLoader

 __all__ = classes = [
     'DataLoader',
+    'NodeLoader',
+    'LinkLoader',
     'NeighborLoader',
     'LinkNeighborLoader',
     'HGTLoader',
214 changes: 214 additions & 0 deletions torch_geometric/loader/link_loader.py
@@ -0,0 +1,214 @@
from typing import Any, Callable, Iterator, Tuple, Union

import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
    filter_custom_store,
    filter_data,
    filter_hetero_data,
    get_edge_label_index,
)
from torch_geometric.sampler.base import (
    BaseSampler,
    EdgeSamplerInput,
    HeteroSamplerOutput,
    SamplerOutput,
)
from torch_geometric.typing import InputEdges, OptTensor


class LinkLoader(torch.utils.data.DataLoader):
    r"""A data loader that performs neighbor sampling from link information,
    using a generic :class:`~torch_geometric.sampler.BaseSampler`
    implementation that defines a :meth:`sample_from_edges` function and is
    supported on the provided input :obj:`data` object.

    Args:
        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):
            The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object.
        link_sampler (torch_geometric.sampler.BaseSampler): The sampler
            implementation to be used with this loader. Note that the
            sampler implementation must be compatible with the input data
            object.
        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The edge indices for which neighbors are sampled to create
            mini-batches.
            If set to :obj:`None`, all edges will be considered.
            In heterogeneous graphs, needs to be passed as a tuple that holds
            the edge type and corresponding edge indices.
            (default: :obj:`None`)
        edge_label (Tensor, optional): The labels of edge indices for
            which neighbors are sampled. Must be the same length as
            the :obj:`edge_label_index`. If set to :obj:`None`, it is set to
            `torch.zeros(...)` internally. (default: :obj:`None`)
        edge_label_time (Tensor, optional): The timestamps for edge indices
            for which neighbors are sampled. Must be the same length as
            :obj:`edge_label_index`. If set, temporal sampling will be
            used such that neighbors are guaranteed to fulfill temporal
            constraints, *i.e.*, neighbors have an earlier timestamp than
            the output edge. The :obj:`time_attr` needs to be set for this
            to work. (default: :obj:`None`)
        neg_sampling_ratio (float, optional): The number of negative samples
            to include as a ratio of the number of positive examples.
            (default: :obj:`0`)
        transform (Callable, optional): A function/transform that takes in
            a sampled mini-batch and returns a transformed version.
            (default: :obj:`None`)
        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
            the returning data in each worker's subprocess rather than in the
            main process.
            Setting this to :obj:`True` is generally not recommended:
            (1) it may result in too many open file handles,
            (2) it may slow down data loading,
            (3) it requires operating on CPU tensors.
            (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
        link_sampler: BaseSampler,
        edge_label_index: InputEdges = None,
        edge_label: OptTensor = None,
        edge_label_time: OptTensor = None,
        neg_sampling_ratio: float = 0.0,
        transform: Callable = None,
        filter_per_worker: bool = False,
        **kwargs,
    ):
        # Remove for PyTorch Lightning:
        if 'dataset' in kwargs:
            del kwargs['dataset']
        if 'collate_fn' in kwargs:
            del kwargs['collate_fn']

        self.data = data

        # Initialize sampler with keyword arguments:
        # NOTE sampler is an attribute of 'DataLoader', so we use link_sampler
        # here:
        self.link_sampler = link_sampler

        # Store additional arguments:
        self.edge_label = edge_label
        self.edge_label_index = edge_label_index
        self.edge_label_time = edge_label_time
        self.transform = transform
        self.filter_per_worker = filter_per_worker
        self.neg_sampling_ratio = neg_sampling_ratio

        # Get input type, or None for homogeneous graphs:
        edge_type, edge_label_index = get_edge_label_index(
            data, edge_label_index)
        if edge_label is None:
            edge_label = torch.zeros(edge_label_index.size(1),
                                     device=edge_label_index.device)
        self.input_type = edge_type

        super().__init__(
            Dataset(edge_label_index, edge_label, edge_label_time),
            collate_fn=self.collate_fn,
            **kwargs,
        )

    def filter_fn(
        self,
        out: Union[SamplerOutput, HeteroSamplerOutput],
    ) -> Union[Data, HeteroData]:
        r"""Joins the sampled nodes with their corresponding features,
        returning the resulting (Data or HeteroData) object to be used
        downstream."""
        if isinstance(out, SamplerOutput):
            edge_label_index, edge_label = out.metadata
            data = filter_data(self.data, out.node, out.row, out.col, out.edge,
                               self.link_sampler.edge_permutation)
            data.edge_label_index = edge_label_index
            data.edge_label = edge_label

        elif isinstance(out, HeteroSamplerOutput):
            edge_label_index, edge_label, edge_label_time = out.metadata
            if isinstance(self.data, HeteroData):
                data = filter_hetero_data(self.data, out.node, out.row,
                                          out.col, out.edge,
                                          self.link_sampler.edge_permutation)
            else:  # Tuple[FeatureStore, GraphStore]
                data = filter_custom_store(*self.data, out.node, out.row,
                                           out.col, out.edge)

            edge_type = self.input_type
            data[edge_type].edge_label_index = edge_label_index
            data[edge_type].edge_label = edge_label
            if edge_label_time is not None:
                data[edge_type].edge_label_time = edge_label_time

        else:
            raise TypeError(f"'{self.__class__.__name__}' found invalid "
                            f"type: '{type(out)}'")

        return data if self.transform is None else self.transform(data)

    def collate_fn(self, index: EdgeSamplerInput) -> Any:
        r"""Samples a subgraph from a batch of input edges."""
        out = self.link_sampler.sample_from_edges(
            index,
            negative_sampling_ratio=self.neg_sampling_ratio,
        )
        if self.filter_per_worker:
            # We execute `filter_fn` in the worker process.
            out = self.filter_fn(out)
        return out

    def _get_iterator(self) -> Iterator:
        if self.filter_per_worker:
            return super()._get_iterator()
        # We execute `filter_fn` in the main process.
        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'


###############################################################################


class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        edge_label_index: torch.Tensor,
        edge_label: torch.Tensor,
        edge_label_time: OptTensor = None,
    ):
        # NOTE see documentation of LinkLoader for details on these three
        # input parameters:
        self.edge_label_index = edge_label_index
        self.edge_label = edge_label
        self.edge_label_time = edge_label_time

    def __getitem__(
        self,
        idx: int,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
               Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        if self.edge_label_time is None:
            return (
                self.edge_label_index[0, idx],
                self.edge_label_index[1, idx],
                self.edge_label[idx],
            )
        else:
            return (
                self.edge_label_index[0, idx],
                self.edge_label_index[1, idx],
                self.edge_label[idx],
                self.edge_label_time[idx],
            )

    def __len__(self) -> int:
        return self.edge_label_index.size(1)
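For context, here is a minimal usage sketch of the new `LinkLoader`. It is illustrative only, not part of this commit: it assumes `NeighborSampler` is importable from `torch_geometric.sampler`, accepts `data` and `num_neighbors` as shown, and implements `sample_from_edges` with negative sampling support.

```python
import torch

from torch_geometric.data import Data
from torch_geometric.loader import LinkLoader
from torch_geometric.sampler import NeighborSampler  # assumed import path

# Toy homogeneous graph with random node features:
edge_index = torch.randint(0, 100, (2, 500))
data = Data(x=torch.randn(100, 16), edge_index=edge_index)

# Sampling-specific configuration lives on the sampler ...
link_sampler = NeighborSampler(data, num_neighbors=[10, 5])

# ... while loading-specific configuration lives on the loader:
loader = LinkLoader(
    data,
    link_sampler=link_sampler,
    edge_label_index=edge_index[:, :100],  # seed edges for mini-batches
    neg_sampling_ratio=1.0,                # one negative sample per positive
    batch_size=32,
    shuffle=True,
)

for batch in loader:  # each `batch` is a sampled subgraph object
    print(batch)
```

Note how everything related to *how* neighbors are drawn is configured on the sampler, while batching, seed edges, negative-sampling ratio, and feature joining are handled by the loader.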