Skip to content

Commit

Permalink
decompile tf control flow
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 19, 2019
1 parent 340678d commit 819f73c
Show file tree
Hide file tree
Showing 2 changed files with 452 additions and 3 deletions.
157 changes: 154 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging
import warnings
from collections import defaultdict
# Numpy support
import numpy as np

Expand Down Expand Up @@ -1270,6 +1271,100 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym

_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

class Branch:
"""A class contains the components that are used to build up a Relay if
node.
"""
def __init__(self):
self._if = None
self.cond_vars = set()
self.cond = None
self.true_branch = None
self.false_branch = None

def _if_node(self):
from tvm import relay

cond_vars = []
bind_map = {}
for i, var in enumerate(list(self.cond_vars)):
if not isinstance(var, _expr.Var):
raise TypeError("var is expected to be _expr.Var type, but "
"received {}".format(repr(var)))
v = relay.var("cond_var" + str(i),
type_annotation=var.type_annotation)
cond_vars.append(v)
bind_map[var] = v

self.cond = relay.bind(self.cond, bind_map)
cond = relay.op.min(self.cond)
self.true_branch = relay.bind(self.true_branch, bind_map)
self.false_branch = relay.bind(self.false_branch, bind_map)

return relay.If(cond, self.true_branch, self.false_branch)

def if_node(self):
"""Create a if node if it hasn't been created yet."""
if self._if is None:
self._if = self._if_node()
return self._if
return self._if


class Loop:
"""A class contains the components that are used to build up a Relay
recursive call.
"""
def __init__(self):
self.loop_vars = []
self.cond = None
self.body = []
self._loop = None

def _while_loop(self):
from tvm import relay
wl = relay.var('while_loop')
sb = relay.scope_builder.ScopeBuilder()

loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
loop_vars.append(v)
bind_map[var] = v

self.cond = relay.bind(self.cond, bind_map)
self.body = [relay.bind(b, bind_map) for b in self.body]

cond = relay.op.min(self.cond)

with sb.if_scope(cond):
sb.ret(wl(*self.body))
with sb.else_scope():
sb.ret(relay.Tuple(loop_vars))

loop_fn = relay.Function(loop_vars, sb.get())
sb = relay.scope_builder.ScopeBuilder()
sb.let(wl, loop_fn)
sb.ret(wl(*self.loop_vars))
return sb.get()

def while_loop(self):
if self._loop is None:
self._loop = self._while_loop()
return self._loop
return self._loop


def _in_while_loop(control_flow_node_map, op_name):
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]


class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -1284,6 +1379,9 @@ def __init__(self):
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}
self._loops = {}
self._branches = {}
# self.module = relay.Module({})

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
Expand Down Expand Up @@ -1332,7 +1430,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

control_flow_node_map = defaultdict(set)
for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder':
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
Expand Down Expand Up @@ -1451,8 +1552,53 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch
node_name_prefix = node.name.rsplit('/', 1)[0]

op = self._convert_operator(node.op, inputs, attr, graph)
if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]]
self._loops[node_name_prefix] = Loop()
else:
if len(self._branches) == 0:
raise RuntimeError("Cannot find a created "
"conditional for merge node")
branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]]
true_br = self._nodes[node.input[1]]
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
# del self._branches[node_name_prefix]
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')
exit_number = int("0" + exit_name[4:])
expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number)
elif node.op == "Enter":
op = self._nodes[node.input[0]]
elif node.op == "LoopCond":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch":
op = self._nodes[node.input[0]]
assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix):
self._loops[node_name_prefix].loop_vars.append(op[0])
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
elif node.op == "NextIteration":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
else:
op = self._convert_operator(node.op, inputs, attr, graph)

# Check if op is converted to param
if isinstance(op, np.ndarray):
Expand Down Expand Up @@ -1493,7 +1639,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

out = []
if outputs is None:
out = op
if node.op == "Exit":
out = [op[0].tuple_value]
else:
out = op
else:
for out_name in outputs:
if ":" in out_name:
Expand Down Expand Up @@ -1529,7 +1678,9 @@ def _parse_import_prerequisites(self, graph):
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
if any([node.op in t for t in [_identity_list, _convert_map,
_convert_map_rnn,
_control_flow_nodes]]):
pass
else:
missing_operators.add(node.op)
Expand Down
Loading

0 comments on commit 819f73c

Please sign in to comment.