From 057e796822205a29f6a691f30b5a1c89294c5bf1 Mon Sep 17 00:00:00 2001 From: lixiaoquan Date: Wed, 1 May 2019 23:02:12 +0800 Subject: [PATCH] [Relay][Tensorflow] Allow an op as loop var. (#3056) --- python/tvm/relay/frontend/tensorflow.py | 35 ++++++++++++++++--- .../frontend/tensorflow/test_control_flow.py | 18 ++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 0f8b19bfb45f..9c312990c379 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -30,6 +30,7 @@ from .. import ir_pass from .. import expr as _expr from .. import op as _op +from ..expr_functor import ExprMutator __all__ = ['from_tensorflow'] @@ -1414,6 +1415,27 @@ def _get_abs_layer_name(node): # 1.x. _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] +class RewriteSubgraph(ExprMutator): + """ + A helper class to rewrite expr in while loop function to variable + + Parameters + ---------- + rewrite_map : Dict[expr, expr] + A dictionay contains a set of expr to var mapping. + """ + def __init__(self, rewrite_map): + ExprMutator.__init__(self) + self.rewrite_map = rewrite_map + + def visit(self, expr): + if expr in self.rewrite_map: + return self.rewrite_map[expr] + return super().visit(expr) + +def rewrite_subgraph(expr, rewrites): + return RewriteSubgraph(rewrites).visit(expr) + def _in_while_loop(control_flow_node_map, op_name): """ Check if a given control flow operator is part of a while loop execution @@ -1594,14 +1616,17 @@ def _while_loop(self): loop_vars = [] bind_map = {} for i, var in enumerate(self.loop_vars): - assert isinstance(var, _expr.Var), repr(var) - v = tvm.relay.var("loop_var" + str(i), - type_annotation=var.type_annotation) + if not isinstance(var, _expr.Var): + var_type = ir_pass.infer_type(var).checked_type + else: + var_type = var.type_annotation + + v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type) loop_vars.append(v) bind_map[var] = v - self.cond = tvm.relay.bind(self.cond, bind_map) - self.body = [tvm.relay.bind(b, bind_map) for b in self.body] + self.cond = rewrite_subgraph(self.cond, bind_map) + self.body = [rewrite_subgraph(b, bind_map) for b in self.body] cond = tvm.relay.op.min(self.cond) diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index e76a849ae8c3..b1860658a961 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -51,6 +51,23 @@ def b(i): return tf.add(i, 1) check_equal(graph, tf_out) +def test_callnode_loop_vars(): + graph = tf.Graph() + with graph.as_default(): + i = tf.add(tf.constant(0), 1) + + def c(i): return tf.less(i, 10) + + def b(i): return tf.add(i, 1) + + r = tf.while_loop(c, b, [i]) + + with tf.Session() as sess: + tf_out = sess.run(r) + + check_equal(graph, tf_out) + + def test_loop_2_vars(): graph = tf.Graph() with graph.as_default(): @@ -288,6 +305,7 @@ def condition(x): test_loop_3_vars() test_loop_conditions() test_loop_bodies() + test_callnode_loop_vars() # tf.cond test_vanilla_cond()