diff --git a/nnvm/python/nnvm/frontend/mxnet.py b/nnvm/python/nnvm/frontend/mxnet.py index 13ee4c43fad47..431a067abe96c 100644 --- a/nnvm/python/nnvm/frontend/mxnet.py +++ b/nnvm/python/nnvm/frontend/mxnet.py @@ -396,6 +396,55 @@ def _as_list(arr): return arr return [arr] +def _topo_sort(symbol): + """Sort all symbols in the mxnet graph in topological order. + + Parameters + ---------- + symbol : mxnet.sym.Symbol + + Returns: + ------- + list + List of mxnet symbol + """ + queue = [] + symbol_map = {} + deps = {} + dep_cnts = {} + for s in symbol: + symbol_map[s.attr('name')] = s + queue.append(s) + while queue: + sym = queue.pop(0) + name = sym.attr('name') + childs = sym.get_children() + if childs is None: + dep_cnts[name] = 0 + else: + dep_cnts[name] = len(set([c.attr('name') for c in childs])) + for child in childs: + child_name = child.attr('name') + if child_name not in deps: + deps[child_name] = set() + deps[child_name].add(name) + if child_name not in symbol_map: + symbol_map[child_name] = child + queue.append(child) + order = [] + while dep_cnts: + remove = [] + for name in dep_cnts: + if dep_cnts[name] == 0: + order.append(symbol_map[name]) + remove.append(name) + if name in deps: + for other in deps[name]: + dep_cnts[other] -= 1 + for name in remove: + del dep_cnts[name] + return order + def _from_mxnet_impl(symbol, graph): """Convert mxnet symbol to nnvm implementation. Reconstruct a nnvm symbol by traversing the mxnet symbol. @@ -413,28 +462,37 @@ def _from_mxnet_impl(symbol, graph): nnvm.sym.Symbol Converted symbol """ - if len(symbol.list_outputs()) > 1: - return [_from_mxnet_impl(s, graph) for s in symbol] - - name = symbol.attr('name') - output_index = json.loads(symbol.tojson())['heads'][0][1] - node = graph.get(name, None) - if node: - return node[output_index] - attr = symbol.list_attr() - op_name = symbol.attr('op_name') - childs = symbol.get_children() - if childs is not None: - childs = [_from_mxnet_impl(childs[i], graph) for i in range(len(childs.list_outputs()))] - childs = [x for y in childs for x in _as_list(y)] # expand group symbol - node = _convert_symbol(op_name, childs, attr) - elif op_name != 'null': - node = _convert_symbol(op_name, [], attr) # no input symbol - else: - op_name = json.loads(symbol.tojson())['nodes'][0]['op'] - node = _sym.Variable(name=name, **attr) - graph[name] = node - return node[output_index] + def get_node(sym): + name = sym.attr('name') + if name not in graph: + return None + output_index = json.loads(sym.tojson())['heads'][0][1] + return graph[name][output_index] + + assert symbol is not None + # Traverse all symbols in topological order + for sym in _topo_sort(symbol): + name = sym.attr('name') + attr = sym.list_attr() + op_name = sym.attr('op_name') + childs = sym.get_children() + if childs is not None: + childs = [get_node(child) for child in childs] + childs = [x for y in childs for x in _as_list(y)] + node = _convert_symbol(op_name, childs, attr) + elif op_name != 'null': + node = _convert_symbol(op_name, [], attr) + else: + node = _sym.Variable(name=name, **attr) + graph[name] = node + nodes = [] + for sym in symbol: + node = get_node(sym) + assert node is not None + nodes.append(node) + if len(nodes) > 1: + return _sym.Group(nodes) + return nodes[0] def from_mxnet(symbol, arg_params=None, aux_params=None): """Convert from MXNet's model into compatible NNVM format.