Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ClusterPooling layer #9627

Merged
merged 13 commits into from
Sep 10, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))
Expand Down
32 changes: 32 additions & 0 deletions test/nn/pool/test_cluster_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import torch

from torch_geometric.nn import ClusterPooling
from torch_geometric.testing import withPackage


@withPackage('scipy')
@pytest.mark.parametrize('edge_score_method', [
'tanh',
'sigmoid',
'log_softmax',
])
def test_cluster_pooling(edge_score_method):
x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]])
edge_index = torch.tensor([
[0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6],
[1, 2, 3, 6, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0],
])
batch = torch.tensor([0, 0, 0, 0, 1, 1, 0])

op = ClusterPooling(in_channels=1, edge_score_method=edge_score_method)
assert str(op) == 'ClusterPooling(1)'
op.reset_parameters()

x, edge_index, batch, unpool_info = op(x, edge_index, batch)
assert x.size(0) <= 7
assert edge_index.size(0) == 2
if edge_index.numel() > 0:
assert edge_index.min() >= 0
assert edge_index.max() < x.size(0)
assert batch.size() == (x.size(0), )
12 changes: 7 additions & 5 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
import torch_geometric.typing
from torch_geometric.typing import OptTensor, torch_cluster

from .asap import ASAPooling
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
from .edge_pool import EdgePooling
from .glob import global_add_pool, global_max_pool, global_mean_pool
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
ApproxMIPSKNNIndex)
from .graclus import graclus
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
from .mem_pool import MemPooling
from .pan_pool import PANPooling
from .sag_pool import SAGPooling
from .topk_pool import TopKPooling
from .sag_pool import SAGPooling
from .edge_pool import EdgePooling
from .cluster_pool import ClusterPooling
from .asap import ASAPooling
from .pan_pool import PANPooling
from .mem_pool import MemPooling
from .voxel_grid import voxel_grid
from .approx_knn import approx_knn, approx_knn_graph

Expand Down Expand Up @@ -344,6 +345,7 @@ def nearest(
'TopKPooling',
'SAGPooling',
'EdgePooling',
'ClusterPooling',
'ASAPooling',
'PANPooling',
'MemPooling',
Expand Down
145 changes: 145 additions & 0 deletions torch_geometric/nn/pool/cluster_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from typing import NamedTuple, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.utils import (
dense_to_sparse,
one_hot,
to_dense_adj,
to_scipy_sparse_matrix,
)


class UnpoolInfo(NamedTuple):
edge_index: Tensor
cluster: Tensor
batch: Tensor


class ClusterPooling(torch.nn.Module):
r"""The cluster pooling operator from the `"Edge-Based Graph Component
Pooling" <paper url>`_ paper.

:class:`ClusterPooling` computes a score for each edge.
Based on the selected edges, graph clusters are calculated and compressed
to one node using the injective :obj:`"sum"` aggregation function.
Edges are remapped based on the nodes created by each cluster and the
original edges.

Args:
in_channels (int): Size of each input sample.
edge_score_method (str, optional): The function to apply
to compute the edge score from raw edge scores (:obj:`"tanh"`,
:obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
dropout (float, optional): The probability with
which to drop edge scores during training. (default: :obj:`0.0`)
threshold (float, optional): The threshold of edge scores. If set to
:obj:`None`, will be automatically inferred depending on
:obj:`edge_score_method`. (default: :obj:`None`)
"""
def __init__(
self,
in_channels: int,
edge_score_method: str = 'tanh',
dropout: float = 0.0,
threshold: Optional[float] = None,
):
super().__init__()
assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']

if threshold is None:
threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0

self.in_channels = in_channels
self.edge_score_method = edge_score_method
self.dropout = dropout
self.threshhold = threshold

self.lin = torch.nn.Linear(2 * in_channels, 1)

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.lin.reset_parameters()

def forward(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
r"""Forward pass.

Args:
x (torch.Tensor): The node features.
edge_index (torch.Tensor): The edge indices.
batch (torch.Tensor): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each node to a specific example.

Return types:
* **x** *(torch.Tensor)* - The pooled node features.
* **edge_index** *(torch.Tensor)* - The coarsened edge indices.
* **batch** *(torch.Tensor)* - The coarsened batch vector.
* **unpool_info** *(UnpoolInfo)* - Information that can be consumed
for unpooling.
"""
mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]

edge_attr = torch.cat(
[x[edge_index[0]], x[edge_index[1]]],
dim=-1,
)
edge_score = self.lin(edge_attr).view(-1)
edge_score = F.dropout(edge_score, p=self.dropout,
training=self.training)

if self.edge_score_method == 'tanh':
edge_score = edge_score.tanh()
elif self.edge_score_method == 'sigmoid':
edge_score = edge_score.sigmoid()
else:
assert self.edge_score_method == 'log_softmax'
edge_score = F.log_softmax(edge_score, dim=0)

return self._merge_edges(x, edge_index, batch, edge_score)

def _merge_edges(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
edge_score: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:

from scipy.sparse.csgraph import connected_components

edge_contract = edge_index[:, edge_score > self.threshhold]

adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
_, cluster_np = connected_components(adj, directed=True,
connection="weak")

cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
C = one_hot(cluster)
A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
S = to_dense_adj(edge_index, edge_attr=edge_score,
max_num_nodes=x.size(0)).squeeze(0)

A_contract = to_dense_adj(edge_contract,
max_num_nodes=x.size(0)).squeeze(0)
nodes_single = ((A_contract.sum(dim=-1) +
A_contract.sum(dim=-2)) == 0).nonzero()
S[nodes_single, nodes_single] = 1.0

x_out = (S @ C).t() @ x
edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
unpool_info = UnpoolInfo(edge_index, cluster, batch)

return x_out, edge_index_out, batch_out, unpool_info

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.in_channels})'
2 changes: 1 addition & 1 deletion torch_geometric/nn/pool/edge_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self,
in_channels: int,
edge_score_method: Optional[Callable] = None,
dropout: Optional[float] = 0.0,
dropout: float = 0.0,
add_to_edge_score: float = 0.5,
):
super().__init__()
Expand Down
Loading