[RELAY][Frontend][TF] decompile tf control flow #2830

Merged: 5 commits, Mar 24, 2019
Changes from all commits
321 changes: 317 additions & 4 deletions python/tvm/relay/frontend/tensorflow.py
@@ -5,6 +5,7 @@

import logging
import warnings
from collections import defaultdict
# Numpy support
import numpy as np

@@ -1270,6 +1271,220 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym

# An internal list of all the control flow primitives used in TensorFlow 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
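# For example, a ``tf.while_loop`` whose frame is named 'while' lowers to
# GraphDef nodes such as 'while/Enter', 'while/Merge', 'while/LoopCond',
# 'while/Switch', 'while/NextIteration', and 'while/Exit' (the node names
# here illustrate the standard TF naming scheme).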

def _in_while_loop(control_flow_node_map, op_name):
"""
Check if a given control flow operator is part of a while loop execution
frame. This is based on the fact that there is only one occurrence of
`LoopCond` for a loop execution frame and it is only presented in the loop
construct.
Parameters
----------
control_flow_node_map : Dict[str, Set[str]]
A dictionay contains the unqiue control flow execution frame name to
a set of primitive operators mapping.
op_name : str
The name of a control flow primitive.
Returns
-------
ret : bool
Return true if the operator is in a while loop execution frame,
otherwise, return false.
"""
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]


class Branch:
"""A class contains the components that are used to build up a Relay if
node.
Parameters
----------
cond : tvm.relay.Expr
The condition of a if node.
true_branch : tvm.relay.Expr
The body of the true branch of a if expression.
false_branch: tvm.relay.Expr
The body of the false branch of a if expression.
_if : tvm.relay.Expr
An internal variable indicates where an if expression is already created
for a matched TF condition construct.
Examples
--------
The following is a cond statement written in TensorFlow:
.. code-block:: python
def vanilla_cond():
i = tf.constant(1)
j = tf.constant(4)
def f1():
return tf.multiply(1, 17)
def f2():
return tf.add(4, 23)
r = tf.cond(tf.less(i, j), f1, f2)
    This condition statement should be converted into Relay in the following
    form:
.. code-block:: python
fn (%Const: Tensor[(1,), int32],
%Const_1: Tensor[(1,), int32],
%cond/Mul/x: Tensor[(1,), int32],
%cond/Mul/y: Tensor[(1,), int32],
%cond/Add/x: Tensor[(1,), int32],
%cond/Add/y: Tensor[(1,), int32]) {
%0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
%1 = min(%0)
if (%1) {
%2 = multiply(%cond/Mul/x, %cond/Mul/y)
%2
} else {
%3 = add(%cond/Add/x, %cond/Add/y)
%3
}
}
"""
def __init__(self):
self._if = None
self.cond = None
self.true_branch = None
self.false_branch = None

def _if_node(self):
"""An internal API to create a relay if node from the matched TF
condition construct.
"""
# `cond` returns a tensor that contains boolean values. We add a `min`
# operator to checks if there is any false value. If so, this condition
# doesn't not hold.
cond = tvm.relay.op.min(self.cond)
return tvm.relay.If(cond, self.true_branch, self.false_branch)

def if_node(self):
"""Create an tvm.relay.If node if it hasn't been created yet."""
if self._if is None:
self._if = self._if_node()
return self._if
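
A hypothetical usage sketch (the variable names below are illustrative, not from this PR):

.. code-block:: python

    b = Branch()
    b.cond = tvm.relay.var("c", shape=(1,), dtype="bool")
    b.true_branch = tvm.relay.var("t", shape=(1,), dtype="int32")
    b.false_branch = tvm.relay.var("f", shape=(1,), dtype="int32")
    # Builds If(min(c), t, f) on first use and caches it afterwards.
    if_expr = b.if_node()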


class Loop:
"""
A class contains the components that are used to build up a Relay
recursive call.
Parameters
----------
loop_vars : List[tvm.relay.Expr]
The loop variables that used in a while loop.
cond : tvm.relay.Expr
The condition of a while loop.
body : tvm.relay.Expr
The body of a matched while loop.
_loop : tvm.relay.Expr
An internal variable indicates where a recursive call is already created
for a matched TF while loop construct.
Examples
--------
The following is a vanilla loop from TensorFlow:
.. code-block:: python
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
It will be converted to the following recursive call in Relay:
.. code-block:: python
fn (%while/Less/y: Tensor[(1,), int32],
%while/Add/y: Tensor[(1,), int32],
%Const: Tensor[(1,), int32]) {
%0 = fn(%loop_var0: Tensor[(1,), int32]) {
%1 = less(%loop_var0, %while/Less/y)
%2 = min(%1)
if (%2) {
%3 = add(%loop_var0, %while/Add/y)
free_var %while_loop
%4 = %while_loop(%3)
%4
} else {
%5 = (%loop_var0,)
%5
}
}
let %while_loop1 = %0
%6 = %while_loop1(%Const)
%6
}
"""
def __init__(self):
self.loop_vars = []
self.cond = None
self.body = []
self._loop = None

def _while_loop(self):
"""An internal API to create a Relay recurisve call for a matched TF
`while_loop` construct.
"""
wl = tvm.relay.var('while_loop')

sb = tvm.relay.scope_builder.ScopeBuilder()

loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = tvm.relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
loop_vars.append(v)
bind_map[var] = v

self.cond = tvm.relay.bind(self.cond, bind_map)
self.body = [tvm.relay.bind(b, bind_map) for b in self.body]

cond = tvm.relay.op.min(self.cond)

with sb.if_scope(cond):
sb.ret(wl(*self.body))
with sb.else_scope():
sb.ret(tvm.relay.Tuple(loop_vars))

loop_fn = tvm.relay.Function(loop_vars, sb.get())
sb = tvm.relay.scope_builder.ScopeBuilder()
sb.let(wl, loop_fn)
sb.ret(wl(*self.loop_vars))
return sb.get()

def while_loop(self):
"""Instantiate a while loop if it has not been created yet."""
if self._loop is None:
self._loop = self._while_loop()
return self._loop
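
As a hypothetical sketch of the translation, a ``Loop`` equivalent to ``while i < 10: i = i + 1`` can be assembled from Relay expressions directly, rather than from TF nodes:

.. code-block:: python

    i = tvm.relay.var("i", shape=(1,), dtype="int32")
    loop = Loop()
    loop.loop_vars.append(i)
    loop.cond = tvm.relay.op.less(i, tvm.relay.const(10, "int32"))
    loop.body.append(tvm.relay.op.add(i, tvm.relay.const(1, "int32")))
    # A let-bound recursive function applied to the initial value of i.
    expr = loop.while_loop()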


class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
@@ -1284,6 +1499,8 @@ def __init__(self):
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}
self._loops = {}
self._branches = {}

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -1332,7 +1549,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
raise NotImplementedError( \
"The following operators are not implemented: {}".format(missing_operators))

control_flow_node_map = defaultdict(set)
for node in graph.node:
node_name_prefix = node.name.rsplit('/', 1)[0]
control_flow_node_map[node_name_prefix].add(node.op)
if node.op == 'Placeholder':
if shape and node.name in shape:
self._input_shapes[node.name] = list(shape[node.name])
@@ -1447,12 +1667,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
-                        input_0d_mismatch.add(in_sym)
+                        input_0d_mismatch.add(in_sym[0])

attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch

-            op = self._convert_operator(node.op, inputs, attr, graph)
+            if node.op in _control_flow_nodes:
+                op = self._convert_control_flow_operator(node, inputs,
+                                                         attr,
+                                                         control_flow_node_map)
+            else:
+                op = self._convert_operator(node.op, inputs, attr, graph)

# Check if op is converted to param
if isinstance(op, np.ndarray):
@@ -1493,7 +1718,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

out = []
if outputs is None:
-                out = op
+                if node.op == "Exit":
+                    out = [op[0].tuple_value]
+                else:
+                    out = op
else:
for out_name in outputs:
if ":" in out_name:
@@ -1529,7 +1757,9 @@ def _parse_import_prerequisites(self, graph):
elif node.op == "Const":
pass
else:
-                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
+                if any([node.op in t for t in [_identity_list, _convert_map,
+                                               _convert_map_rnn,
+                                               _control_flow_nodes]]):
pass
else:
missing_operators.add(node.op)
@@ -1656,6 +1886,89 @@ def _convert_rnn_operator(self, op_name, inputs,
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym

def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
"""
Convert the Relay control flow primitive into corresponding component
of a Relay control flow construct, i.e. `tf.cond` and `tf.while_loop`
are converted in Relay `If` and recusrive call, respectively.
Parameters
----------
node: TensorFlow graph node object.
A TensorFlow graph node object.
inputs : List[tvm.relay.Expr]
List of input symbols.
attrs : Dict[tvm.Attrs]
Dict of operator attributes.
control_flow_node_map : Dict[str, Set[str]]
A dictionary contains the execution frame name to primitives
mapping.
Returns
-------
op : tvm.relay.Expr
Converted relay expression.
"""
node_name_prefix = node.name.rsplit('/', 1)[0]
if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]]
self._loops[node_name_prefix] = Loop()
else:
if len(self._branches) == 0:
raise RuntimeError("Cannot find a created "
"conditional for merge node")
branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]]
true_br = self._nodes[node.input[1]]
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')

            # TensorFlow uses different naming conventions for `Exit` nodes
            # across versions.
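            # For example (node names here are hypothetical): the first exit
            # 'while/Exit' has no numeric suffix and yields 0 via the "0"
            # prefix below; 'while/Exit_1' (newer TF) and 'while/Exit1'
            # (older TF) both yield 1.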
if '_' in exit_name:
exit_number = int("0" + exit_name[5:])
else:
exit_number = int("0" + exit_name[4:])

expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number)
elif node.op == "Enter":
op = self._nodes[node.input[0]]
elif node.op == "LoopCond":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch":
op = self._nodes[node.input[0]]
assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix):
self._loops[node_name_prefix].loop_vars.append(op[0])
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
elif node.op == "NextIteration":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))

return op


def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator.