From 9702d459e91ceea9a8de911299bfe2ecfa83b8d8 Mon Sep 17 00:00:00 2001 From: Shen Date: Fri, 26 Oct 2018 15:29:35 -0700 Subject: [PATCH 1/3] Change mxnet graph traversal from recursion to iteration --- nnvm/python/nnvm/frontend/mxnet.py | 94 +++++++++++++++++++++++------- 1 file changed, 72 insertions(+), 22 deletions(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 87b169a1cfbc..7617c3ca2ace 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -381,6 +381,55 @@ def _as_list(arr): return arr return [arr] +def _topo_sort(symbol): + """Sort all symbols in the mxnet graph in topological order. + + Parameters + ---------- + symbol : mxnet.sym.Symbol + + Returns: + ------- + list + List of mxnet symbol + """ + queue = [] + symbol_map = {} + deps = {} + dep_cnts = {} + for s in symbol: + symbol_map[s.attr('name')] = s + queue.append(s) + while queue: + sym = queue.pop(0) + name = sym.attr('name') + childs = sym.get_children() + if childs is None: + dep_cnts[name] = 0 + else: + dep_cnts[name] = len(childs) + for child in childs: + child_name = child.attr('name') + if child_name not in deps: + deps[child_name] = set() + deps[child_name].add(name) + if child_name not in symbol_map: + symbol_map[child_name] = child + queue.append(child) + order = [] + while dep_cnts: + remove = [] + for name in dep_cnts: + if dep_cnts[name] == 0: + order.append(symbol_map[name]) + remove.append(name) + if name in deps: + for other in deps[name]: + dep_cnts[other] -= 1 + for name in remove: + del dep_cnts[name] + return order + def _from_mxnet_impl(symbol, graph): """Convert mxnet symbol to nnvm implementation. Reconstruct a nnvm symbol by traversing the mxnet symbol. @@ -398,28 +447,29 @@ def _from_mxnet_impl(symbol, graph): nnvm.sym.Symbol Converted symbol """ - if len(symbol.list_outputs()) > 1: - return [_from_mxnet_impl(s, graph) for s in symbol] - - name = symbol.attr('name') - output_index = json.loads(symbol.tojson())['heads'][0][1] - node = graph.get(name, None) - if node: - return node[output_index] - attr = symbol.list_attr() - op_name = symbol.attr('op_name') - childs = symbol.get_children() - if childs is not None: - childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))] - childs = [x for y in childs for x in _as_list(y)] # expand group symbol - node = _convert_symbol(op_name, childs, attr) - elif op_name != 'null': - node = _convert_symbol(op_name, [], attr) # no input symbol - else: - op_name = json.loads(symbol.tojson())['nodes'][0]['op'] - node = _sym.Variable(name=name, **attr) - graph[name] = node - return node[output_index] + def get_node(sym): + name = sym.attr('name') + if name not in graph: + return None + output_index = json.loads(symbol.tojson())['heads'][0][1] + return graph[name][output_index] + + # Traverse all symbols in topological order + for sym in _topo_sort(symbol): + name = sym.attr('name') + attr = sym.list_attr() + op_name = sym.attr('op_name') + childs = sym.get_children() + if childs is not None: + childs = [get_node(child) for child in childs] + childs = [x for y in childs for x in _as_list(y)] + node = _convert_symbol(op_name, childs, attr) + elif op_name != 'null': + node = _convert_symbol(op_name, [], attr) + else: + node = _sym.Variable(name=name, **attr) + graph[name] = node + return [get_node(sym) for sym in symbol] def from_mxnet(symbol, arg_params=None, aux_params=None): """Convert from MXNet's model into compatible NNVM format. From 0c15eb28755849fcb3e9fa657f5ff851ab0f4a3e Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 29 Oct 2018 10:06:30 -0700 Subject: [PATCH 2/3] Fix --- nnvm/python/nnvm/frontend/mxnet.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 7617c3ca2ace..f8ee3ca73470 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -454,6 +454,7 @@ def get_node(sym): output_index = json.loads(symbol.tojson())['heads'][0][1] return graph[name][output_index] + assert symbol is not None # Traverse all symbols in topological order for sym in _topo_sort(symbol): name = sym.attr('name') @@ -469,7 +470,14 @@ def get_node(sym): else: node = _sym.Variable(name=name, **attr) graph[name] = node - return [get_node(sym) for sym in symbol] + nodes = [] + for sym in symbol: + node = get_node(sym) + assert node is not None + nodes.append(node) + if len(nodes) > 1: + return _sym.Group(nodes) + return nodes[0] def from_mxnet(symbol, arg_params=None, aux_params=None): """Convert from MXNet's model into compatible NNVM format. From 173f04e67dffe957b399a9a059aff0c6f7e73c9a Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 29 Oct 2018 12:19:23 -0700 Subject: [PATCH 3/3] Fix infinite loop in inputs with same name --- nnvm/python/nnvm/frontend/mxnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index f8ee3ca73470..d1c2f305c27d 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -407,7 +407,7 @@ def _topo_sort(symbol): if childs is None: dep_cnts[name] = 0 else: - dep_cnts[name] = len(childs) + dep_cnts[name] = len(set([c.attr('name') for c in childs])) for child in childs: child_name = child.attr('name') if child_name not in deps: @@ -451,7 +451,7 @@ def get_node(sym): name = sym.attr('name') if name not in graph: return None - output_index = json.loads(symbol.tojson())['heads'][0][1] + output_index = json.loads(sym.tojson())['heads'][0][1] return graph[name][output_index] assert symbol is not None