diff --git a/CHANGELOG.md b/CHANGELOG.md
index 15ebe1b9ca5a..de26923625aa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added a hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
 - Added a `LocalGraphStore` implementation for distributed training ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451))
 - Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
 - Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
diff --git a/examples/hetero/hierarchical_sage.py b/examples/hetero/hierarchical_sage.py
new file mode 100644
index 000000000000..c4ae8ad12d95
--- /dev/null
+++ b/examples/hetero/hierarchical_sage.py
@@ -0,0 +1,141 @@
+import argparse
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+import torch_geometric.transforms as T
+from torch_geometric.datasets import OGB_MAG
+from torch_geometric.loader import NeighborLoader
+from torch_geometric.nn import HeteroConv, Linear, SAGEConv
+from torch_geometric.utils import trim_to_layer
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', type=str, default='cuda')
+parser.add_argument('--use-sparse-tensor', action='store_true')
+args = parser.parse_args()
+
+device = args.device if torch.cuda.is_available() else 'cpu'
+
+transforms = [T.ToUndirected(merge=True)]
+if args.use_sparse_tensor:
+    transforms.append(T.ToSparseTensor())
+dataset = OGB_MAG(root='../../data', preprocess='metapath2vec',
+                  transform=T.Compose(transforms))
+data = dataset[0].to(device, 'x', 'y')
+
+
+class HierarchicalHeteroGraphSage(torch.nn.Module):
+    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HeteroConv(
+                {
+                    edge_type: SAGEConv((-1, -1), hidden_channels)
+                    for edge_type in edge_types
+                }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin = Linear(hidden_channels, out_channels)
+
+    def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,
+                num_sampled_nodes_dict):
+
+        for i, conv in enumerate(self.convs):
+            x_dict, edge_index_dict, _ = trim_to_layer(
+                layer=i,
+                num_sampled_nodes_per_hop=num_sampled_nodes_dict,
+                num_sampled_edges_per_hop=num_sampled_edges_dict,
+                x=x_dict,
+                edge_index=edge_index_dict,
+            )
+
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: x.relu() for key, x in x_dict.items()}
+
+        return self.lin(x_dict['paper'])
+
+
+model = HierarchicalHeteroGraphSage(
+    edge_types=data.edge_types,
+    hidden_channels=64,
+    out_channels=dataset.num_classes,
+    num_layers=2,
+).to(device)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+kwargs = {'batch_size': 1024, 'num_workers': 0}
+train_loader = NeighborLoader(
+    data,
+    num_neighbors=[10] * 2,
+    shuffle=True,
+    input_nodes=('paper', data['paper'].train_mask),
+    **kwargs,
+)
+
+val_loader = NeighborLoader(
+    data,
+    num_neighbors=[10] * 2,
+    shuffle=False,
+    input_nodes=('paper', data['paper'].val_mask),
+    **kwargs,
+)
+
+
+def train():
+    model.train()
+
+    total_examples = total_loss = 0
+    for batch in tqdm(train_loader):
+        batch = batch.to(device)
+        optimizer.zero_grad()
+
+        out = model(
+            batch.x_dict,
+            batch.adj_t_dict
+            if args.use_sparse_tensor else batch.edge_index_dict,
+            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
+            num_sampled_edges_dict=batch.num_sampled_edges_dict,
+        )
+
+        batch_size = batch['paper'].batch_size
+        loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size])
+        loss.backward()
+        optimizer.step()
+
+        total_examples += batch_size
+        total_loss += float(loss) * batch_size
+
+    return total_loss / total_examples
+
+
+@torch.no_grad()
+def test(loader):
+    model.eval()
+
+    total_examples = total_correct = 0
+    for batch in tqdm(loader):
+        batch = batch.to(device)
+        out = model(
+            batch.x_dict,
+            batch.adj_t_dict
+            if args.use_sparse_tensor else batch.edge_index_dict,
+            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,
+            num_sampled_edges_dict=batch.num_sampled_edges_dict,
+        )
+
+        batch_size = batch['paper'].batch_size
+        pred = out[:batch_size].argmax(dim=-1)
+        total_examples += batch_size
+        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())
+
+    return total_correct / total_examples
+
+
+for epoch in range(1, 6):
+    loss = train()
+    val_acc = test(val_loader)
+    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')
diff --git a/torch_geometric/utils/trim_to_layer.py b/torch_geometric/utils/trim_to_layer.py
index dd2ce60b8c57..025041d2da38 100644
--- a/torch_geometric/utils/trim_to_layer.py
+++ b/torch_geometric/utils/trim_to_layer.py
@@ -4,6 +4,7 @@
 from torch import Tensor
 
 from torch_geometric.typing import (
+    Adj,
     EdgeType,
     MaybeHeteroEdgeTensor,
     MaybeHeteroNodeTensor,
@@ -17,10 +18,10 @@ def trim_to_layer(
     layer: int,
     num_sampled_nodes_per_hop: Union[List[int], Dict[NodeType, List[int]]],
     num_sampled_edges_per_hop: Union[List[int], Dict[EdgeType, List[int]]],
-    x: Union[MaybeHeteroNodeTensor],
-    edge_index: Union[MaybeHeteroEdgeTensor],
+    x: MaybeHeteroNodeTensor,
+    edge_index: MaybeHeteroEdgeTensor,
     edge_attr: Optional[MaybeHeteroEdgeTensor] = None,
-) -> Tuple[MaybeHeteroEdgeTensor, MaybeHeteroNodeTensor,
+) -> Tuple[MaybeHeteroNodeTensor, MaybeHeteroEdgeTensor,
            Optional[MaybeHeteroEdgeTensor]]:
     r"""Trims the :obj:`edge_index` representation, node features :obj:`x`
     and edge features :obj:`edge_attr` to a minimal-sized representation for the
@@ -47,62 +48,31 @@
     if layer <= 0:
         return x, edge_index, edge_attr
 
-    # TODO Support `SparseTensor` for heterogeneous graphs.
     if isinstance(num_sampled_edges_per_hop, dict):
         x = {
-            k: v.narrow(
-                dim=0,
-                start=0,
-                length=v.size(0) - num_sampled_nodes_per_hop[k][-layer],
-            )
+            k: trim_feat(v, layer, num_sampled_nodes_per_hop[k])
             for k, v in x.items()
         }
         edge_index = {
-            k: v.narrow(
-                dim=1,
-                start=0,
-                length=v.size(1) - num_sampled_edges_per_hop[k][-layer],
-            )
+            k: trim_adj(v, layer, num_sampled_nodes_per_hop[k[-1]],
+                        num_sampled_edges_per_hop[k])
             for k, v in edge_index.items()
         }
         if edge_attr is not None:
             edge_attr = {
-                k: v.narrow(
-                    dim=0,
-                    start=0,
-                    length=v.size(0) - num_sampled_edges_per_hop[k][-layer],
-                )
+                k: trim_feat(v, layer, num_sampled_edges_per_hop[k])
                 for k, v in edge_attr.items()
             }
         return x, edge_index, edge_attr
 
-    x = x.narrow(
-        dim=0,
-        start=0,
-        length=x.size(0) - num_sampled_nodes_per_hop[-layer],
-    )
-    if edge_attr is not None:
-        edge_attr = edge_attr.narrow(
-            dim=0,
-            start=0,
-            length=edge_attr.size(0) - num_sampled_edges_per_hop[-layer],
-        )
-    if isinstance(edge_index, Tensor):
-        edge_index = edge_index.narrow(
-            dim=1,
-            start=0,
-            length=edge_index.size(1) - num_sampled_edges_per_hop[-layer],
-        )
-        return x, edge_index, edge_attr
-
-    elif isinstance(edge_index, SparseTensor):
-        num_nodes = edge_index.size(0) - num_sampled_nodes_per_hop[-layer]
-        num_seed_nodes = num_nodes - num_sampled_nodes_per_hop[-(layer + 1)]
-        edge_index = trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes)
+    x = trim_feat(x, layer, num_sampled_nodes_per_hop)
+    edge_index = trim_adj(edge_index, layer, num_sampled_nodes_per_hop,
+                          num_sampled_edges_per_hop)
 
-        return x, edge_index, edge_attr
+    if edge_attr is not None:
+        edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop)
 
-    raise NotImplementedError
+    return x, edge_index, edge_attr
 
 
 class TrimToLayer(torch.nn.Module):
@@ -112,7 +82,7 @@ def forward(
         num_sampled_nodes_per_hop: Optional[List[int]],
         num_sampled_edges_per_hop: Optional[List[int]],
         x: Tensor,
-        edge_index: Union[Tensor, SparseTensor],
+        edge_index: Adj,
         edge_attr: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
 
@@ -141,6 +111,42 @@
 # Helper functions ###########################################################
 
 
+def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor:
+    if layer <= 0:
+        return x
+
+    return x.narrow(
+        dim=0,
+        start=0,
+        length=x.size(0) - num_samples_per_hop[-layer],
+    )
+
+
+def trim_adj(
+    edge_index: Adj,
+    layer: int,
+    num_sampled_nodes_per_hop: List[int],
+    num_sampled_edges_per_hop: List[int],
+) -> Adj:
+
+    if layer <= 0:
+        return edge_index
+
+    if isinstance(edge_index, Tensor):
+        return edge_index.narrow(
+            dim=1,
+            start=0,
+            length=edge_index.size(1) - num_sampled_edges_per_hop[-layer],
+        )
+
+    elif isinstance(edge_index, SparseTensor):
+        num_nodes = edge_index.size(0) - num_sampled_nodes_per_hop[-layer]
+        num_seed_nodes = num_nodes - num_sampled_nodes_per_hop[-(layer + 1)]
+        return trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes)
+
+    raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'")
+
+
 def trim_sparse_tensor(src: SparseTensor, num_nodes: int,
                        num_seed_nodes: int) -> SparseTensor:
     r"""Trims a :class:`SparseTensor` along both dimensions to only contain
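
Note (not part of the diff): the refactored `trim_to_layer` narrows the node and edge tensors by the size of the outermost sampled hop once that hop can no longer influence the seed-node embeddings. Below is a minimal sketch of the homogeneous call pattern, assuming a `NeighborLoader`-style batch layout of `[seed | 1-hop | 2-hop]` nodes with hop-ordered edges; all sizes are made up for illustration.

```python
import torch

from torch_geometric.utils import trim_to_layer

# Hypothetical 2-hop batch: 2 seed nodes, 3 one-hop and 4 two-hop neighbors,
# with edges stored in sampling order (3 one-hop edges, then 4 two-hop edges).
x = torch.randn(9, 16)
edge_index = torch.randint(0, 9, (2, 7))
num_sampled_nodes_per_hop = [2, 3, 4]
num_sampled_edges_per_hop = [3, 4]

for layer in range(2):
    x, edge_index, _ = trim_to_layer(
        layer=layer,
        num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,
        num_sampled_edges_per_hop=num_sampled_edges_per_hop,
        x=x,
        edge_index=edge_index,
    )
    print(layer, x.size(0), edge_index.size(1))
# layer 0: 9 nodes, 7 edges (nothing to trim before the first conv)
# layer 1: 5 nodes, 3 edges (the two-hop frontier is dropped)
```

In the heterogeneous branch the same narrowing is applied per node and edge type, which is exactly what the new `trim_feat`/`trim_adj` helpers factor out.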