
Commit

Cherry pick changes from larger commit that pertain only to large graph indexer
zaristei committed Oct 25, 2024
1 parent facf0c4 commit 374ba5c
Showing 7 changed files with 1,098 additions and 89 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `data.LargeGraphIndexer` ([#9597](https://github.com/pyg-team/pytorch_geometric/pull/9597))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))

### Changed
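
As context for the diff below, here is a minimal usage sketch of the new indexer, pieced together from the API calls the added tests exercise; the triplets and feature values are illustrative placeholders, not part of the commit:

import torch

from torch_geometric.data import LargeGraphIndexer, get_features_for_triplets
from torch_geometric.data.large_graph_indexer import EDGE_RELATION

# Hypothetical knowledge-graph triplets of the form (head, relation, tail).
triplets = [('paris', 'capital_of', 'france'),
            ('berlin', 'capital_of', 'germany')]

# Build a deduplicated node/edge index over the triplets.
indexer = LargeGraphIndexer.from_triplets(triplets)

# Register one placeholder feature value per unique node and per unique
# relation, mirroring how the tests below attach 'x' and 'edge_attr'.
node_vals = torch.arange(len(indexer.get_unique_node_features()),
                         dtype=torch.float)
indexer.add_node_feature('x', node_vals)
edge_vals = torch.arange(
    len(indexer.get_unique_edge_features(feature_name=EDGE_RELATION)),
    dtype=torch.float)
indexer.add_edge_feature('edge_attr', edge_vals,
                         map_from_feature=EDGE_RELATION)

# Materialize a Data object holding the features for one triplet set.
data = get_features_for_triplets(indexer=indexer, triplets=triplets,
                                 node_feature_name='x',
                                 edge_feature_name='edge_attr')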
170 changes: 170 additions & 0 deletions test/data/test_large_graph_indexer.py
@@ -0,0 +1,170 @@
import random
import string
from typing import List

import networkx as nx
import torch

from torch_geometric.data import (
    Data,
    LargeGraphIndexer,
    TripletLike,
    get_features_for_triplets,
)
from torch_geometric.data.large_graph_indexer import (
    EDGE_PID,
    EDGE_RELATION,
    NODE_PID,
)

# Create pools of random node and edge identifiers to sample graphs from.
strkeys = string.ascii_letters + string.digits
NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)})
EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)})


def featurize(s: str) -> int:
    # Deterministic toy feature: the string's little-endian byte value.
    return int.from_bytes(s.encode(), 'little')


def sample_triplets(amount: int = 1) -> List[TripletLike]:
    trips = []
    for i in range(amount):
        h, t = random.sample(NODE_POOL, k=2)
        r = random.sample(EDGE_POOL, k=1)[0]
        trips.append(tuple([h, r, t]))
    return trips


def preprocess_triplet(triplet: TripletLike) -> TripletLike:
    h, r, t = triplet
    return h.lower(), r, t.lower()


def test_basic_collate():
    graphs = [sample_triplets(1000) for i in range(2)]

    indexer_0 = LargeGraphIndexer.from_triplets(
        graphs[0], pre_transform=preprocess_triplet)
    indexer_1 = LargeGraphIndexer.from_triplets(
        graphs[1], pre_transform=preprocess_triplet)

    big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1])

    # Collation deduplicates: by inclusion-exclusion, the merged indexer
    # holds exactly the union of both indexers' node and edge keys.
    assert len(indexer_0._nodes) + len(indexer_1._nodes) - len(
        indexer_0._nodes.keys() & indexer_1._nodes.keys()) == len(
            big_indexer._nodes)
    assert len(indexer_0._edges) + len(indexer_1._edges) - len(
        indexer_0._edges.keys() & indexer_1._edges.keys()) == len(
            big_indexer._edges)

    # Assigned indices remain unique after the merge.
    assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes)
    assert len(set(big_indexer._edges.values())) == len(big_indexer._edges)

    # Every original node id still maps back to itself.
    for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()):
        assert big_indexer.node_attr[NODE_PID][
            big_indexer._nodes[node]] == node


def test_large_graph_index():
    graphs = [sample_triplets(1000) for i in range(100)]

    # Preprocessing of triplets lowercases node names but not edge names.
    node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL}
    edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL}

    def encode_graph_from_trips(triplets: List[TripletLike]) -> Data:
        # Naive baseline: build each Data object directly from its triplets.
        seen_nodes = dict()
        edge_attrs = list()
        edge_idx = []
        for trip in triplets:
            trip = preprocess_triplet(trip)
            h, r, t = trip
            seen_nodes.setdefault(h, len(seen_nodes))
            seen_nodes.setdefault(t, len(seen_nodes))
            edge_attrs.append(edge_feature_vecs[r])
            edge_idx.append((seen_nodes[h], seen_nodes[t]))

        x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()])
        edge_idx = torch.LongTensor(edge_idx).T
        edge_attrs = torch.Tensor(edge_attrs)
        return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs)

    naive_graph_ds = [
        encode_graph_from_trips(triplets=trips) for trips in graphs
    ]

    indexer = LargeGraphIndexer.collate([
        LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet)
        for g in graphs
    ])
    indexer_nodes = indexer.get_unique_node_features()
    indexer_node_vals = torch.Tensor(
        [node_feature_vecs[n] for n in indexer_nodes])
    indexer_edges = indexer.get_unique_edge_features(
        feature_name=EDGE_RELATION)
    indexer_edge_vals = torch.Tensor(
        [edge_feature_vecs[e] for e in indexer_edges])
    indexer.add_node_feature('x', indexer_node_vals)
    indexer.add_edge_feature('edge_attr', indexer_edge_vals,
                             map_from_feature=EDGE_RELATION)
    large_graph_ds = [
        get_features_for_triplets(indexer=indexer, triplets=g,
                                  node_feature_name='x',
                                  edge_feature_name='edge_attr',
                                  pre_transform=preprocess_triplet)
        for g in graphs
    ]

    for ds in large_graph_ds:
        assert NODE_PID in ds
        assert EDGE_PID in ds
        assert "node_idx" in ds
        assert "edge_idx" in ds

    def results_are_close_enough(ground_truth: Data, new_method: Data,
                                 thresh=.99):
        def _sorted_tensors_are_close(tensor1, tensor2):
            # isclose returns booleans, so `> thresh` simply keeps the
            # elementwise matches; all sorted entries must agree.
            return torch.all(
                torch.isclose(tensor1.sort()[0], tensor2.sort()[0]) > thresh)

        def _graphs_are_same(tensor1, tensor2):
            # Compare graph structure via Weisfeiler-Lehman hashes, which
            # are invariant to node ordering in edge_index.
            return nx.weisfeiler_lehman_graph_hash(
                nx.Graph(tensor1.T)) == nx.weisfeiler_lehman_graph_hash(
                    nx.Graph(tensor2.T))

        return _sorted_tensors_are_close(ground_truth.x, new_method.x) \
            and _sorted_tensors_are_close(ground_truth.edge_attr,
                                          new_method.edge_attr) \
            and _graphs_are_same(ground_truth.edge_index,
                                 new_method.edge_index)

    for dsets in zip(naive_graph_ds, large_graph_ds):
        assert results_are_close_enough(*dsets)


def test_save_load(tmp_path):
    graph = sample_triplets(1000)

    node_feature_vecs = {s: featurize(s) for s in NODE_POOL}
    edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL}

    indexer = LargeGraphIndexer.from_triplets(graph)
    indexer_nodes = indexer.get_unique_node_features()
    indexer_node_vals = torch.Tensor(
        [node_feature_vecs[n] for n in indexer_nodes])
    indexer_edges = indexer.get_unique_edge_features(
        feature_name=EDGE_RELATION)
    indexer_edge_vals = torch.Tensor(
        [edge_feature_vecs[e] for e in indexer_edges])
    indexer.add_node_feature('x', indexer_node_vals)
    indexer.add_edge_feature('edge_attr', indexer_edge_vals,
                             map_from_feature=EDGE_RELATION)

    indexer.save(str(tmp_path))
    # A round trip through disk must reproduce an equal indexer.
    assert indexer == LargeGraphIndexer.from_disk(str(tmp_path))
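
Taken together, test_basic_collate and test_save_load suggest a sharding workflow: index shards independently, merge, then persist. A minimal sketch under that assumption (the shard contents and path are placeholders):

from torch_geometric.data import LargeGraphIndexer

# Hypothetical shards of a larger triplet collection.
shards = [[('a', 'rel', 'b')], [('b', 'rel', 'c')]]

# Index each shard independently, then merge; collate deduplicates
# nodes and edges that appear in more than one shard.
indexer = LargeGraphIndexer.collate(
    [LargeGraphIndexer.from_triplets(shard) for shard in shards])

# Persist and reload; equality holds for the round trip.
indexer.save('/tmp/indexer')
assert indexer == LargeGraphIndexer.from_disk('/tmp/indexer')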
29 changes: 29 additions & 0 deletions test/datasets/test_web_qsp_dataset.py
@@ -0,0 +1,29 @@
import pytest

from torch_geometric.datasets import WebQSPDataset
from torch_geometric.testing import onlyFullTest, onlyOnline


@pytest.mark.skip(reason="Times out")
@onlyOnline
@onlyFullTest
def test_web_qsp_dataset():
    dataset = WebQSPDataset()
    assert len(dataset) == 4700
    assert str(dataset) == "WebQSPDataset(4700)"


@onlyOnline
@onlyFullTest
def test_web_qsp_dataset_limit(tmp_path):
    dataset = WebQSPDataset(root=tmp_path, limit=100)
    assert len(dataset) == 100
    assert str(dataset) == "WebQSPDataset(100)"


@onlyOnline
@onlyFullTest
def test_web_qsp_dataset_limit_no_pcst(tmp_path):
    dataset = WebQSPDataset(root=tmp_path, limit=100, include_pcst=False)
    assert len(dataset) == 100
    assert str(dataset) == "WebQSPDataset(100)"
5 changes: 5 additions & 0 deletions torch_geometric/data/__init__.py
@@ -16,6 +16,7 @@
from .makedirs import makedirs
from .download import download_url, download_google_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups

from torch_geometric.lazy_loader import LazyLoader

@@ -27,6 +28,8 @@
    'Dataset',
    'InMemoryDataset',
    'OnDiskDataset',
    'LargeGraphIndexer',
    'TripletLike',
]

remote_backend_classes = [
@@ -50,6 +53,8 @@
    'extract_zip',
    'extract_bz2',
    'extract_gz',
    'get_features_for_triplets',
    'get_features_for_triplets_groups',
]

__all__ = data_classes + remote_backend_classes + helper_functions
