Skip to content

Commit

Permalink
[RELAY][Frontend][TF] decompile tf control flow (#2830)
Browse files Browse the repository at this point in the history
* decompile tf control flow

* Add docs

* remove import relay

* move tests under tensorflow frontend

* minor fix
  • Loading branch information
zhiics authored and yzhliu committed Mar 24, 2019
1 parent 8ef35dc commit 2df3364
Show file tree
Hide file tree
Showing 3 changed files with 630 additions and 12 deletions.
321 changes: 317 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import logging
import warnings
from collections import defaultdict
# Numpy support
import numpy as np

Expand Down Expand Up @@ -1270,6 +1271,220 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym

# An internal list to contain all the control flow primitives used in Tensorflow
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

def _in_while_loop(control_flow_node_map, op_name):
"""
Check if a given control flow operator is part of a while loop execution
frame. This is based on the fact that there is only one occurrence of
`LoopCond` for a loop execution frame and it is only presented in the loop
construct.
Parameters
----------
control_flow_node_map : Dict[str, Set[str]]
A dictionay contains the unqiue control flow execution frame name to
a set of primitive operators mapping.
op_name : str
The name of a control flow primitive.
Returns
-------
ret : bool
Return true if the operator is in a while loop execution frame,
otherwise, return false.
"""
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]


class Branch:
"""A class contains the components that are used to build up a Relay if
node.
Parameters
----------
cond : tvm.relay.Expr
The condition of a if node.
true_branch : tvm.relay.Expr
The body of the true branch of a if expression.
false_branch: tvm.relay.Expr
The body of the false branch of a if expression.
_if : tvm.relay.Expr
An internal variable indicates where an if expression is already created
for a matched TF condition construct.
Examples
--------
The following is a cond statement written in TensorFlow:
.. code-block:: python
def vanilla_cond():
i = tf.constant(1)
j = tf.constant(4)
def f1():
return tf.multiply(1, 17)
def f2():
return tf.add(4, 23)
r = tf.cond(tf.less(i, j), f1, f2)
This condition statement should be coverted into Relay in the following
form:
.. code-block:: python
fn (%Const: Tensor[(1,), int32],
%Const_1: Tensor[(1,), int32],
%cond/Mul/x: Tensor[(1,), int32],
%cond/Mul/y: Tensor[(1,), int32],
%cond/Add/x: Tensor[(1,), int32],
%cond/Add/y: Tensor[(1,), int32]) {
%0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
%1 = min(%0)
if (%1) {
%2 = multiply(%cond/Mul/x, %cond/Mul/y)
%2
} else {
%3 = add(%cond/Add/x, %cond/Add/y)
%3
}
}
"""
def __init__(self):
self._if = None
self.cond = None
self.true_branch = None
self.false_branch = None

def _if_node(self):
"""An internal API to create a relay if node from the matched TF
condition construct.
"""
# `cond` returns a tensor that contains boolean values. We add a `min`
# operator to checks if there is any false value. If so, this condition
# doesn't not hold.
cond = tvm.relay.op.min(self.cond)
return tvm.relay.If(cond, self.true_branch, self.false_branch)

def if_node(self):
"""Create an tvm.relay.If node if it hasn't been created yet."""
if self._if is None:
self._if = self._if_node()
return self._if


class Loop:
"""
A class contains the components that are used to build up a Relay
recursive call.
Parameters
----------
loop_vars : List[tvm.relay.Expr]
The loop variables that used in a while loop.
cond : tvm.relay.Expr
The condition of a while loop.
body : tvm.relay.Expr
The body of a matched while loop.
_loop : tvm.relay.Expr
An internal variable indicates where a recursive call is already created
for a matched TF while loop construct.
Examples
--------
The following is a vanilla loop from TensorFlow:
.. code-block:: python
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
It will be converted to the following recursive call in Relay:
.. code-block:: python
fn (%while/Less/y: Tensor[(1,), int32],
%while/Add/y: Tensor[(1,), int32],
%Const: Tensor[(1,), int32]) {
%0 = fn(%loop_var0: Tensor[(1,), int32]) {
%1 = less(%loop_var0, %while/Less/y)
%2 = min(%1)
if (%2) {
%3 = add(%loop_var0, %while/Add/y)
free_var %while_loop
%4 = %while_loop(%3)
%4
} else {
%5 = (%loop_var0,)
%5
}
}
let %while_loop1 = %0
%6 = %while_loop1(%Const)
%6
}
"""
def __init__(self):
self.loop_vars = []
self.cond = None
self.body = []
self._loop = None

def _while_loop(self):
"""An internal API to create a Relay recurisve call for a matched TF
`while_loop` construct.
"""
wl = tvm.relay.var('while_loop')

sb = tvm.relay.scope_builder.ScopeBuilder()

loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = tvm.relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
loop_vars.append(v)
bind_map[var] = v

self.cond = tvm.relay.bind(self.cond, bind_map)
self.body = [tvm.relay.bind(b, bind_map) for b in self.body]

cond = tvm.relay.op.min(self.cond)

with sb.if_scope(cond):
sb.ret(wl(*self.body))
with sb.else_scope():
sb.ret(tvm.relay.Tuple(loop_vars))

loop_fn = tvm.relay.Function(loop_vars, sb.get())
sb = tvm.relay.scope_builder.ScopeBuilder()
sb.let(wl, loop_fn)
sb.ret(wl(*self.loop_vars))
return sb.get()

def while_loop(self):
"""Instantiate a while loop if it has not been created yet."""
if self._loop is None:
self._loop = self._while_loop()
return self._loop
return self._loop


class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -1284,6 +1499,8 @@ def __init__(self):
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}
self._loops = {}
self._branches = {}

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
Expand Down Expand Up @@ -1332,7 +1549,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

control_flow_node_map = defaultdict(set)
for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder':
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
Expand Down Expand Up @@ -1447,12 +1667,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym)
input_0d_mismatch.add(in_sym[0])

attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch

op = self._convert_operator(node.op, inputs, attr, graph)
if node.op in _control_flow_nodes:
op = self._convert_control_flow_operator(node, inputs,
attr,
control_flow_node_map)
else:
op = self._convert_operator(node.op, inputs, attr, graph)

# Check if op is converted to param
if isinstance(op, np.ndarray):
Expand Down Expand Up @@ -1493,7 +1718,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

out = []
if outputs is None:
out = op
if node.op == "Exit":
out = [op[0].tuple_value]
else:
out = op
else:
for out_name in outputs:
if ":" in out_name:
Expand Down Expand Up @@ -1529,7 +1757,9 @@ def _parse_import_prerequisites(self, graph):
elif node.op == "Const":
pass
else:
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
if any([node.op in t for t in [_identity_list, _convert_map,
_convert_map_rnn,
_control_flow_nodes]]):
pass
else:
missing_operators.add(node.op)
Expand Down Expand Up @@ -1656,6 +1886,89 @@ def _convert_rnn_operator(self, op_name, inputs,
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym

def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
"""
Convert the Relay control flow primitive into corresponding component
of a Relay control flow construct, i.e. `tf.cond` and `tf.while_loop`
are converted in Relay `If` and recusrive call, respectively.
Parameters
----------
node: TensorFlow graph node object.
A TensorFlow graph node object.
inputs : List[tvm.relay.Expr]
List of input symbols.
attrs : Dict[tvm.Attrs]
Dict of operator attributes.
control_flow_node_map : Dict[str, Set[str]]
A dictionary contains the execution frame name to primitives
mapping.
Returns
-------
op : tvm.relay.Expr
Converted relay expression.
"""
node_name_prefix = node.name.rsplit('/', 1)[0]
if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]]
self._loops[node_name_prefix] = Loop()
else:
if len(self._branches) == 0:
raise RuntimeError("Cannot find a created "
"conditional for merge node")
branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]]
true_br = self._nodes[node.input[1]]
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')

# TensorFlow has differen naming convention on different
# versions.
if '_' in exit_name:
exit_number = int("0" + exit_name[5:])
else:
exit_number = int("0" + exit_name[4:])

expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number)
elif node.op == "Enter":
op = self._nodes[node.input[0]]
elif node.op == "LoopCond":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch":
op = self._nodes[node.input[0]]
assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix):
self._loops[node_name_prefix].loop_vars.append(op[0])
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
elif node.op == "NextIteration":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))

return op


def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator.
Expand Down
Loading

0 comments on commit 2df3364

Please sign in to comment.