Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter out empty tensors inside trim_to_layer #7942

Merged
merged 14 commits into from
Sep 11, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942))
- Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737))
- Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955))
- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956))
Expand Down
36 changes: 36 additions & 0 deletions test/utils/test_trim_to_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,39 @@ def test_trim_to_layer_with_neighbor_loader():
assert out2.size() == (2, 16)

assert torch.allclose(out1, out2)


def test_trim_to_layer_filtering():
    # Hetero-graph fixture: three node types with 128-dim features.
    feat_dict = {
        'paper': torch.rand((13, 128)),
        'author': torch.rand((5, 128)),
        'field_of_study': torch.rand((6, 128))
    }
    # Two edge types; 'has_topic' has zero edges sampled in the first hop.
    adj_dict = {
        ('author', 'writes', 'paper'):
        torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 2]]),
        ('paper', 'has_topic', 'field_of_study'):
        torch.tensor([[6, 7, 8, 9], [0, 0, 1, 1]])
    }
    # Per-hop sampling counts used by `trim_to_layer` to cut away the
    # nodes/edges that are no longer reachable at the given layer.
    sampled_nodes = {
        'paper': [1, 2, 10],
        'author': [0, 2, 3],
        'field_of_study': [0, 2, 4]
    }
    sampled_edges = {
        ('author', 'writes', 'paper'): [2, 3],
        ('paper', 'has_topic', 'field_of_study'): [0, 4]
    }

    feat_dict, adj_dict, _ = trim_to_layer(
        layer=1,
        num_sampled_nodes_per_hop=sampled_nodes,
        num_sampled_edges_per_hop=sampled_edges,
        x=feat_dict,
        edge_index=adj_dict,
    )

    # The now-empty 'has_topic' relation must have been filtered out,
    # while the remaining relation is trimmed to the first hop.
    assert list(adj_dict.keys()) == [('author', 'writes', 'paper')]
    expected = torch.tensor([[0, 1], [0, 0]])
    assert torch.equal(adj_dict[('author', 'writes', 'paper')], expected)

    # Node features are trimmed to the nodes reachable at layer 1.
    assert feat_dict['paper'].size() == (3, 128)
    assert feat_dict['author'].size() == (2, 128)
    assert feat_dict['field_of_study'].size() == (2, 128)
20 changes: 19 additions & 1 deletion torch_geometric/utils/trim_to_layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional, Tuple, Union
import copy
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
Expand All @@ -14,6 +15,17 @@
)


def filter_empty_entries(
        input_dict: Dict[Any, Tensor]) -> Dict[Any, Tensor]:
    r"""Removes empty tensors from a dictionary. This avoids unnecessary
    computation when some node/edge types are non-reachable after trimming.

    Args:
        input_dict: A dictionary mapping keys (*e.g.*, node or edge types)
            to tensors.

    Returns:
        A new dictionary holding only the entries whose tensor contains at
        least one element. The input dictionary is not modified.
    """
    # NOTE: The original annotation `Dict[Union[Any], Tensor]` was
    # malformed (`Union` of a single type); `Dict[Any, Tensor]` is the
    # intended meaning. A dict comprehension also replaces the previous
    # copy-then-delete loop without changing behavior.
    return {
        key: value
        for key, value in input_dict.items() if value.numel() > 0
    }


def trim_to_layer(
layer: int,
num_sampled_nodes_per_hop: Union[List[int], Dict[NodeType, List[int]]],
Expand Down Expand Up @@ -53,6 +65,8 @@ def trim_to_layer(
k: trim_feat(v, layer, num_sampled_nodes_per_hop[k])
for k, v in x.items()
}
x = filter_empty_entries(x)

edge_index = {
k:
trim_adj(
Expand All @@ -64,11 +78,15 @@ def trim_to_layer(
)
for k, v in edge_index.items()
}
edge_index = filter_empty_entries(edge_index)

if edge_attr is not None:
edge_attr = {
k: trim_feat(v, layer, num_sampled_edges_per_hop[k])
for k, v in edge_attr.items()
}
edge_attr = filter_empty_entries(edge_attr)

return x, edge_index, edge_attr

x = trim_feat(x, layer, num_sampled_nodes_per_hop)
Expand Down
Loading