diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d684c12c87576..5b5e273b2466a 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2205,6 +2205,7 @@ def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice] Instead, we recognize whether slice is inside while_loop block and pass it as an extra loop variable to avoid duplicate computation. + TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function. """ def __init__(self, loop_name, hash2tfnode, while_loop_name_set): ExprVisitor.__init__(self) @@ -2219,8 +2220,14 @@ def _find_parent_loop_name(self, node_name): name_prefix = node_name.rsplit('/', 1)[0] if name_prefix.startswith("^"): name_prefix = name_prefix[1:] + # To get the name of the direct parent while loop for a given node, + # we iterate all the while loop names inside TensorFlow graph def. + # If we find a loop name with which current node name starts, + # it means current node is under this loop. However, due to nested + # loop, this loop may not be the direct parent while loop of current + # node. We need to keep the longest loop name, which represents the + # innermost while loop corresponding to current node. for lname in self._while_loop_name_set: - # For nested loop, we should pick the inner most one. if name_prefix.startswith(lname) and len(ploop_name) < len(lname): ploop_name = lname @@ -2232,8 +2239,8 @@ def _find_parent_loop_name(self, node_name): def visit(self, expr): """ For each expression in the body, look up the corresponding - tensorflow node with its structural hash. If the current loop is the - direct parent of this node, we check whether its every input node belong + TensorFlow node with its structural hash. If the current loop is the + direct parent of this node, we check whether its every input node belongs to the current loop. If not, we mark this input node as an extra loop variable to the current loop. """ @@ -2242,6 +2249,8 @@ def visit(self, expr): if expr_hash in self._hash2tfnode: node = self._hash2tfnode[expr_hash] ploop_name = self._find_parent_loop_name(node.name) + # It is possibel that a node is under nested loop of current loop. + # We only check the direct children of current loop. if ploop_name == self._loop_name: for iname in node.input: iploop_name = self._find_parent_loop_name(iname)