From 5e58143173ce2905afa56271d140949bc3586e79 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Fri, 13 Mar 2020 17:35:20 +0000 Subject: [PATCH] Add more comments --- python/tvm/relay/frontend/tensorflow.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 2aa6a6ff478fb..4e04a620d1be2 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2117,6 +2117,7 @@ def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice] Instead, we recognize whether slice is inside while_loop block and pass it as an extra loop variable to avoid duplicate computation. + TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function. """ def __init__(self, loop_name, hash2tfnode, while_loop_name_set): ExprVisitor.__init__(self) @@ -2131,8 +2132,14 @@ def _find_parent_loop_name(self, node_name): name_prefix = node_name.rsplit('/', 1)[0] if name_prefix.startswith("^"): name_prefix = name_prefix[1:] + # To get the name of the direct parent while loop for a given node, + # we iterate all the while loop names inside TensorFlow graph def. + # If we find a loop name with which current node name starts, + # it means current node is under this loop. However, due to nested + # loop, this loop may not be the direct parent while loop of current + # node. We need to keep the longest loop name, which represents the + # innermost while loop corresponding to current node. for lname in self._while_loop_name_set: - # For nested loop, we should pick the inner most one. if name_prefix.startswith(lname) and len(ploop_name) < len(lname): ploop_name = lname @@ -2144,8 +2151,8 @@ def _find_parent_loop_name(self, node_name): def visit(self, expr): """ For each expression in the body, look up the corresponding - tensorflow node with its structural hash. If the current loop is the - direct parent of this node, we check whether its every input node belong + TensorFlow node with its structural hash. If the current loop is the + direct parent of this node, we check whether its every input node belongs to the current loop. If not, we mark this input node as an extra loop variable to the current loop. """ @@ -2154,6 +2161,8 @@ def visit(self, expr): if expr_hash in self._hash2tfnode: node = self._hash2tfnode[expr_hash] ploop_name = self._find_parent_loop_name(node.name) + # It is possibel that a node is under nested loop of current loop. + # We only check the direct children of current loop. if ploop_name == self._loop_name: for iname in node.input: iploop_name = self._find_parent_loop_name(iname)