Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable trim_to_layer with hetero CSR flow #7514

Merged
merged 6 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/utils/test_trim_to_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_trim_sparse_tensor():
edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 3, 4]])
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[5, 5])

adj = trim_sparse_tensor(adj, num_nodes=3, num_seed_nodes=1)
adj = trim_sparse_tensor(adj, size=(3, 3), num_seed_nodes=1)

row, col, _ = adj.coo()
assert row.tolist() == [0, 0]
Expand Down
41 changes: 29 additions & 12 deletions torch_geometric/utils/trim_to_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ def trim_to_layer(
for k, v in x.items()
}
edge_index = {
k: trim_adj(v, layer, num_sampled_nodes_per_hop[k[-1]],
num_sampled_edges_per_hop[k])
k: trim_adj(
v,
layer,
num_sampled_nodes_per_hop[k[0]],
num_sampled_nodes_per_hop[k[-1]],
num_sampled_edges_per_hop[k],
)
for k, v in edge_index.items()
}
if edge_attr is not None:
Expand All @@ -66,8 +71,13 @@ def trim_to_layer(
return x, edge_index, edge_attr

x = trim_feat(x, layer, num_sampled_nodes_per_hop)
edge_index = trim_adj(edge_index, layer, num_sampled_nodes_per_hop,
num_sampled_edges_per_hop)
edge_index = trim_adj(
edge_index,
layer,
num_sampled_nodes_per_hop,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
)

if edge_attr is not None:
edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop)
Expand Down Expand Up @@ -125,7 +135,8 @@ def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor:
def trim_adj(
edge_index: Adj,
layer: int,
num_sampled_nodes_per_hop: List[int],
num_sampled_src_nodes_per_hop: List[int],
num_sampled_dst_nodes_per_hop: List[int],
num_sampled_edges_per_hop: List[int],
) -> Adj:

Expand All @@ -140,14 +151,19 @@ def trim_adj(
)

elif isinstance(edge_index, SparseTensor):
num_nodes = edge_index.size(0) - num_sampled_nodes_per_hop[-layer]
num_seed_nodes = num_nodes - num_sampled_nodes_per_hop[-(layer + 1)]
return trim_sparse_tensor(edge_index, num_nodes, num_seed_nodes)
size = (
edge_index.size(0) - num_sampled_dst_nodes_per_hop[-layer],
edge_index.size(1) - num_sampled_src_nodes_per_hop[-layer],
)

num_seed_nodes = size[0] - num_sampled_dst_nodes_per_hop[-(layer + 1)]

return trim_sparse_tensor(edge_index, size, num_seed_nodes)

raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'")


def trim_sparse_tensor(src: SparseTensor, num_nodes: int,
def trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int],
num_seed_nodes: None) -> SparseTensor:
r"""Trims a :class:`SparseTensor` along both dimensions to only contain
the upper :obj:`num_nodes` in both dimensions.
Expand All @@ -157,13 +173,14 @@ def trim_sparse_tensor(src: SparseTensor, num_nodes: int,

Args:
src (SparseTensor): The sparse tensor.
num_nodes (int): The number of first nodes to keep.
size (Tuple[int, int]): The number of source and destination nodes to
keep.
num_seed_nodes (int): The number of seed nodes to compute
representations.
"""
rowptr, col, value = src.csr()

rowptr = torch.narrow(rowptr, 0, 0, num_nodes + 1).clone()
rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone()
rowptr[num_seed_nodes + 1:] = rowptr[num_seed_nodes]

col = torch.narrow(col, 0, 0, rowptr[-1])
Expand All @@ -180,7 +197,7 @@ def trim_sparse_tensor(src: SparseTensor, num_nodes: int,
rowptr=rowptr,
col=col,
value=value,
sparse_sizes=(num_nodes, num_nodes),
sparse_sizes=size,
rowcount=None,
colptr=None,
colcount=None,
Expand Down