diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 2cb592323167..178805ca0415 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -1298,122 +1298,127 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); *s = sym->Copy(); - nnvm::Graph g = Symbol2Graph(*s); - const auto& indexed_graph = g.indexed_graph(); - const auto& mutable_nodes = indexed_graph.mutable_input_nodes(); - std::vector input_names = sym->ListInputNames(nnvm::Symbol::kAll); - size_t num_forward_inputs = input_names.size(); + + // create a data structure from pointer array + std::unordered_map options_map; + for (mx_uint i = 0; i < num_options; ++i) + options_map.emplace(keys[i], vals[i]); NDArray ***new_args_ptr = reinterpret_cast(new_args_handle); NDArray ***new_aux_ptr = reinterpret_cast(new_aux_handle); + NDArray **in_args_ptr = reinterpret_cast(in_args_handle); + NDArray **in_aux_ptr = reinterpret_cast(in_aux_handle); - if (args_len || aux_len) { - NDArray **in_args_ptr = reinterpret_cast(in_args_handle); - NDArray **in_aux_ptr = reinterpret_cast(in_aux_handle); - if (!skip_infer) { - Context default_ctx = Context::Create(static_cast(dev_type), 0); - mxnet::ShapeVector arg_shapes(args_len + aux_len); - nnvm::DTypeVector arg_dtypes(args_len + aux_len); - StorageTypeVector arg_stypes(args_len + aux_len); - - // create the input shape, dtype and stype maps - std::unordered_map input_shape_map(num_input_shapes); - for (uint32_t i = 0; i < num_input_shapes; ++i) { - input_shape_map.emplace(input_shape_names[i], + auto init_graph = [&](auto s) { + nnvm::Graph g = Symbol2Graph(*s); + const auto& indexed_graph = g.indexed_graph(); + const auto& mutable_nodes = indexed_graph.mutable_input_nodes(); + std::vector input_names = s->ListInputNames(nnvm::Symbol::kAll); + size_t num_forward_inputs = input_names.size(); + + if (args_len || aux_len) { + if (!skip_infer) { + Context default_ctx = Context::Create(static_cast(dev_type), 0); + mxnet::ShapeVector arg_shapes(args_len + aux_len); + nnvm::DTypeVector arg_dtypes(args_len + aux_len); + StorageTypeVector arg_stypes(args_len + aux_len); + + // create the input shape, dtype and stype maps + std::unordered_map input_shape_map(num_input_shapes); + for (uint32_t i = 0; i < num_input_shapes; ++i) { + input_shape_map.emplace(input_shape_names[i], mxnet::TShape(input_shape_data + input_shape_idx[i], input_shape_data + input_shape_idx[i+1])); - } - std::unordered_map input_dtype_map(num_input_dtypes); - for (uint32_t i = 0; i < num_input_dtypes; ++i) { - input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]); - } - std::unordered_map input_stype_map(num_input_stypes); - for (uint32_t i = 0; i < num_input_stypes; ++i) { - input_stype_map.emplace(input_stype_names[i], input_stypes[i]); - } + } + std::unordered_map input_dtype_map(num_input_dtypes); + for (uint32_t i = 0; i < num_input_dtypes; ++i) { + input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]); + } + std::unordered_map input_stype_map(num_input_stypes); + for (uint32_t i = 0; i < num_input_stypes; ++i) { + input_stype_map.emplace(input_stype_names[i], input_stypes[i]); + } - size_t args_top = 0, aux_top = 0; - // loop over inputs to symbol in order and add to args/aux if mutable - for (size_t i = 0; i < num_forward_inputs; ++i) { - const uint32_t nid = indexed_graph.input_nodes().at(i); - if (mutable_nodes.count(nid)) { - CHECK_LT(aux_top, aux_len) - << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; - if (in_aux_ptr[aux_top] != nullptr) { - const auto &in_arg = *(in_aux_ptr[aux_top]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); - } - aux_top++; - } else { - auto name = input_names[i]; - CHECK_LT(args_top, args_len) - << "Cannot find arg '" << name << "' in provided args to optimize_for"; - if (in_args_ptr[args_top] != nullptr) { - const auto &in_arg = *(in_args_ptr[args_top]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); - } else { - // input_names[i] is not in args but can be in the optional - // shape/type/stype attribute dicts. - auto it_shape = input_shape_map.find(name); - if (it_shape != input_shape_map.end()) { - arg_shapes[i] = it_shape->second; - } - auto it_type = input_dtype_map.find(name); - if (it_type != input_dtype_map.end()) { - arg_dtypes[i] = it_type->second; + size_t args_top = 0, aux_top = 0; + // loop over inputs to symbol in order and add to args/aux if mutable + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = indexed_graph.input_nodes().at(i); + if (mutable_nodes.count(nid)) { + CHECK_LT(aux_top, aux_len) + << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; + if (in_aux_ptr[aux_top] != nullptr) { + const auto &in_arg = *(in_aux_ptr[aux_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); } - it_type = input_stype_map.find(name); - if (it_type != input_stype_map.end()) { - arg_stypes[i] = it_type->second; + aux_top++; + } else { + auto name = input_names[i]; + CHECK_LT(args_top, args_len) + << "Cannot find arg '" << name << "' in provided args to optimize_for"; + if (in_args_ptr[args_top] != nullptr) { + const auto &in_arg = *(in_args_ptr[args_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); + } else { + // input_names[i] is not in args but can be in the optional + // shape/type/stype attribute dicts. + auto it_shape = input_shape_map.find(name); + if (it_shape != input_shape_map.end()) { + arg_shapes[i] = it_shape->second; + } + auto it_type = input_dtype_map.find(name); + if (it_type != input_dtype_map.end()) { + arg_dtypes[i] = it_type->second; + } + it_type = input_stype_map.find(name); + if (it_type != input_stype_map.end()) { + arg_stypes[i] = it_type->second; + } } + args_top++; } - args_top++; } - } - g.attrs["context"] = std::make_shared( + g.attrs["context"] = std::make_shared( exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); - // infer shapes - g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); - // infer dtypes - g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); - // infer stypes - g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + // infer shapes + g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); + // infer dtypes + g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__"); + // infer stypes + g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__"); + } + // set args/aux as attributes on graph so that subgraph property can use them + std::vector arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs); + g.attrs["in_args"] = std::make_shared(in_args_ptr); + g.attrs["in_arg_names"] = std::make_shared(arg_names); + + std::vector aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates); + g.attrs["in_aux"] = std::make_shared(in_aux_ptr); + g.attrs["in_aux_names"] = std::make_shared(aux_names); + } else { + // args/aux were not specified, so set nullptr/empty-lists + NDArray **in_args_ptr = static_cast(nullptr); + std::vector arg_names; + g.attrs["in_args"] = std::make_shared(in_args_ptr); + g.attrs["in_arg_names"] = std::make_shared(arg_names); + + NDArray **in_aux_ptr = static_cast(nullptr); + std::vector aux_names; + g.attrs["in_aux"] = std::make_shared(in_aux_ptr); + g.attrs["in_aux_names"] = std::make_shared(aux_names); } - // set args/aux as attributes on graph so that subgraph property can use them - std::vector arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs); - g.attrs["in_args"] = std::make_shared(in_args_ptr); - g.attrs["in_arg_names"] = std::make_shared(arg_names); - - std::vector aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates); - g.attrs["in_aux"] = std::make_shared(in_aux_ptr); - g.attrs["in_aux_names"] = std::make_shared(aux_names); - } else { - // args/aux were not specified, so set nullptr/empty-lists - NDArray **in_args_ptr = static_cast(nullptr); - std::vector arg_names; - g.attrs["in_args"] = std::make_shared(in_args_ptr); - g.attrs["in_arg_names"] = std::make_shared(arg_names); - - NDArray **in_aux_ptr = static_cast(nullptr); - std::vector aux_names; - g.attrs["in_aux"] = std::make_shared(in_aux_ptr); - g.attrs["in_aux_names"] = std::make_shared(aux_names); - } - // create a data structure from pointer array - std::unordered_map options_map; - for (mx_uint i = 0; i < num_options; ++i) - options_map.emplace(keys[i], vals[i]); - // set dedup option as attribute on graph to enable dedup during partitioning - if (options_map.count("dedup_subgraph") > 0 && - options_map.at("dedup_subgraph").compare("True") == 0) - g.attrs["dedup_subgraph"] = std::make_shared(std::string("True")); + // set dedup option as attribute on graph to enable dedup during partitioning + if (options_map.count("dedup_subgraph") > 0 && + options_map.at("dedup_subgraph").compare("True") == 0) + g.attrs["dedup_subgraph"] = std::make_shared(std::string("True")); + return g; + }; if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) { // use subgraph backend @@ -1421,14 +1426,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, ::Get()->GetSubgraphBackend(backend_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { + nnvm::Graph g = init_graph(s); property->PrePartition(g, options_map); g.attrs["subgraph_property"] = std::make_shared(property); g = ApplyPass(std::move(g), "BuildSubgraph"); g.attrs.erase("subgraph_property"); property->PostPartition(g); + s->outputs = g.outputs; } } else if (dmlc::Registry::Find(backend_name) != nullptr) { // use graph pass + nnvm::Graph g = init_graph(s); g.attrs["options_map"] = std::make_shared(options_map); g.attrs["pass_name"] = std::make_shared(backend_name); g = ApplyPass(std::move(g), backend_name); @@ -1441,6 +1449,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, g.attrs.erase("new_aux"); g.attrs.erase("new_arg_names"); g.attrs.erase("new_aux_names"); + s->outputs = g.outputs; NDArray** new_arg_arr = new NDArray*[new_arg_names.size()]; NDArray** new_aux_arr = new NDArray*[new_aux_names.size()]; @@ -1472,7 +1481,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, // cannot find graph pass or subgraph backend registered in this name LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found"; } - s->outputs = g.outputs; + *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); }