Skip to content

Commit

Permalink
[Frontend][MXNet] Change mxnet graph traversal from recursion to iter…
Browse files Browse the repository at this point in the history
…ation (apache#2007)
  • Loading branch information
icemelon authored and Wei Chen committed Feb 19, 2019
1 parent a07beee commit 0c3cfb3
Showing 1 changed file with 80 additions and 22 deletions.
102 changes: 80 additions & 22 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,55 @@ def _as_list(arr):
return arr
return [arr]

def _topo_sort(symbol):
"""Sort all symbols in the mxnet graph in topological order.
Parameters
----------
symbol : mxnet.sym.Symbol
Returns:
-------
list
List of mxnet symbol
"""
queue = []
symbol_map = {}
deps = {}
dep_cnts = {}
for s in symbol:
symbol_map[s.attr('name')] = s
queue.append(s)
while queue:
sym = queue.pop(0)
name = sym.attr('name')
childs = sym.get_children()
if childs is None:
dep_cnts[name] = 0
else:
dep_cnts[name] = len(set([c.attr('name') for c in childs]))
for child in childs:
child_name = child.attr('name')
if child_name not in deps:
deps[child_name] = set()
deps[child_name].add(name)
if child_name not in symbol_map:
symbol_map[child_name] = child
queue.append(child)
order = []
while dep_cnts:
remove = []
for name in dep_cnts:
if dep_cnts[name] == 0:
order.append(symbol_map[name])
remove.append(name)
if name in deps:
for other in deps[name]:
dep_cnts[other] -= 1
for name in remove:
del dep_cnts[name]
return order

def _from_mxnet_impl(symbol, graph):
"""Convert mxnet symbol to nnvm implementation.
Reconstruct a nnvm symbol by traversing the mxnet symbol.
Expand All @@ -413,28 +462,37 @@ def _from_mxnet_impl(symbol, graph):
nnvm.sym.Symbol
Converted symbol
"""
if len(symbol.list_outputs()) > 1:
return [_from_mxnet_impl(s, graph) for s in symbol]

name = symbol.attr('name')
output_index = json.loads(symbol.tojson())['heads'][0][1]
node = graph.get(name, None)
if node:
return node[output_index]
attr = symbol.list_attr()
op_name = symbol.attr('op_name')
childs = symbol.get_children()
if childs is not None:
childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr) # no input symbol
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
graph[name] = node
return node[output_index]
def get_node(sym):
name = sym.attr('name')
if name not in graph:
return None
output_index = json.loads(sym.tojson())['heads'][0][1]
return graph[name][output_index]

assert symbol is not None
# Traverse all symbols in topological order
for sym in _topo_sort(symbol):
name = sym.attr('name')
attr = sym.list_attr()
op_name = sym.attr('op_name')
childs = sym.get_children()
if childs is not None:
childs = [get_node(child) for child in childs]
childs = [x for y in childs for x in _as_list(y)]
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr)
else:
node = _sym.Variable(name=name, **attr)
graph[name] = node
nodes = []
for sym in symbol:
node = get_node(sym)
assert node is not None
nodes.append(node)
if len(nodes) > 1:
return _sym.Group(nodes)
return nodes[0]

def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
Expand Down

0 comments on commit 0c3cfb3

Please sign in to comment.