From 33d76ac00518ed867fbc168ff7fe0f43d3362c10 Mon Sep 17 00:00:00 2001 From: "Szarmach, Michal" Date: Mon, 5 Jun 2023 11:45:16 +0200 Subject: [PATCH 1/5] Enable trim to layer with hetero CSR flow --- torch_geometric/nn/conv/hetero_conv.py | 8 +++--- torch_geometric/utils/trim_to_layer.py | 36 ++++++++++++++++++++------ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index e1d4d2a9d26e..527e9b356c1b 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -7,7 +7,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.module_dict import ModuleDict -from torch_geometric.typing import Adj, EdgeType, NodeType +from torch_geometric.typing import Adj, EdgeType, NodeType, SparseTensor from torch_geometric.utils.hetero import check_add_self_loops @@ -146,8 +146,10 @@ def forward( value_dict.get(dst, None)) conv = self.convs[str_edge_type] - - if src == dst: + if isinstance(edge_index, + SparseTensor) and edge_index.numel() == 0: + out = x_dict[dst] + elif src == dst: out = conv(x_dict[src], edge_index, *args, **kwargs) else: out = conv((x_dict[src], x_dict[dst]), edge_index, *args, diff --git a/torch_geometric/utils/trim_to_layer.py b/torch_geometric/utils/trim_to_layer.py index 025041d2da38..131ef7630214 100644 --- a/torch_geometric/utils/trim_to_layer.py +++ b/torch_geometric/utils/trim_to_layer.py @@ -54,8 +54,14 @@ def trim_to_layer( for k, v in x.items() } edge_index = { - k: trim_adj(v, layer, num_sampled_nodes_per_hop[k[-1]], - num_sampled_edges_per_hop[k]) + k: trim_adj( + v, + layer, + num_sampled_nodes_per_hop[k[-1]] + if k[0] == k[-1] else # src != dst + (num_sampled_nodes_per_hop[k[0]], + num_sampled_nodes_per_hop[k[-1]]), + num_sampled_edges_per_hop[k]) for k, v in edge_index.items() } if edge_attr is not None: @@ -140,14 +146,28 @@ def trim_adj( ) elif isinstance(edge_index, SparseTensor): - num_nodes = edge_index.size(0) - num_sampled_nodes_per_hop[-layer] - num_seed_nodes = num_nodes - num_sampled_nodes_per_hop[-(layer + 1)] + if isinstance(num_sampled_nodes_per_hop, tuple): + num_nodes = (edge_index.size(0) - + num_sampled_nodes_per_hop[1][-layer], + edge_index.size(1) - + num_sampled_nodes_per_hop[0][-layer]) + + num_seed_nodes = num_nodes[0] - num_sampled_nodes_per_hop[1][-( + layer + 1)] + else: + num_nodes = (edge_index.size(0) - + num_sampled_nodes_per_hop[-layer], + edge_index.size(0) - + num_sampled_nodes_per_hop[-layer]) + num_seed_nodes = num_nodes[0] - num_sampled_nodes_per_hop[-(layer + + 1)] + return trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes) raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'") -def trim_sparse_tensor(src: SparseTensor, num_nodes: int, +def trim_sparse_tensor(src: SparseTensor, num_nodes: tuple, num_seed_nodes: None) -> SparseTensor: r"""Trims a :class:`SparseTensor` along both dimensions to only contain the upper :obj:`num_nodes` in both dimensions. @@ -157,13 +177,13 @@ def trim_sparse_tensor(src: SparseTensor, num_nodes: int, Args: src (SparseTensor): The sparse tensor. - num_nodes (int): The number of first nodes to keep. + num_nodes (tuple): The number of first nodes to keep. num_seed_nodes (int): The number of seed nodes to compute representations. """ rowptr, col, value = src.csr() - rowptr = torch.narrow(rowptr, 0, 0, num_nodes + 1).clone() + rowptr = torch.narrow(rowptr, 0, 0, num_nodes[0] + 1).clone() rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes] col = torch.narrow(col, 0, 0, rowptr[-1]) @@ -180,7 +200,7 @@ def trim_sparse_tensor(src: SparseTensor, num_nodes: int, rowptr=rowptr, col=col, value=value, - sparse_sizes=(num_nodes, num_nodes), + sparse_sizes=num_nodes, rowcount=None, colptr=None, colcount=None, From abdc4dd0ff9721e825f0fa1991fb7c7fea72ee6e Mon Sep 17 00:00:00 2001 From: "Szarmach, Michal" Date: Mon, 5 Jun 2023 12:42:07 +0200 Subject: [PATCH 2/5] Refactor --- torch_geometric/utils/trim_to_layer.py | 39 +++++++++----------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/torch_geometric/utils/trim_to_layer.py b/torch_geometric/utils/trim_to_layer.py index 131ef7630214..0376b5fa5af3 100644 --- a/torch_geometric/utils/trim_to_layer.py +++ b/torch_geometric/utils/trim_to_layer.py @@ -54,14 +54,9 @@ def trim_to_layer( for k, v in x.items() } edge_index = { - k: trim_adj( - v, - layer, - num_sampled_nodes_per_hop[k[-1]] - if k[0] == k[-1] else # src != dst - (num_sampled_nodes_per_hop[k[0]], - num_sampled_nodes_per_hop[k[-1]]), - num_sampled_edges_per_hop[k]) + k: trim_adj(v, layer, (num_sampled_nodes_per_hop[k[0]], + num_sampled_nodes_per_hop[k[-1]]), + num_sampled_edges_per_hop[k]) for k, v in edge_index.items() } if edge_attr is not None: @@ -72,8 +67,10 @@ def trim_to_layer( return x, edge_index, edge_attr x = trim_feat(x, layer, num_sampled_nodes_per_hop) - edge_index = trim_adj(edge_index, layer, num_sampled_nodes_per_hop, - num_sampled_edges_per_hop) + edge_index = trim_adj( + edge_index, layer, + (num_sampled_nodes_per_hop, num_sampled_nodes_per_hop), + num_sampled_edges_per_hop) if edge_attr is not None: edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop) @@ -131,7 +128,7 @@ def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor: def trim_adj( edge_index: Adj, layer: int, - num_sampled_nodes_per_hop: List[int], + num_sampled_nodes_per_hop: Tuple[List[int]], num_sampled_edges_per_hop: List[int], ) -> Adj: @@ -146,21 +143,11 @@ def trim_adj( ) elif isinstance(edge_index, SparseTensor): - if isinstance(num_sampled_nodes_per_hop, tuple): - num_nodes = (edge_index.size(0) - - num_sampled_nodes_per_hop[1][-layer], - edge_index.size(1) - - num_sampled_nodes_per_hop[0][-layer]) - - num_seed_nodes = num_nodes[0] - num_sampled_nodes_per_hop[1][-( - layer + 1)] - else: - num_nodes = (edge_index.size(0) - - num_sampled_nodes_per_hop[-layer], - edge_index.size(0) - - num_sampled_nodes_per_hop[-layer]) - num_seed_nodes = num_nodes[0] - num_sampled_nodes_per_hop[-(layer + - 1)] + num_nodes = (edge_index.size(0) - num_sampled_nodes_per_hop[1][-layer], + edge_index.size(1) - num_sampled_nodes_per_hop[0][-layer]) + + num_seed_nodes = num_nodes[0] - num_sampled_nodes_per_hop[1][-(layer + + 1)] return trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes) From 1f5080c59ead35d384ead4ae8f22ab67a75e9f00 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 5 Jun 2023 13:17:59 +0000 Subject: [PATCH 3/5] update --- torch_geometric/nn/conv/hetero_conv.py | 7 ++--- torch_geometric/utils/trim_to_layer.py | 42 ++++++++++++++++---------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index 527e9b356c1b..b86e0cba7f11 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -7,7 +7,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.module_dict import ModuleDict -from torch_geometric.typing import Adj, EdgeType, NodeType, SparseTensor +from torch_geometric.typing import Adj, EdgeType, NodeType from torch_geometric.utils.hetero import check_add_self_loops @@ -146,10 +146,7 @@ def forward( value_dict.get(dst, None)) conv = self.convs[str_edge_type] - if isinstance(edge_index, - SparseTensor) and edge_index.numel() == 0: - out = x_dict[dst] - elif src == dst: + if src == dst: out = conv(x_dict[src], edge_index, *args, **kwargs) else: out = conv((x_dict[src], x_dict[dst]), edge_index, *args, diff --git a/torch_geometric/utils/trim_to_layer.py b/torch_geometric/utils/trim_to_layer.py index 0376b5fa5af3..524b92fce3ec 100644 --- a/torch_geometric/utils/trim_to_layer.py +++ b/torch_geometric/utils/trim_to_layer.py @@ -54,9 +54,13 @@ def trim_to_layer( for k, v in x.items() } edge_index = { - k: trim_adj(v, layer, (num_sampled_nodes_per_hop[k[0]], - num_sampled_nodes_per_hop[k[-1]]), - num_sampled_edges_per_hop[k]) + k: trim_adj( + v, + layer, + num_sampled_nodes_per_hop[k[0]], + num_sampled_nodes_per_hop[k[-1]], + num_sampled_edges_per_hop[k], + ) for k, v in edge_index.items() } if edge_attr is not None: @@ -68,9 +72,12 @@ def trim_to_layer( x = trim_feat(x, layer, num_sampled_nodes_per_hop) edge_index = trim_adj( - edge_index, layer, - (num_sampled_nodes_per_hop, num_sampled_nodes_per_hop), - num_sampled_edges_per_hop) + edge_index, + layer, + num_sampled_nodes_per_hop, + num_sampled_nodes_per_hop, + num_sampled_edges_per_hop, + ) if edge_attr is not None: edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop) @@ -128,7 +135,8 @@ def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor: def trim_adj( edge_index: Adj, layer: int, - num_sampled_nodes_per_hop: Tuple[List[int]], + num_sampled_src_nodes_per_hop: List[int], + num_sampled_dst_nodes_per_hop: List[int], num_sampled_edges_per_hop: List[int], ) -> Adj: @@ -143,18 +151,19 @@ def trim_adj( ) elif isinstance(edge_index, SparseTensor): - num_nodes = (edge_index.size(0) - num_sampled_nodes_per_hop[1][-layer], - edge_index.size(1) - num_sampled_nodes_per_hop[0][-layer]) + size = ( + edge_index.size(0) - num_sampled_dst_nodes_per_hop[-layer], + edge_index.size(1) - num_sampled_src_nodes_per_hop[-layer], + ) - num_seed_nodes = num_nodes[0] - num_sampled_nodes_per_hop[1][-(layer + - 1)] + num_seed_nodes = size[0] - num_sampled_dst_nodes_per_hop[-(layer + 1)] - return trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes) + return trim_sparse_tensor(edge_index, size, num_seed_nodes) raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'") -def trim_sparse_tensor(src: SparseTensor, num_nodes: tuple, +def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int], num_seed_nodes: None) -> SparseTensor: r"""Trims a :class:`SparseTensor` along both dimensions to only contain the upper :obj:`num_nodes` in both dimensions. @@ -164,13 +173,14 @@ def trim_sparse_tensor(src: SparseTensor, num_nodes: tuple, Args: src (SparseTensor): The sparse tensor. - num_nodes (tuple): The number of first nodes to keep. + size (Tuple[int, int]): The number of source and destination nodes to + keep. num_seed_nodes (int): The number of seed nodes to compute representations. """ rowptr, col, value = src.csr() - rowptr = torch.narrow(rowptr, 0, 0, num_nodes[0] + 1).clone() + rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone() rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes] col = torch.narrow(col, 0, 0, rowptr[-1]) @@ -187,7 +197,7 @@ def trim_sparse_tensor(src: SparseTensor, num_nodes: tuple, rowptr=rowptr, col=col, value=value, - sparse_sizes=num_nodes, + sparse_sizes=size, rowcount=None, colptr=None, colcount=None, From 323f3854b8793355e304b95568d28cebe3e418ee Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 5 Jun 2023 13:18:19 +0000 Subject: [PATCH 4/5] update --- torch_geometric/nn/conv/hetero_conv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index b86e0cba7f11..e1d4d2a9d26e 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -146,6 +146,7 @@ def forward( value_dict.get(dst, None)) conv = self.convs[str_edge_type] + if src == dst: out = conv(x_dict[src], edge_index, *args, **kwargs) else: From ce837bcdb297aba91361fcb0e2a3b06f58f03678 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 5 Jun 2023 13:24:50 +0000 Subject: [PATCH 5/5] update --- test/utils/test_trim_to_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils/test_trim_to_layer.py b/test/utils/test_trim_to_layer.py index 3260a133c6e1..5ee789feb7c8 100644 --- a/test/utils/test_trim_to_layer.py +++ b/test/utils/test_trim_to_layer.py @@ -18,7 +18,7 @@ def test_trim_sparse_tensor(): edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 3, 4]]) adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[5, 5]) - adj = trim_sparse_tensor(adj, num_nodes=3, num_seed_nodes=1) + adj = trim_sparse_tensor(adj, size=(3, 3), num_seed_nodes=1) row, col, _ = adj.coo() assert row.tolist() == [0, 0]