-
Notifications
You must be signed in to change notification settings - Fork 5
/
connectivity_utils.py
120 lines (96 loc) · 4.12 KB
/
connectivity_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tools to compute the connectivity of the graph."""
import functools
import numpy as np
from sklearn import neighbors
import tensorflow.compat.v1 as tf
def _compute_connectivity(positions, radius, add_self_edges):
  """Get the indices of connected edges with radius connectivity.

  Every node is queried against a KD-tree built from all nodes of the graph;
  each neighbor returned by the radius query yields one directed edge
  (sender = query node, receiver = neighbor found).

  Args:
    positions: Positions of nodes in the graph. Shape:
      [num_nodes_in_graph, num_dims].
    radius: Radius of connectivity.
    add_self_edges: Whether to include self edges or not.

  Returns:
    senders indices [num_edges_in_graph]
    receiver indices [num_edges_in_graph]
    Both are empty integer arrays when the graph has no nodes.
  """
  num_nodes = len(positions)
  if num_nodes == 0:
    # sklearn's KDTree rejects an empty point set, so handle the empty graph
    # explicitly: no nodes means no edges.
    return np.zeros([0], dtype=np.int64), np.zeros([0], dtype=np.int64)

  tree = neighbors.KDTree(positions)
  # One array of neighbor indices per query node.
  receivers_list = tree.query_radius(positions, r=radius)
  # Repeat each node index once per neighbor so that senders lines up
  # element-wise with the flattened receivers.
  senders = np.repeat(np.arange(num_nodes), [len(a) for a in receivers_list])
  receivers = np.concatenate(receivers_list, axis=0)

  if not add_self_edges:
    # Remove self edges.
    mask = senders != receivers
    senders = senders[mask]
    receivers = receivers[mask]

  return senders, receivers
def _compute_connectivity_for_batch(
    positions, n_node, radius, add_self_edges):
  """Radius connectivity for several graphs stored back to back.

  The node positions of all graphs are expected concatenated along axis 0.
  They are split per graph using `n_node`, connectivity is computed for each
  graph independently, and the resulting indices are shifted so that they
  index into the concatenated node array.

  Args:
    positions: Positions of nodes in the batch of graphs. Shape:
      [num_nodes_in_batch, num_dims].
    n_node: Number of nodes for each graph in the batch. Shape:
      [num_graphs in batch].
    radius: Radius of connectivity.
    add_self_edges: Whether to include self edges or not.

  Returns:
    senders indices [num_edges_in_batch], int32
    receiver indices [num_edges_in_batch], int32
    number of edges per graph [num_graphs_in_batch], int32
  """
  # TODO(alvarosg): Consider if we want to support batches here or not.
  # Boundaries between consecutive graphs: the running node count of every
  # graph except the last one.
  split_indices = np.cumsum(n_node[:-1])
  per_graph_positions = np.split(positions, split_indices, axis=0)

  batch_senders = []
  batch_receivers = []
  edges_per_graph = []
  node_offset = 0

  for graph_positions in per_graph_positions:
    graph_senders, graph_receivers = _compute_connectivity(
        graph_positions, radius, add_self_edges)
    edges_per_graph.append(len(graph_senders))

    # Indices returned for a single graph are local to it; shift them by the
    # number of nodes seen so far so they address the concatenated batch.
    batch_receivers.append(graph_receivers + node_offset)
    batch_senders.append(graph_senders + node_offset)
    node_offset += len(graph_positions)

  # Merge the per-graph results into flat int32 arrays.
  senders = np.concatenate(batch_senders, axis=0).astype(np.int32)
  receivers = np.concatenate(batch_receivers, axis=0).astype(np.int32)
  n_edge = np.asarray(edges_per_graph).astype(np.int32)
  return senders, receivers, n_edge
def compute_connectivity_for_batch_pyfunc(
    positions, n_node, radius, add_self_edges=True):
  """`_compute_connectivity_for_batch` wrapped in a pyfunc.

  Args:
    positions: Node positions for the whole batch, forwarded to the wrapped
      function.
    n_node: Number of nodes per graph. Must expose `get_shape()`, since its
      static shape is copied onto the returned `n_edge`.
    radius: Radius of connectivity, forwarded to the wrapped function.
    add_self_edges: Whether to include self edges or not. Bound ahead of time
      rather than passed through `tf.py_function` as an input.

  Returns:
    A `(senders, receivers, n_edge)` tuple of int32 outputs. `senders` and
    `receivers` are given static shape `[None]`; `n_edge` gets the static
    shape of `n_node`.
  """
  batch_connectivity_fn = functools.partial(
      _compute_connectivity_for_batch, add_self_edges=add_self_edges)
  senders, receivers, n_edge = tf.py_function(
      func=batch_connectivity_fn,
      inp=[positions, n_node, radius],
      Tout=[tf.int32] * 3)

  # py_function outputs carry no static shape information; restore what is
  # known: edge index vectors are rank 1, and there is one edge count per
  # entry of n_node.
  for edge_indices in (senders, receivers):
    edge_indices.set_shape([None])
  n_edge.set_shape(n_node.get_shape())
  return senders, receivers, n_edge