Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][MXNet] Change mxnet graph traversal from recursion to iteration #2007

Merged
merged 3 commits into from
Oct 29, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 80 additions & 22 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,55 @@ def _as_list(arr):
return arr
return [arr]

def _topo_sort(symbol):
"""Sort all symbols in the mxnet graph in topological order.

Parameters
----------
symbol : mxnet.sym.Symbol

Returns:
-------
list
List of mxnet symbol
"""
queue = []
symbol_map = {}
deps = {}
dep_cnts = {}
for s in symbol:
symbol_map[s.attr('name')] = s
queue.append(s)
while queue:
sym = queue.pop(0)
name = sym.attr('name')
childs = sym.get_children()
if childs is None:
dep_cnts[name] = 0
else:
dep_cnts[name] = len(set([c.attr('name') for c in childs]))
for child in childs:
child_name = child.attr('name')
if child_name not in deps:
deps[child_name] = set()
deps[child_name].add(name)
if child_name not in symbol_map:
symbol_map[child_name] = child
queue.append(child)
order = []
while dep_cnts:
remove = []
for name in dep_cnts:
if dep_cnts[name] == 0:
order.append(symbol_map[name])
remove.append(name)
if name in deps:
for other in deps[name]:
dep_cnts[other] -= 1
for name in remove:
del dep_cnts[name]
return order

def _from_mxnet_impl(symbol, graph):
"""Convert mxnet symbol to nnvm implementation.
Reconstruct a nnvm symbol by traversing the mxnet symbol.
Expand All @@ -398,28 +447,37 @@ def _from_mxnet_impl(symbol, graph):
nnvm.sym.Symbol
Converted symbol
"""
if len(symbol.list_outputs()) > 1:
return [_from_mxnet_impl(s, graph) for s in symbol]

name = symbol.attr('name')
output_index = json.loads(symbol.tojson())['heads'][0][1]
node = graph.get(name, None)
if node:
return node[output_index]
attr = symbol.list_attr()
op_name = symbol.attr('op_name')
childs = symbol.get_children()
if childs is not None:
childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))]
childs = [x for y in childs for x in _as_list(y)] # expand group symbol
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr) # no input symbol
else:
op_name = json.loads(symbol.tojson())['nodes'][0]['op']
node = _sym.Variable(name=name, **attr)
graph[name] = node
return node[output_index]
def get_node(sym):
name = sym.attr('name')
if name not in graph:
return None
output_index = json.loads(sym.tojson())['heads'][0][1]
return graph[name][output_index]

assert symbol is not None
# Traverse all symbols in topological order
for sym in _topo_sort(symbol):
name = sym.attr('name')
attr = sym.list_attr()
op_name = sym.attr('op_name')
childs = sym.get_children()
if childs is not None:
childs = [get_node(child) for child in childs]
childs = [x for y in childs for x in _as_list(y)]
node = _convert_symbol(op_name, childs, attr)
elif op_name != 'null':
node = _convert_symbol(op_name, [], attr)
else:
node = _sym.Variable(name=name, **attr)
graph[name] = node
nodes = []
for sym in symbol:
node = get_node(sym)
assert node is not None
nodes.append(node)
if len(nodes) > 1:
return _sym.Group(nodes)
return nodes[0]

def from_mxnet(symbol, arg_params=None, aux_params=None):
"""Convert from MXNet's model into compatible NNVM format.
Expand Down