From 557492b60e65f64573197de1bddc56b29f499329 Mon Sep 17 00:00:00 2001 From: MatthieuMelennec <77962043+MatthieuMelennec@users.noreply.github.com> Date: Thu, 28 Mar 2024 15:25:21 +0100 Subject: [PATCH] Add `reduce` argument to `nn.pool.pool.pool_edge` (#9116) To specify edge_attr reduction scheme (issue #9107) --------- Co-authored-by: Matthieu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matthias Fey --- torch_geometric/nn/pool/pool.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/pool/pool.py b/torch_geometric/nn/pool/pool.py index da076bc1cbef..52bc5c6c4cd3 100644 --- a/torch_geometric/nn/pool/pool.py +++ b/torch_geometric/nn/pool/pool.py @@ -5,12 +5,18 @@ from torch_geometric.utils import coalesce, remove_self_loops, scatter -def pool_edge(cluster, edge_index, edge_attr: Optional[torch.Tensor] = None): +def pool_edge( + cluster, + edge_index, + edge_attr: Optional[torch.Tensor] = None, + reduce: Optional[str] = 'sum', +): num_nodes = cluster.size(0) edge_index = cluster[edge_index.view(-1)].view(2, -1) edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) if edge_index.numel() > 0: - edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes) + edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, + reduce=reduce) return edge_index, edge_attr