Skip to content

Commit

Permalink
Add check on add_self_loops in HeteroConv and to_hetero (#4647)
Browse files Browse the repository at this point in the history
* add check on self-loops in hetero conv

* add check on self-loops in to_hetero

* fix old tets

* remove pytest import

* update

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
Padarn and rusty1s authored May 15, 2022
1 parent db40aa6 commit 0ded02b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
- Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
- Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))
- Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644))
Expand Down
15 changes: 14 additions & 1 deletion test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def test_hetero_conv(aggr):
{
('paper', 'to', 'paper'): GCNConv(-1, 64),
('author', 'to', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv((-1, -1), 64),
('paper', 'to', 'author'): GATConv(
(-1, -1), 64, add_self_loops=False),
}, aggr=aggr)

assert len(list(conv.parameters())) > 0
Expand Down Expand Up @@ -77,3 +78,15 @@ def test_hetero_conv_with_custom_conv():
assert len(out) == 2
assert out['paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)


class MessagePassingLoops(MessagePassing):
def __init__(self):
super().__init__()
self.add_self_loops = True


def test_hetero_conv_self_loop_error():
HeteroConv({('a', 'to', 'a'): MessagePassingLoops()})
with pytest.raises(ValueError, match="incorrect message passing"):
HeteroConv({('a', 'to', 'b'): MessagePassingLoops()})
26 changes: 26 additions & 0 deletions test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple

import pytest
import torch
from torch import Tensor
from torch.nn import Linear, ReLU, Sequential
Expand Down Expand Up @@ -363,3 +364,28 @@ def test_graph_level_to_hetero():
model = to_hetero(model, metadata, aggr='mean', debug=False)
out = model(x_dict, edge_index_dict, batch_dict)
assert out.size() == (1, 64)


class MessagePassingLoops(MessagePassing):
def __init__(self):
super().__init__()
self.add_self_loops = True

def forward(self, x):
return x


class ModelLoops(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = MessagePassingLoops()

def forward(self, x):
return self.conv(x)


def test_hetero_transformer_self_loop_error():
to_hetero(ModelLoops(), metadata=(['a'], [('a', 'to', 'a')]))
with pytest.raises(ValueError, match="incorrect message passing"):
to_hetero(ModelLoops(), metadata=(['a', 'b'], [('a', 'to', 'b'),
('b', 'to', 'a')]))
4 changes: 4 additions & 0 deletions torch_geometric/nn/conv/hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from torch_geometric.nn.conv.hgt_conv import group
from torch_geometric.typing import Adj, EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


class HeteroConv(Module):
Expand Down Expand Up @@ -47,6 +48,9 @@ def __init__(self, convs: Dict[EdgeType, Module],
aggr: Optional[str] = "sum"):
super().__init__()

for edge_type, module in convs.items():
check_add_self_loops(module, [edge_type])

src_node_types = set([key[0] for key in convs.keys()])
dst_node_types = set([key[-1] for key in convs.keys()])
if len(src_node_types - dst_node_types) > 0:
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/nn/to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import torch
from torch.nn import Module

from torch_geometric.nn.fx import Transformer
from torch_geometric.nn.fx import Transformer, get_submodule
from torch_geometric.typing import EdgeType, Metadata, NodeType
from torch_geometric.utils.hetero import get_unused_node_types
from torch_geometric.utils.hetero import (
check_add_self_loops,
get_unused_node_types,
)

try:
from torch.fx import Graph, GraphModule, Node
Expand Down Expand Up @@ -168,6 +171,9 @@ def call_message_passing_module(self, node: Node, target: Any, name: str):
# Add calls to edge type-wise `MessagePassing` modules and aggregate
# the outputs to node type-wise embeddings afterwards.

module = get_submodule(self.module, target)
check_add_self_loops(module, self.metadata[1])

# Group edge-wise keys per destination:
key_name, keys_per_dst = {}, defaultdict(list)
for key in self.metadata[1]:
Expand Down
9 changes: 9 additions & 0 deletions torch_geometric/utils/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,12 @@ def get_unused_node_types(node_types: List[NodeType],
edge_types: List[EdgeType]) -> Set[NodeType]:
dst_node_types = set(edge_type[-1] for edge_type in edge_types)
return set(node_types) - set(dst_node_types)


def check_add_self_loops(module: torch.nn.Module, edge_types: List[EdgeType]):
is_bipartite = any([key[0] != key[-1] for key in edge_types])
if is_bipartite and getattr(module, 'add_self_loops', False):
raise ValueError(
f"'add_self_loops' attribute set to 'True' on module '{module}' "
f"for use with edge type(s) '{edge_types}'. This will lead to "
f"incorrect message passing results.")

0 comments on commit 0ded02b

Please sign in to comment.