Skip to content

Commit

Permalink
Add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 21, 2019
1 parent 819f73c commit a65cfbb
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 123 deletions.
305 changes: 233 additions & 72 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np

import tvm
from tvm import relay
from topi.util import get_const_tuple
from .. import ir_pass
from .. import expr as _expr
Expand Down Expand Up @@ -1271,51 +1272,170 @@ def _get_abs_layer_name(node):
params, num_layers)
return sym

# An internal list to contain all the control flow primitives used in Tensorflow
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

def _in_while_loop(control_flow_node_map, op_name):
"""
Check if a given control flow operator is part of a while loop execution
frame. This is based on the fact that there is only one occurrence of
`LoopCond` for a loop execution frame and it is only presented in the loop
construct.
Parameters
----------
control_flow_node_map : Dict[str, Set[str]]
A dictionay contains the unqiue control flow execution frame name to
a set of primitive operators mapping.
op_name : str
The name of a control flow primitive.
Returns
-------
ret : bool
Return true if the operator is in a while loop execution frame,
otherwise, return false.
"""
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]


class Branch:
"""A class contains the components that are used to build up a Relay if
node.
Parameters
----------
cond : tvm.relay.Expr
The condition of a if node.
true_branch : tvm.relay.Expr
The body of the true branch of a if expression.
false_branch: tvm.relay.Expr
The body of the false branch of a if expression.
_if : tvm.relay.Expr
An internal variable indicates where an if expression is already created
for a matched TF condition construct.
Examples
--------
The following is a cond statement written in TensorFlow:
.. code-block:: python
def vanilla_cond():
i = tf.constant(1)
j = tf.constant(4)
def f1():
return tf.multiply(1, 17)
def f2():
return tf.add(4, 23)
r = tf.cond(tf.less(i, j), f1, f2)
This condition statement should be coverted into Relay in the following
form:
.. code-block:: python
fn (%Const: Tensor[(1,), int32],
%Const_1: Tensor[(1,), int32],
%cond/Mul/x: Tensor[(1,), int32],
%cond/Mul/y: Tensor[(1,), int32],
%cond/Add/x: Tensor[(1,), int32],
%cond/Add/y: Tensor[(1,), int32]) {
%0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
%1 = min(%0)
if (%1) {
%2 = multiply(%cond/Mul/x, %cond/Mul/y)
%2
} else {
%3 = add(%cond/Add/x, %cond/Add/y)
%3
}
}
"""
def __init__(self):
self._if = None
self.cond_vars = set()
self.cond = None
self.true_branch = None
self.false_branch = None

def _if_node(self):
from tvm import relay

cond_vars = []
bind_map = {}
for i, var in enumerate(list(self.cond_vars)):
if not isinstance(var, _expr.Var):
raise TypeError("var is expected to be _expr.Var type, but "
"received {}".format(repr(var)))
v = relay.var("cond_var" + str(i),
type_annotation=var.type_annotation)
cond_vars.append(v)
bind_map[var] = v

self.cond = relay.bind(self.cond, bind_map)
"""An internal API to create a relay if node from the matched TF
condition construct.
"""
cond = relay.op.min(self.cond)
self.true_branch = relay.bind(self.true_branch, bind_map)
self.false_branch = relay.bind(self.false_branch, bind_map)

return relay.If(cond, self.true_branch, self.false_branch)

def if_node(self):
"""Create a if node if it hasn't been created yet."""
"""Create an tvm.relay.If node if it hasn't been created yet."""
if self._if is None:
self._if = self._if_node()
return self._if
return self._if


class Loop:
"""A class contains the components that are used to build up a Relay
"""
A class contains the components that are used to build up a Relay
recursive call.
Parameters
----------
loop_vars : List[tvm.relay.Expr]
The loop variables that used in a while loop.
cond : tvm.relay.Expr
The condition of a while loop.
body : tvm.relay.Expr
The body of a matched while loop.
_loop : tvm.relay.Expr
An internal variable indicates where a recursive call is already created
for a matched TF while loop construct.
Examples
--------
The following is a vanilla loop from TensorFlow:
.. code-block:: python
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
It will be converted to the following recursive call in Relay:
.. code-block:: python
fn (%while/Less/y: Tensor[(1,), int32],
%while/Add/y: Tensor[(1,), int32],
%Const: Tensor[(1,), int32]) {
%0 = fn(%loop_var0: Tensor[(1,), int32]) {
%1 = less(%loop_var0, %while/Less/y)
%2 = min(%1)
if (%2) {
%3 = add(%loop_var0, %while/Add/y)
free_var %while_loop
%4 = %while_loop(%3)
%4
} else {
%5 = (%loop_var0,)
%5
}
}
let %while_loop1 = %0
%6 = %while_loop1(%Const)
%6
}
"""
def __init__(self):
self.loop_vars = []
Expand All @@ -1324,8 +1444,11 @@ def __init__(self):
self._loop = None

def _while_loop(self):
from tvm import relay
"""An internal API to create a Relay recurisve call for a matched TF
`while_loop` construct.
"""
wl = relay.var('while_loop')

sb = relay.scope_builder.ScopeBuilder()

loop_vars = []
Expand Down Expand Up @@ -1354,17 +1477,13 @@ def _while_loop(self):
return sb.get()

def while_loop(self):
"""Instantiate a while loop if it has not been created yet."""
if self._loop is None:
self._loop = self._while_loop()
return self._loop
return self._loop


def _in_while_loop(control_flow_node_map, op_name):
return op_name in control_flow_node_map and \
"LoopCond" in control_flow_node_map[op_name]


class GraphProto(object):
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
Definition:
Expand All @@ -1381,7 +1500,6 @@ def __init__(self):
self._input_shapes = {}
self._loops = {}
self._branches = {}
# self.module = relay.Module({})

def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
"""Construct relay nodes from tensorflow graph definition - GraphDef.
Expand Down Expand Up @@ -1548,55 +1666,15 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`.
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym)
input_0d_mismatch.add(in_sym[0])

attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch
node_name_prefix = node.name.rsplit('/', 1)[0]

if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]]
self._loops[node_name_prefix] = Loop()
else:
if len(self._branches) == 0:
raise RuntimeError("Cannot find a created "
"conditional for merge node")
branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]]
true_br = self._nodes[node.input[1]]
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
# del self._branches[node_name_prefix]
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')
exit_number = int("0" + exit_name[4:])
expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number)
elif node.op == "Enter":
op = self._nodes[node.input[0]]
elif node.op == "LoopCond":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch":
op = self._nodes[node.input[0]]
assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix):
self._loops[node_name_prefix].loop_vars.append(op[0])
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
elif node.op == "NextIteration":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
if node.op in _control_flow_nodes:
op = self._convert_control_flow_operator(node, inputs,
attr,
control_flow_node_map)
else:
op = self._convert_operator(node.op, inputs, attr, graph)

Expand Down Expand Up @@ -1807,6 +1885,89 @@ def _convert_rnn_operator(self, op_name, inputs,
sym = self.rnn.process_op(op_name, inputs, attrs, params)
return sym

def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
"""
Convert the Relay control flow primitive into corresponding component
of a Relay control flow construct, i.e. `tf.cond` and `tf.while_loop`
are converted in Relay `If` and recusrive call, respectively.
Parameters
----------
node: TensorFlow graph node object.
A TensorFlow graph node object.
inputs : List[tvm.relay.Expr]
List of input symbols.
attrs : Dict[tvm.Attrs]
Dict of operator attributes.
control_flow_node_map : Dict[str, Set[str]]
A dictionary contains the execution frame name to primitives
mapping.
Returns
-------
op : tvm.relay.Expr
Converted relay expression.
"""
node_name_prefix = node.name.rsplit('/', 1)[0]
if node.op == "Merge":
if _in_while_loop(control_flow_node_map, node_name_prefix):
op = self._nodes[node.input[0]]
self._loops[node_name_prefix] = Loop()
else:
if len(self._branches) == 0:
raise RuntimeError("Cannot find a created "
"conditional for merge node")
branch = self._branches[node_name_prefix]
false_br = self._nodes[node.input[0]]
true_br = self._nodes[node.input[1]]
assert len(true_br) == 1
assert len(false_br) == 1
branch.true_branch = true_br[0]
branch.false_branch = false_br[0]
op = [branch.if_node()]
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
exit_name = node.name.split('/')[-1]
assert str.startswith(exit_name, 'Exit')

# TensorFlow has differen naming convention on different
# versions.
if '_' in exit_name:
exit_number = int("0" + exit_name[5:])
else:
exit_number = int("0" + exit_name[4:])

expr = loop.while_loop()
op = _expr.TupleGetItem(expr, exit_number)
elif node.op == "Enter":
op = self._nodes[node.input[0]]
elif node.op == "LoopCond":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].cond = op[0]
elif node.op == "Switch":
op = self._nodes[node.input[0]]
assert len(op) == 1
if _in_while_loop(control_flow_node_map, node_name_prefix):
self._loops[node_name_prefix].loop_vars.append(op[0])
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
elif node.op == "NextIteration":
op = self._nodes[node.input[0]]
assert len(op) == 1
self._loops[node_name_prefix].body.append(op[0])
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))

return op


def _convert_operator(self, op_name, inputs, attrs,
graph, identity_list=None, convert_map=None):
"""Convert from Tensorflow operator to relay operator.
Expand Down
Loading

0 comments on commit a65cfbb

Please sign in to comment.