Skip to content

Commit

Permalink
[Relay][Tensorflow] Allow an op as loop var. (apache#3056)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiaoquan authored and Wei Chen committed May 13, 2019
1 parent 7cbcf7e commit 057e796
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
35 changes: 30 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from ..expr_functor import ExprMutator

__all__ = ['from_tensorflow']

Expand Down Expand Up @@ -1414,6 +1415,27 @@ def _get_abs_layer_name(node):
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

class RewriteSubgraph(ExprMutator):
"""
A helper class to rewrite expr in while loop function to variable
Parameters
----------
rewrite_map : Dict[expr, expr]
A dictionay contains a set of expr to var mapping.
"""
def __init__(self, rewrite_map):
ExprMutator.__init__(self)
self.rewrite_map = rewrite_map

def visit(self, expr):
if expr in self.rewrite_map:
return self.rewrite_map[expr]
return super().visit(expr)

def rewrite_subgraph(expr, rewrites):
return RewriteSubgraph(rewrites).visit(expr)

def _in_while_loop(control_flow_node_map, op_name):
"""
Check if a given control flow operator is part of a while loop execution
Expand Down Expand Up @@ -1594,14 +1616,17 @@ def _while_loop(self):
loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = tvm.relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
if not isinstance(var, _expr.Var):
var_type = ir_pass.infer_type(var).checked_type
else:
var_type = var.type_annotation

v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
loop_vars.append(v)
bind_map[var] = v

self.cond = tvm.relay.bind(self.cond, bind_map)
self.body = [tvm.relay.bind(b, bind_map) for b in self.body]
self.cond = rewrite_subgraph(self.cond, bind_map)
self.body = [rewrite_subgraph(b, bind_map) for b in self.body]

cond = tvm.relay.op.min(self.cond)

Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ def b(i): return tf.add(i, 1)
check_equal(graph, tf_out)


def test_callnode_loop_vars():
graph = tf.Graph()
with graph.as_default():
i = tf.add(tf.constant(0), 1)

def c(i): return tf.less(i, 10)

def b(i): return tf.add(i, 1)

r = tf.while_loop(c, b, [i])

with tf.Session() as sess:
tf_out = sess.run(r)

check_equal(graph, tf_out)


def test_loop_2_vars():
graph = tf.Graph()
with graph.as_default():
Expand Down Expand Up @@ -288,6 +305,7 @@ def condition(x):
test_loop_3_vars()
test_loop_conditions()
test_loop_bodies()
test_callnode_loop_vars()

# tf.cond
test_vanilla_cond()
Expand Down

0 comments on commit 057e796

Please sign in to comment.