Skip to content

Commit

Permalink
Add more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun committed Mar 13, 2020
1 parent 785964b commit 5e58143
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2117,6 +2117,7 @@ def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice]
Instead, we recognize whether slice is inside while_loop block and pass it as an
extra loop variable to avoid duplicate computation.
TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function.
"""
def __init__(self, loop_name, hash2tfnode, while_loop_name_set):
ExprVisitor.__init__(self)
Expand All @@ -2131,8 +2132,14 @@ def _find_parent_loop_name(self, node_name):
name_prefix = node_name.rsplit('/', 1)[0]
if name_prefix.startswith("^"):
name_prefix = name_prefix[1:]
# To get the name of the direct parent while loop for a given node,
# we iterate all the while loop names inside TensorFlow graph def.
# If we find a loop name with which current node name starts,
# it means current node is under this loop. However, due to nested
# loop, this loop may not be the direct parent while loop of current
# node. We need to keep the longest loop name, which represents the
# innermost while loop corresponding to current node.
for lname in self._while_loop_name_set:
# For nested loop, we should pick the inner most one.
if name_prefix.startswith(lname) and len(ploop_name) < len(lname):
ploop_name = lname

Expand All @@ -2144,8 +2151,8 @@ def _find_parent_loop_name(self, node_name):
def visit(self, expr):
"""
For each expression in the body, look up the corresponding
tensorflow node with its structural hash. If the current loop is the
direct parent of this node, we check whether its every input node belong
TensorFlow node with its structural hash. If the current loop is the
direct parent of this node, we check whether its every input node belongs
to the current loop. If not, we mark this input node as an extra loop
variable to the current loop.
"""
Expand All @@ -2154,6 +2161,8 @@ def visit(self, expr):
if expr_hash in self._hash2tfnode:
node = self._hash2tfnode[expr_hash]
ploop_name = self._find_parent_loop_name(node.name)
# It is possibel that a node is under nested loop of current loop.
# We only check the direct children of current loop.
if ploop_name == self._loop_name:
for iname in node.input:
iploop_name = self._find_parent_loop_name(iname)
Expand Down

0 comments on commit 5e58143

Please sign in to comment.