diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 0efebe3cfec9..304c5e11f1a5 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -5,6 +5,7 @@
 import logging
 import warnings
+from collections import defaultdict
 
 # Numpy support
 import numpy as np
@@ -1270,6 +1271,220 @@ def _get_abs_layer_name(node):
                              params, num_layers)
     return sym
 
+# An internal list of all the control flow primitives used in
+# TensorFlow 1.x.
+_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
+
+def _in_while_loop(control_flow_node_map, op_name):
+    """
+    Check if a given control flow operator is part of a while loop execution
+    frame. This is based on the fact that there is only one occurrence of
+    `LoopCond` per loop execution frame and it only appears in the loop
+    construct.
+
+    Parameters
+    ----------
+    control_flow_node_map : Dict[str, Set[str]]
+        A dictionary mapping each unique control flow execution frame name
+        to a set of primitive operators.
+
+    op_name : str
+        The name of a control flow primitive.
+
+    Returns
+    -------
+    ret : bool
+        True if the operator is in a while loop execution frame,
+        False otherwise.
+    """
+    return op_name in control_flow_node_map and \
+        "LoopCond" in control_flow_node_map[op_name]
+
+
+class Branch:
+    """A class containing the components used to build up a Relay if node.
+
+    Parameters
+    ----------
+    cond : tvm.relay.Expr
+        The condition of an if node.
+
+    true_branch : tvm.relay.Expr
+        The body of the true branch of an if expression.
+
+    false_branch: tvm.relay.Expr
+        The body of the false branch of an if expression.
+
+    _if : tvm.relay.Expr
+        An internal variable indicating whether an if expression has already
+        been created for a matched TF conditional construct.
+
+    Examples
+    --------
+    The following is a cond statement written in TensorFlow:
+
+    .. code-block:: python
+
+        def vanilla_cond():
+            i = tf.constant(1)
+            j = tf.constant(4)
+
+            def f1():
+                return tf.multiply(1, 17)
+
+            def f2():
+                return tf.add(4, 23)
+            r = tf.cond(tf.less(i, j), f1, f2)
+
+    This condition statement should be converted into Relay in the
+    following form:
+
+    .. code-block:: python
+
+        fn (%Const: Tensor[(1,), int32],
+            %Const_1: Tensor[(1,), int32],
+            %cond/Mul/x: Tensor[(1,), int32],
+            %cond/Mul/y: Tensor[(1,), int32],
+            %cond/Add/x: Tensor[(1,), int32],
+            %cond/Add/y: Tensor[(1,), int32]) {
+          %0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
+          %1 = min(%0)
+          if (%1) {
+            %2 = multiply(%cond/Mul/x, %cond/Mul/y)
+            %2
+          } else {
+            %3 = add(%cond/Add/x, %cond/Add/y)
+            %3
+          }
+        }
+    """
+    def __init__(self):
+        self._if = None
+        self.cond = None
+        self.true_branch = None
+        self.false_branch = None
+
+    def _if_node(self):
+        """An internal API to create a relay if node from the matched TF
+        condition construct.
+        """
+        # `cond` returns a tensor that contains boolean values. We add a
+        # `min` operator to check if there is any false value. If so, this
+        # condition does not hold.
+        cond = tvm.relay.op.min(self.cond)
+        return tvm.relay.If(cond, self.true_branch, self.false_branch)
+
+    def if_node(self):
+        """Create a tvm.relay.If node if it has not been created yet."""
+        if self._if is None:
+            self._if = self._if_node()
+        return self._if
+
+
+class Loop:
+    """
+    A class containing the components used to build up a Relay recursive
+    call.
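+
+    The matched TF while loop is lowered to a Relay function that calls
+    itself through a `let`-bound variable, so each iteration of the loop
+    becomes a recursive call (see the example below).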
+
+    Parameters
+    ----------
+    loop_vars : List[tvm.relay.Expr]
+        The loop variables that are used in a while loop.
+
+    cond : tvm.relay.Expr
+        The condition of a while loop.
+
+    body : List[tvm.relay.Expr]
+        The body of a matched while loop.
+
+    _loop : tvm.relay.Expr
+        An internal variable indicating whether a recursive call has already
+        been created for a matched TF while loop construct.
+
+    Examples
+    --------
+    The following is a vanilla loop from TensorFlow:
+
+    .. code-block:: python
+
+        i = tf.constant(0)
+        c = lambda i: tf.less(i, 10)
+        b = lambda i: tf.add(i, 1)
+        r = tf.while_loop(c, b, [i])
+
+    It will be converted to the following recursive call in Relay:
+
+    .. code-block:: python
+
+        fn (%while/Less/y: Tensor[(1,), int32],
+            %while/Add/y: Tensor[(1,), int32],
+            %Const: Tensor[(1,), int32]) {
+          %0 = fn(%loop_var0: Tensor[(1,), int32]) {
+            %1 = less(%loop_var0, %while/Less/y)
+            %2 = min(%1)
+            if (%2) {
+              %3 = add(%loop_var0, %while/Add/y)
+              free_var %while_loop
+              %4 = %while_loop(%3)
+              %4
+            } else {
+              %5 = (%loop_var0,)
+              %5
+            }
+          }
+          let %while_loop1 = %0
+          %6 = %while_loop1(%Const)
+          %6
+        }
+    """
+    def __init__(self):
+        self.loop_vars = []
+        self.cond = None
+        self.body = []
+        self._loop = None
+
+    def _while_loop(self):
+        """An internal API to create a Relay recursive call for a matched TF
+        `while_loop` construct.
+        """
+        wl = tvm.relay.var('while_loop')
+
+        sb = tvm.relay.scope_builder.ScopeBuilder()
+
+        loop_vars = []
+        bind_map = {}
+        for i, var in enumerate(self.loop_vars):
+            assert isinstance(var, _expr.Var), repr(var)
+            v = tvm.relay.var("loop_var" + str(i),
+                              type_annotation=var.type_annotation)
+            loop_vars.append(v)
+            bind_map[var] = v
+
+        self.cond = tvm.relay.bind(self.cond, bind_map)
+        self.body = [tvm.relay.bind(b, bind_map) for b in self.body]
+
+        cond = tvm.relay.op.min(self.cond)
+
+        with sb.if_scope(cond):
+            sb.ret(wl(*self.body))
+        with sb.else_scope():
+            sb.ret(tvm.relay.Tuple(loop_vars))
+
+        loop_fn = tvm.relay.Function(loop_vars, sb.get())
+        sb = tvm.relay.scope_builder.ScopeBuilder()
+        sb.let(wl, loop_fn)
+        sb.ret(wl(*self.loop_vars))
+        return sb.get()
+
+    def while_loop(self):
+        """Instantiate a while loop if it has not been created yet."""
+        if self._loop is None:
+            self._loop = self._while_loop()
+        return self._loop
+
+
 class GraphProto(object):
     """ A helper class for handling relay graph copying from Tensorflow GraphDef.
     Definition:
@@ -1284,6 +1499,8 @@ def __init__(self):
         self._num_rnn_layer = False
         self._outputs_are_0d = {}
         self._input_shapes = {}
+        self._loops = {}
+        self._branches = {}
 
     def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
         """Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -1332,7 +1549,10 @@
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))
 
+        control_flow_node_map = defaultdict(set)
         for node in graph.node:
+            node_name_prefix = node.name.rsplit('/', 1)[0]
+            control_flow_node_map[node_name_prefix].add(node.op)
             if node.op == 'Placeholder':
                 if shape and node.name in shape:
                     self._input_shapes[node.name] = list(shape[node.name])
@@ -1447,12 +1667,17 @@
                     # This means the node is 1d in Relay and 0d in TF.
                     # See `_expand_dims_0d_aware`.
                    if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
-                        input_0d_mismatch.add(in_sym)
+                        input_0d_mismatch.add(in_sym[0])
 
             attr['_input_shapes'] = input_shapes
             attr['_input_0d_mismatch'] = input_0d_mismatch
 
-            op = self._convert_operator(node.op, inputs, attr, graph)
+            if node.op in _control_flow_nodes:
+                op = self._convert_control_flow_operator(node, inputs,
+                                                         attr,
+                                                         control_flow_node_map)
+            else:
+                op = self._convert_operator(node.op, inputs, attr, graph)
 
             # Check if op is converted to param
             if isinstance(op, np.ndarray):
@@ -1493,7 +1718,10 @@
 
         out = []
         if outputs is None:
-            out = op
+            if node.op == "Exit":
+                out = [op[0].tuple_value]
+            else:
+                out = op
         else:
             for out_name in outputs:
                 if ":" in out_name:
@@ -1529,7 +1757,9 @@
             elif node.op == "Const":
                 pass
             else:
-                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
+                if any([node.op in t for t in [_identity_list, _convert_map,
+                                               _convert_map_rnn,
+                                               _control_flow_nodes]]):
                     pass
                 else:
                     missing_operators.add(node.op)
@@ -1656,6 +1886,89 @@ def _convert_rnn_operator(self, op_name, inputs,
                                   params)
         sym = self.rnn.process_op(op_name, inputs, attrs, params)
         return sym
 
+    def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
+        """
+        Convert a TensorFlow control flow primitive into the corresponding
+        component of a Relay control flow construct, i.e. `tf.cond` and
+        `tf.while_loop` are converted into a Relay `If` and a recursive
+        call, respectively.
+
+        Parameters
+        ----------
+        node : TensorFlow graph node object.
+            The TensorFlow graph node to be converted.
+
+        inputs : List[tvm.relay.Expr]
+            List of input symbols.
+
+        attrs : Dict[tvm.Attrs]
+            Dict of operator attributes.
+
+        control_flow_node_map : Dict[str, Set[str]]
+            A dictionary mapping each execution frame name to a set of
+            primitive operators.
+
+        Returns
+        -------
+        op : tvm.relay.Expr
+            Converted relay expression.
+        """
+        node_name_prefix = node.name.rsplit('/', 1)[0]
+        if node.op == "Merge":
+            if _in_while_loop(control_flow_node_map, node_name_prefix):
+                op = self._nodes[node.input[0]]
+                self._loops[node_name_prefix] = Loop()
+            else:
+                if len(self._branches) == 0:
+                    raise RuntimeError("Cannot find a created "
+                                       "conditional for merge node")
+                branch = self._branches[node_name_prefix]
+                false_br = self._nodes[node.input[0]]
+                true_br = self._nodes[node.input[1]]
+                assert len(true_br) == 1
+                assert len(false_br) == 1
+                branch.true_branch = true_br[0]
+                branch.false_branch = false_br[0]
+                op = [branch.if_node()]
+        elif node.op == "Exit":
+            loop = self._loops[node_name_prefix]
+            exit_name = node.name.split('/')[-1]
+            assert str.startswith(exit_name, 'Exit')
+
+            # TensorFlow uses different naming conventions for `Exit` nodes
+            # across versions.
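+            # For example, the second exit node of a loop is named `Exit_1`
+            # in newer versions but `Exit1` in older ones; the trailing
+            # number is the output index, and a bare `Exit` maps to index 0.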
+            if '_' in exit_name:
+                exit_number = int("0" + exit_name[5:])
+            else:
+                exit_number = int("0" + exit_name[4:])
+
+            expr = loop.while_loop()
+            op = _expr.TupleGetItem(expr, exit_number)
+        elif node.op == "Enter":
+            op = self._nodes[node.input[0]]
+        elif node.op == "LoopCond":
+            op = self._nodes[node.input[0]]
+            assert len(op) == 1
+            self._loops[node_name_prefix].cond = op[0]
+        elif node.op == "Switch":
+            op = self._nodes[node.input[0]]
+            assert len(op) == 1
+            if _in_while_loop(control_flow_node_map, node_name_prefix):
+                self._loops[node_name_prefix].loop_vars.append(op[0])
+            else:
+                if node_name_prefix not in self._branches:
+                    self._branches[node_name_prefix] = Branch()
+                self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
+        elif node.op == "NextIteration":
+            op = self._nodes[node.input[0]]
+            assert len(op) == 1
+            self._loops[node_name_prefix].body.append(op[0])
+        else:
+            raise Exception("Cannot identify control flow operator: "
+                            "{}".format(node.op))
+
+        return op
+
+
     def _convert_operator(self, op_name, inputs, attrs,
                           graph, identity_list=None, convert_map=None):
         """Convert from Tensorflow operator to relay operator.
diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc
index 77f5c10e7ed7..77840c9be824 100644
--- a/src/relay/backend/interpreter.cc
+++ b/src/relay/backend/interpreter.cc
@@ -270,16 +270,30 @@ class Interpreter :
     return TupleValueNode::make(values);
   }
 
-  Value VisitExpr_(const FunctionNode* func_node) final {
-    auto func = GetRef<Function>(func_node);
+  // TODO(@jroesch): this doesn't support mutual letrec.
+  Value MakeClosure(const Function& func, const Var& letrec_name = Var()) {
     tvm::Map<Var, Value> captured_mod;
     Array<Var> free_vars = FreeVars(func);
 
     for (const auto& var : free_vars) {
-      captured_mod.Set(var, Eval(var));
+      // Evaluate the free var (which could be a function call) only if it
+      // isn't the let-bound name that invokes this function.
+      if (!letrec_name.defined() || letrec_name != var) {
+        captured_mod.Set(var, Eval(var));
+      }
     }
 
-    return ClosureNode::make(captured_mod, func);
+    // We must use mutation here to build a self-referential closure.
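+    // Every free variable except the let-bound name was captured above, so
+    // after constructing the closure we patch that name to refer back to
+    // the closure itself, letting the recursive call in the body resolve.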
+    auto closure = ClosureNode::make(captured_mod, func);
+    if (letrec_name.defined()) {
+      auto mut_closure =
+          static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
+      mut_closure->env.Set(letrec_name, closure);
+    }
+    return closure;
+  }
+
+  Value VisitExpr_(const FunctionNode* func_node) final {
+    auto func = GetRef<Function>(func_node);
+    return MakeClosure(func);
   }
 
   Value InvokePrimitiveOp(Function func,
@@ -438,10 +452,16 @@
     }
   }
 
-  Value VisitExpr_(const LetNode* op) final {
-    auto value = Eval(op->value);
-    this->extend(op->var, value);
-    return Eval(op->body);
+  Value VisitExpr_(const LetNode* let) final {
+    if (auto func = let->value.as<FunctionNode>()) {
+      auto clo = MakeClosure(GetRef<Function>(func), let->var);
+      this->extend(let->var, clo);
+    } else {
+      auto value = Eval(let->value);
+      this->extend(let->var, value);
+    }
+
+    return Eval(let->body);
   }
 
   Value VisitExpr_(const TupleGetItemNode* op) final {
diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py
new file mode 100644
index 000000000000..c5b38c319467
--- /dev/null
+++ b/tests/python/frontend/tensorflow/test_control_flow.py
@@ -0,0 +1,285 @@
+"""Unit tests for converting TensorFlow control flow ops to Relay."""
+import tensorflow as tf
+import numpy as np
+from tvm import relay
+from tvm.relay.frontend.tensorflow import from_tensorflow
+
+
+def check_equal(graph, tf_out):
+    expr, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
+    ex = relay.create_executor('debug')
+    relay_out = ex.evaluate(expr)(**params)
+    if isinstance(relay_out, relay.backend.interpreter.TensorValue):
+        np.testing.assert_allclose(tf_out, relay_out.asnumpy())
+    else:
+        if not isinstance(tf_out, list):
+            tf_out = [tf_out]
+        for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
+            np.testing.assert_allclose(x, y)
+
+
+def test_vanilla_loop():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.constant(0)
+
+        def c(i): return tf.less(i, 10)
+
+        def b(i): return tf.add(i, 1)
+
+        r = tf.while_loop(c, b, [i])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_loop_2_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        i0 = tf.constant(0)
+        j0 = tf.ones([2, 2])
+
+        def c(i, j): return i < 10
+
+        def b(i, j): return [tf.add(i, 1), j]
+
+        i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0])
+        i1 += tf.constant(1337)
+
+        with tf.Session() as sess:
+            tf_out = sess.run(i1)
+
+    check_equal(graph, tf_out)
+
+
+def test_loop_3_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        i0 = tf.constant(1)
+        j0 = tf.constant(2)
+        k0 = tf.constant(4)
+
+        def c(i, j, k): return i < 10
+
+        def b(i, j, k): return [i+1, j * k, k + i]
+
+        r = tf.while_loop(c, b, loop_vars=[i0, j0, k0])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_loop_conditions():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.constant(1)
+        j = tf.constant(1)
+        k = tf.constant(5)
+
+        def c(i, j, k): return \
+            tf.equal(tf.not_equal(tf.less(i + j, 10),
+                                  tf.less(j * k, 100)),
+                     tf.greater_equal(k, i + j))
+
+        def b(i, j, k): return [i+j, j+k, k+1]
+
+        r = tf.while_loop(c, b, loop_vars=[i, j, k])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_loop_bodies():
+    graph = tf.Graph()
+    with graph.as_default():
+        def body(x):
+            a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32)
+            b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
+            c = a + b
+            return tf.nn.relu(x + c)
+
+        def condition(x):
+            return tf.reduce_sum(x) < 100
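+
+        # The loop accumulates the constant 2x2 matrix (passed through a
+        # ReLU) until the element sum of `x` reaches 100.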
+        x = tf.constant(0, shape=[2, 2])
+        r = tf.while_loop(condition, body, [x])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_nested_loop():
+    graph = tf.Graph()
+    with graph.as_default():
+
+        def body(x):
+            def nest_body(c):
+                return tf.multiply(c, 2)
+
+            def cd(c): return tf.less(c, 10)
+
+            c = tf.constant(2)
+            res = tf.while_loop(cd, nest_body, loop_vars=[c])
+            return tf.nn.relu(x + res)
+
+        def condition(x):
+            return tf.greater(x, 100)
+
+        x = tf.constant(3)
+        r = tf.while_loop(condition, body, loop_vars=[x])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_vanilla_cond():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.constant(1)
+        j = tf.constant(4)
+
+        def f1():
+            return tf.multiply(1, 17)
+
+        def f2():
+            return tf.add(4, 23)
+
+        r = tf.cond(tf.less(i, j), f1, f2)
+
+    with tf.Session(graph=graph) as sess:
+        tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_multiple_cond_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        x1 = tf.constant(7)
+        x2 = tf.constant(12)
+        z = tf.constant(20)
+        r = tf.cond(tf.less(tf.add(x1, x2), 10),
+                    lambda: tf.add(10, 2), lambda: tf.square(5))
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+def test_cond_fn_parameters():
+    graph = tf.Graph()
+    with graph.as_default():
+        def fn1(x, y):
+            return tf.multiply(5, 6)
+
+        def fn2(x, y):
+            return tf.add(3, 4)
+
+        i = tf.constant(1)
+        j = tf.constant(2)
+        k = tf.constant(3)
+        r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k))
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3})
+
+    check_equal(graph, tf_out)
+
+
+def test_nested_cond():
+    graph = tf.Graph()
+    with graph.as_default():
+        def fn1(a, b):
+            def nest_fn1():
+                return tf.add(1, 2)
+
+            def nest_fn2():
+                return tf.subtract(10, 5)
+
+            res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2)
+            return tf.multiply(tf.add(87, res), 10)
+
+        def fn2(a, b):
+            return tf.add(10, 10)
+
+        x = tf.constant(5)
+        y = tf.constant(6)
+        z = tf.constant(7)
+        pred = tf.less(x, y)
+        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
+
+    check_equal(graph, tf_out)
+
+
+def test_loop_in_cond():
+    graph = tf.Graph()
+    with graph.as_default():
+        def fn1(a, b):
+            i = tf.constant(0)
+
+            def cd(i): return tf.less(i, 10)
+
+            def bd(i): return tf.add(i, 1)
+
+            res = tf.while_loop(cd, bd, [i])
+            return tf.multiply(tf.add(20, res), 10)
+
+        def fn2(a, b):
+            return tf.add(10, 20)
+
+        x = tf.constant(7)
+        y = tf.constant(20)
+        z = tf.constant(10)
+        pred = tf.less(x, y)
+        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
+
+    check_equal(graph, tf_out)
+
+
+def test_cond_in_loop():
+    graph = tf.Graph()
+    with graph.as_default():
+        def body(x):
+            x = tf.constant(7)
+            z = tf.constant(20)
+            res = tf.cond(tf.less(x, 10), lambda: tf.add(10, 20),
+                          lambda: tf.square(10))
+            return tf.multiply(res, x)
+
+        x = tf.constant(21)
+
+        def condition(x):
+            return tf.less(x, 100)
+
+        r = tf.while_loop(condition, body, loop_vars=[x])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
+if __name__ == "__main__":
+
+    # tf.while_loop
+    test_vanilla_loop()
+    test_loop_2_vars()
+    test_loop_3_vars()
+    test_loop_conditions()
+    test_loop_bodies()
+
+    # tf.cond
+    test_vanilla_cond()
+    test_multiple_cond_vars()
+    test_cond_fn_parameters()
+
+    # nested cases
+    test_nested_loop()
+    test_nested_cond()
+    test_loop_in_cond()
+    test_cond_in_loop()