Skip to content

Commit

Permalink
Add reduce argument to nn.pool.pool.pool_edge (#9116)
Browse files Browse the repository at this point in the history
To specify edge_attr reduction scheme (issue #9107)

---------

Co-authored-by: Matthieu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
4 people authored Mar 28, 2024
1 parent dcd2844 commit 557492b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torch_geometric/nn/pool/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
from torch_geometric.utils import coalesce, remove_self_loops, scatter


def pool_edge(cluster, edge_index, edge_attr: Optional[torch.Tensor] = None):
def pool_edge(
cluster,
edge_index,
edge_attr: Optional[torch.Tensor] = None,
reduce: Optional[str] = 'sum',
):
num_nodes = cluster.size(0)
edge_index = cluster[edge_index.view(-1)].view(2, -1)
edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
if edge_index.numel() > 0:
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes)
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
reduce=reduce)
return edge_index, edge_attr


Expand Down

0 comments on commit 557492b

Please sign in to comment.