Skip to content

Commit

Permalink
Slicing with multi-input modules between start_module and end_module …
Browse files Browse the repository at this point in the history
…now possible
  • Loading branch information
alexander-g committed Dec 4, 2020
1 parent 2f18a9a commit a75c7e3
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 45 deletions.
161 changes: 134 additions & 27 deletions elegy/module_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def slice_module_from_to(
end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]],
sample_input: np.ndarray,
) -> Module:
"""Creates a new submodule starting from the input of 'start_module' to the outputs of 'end_module'.
"""Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`.
Current limitations:
- only one input module is supported
- all operations between start_module and end_module must be performed by modules
i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module()
- all modules between start_module and end_module must have a single input and a single output
- only one `start_module` is supported
- all operations between `start_module` and `end_module` must be performed by modules
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- all modules between `start_module` and `end_module` must have a single output
"""
assert not isinstance(
start_module, (tp.Tuple, tp.List)
Expand All @@ -39,8 +39,8 @@ def slice_module_from_to(
end_ids = [get_output_id(edges, m) for m in end_module]

graph = construct_graph(edges)
paths = [find_path(graph, start_id, end_id) for end_id in end_ids]
tree = combine_paths(paths)
dag_paths = [find_dag_path(graph, start_id, end_id) for end_id in end_ids]
tree = combine_paths(dag_paths) #not really a tree
submodule = SlicedModule(tree)
return submodule

Expand Down Expand Up @@ -99,6 +99,17 @@ def merge_args_kwargs(*args, **kwargs) -> tp.List[tp.Tuple[tp.Any, tp.Any]]:
e.g. merge_args_kwargs(0, 77, a=-2) returns [(0,0), (1,77), ('a',-2)]"""
return list(enumerate(args)) + list(kwargs.items())

def split_merged_args_kwargs(args_kwargs: tp.List[tp.Tuple[tp.Any, tp.Any]]) -> tp.Tuple[tp.Tuple, tp.Dict]:
'''Reverse operation of merge_args_kwargs().
e.g. split_merged_args_kwargs([(0,0), (1,77), ('a':-2)]) -> (0,77), {'a':-2}'''
args,kwargs = list(), dict()
for key,value in args_kwargs:
if isinstance(key, int):
args.append(value)
else:
kwargs[key]=value
return tuple(args), kwargs


def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph:
"""Constructs a directed graph with IDs of input/output arrays representing the nodes
Expand Down Expand Up @@ -126,27 +137,97 @@ def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph:
return G


def find_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph:
"""Returns a new graph with only nodes and edges from start_node to end_node"""

def are_paths_computationally_equivalent(path0: nx.DiGraph, path1: nx.DiGraph) -> bool:
'''Checks two paths for computaional equivalence i.e. whether or not they differ only in depth of modules.
E.g if node B is computed by a module composed of several submodules with subnodes B0 and B1
then paths A->B->C and A->B0->B1->B->C are computationally equivalent.
On the other hand, this does not apply to branches A->B->C vs A->D->C.
Importantly, the edge["inkey"] attributes must be the same:
A->C != A->B->C if C if computed by a dual-input module (e.g. C = A+B)'''
#traverse both paths and check if nodes path0 are in path1 or vice versa
#get nodes from both paths, make sure they are ordered
#skip the first one assuming both have the same source node
nodes0 = list(nx.dfs_postorder_nodes(path0))[::-1][1:]
nodes1 = list(nx.dfs_postorder_nodes(path1))[::-1][1:]
while len(nodes0) and len(nodes1):
#currently traversed nodes from both paths
n0, n1 = nodes0[0], nodes1[0]

if n0 in nodes1:
#current node of path0 is in path1, still need to check 'inkey'
inkey0 = path0.get_edge_data(*list(path0.in_edges(n0))[0])['inkey']
inkey1 = path1.get_edge_data(*list(path1.in_edges(n0))[0])['inkey']
if inkey0 == inkey1:
#all ok, continue traversing paths
nodes1 = nodes1[nodes1.index(n0)+1:]
nodes0 = nodes0[1:]
continue
else:
#inkey is not the same, must be a multi-input module -> reject
return False
elif n1 in nodes0:
#current node of path1 is in path0, still need to check 'inkey'
inkey0 = path0.get_edge_data(*list(path0.in_edges(n1))[0])['inkey']
inkey1 = path1.get_edge_data(*list(path1.in_edges(n1))[0])['inkey']
if inkey0 == inkey1:
#all ok, continue traversing paths
nodes0 = nodes0[nodes0.index(n1)+1:]
nodes1 = nodes1[1:]
continue
else:
#inkey is not the same, must be a multi-input module -> reject
return False
else:
#neither path contains the current node of the other path -> reject
return False
if len(nodes0)>0 or len(nodes1)>0:
#should not happen because our paths have the same first and last nodes
return False
#traversed both paths until the end
return True


def filter_computationally_equivalent_paths(paths: tp.List[nx.DiGraph]) -> tp.List[nx.DiGraph]:
'''Removes paths with deep modules if there are paths with equivalent, shallow modules.
E.g: remove A->B0->B1->B->C in favor of A->B->C'''
filtered = set() #contains indices of paths to be removed
for i,j in itertools.combinations(range(len(paths)), 2):
if i in filtered or j in filtered:
continue
if are_paths_computationally_equivalent(paths[i], paths[j]):
#keep the shorter path
if len(paths[i]) > len(paths[j]):
filtered.add(i)
else:
filtered.add(j)
paths = [paths[i] for i in range(len(paths)) if i not in filtered]
return paths


def find_dag_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph:
"""Returns a new (possibly multi-path) graph with only nodes and edges from start_node to end_node"""
startname = list(graph[start_node].values())[0]["modulename"]
endname = list(graph.reverse()[end_node].values())[0]["modulename"]

try:
pathnodes = nx.shortest_path(graph, start_node, end_node)
edge_paths = list(nx.all_simple_edge_paths(graph, start_node, end_node)) #list of lists of tuples
if len(edge_paths)==0:
raise nx.NetworkXNoPath
except nx.NetworkXNoPath:
raise RuntimeError(
f"No path from {startname} to {endname}. Make sure all operations inbetween are performed by modules."
) from None
if len(pathnodes) < 2:
raise RuntimeError(
f"No operations between the input of {startname} and the output of {endname}."
) from None
pathgraph = graph.subgraph(pathnodes).copy()
# pathgraph is unordered, need to mark input and output edges
pathgraph[pathnodes[0]][pathnodes[1]]["is_input"] = True
pathgraph[pathnodes[-2]][pathnodes[-1]]["is_output"] = True
return pathgraph

graph_paths = [nx.edge_subgraph(graph, path) for path in edge_paths] #list of nx.DiGraphs
graph_paths = filter_computationally_equivalent_paths(graph_paths)
dag_graph = nx.algorithms.compose_all(graph_paths)
#dag_graph is unordered, need to mark input and output edges
for _,_, edgedata in dag_graph.out_edges(start_node, data=True):
edgedata['is_input'] = True
for _,_, edgedata in dag_graph.in_edges(end_node, data=True):
edgedata['is_output'] = True
return dag_graph


def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph:
Expand All @@ -172,37 +253,63 @@ def call(self, x: tp.Any) -> tp.Union[tp.Any, tp.Tuple[tp.Any]]:
for nodes, edge in self._tree.edges.items()
if edge.get("is_input", False)
]

# should not happen
assert len(set(input_nodes)) > 0, "could not find any input nodes"
assert len(set(input_nodes)) < 2, "multi-inputs not yet supported"
start_node = input_nodes[0]

outputs = self.visit_node(start_node, x)
outputs = self.visit_node(start_node, x, deferred_call_args=dict())

outputs = tuple(outputs)
if len(outputs) == 1:
outputs = outputs[0]
return outputs

def visit_edge(self, edge: tp.Dict, x: tp.Any) -> tp.Any:
def visit_edge(self, edge: tp.Dict, x: tp.Any, deferred_call_args: tp.Dict) -> tp.Any:
"""Performs the operation to get from node A to node B which the parameter "edge" connects"""
assert edge["inkey"] == 0, "inputs other than 0 not yet implemented"

x = edge["module"](x)
n_inputs = len(jax.tree_leaves(edge['input_ids']))
if n_inputs==1:
#a single-input module, simply call it with the input
x = edge["module"](x)
else:
#multi-input module
#check if all the inputs are ready
call_args = deferred_call_args.get(edge['modulename'], dict())
call_args[edge['inkey']] = x
if len(call_args) == n_inputs:
#all inputs are ready, call module
args, kwargs = split_merged_args_kwargs(call_args.items())
x = edge['module'](*args, **kwargs)
del deferred_call_args[edge['modulename']]
else:
#still missing some inputs, continue traversing the graph
deferred_call_args[edge['modulename']] = call_args
return DeferredCall

if isinstance(x, (tuple, list)):
# XXX: what if the whole tuple/list is needed as input later?
x = x[edge["outkey"]]

return x

def visit_node(self, node: int, x: tp.Any) -> tp.List[tp.Any]:
def visit_node(self, node: int, x: tp.Any, deferred_call_args: tp.Dict) -> tp.List[tp.Any]:
"""Recursively visits all nodes starting from the parameter "node" and collects outputs."""
outputs = []
for nextnode, edge in self._tree[node].items():
y = self.visit_edge(edge, x)
y = self.visit_edge(edge, x, deferred_call_args)
if y==DeferredCall:
#visited edge module is missing some inputs, will come back here later
continue
if edge.get("is_output", False):
outputs.append(y)
outputs.extend(self.visit_node(nextnode, y))
outputs.extend(self.visit_node(nextnode, y, deferred_call_args))

return outputs


class DeferredCall:
'''Dummy class that indicates that a call has to be deferred'''
...


115 changes: 97 additions & 18 deletions elegy/module_slicing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,88 @@ def test_retrain(self):
def test_no_path(self):
x = jnp.ones((32, 100))
basicmodule = BasicModule0()
try:
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, "linear2", "linear0", x
)
except RuntimeError as e:
assert e.args[0].startswith("No path from /linear2 to /linear0")
else:
assert False, "No error or wrong error raised"
for start_module in ['linear2', 'linear1']:
try:
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, start_module, "linear0", x
)
except RuntimeError as e:
assert e.args[0].startswith(f"No path from /{start_module} to /linear0")
else:
assert False, "No error or wrong error raised"

def test_multi_input_modules(self):
x = jnp.ones((32, 100))

module = ContainsMultiInputModule()
model = elegy.Model(module)
model.summary(x)

submodule = elegy.module_slicing.slice_module_from_to(module, None, '/multi_input_module', x)
submodel = elegy.Model(submodule)
submodel.summary(x)
print(submodule.get_parameters())

y = submodel.predict(x)
print(y.shape)
assert(y.shape==(32,25))
assert(jnp.allclose(y, module.test_call(x) ))


def test_computationally_equivalent_paths(self):
import networkx as nx
G = nx.DiGraph()
G.add_edge(0,1, inkey=0)
G.add_edge(1,2, inkey=0)
G.add_edge(0,2, inkey=0) #0->2 is equivalent to the path 0->1->2
G.add_edge(2,3, inkey=0)
G.add_edge(3,4, inkey=0)

g0 = G.edge_subgraph([(0,1), (1,2), (2,3)]).copy()
g1 = G.edge_subgraph([(0,2), (2,3)]).copy()

apce = elegy.module_slicing.are_paths_computationally_equivalent
fcep = elegy.module_slicing.filter_computationally_equivalent_paths

assert apce(g0,g1)
assert apce(g1,g0)
filtered_paths = fcep([g0,g1])
assert len(filtered_paths) == 1
assert filtered_paths[0] == g1

G = nx.DiGraph()
G.add_edge(0,1, inkey=0)
G.add_edge(1,2, inkey=0)
G.add_edge(0,2, inkey=1) #not equivalent, multi-input module
G.add_edge(2,3, inkey=0)
G.add_edge(3,4, inkey=0)

g0 = G.edge_subgraph([(0,1), (1,2), (2,3)]).copy()
g1 = G.edge_subgraph([(0,2), (2,3)]).copy()
g2 = G.edge_subgraph([(0,2), (2,3), (3,4)]).copy()

apce = elegy.module_slicing.are_paths_computationally_equivalent
assert not apce(g0,g1)
assert not apce(g1,g0)
assert not apce(g1,g2)
filtered_paths = fcep([g0,g1,g2])
assert len(filtered_paths) == 3
assert g0 in filtered_paths and g1 in filtered_paths and g2 in filtered_paths



def test_split_merge_args_kwargs(self):
args_kwargs = elegy.module_slicing.merge_args_kwargs(0,101,-2,a=65,b=77)
assert len(args_kwargs)==5
for x in [(0,0), (1,101), (2,-2), ('a',65), ('b',77)]:
assert x in args_kwargs

args,kwargs = elegy.module_slicing.split_merged_args_kwargs(args_kwargs)
assert args==(0,101,-2)
assert len(kwargs)==2
assert kwargs['a']==65 and kwargs['b']==77


try:
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, "linear1", "linear0", x
)
except RuntimeError as e:
assert e.args[0].startswith(
"No operations between the input of /linear1 and the output of /linear0"
)
else:
assert False, "No error or wrong error raised"


class BasicModule0(elegy.Module):
Expand All @@ -121,3 +184,19 @@ def test_call(self, x):
x = self.linear0(x)
x = self.linear1(x)
return x

class MultiInputModule(elegy.Module):
def call(self, x0, x1):
return x0[...,:25]+x1[...,:25]

class ContainsMultiInputModule(elegy.Module):
def call(self, x):
x0 = elegy.nn.Linear(25, name='linear0')(x)
x = MultiInputModule(name='multi_input_module')(x,x0)
x = elegy.nn.Linear(10)(x)
return x

def test_call(self, x):
x0 = self.linear0(x)
x = self.multi_input_module(x, x0)
return x

0 comments on commit a75c7e3

Please sign in to comment.