[Feature] HGAM hetero low-level example (#7425)
- Enable the heterogeneous code path for CSR (`SparseTensor`) inputs in `trim_to_layer`
- Refactor the `trim_to_layer` utilities to add low-level helper functions
- Add an HGAM heterogeneous low-level example (see the sketch below)
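For orientation, a minimal sketch of what the trimming does on a heterogeneous mini-batch once this PR is applied. All tensors and per-hop counts below are made-up illustrative values; with this change, the `edge_index` values may also be CSR `SparseTensor`s produced by `T.ToSparseTensor()`:

import torch

from torch_geometric.utils import trim_to_layer

# Toy heterogeneous mini-batch (illustrative values only):
x_dict = {'paper': torch.randn(5, 16), 'author': torch.randn(4, 16)}
edge_index_dict = {
    ('author', 'writes', 'paper'):
    torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]]),
}
# Per-hop bookkeeping in NeighborLoader order: [seed nodes, hop 1, hop 2]:
num_sampled_nodes_dict = {'paper': [1, 2, 2], 'author': [0, 2, 2]}
num_sampled_edges_dict = {('author', 'writes', 'paper'): [2, 2]}

# At layer 1, the outermost sampling hop is no longer needed:
x_dict, edge_index_dict, _ = trim_to_layer(
    layer=1,
    num_sampled_nodes_per_hop=num_sampled_nodes_dict,
    num_sampled_edges_per_hop=num_sampled_edges_dict,
    x=x_dict,
    edge_index=edge_index_dict,
)
print(x_dict['paper'].size(0))  # 3 -- the last 2 'paper' nodes were trimmed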

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
3 people authored May 27, 2023
1 parent d64fa43 commit d8a651c
Showing 3 changed files with 193 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
- Added a `LocalGraphStore` implementation for distributed training ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451))
- Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
- Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
141 changes: 141 additions & 0 deletions examples/hetero/hierarchical_sage.py
@@ -0,0 +1,141 @@
import argparse

import torch
import torch.nn.functional as F
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import HeteroConv, Linear, SAGEConv
from torch_geometric.utils import trim_to_layer

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--use-sparse-tensor', action='store_true')
args = parser.parse_args()

device = args.device if torch.cuda.is_available() else 'cpu'

transforms = [T.ToUndirected(merge=True)]
if args.use_sparse_tensor:
    transforms.append(T.ToSparseTensor())
dataset = OGB_MAG(root='../../data', preprocess='metapath2vec',
                  transform=T.Compose(transforms))
data = dataset[0].to(device, 'x', 'y')


class HierarchicalHeteroGraphSage(torch.nn.Module):
    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv(
                {
                    edge_type: SAGEConv((-1, -1), hidden_channels)
                    for edge_type in edge_types
                }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,
                num_sampled_nodes_dict):

        for i, conv in enumerate(self.convs):
            # Trim away nodes and edges that are no longer needed beyond
            # this layer, so deeper layers operate on smaller subgraphs:
            x_dict, edge_index_dict, _ = trim_to_layer(
                layer=i,
                num_sampled_nodes_per_hop=num_sampled_nodes_dict,
                num_sampled_edges_per_hop=num_sampled_edges_dict,
                x=x_dict,
                edge_index=edge_index_dict,
            )

            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        return self.lin(x_dict['paper'])


model = HierarchicalHeteroGraphSage(
    edge_types=data.edge_types,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2,
).to(device)  # use `device`, which falls back to 'cpu' if CUDA is missing

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

kwargs = {'batch_size': 1024, 'num_workers': 0}
train_loader = NeighborLoader(
    data,
    num_neighbors=[10] * 2,
    shuffle=True,
    input_nodes=('paper', data['paper'].train_mask),
    **kwargs,
)

val_loader = NeighborLoader(
    data,
    num_neighbors=[10] * 2,
    shuffle=False,
    input_nodes=('paper', data['paper'].val_mask),
    **kwargs,
)


def train():
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()

        out = model(
            batch.x_dict,
            batch.adj_t_dict
            if args.use_sparse_tensor else batch.edge_index_dict,
            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
            num_sampled_edges_dict=batch.num_sampled_edges_dict,
        )

        batch_size = batch['paper'].batch_size
        loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples


@torch.no_grad()
def test(loader):
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(loader):
        batch = batch.to(device)
        out = model(
            batch.x_dict,
            batch.adj_t_dict
            if args.use_sparse_tensor else batch.edge_index_dict,
            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
            num_sampled_edges_dict=batch.num_sampled_edges_dict,
        )

        batch_size = batch['paper'].batch_size
        pred = out[:batch_size].argmax(dim=-1)
        total_examples += batch_size
        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())

    return total_correct / total_examples


for epoch in range(1, 6):
    loss = train()
    val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
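The training and evaluation loops above rely on `NeighborLoader` attaching per-hop sampling counts to every batch. A quick way to inspect the bookkeeping that `trim_to_layer` consumes (a sketch, reusing the `train_loader` defined above; exact numbers vary per batch):

# Sketch: peek at the per-hop counts attached by NeighborLoader.
batch = next(iter(train_loader))
for node_type, counts in batch.num_sampled_nodes_dict.items():
    print(node_type, counts)  # e.g. 'paper', [1024, <hop-1>, <hop-2>]
for edge_type, counts in batch.num_sampled_edges_dict.items():
    print(edge_type, counts)  # one entry per hop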
96 changes: 51 additions & 45 deletions torch_geometric/utils/trim_to_layer.py
@@ -4,6 +4,7 @@
from torch import Tensor

from torch_geometric.typing import (
+   Adj,
    EdgeType,
    MaybeHeteroEdgeTensor,
    MaybeHeteroNodeTensor,
@@ -17,10 +18,10 @@ def trim_to_layer(
    layer: int,
    num_sampled_nodes_per_hop: Union[List[int], Dict[NodeType, List[int]]],
    num_sampled_edges_per_hop: Union[List[int], Dict[EdgeType, List[int]]],
-   x: Union[MaybeHeteroNodeTensor],
-   edge_index: Union[MaybeHeteroEdgeTensor],
+   x: MaybeHeteroNodeTensor,
+   edge_index: MaybeHeteroEdgeTensor,
    edge_attr: Optional[MaybeHeteroEdgeTensor] = None,
-) -> Tuple[MaybeHeteroEdgeTensor, MaybeHeteroNodeTensor,
+) -> Tuple[MaybeHeteroNodeTensor, MaybeHeteroEdgeTensor,
           Optional[MaybeHeteroEdgeTensor]]:
    r"""Trims the :obj:`edge_index` representation, node features :obj:`x` and
    edge features :obj:`edge_attr` to a minimal-sized representation for the
@@ -47,62 +48,31 @@
    if layer <= 0:
        return x, edge_index, edge_attr

-   # TODO Support `SparseTensor` for heterogeneous graphs.
    if isinstance(num_sampled_edges_per_hop, dict):
        x = {
-           k: v.narrow(
-               dim=0,
-               start=0,
-               length=v.size(0) - num_sampled_nodes_per_hop[k][-layer],
-           )
+           k: trim_feat(v, layer, num_sampled_nodes_per_hop[k])
            for k, v in x.items()
        }
        edge_index = {
-           k: v.narrow(
-               dim=1,
-               start=0,
-               length=v.size(1) - num_sampled_edges_per_hop[k][-layer],
-           )
+           k: trim_adj(v, layer, num_sampled_nodes_per_hop[k[-1]],
+                       num_sampled_edges_per_hop[k])
            for k, v in edge_index.items()
        }
        if edge_attr is not None:
            edge_attr = {
-               k: v.narrow(
-                   dim=0,
-                   start=0,
-                   length=v.size(0) - num_sampled_edges_per_hop[k][-layer],
-               )
+               k: trim_feat(v, layer, num_sampled_edges_per_hop[k])
                for k, v in edge_attr.items()
            }
        return x, edge_index, edge_attr

-   x = x.narrow(
-       dim=0,
-       start=0,
-       length=x.size(0) - num_sampled_nodes_per_hop[-layer],
-   )
-   if edge_attr is not None:
-       edge_attr = edge_attr.narrow(
-           dim=0,
-           start=0,
-           length=edge_attr.size(0) - num_sampled_edges_per_hop[-layer],
-       )
-   if isinstance(edge_index, Tensor):
-       edge_index = edge_index.narrow(
-           dim=1,
-           start=0,
-           length=edge_index.size(1) - num_sampled_edges_per_hop[-layer],
-       )
-       return x, edge_index, edge_attr
-
-   elif isinstance(edge_index, SparseTensor):
-       num_nodes = edge_index.size(0) - num_sampled_nodes_per_hop[-layer]
-       num_seed_nodes = num_nodes - num_sampled_nodes_per_hop[-(layer + 1)]
-       edge_index = trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes)
-
-       return x, edge_index, edge_attr
-
-   raise NotImplementedError
+   x = trim_feat(x, layer, num_sampled_nodes_per_hop)
+   edge_index = trim_adj(edge_index, layer, num_sampled_nodes_per_hop,
+                         num_sampled_edges_per_hop)
+
+   if edge_attr is not None:
+       edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop)
+
+   return x, edge_index, edge_attr
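With the refactor, the homogeneous path above reduces to the two helpers defined further down. A quick sanity sketch of the trimmed shapes (toy tensors; all numbers are illustrative):

import torch

from torch_geometric.utils import trim_to_layer

x = torch.randn(5, 8)  # 1 seed node + 2 hops of 2 nodes each
edge_index = torch.tensor([[1, 2, 3, 4, 3], [0, 0, 1, 1, 2]])

x, edge_index, _ = trim_to_layer(
    layer=1,
    num_sampled_nodes_per_hop=[1, 2, 2],
    num_sampled_edges_per_hop=[2, 3],
    x=x,
    edge_index=edge_index,
)
assert x.size(0) == 3 and edge_index.size(1) == 2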


class TrimToLayer(torch.nn.Module):
@@ -112,7 +82,7 @@ def forward(
        num_sampled_nodes_per_hop: Optional[List[int]],
        num_sampled_edges_per_hop: Optional[List[int]],
        x: Tensor,
-       edge_index: Union[Tensor, SparseTensor],
+       edge_index: Adj,
        edge_attr: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
@@ -141,6 +111,42 @@
# Helper functions ############################################################


+def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor:
+   if layer <= 0:
+       return x
+
+   return x.narrow(
+       dim=0,
+       start=0,
+       length=x.size(0) - num_samples_per_hop[-layer],
+   )


+def trim_adj(
+   edge_index: Adj,
+   layer: int,
+   num_sampled_nodes_per_hop: List[int],
+   num_sampled_edges_per_hop: List[int],
+) -> Adj:
+
+   if layer <= 0:
+       return edge_index
+
+   if isinstance(edge_index, Tensor):
+       return edge_index.narrow(
+           dim=1,
+           start=0,
+           length=edge_index.size(1) - num_sampled_edges_per_hop[-layer],
+       )
+
+   elif isinstance(edge_index, SparseTensor):
+       num_nodes = edge_index.size(0) - num_sampled_nodes_per_hop[-layer]
+       num_seed_nodes = num_nodes - num_sampled_nodes_per_hop[-(layer + 1)]
+       return trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes)
+
+   raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'")
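The heterogeneous branch calls these helpers per key, passing the destination node type's counts (`num_sampled_nodes_per_hop[k[-1]]`) for each edge type. They can also be used directly; a sketch with toy values, assuming the helpers are importable from the `torch_geometric.utils.trim_to_layer` module where this diff defines them (they may not be re-exported from `torch_geometric.utils` itself):

import torch

from torch_geometric.utils.trim_to_layer import trim_adj, trim_feat

x = torch.randn(5, 8)
edge_index = torch.tensor([[1, 2, 3, 4, 3], [0, 0, 1, 1, 2]])

# Drop the outermost hop: 2 of the 5 nodes and 3 of the 5 edges.
x = trim_feat(x, layer=1, num_samples_per_hop=[1, 2, 2])
edge_index = trim_adj(edge_index, layer=1,
                      num_sampled_nodes_per_hop=[1, 2, 2],
                      num_sampled_edges_per_hop=[2, 3])
assert x.size(0) == 3 and edge_index.size(1) == 2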


def trim_sparse_tensor(src: SparseTensor, num_nodes: int,
                       num_seed_nodes: int) -> SparseTensor:
    r"""Trims a :class:`SparseTensor` along both dimensions to only contain
