Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix for optimize_for multiple subgraph properties issue #19263

Merged
merged 2 commits into from
Oct 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 110 additions & 101 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1298,137 +1298,145 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
const auto& indexed_graph = g.indexed_graph();
const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
std::vector<std::string> input_names = sym->ListInputNames(nnvm::Symbol::kAll);
size_t num_forward_inputs = input_names.size();

// create a data structure from pointer array
std::unordered_map<std::string, std::string> options_map;
for (mx_uint i = 0; i < num_options; ++i)
options_map.emplace(keys[i], vals[i]);

NDArray ***new_args_ptr = reinterpret_cast<NDArray***>(new_args_handle);
NDArray ***new_aux_ptr = reinterpret_cast<NDArray***>(new_aux_handle);
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);

if (args_len || aux_len) {
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
if (!skip_infer) {
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);

// create the input shape, dtype and stype maps
std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
for (uint32_t i = 0; i < num_input_shapes; ++i) {
input_shape_map.emplace(input_shape_names[i],
auto init_graph = [&](auto s) {
nnvm::Graph g = Symbol2Graph(*s);
const auto& indexed_graph = g.indexed_graph();
const auto& mutable_nodes = indexed_graph.mutable_input_nodes();
std::vector<std::string> input_names = s->ListInputNames(nnvm::Symbol::kAll);
size_t num_forward_inputs = input_names.size();

if (args_len || aux_len) {
if (!skip_infer) {
Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
mxnet::ShapeVector arg_shapes(args_len + aux_len);
nnvm::DTypeVector arg_dtypes(args_len + aux_len);
StorageTypeVector arg_stypes(args_len + aux_len);

// create the input shape, dtype and stype maps
std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
for (uint32_t i = 0; i < num_input_shapes; ++i) {
input_shape_map.emplace(input_shape_names[i],
mxnet::TShape(input_shape_data + input_shape_idx[i],
input_shape_data + input_shape_idx[i+1]));
}
std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
for (uint32_t i = 0; i < num_input_dtypes; ++i) {
input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
}
std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
for (uint32_t i = 0; i < num_input_stypes; ++i) {
input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}
}
std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
for (uint32_t i = 0; i < num_input_dtypes; ++i) {
input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
}
std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
for (uint32_t i = 0; i < num_input_stypes; ++i) {
input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
}

size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
aux_top++;
} else {
auto name = input_names[i];
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << name << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
// input_names[i] is not in args but can be in the optional
// shape/type/stype attribute dicts.
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
arg_shapes[i] = it_shape->second;
}
auto it_type = input_dtype_map.find(name);
if (it_type != input_dtype_map.end()) {
arg_dtypes[i] = it_type->second;
size_t args_top = 0, aux_top = 0;
// loop over inputs to symbol in order and add to args/aux if mutable
for (size_t i = 0; i < num_forward_inputs; ++i) {
const uint32_t nid = indexed_graph.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_len)
<< "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
if (in_aux_ptr[aux_top] != nullptr) {
const auto &in_arg = *(in_aux_ptr[aux_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
}
it_type = input_stype_map.find(name);
if (it_type != input_stype_map.end()) {
arg_stypes[i] = it_type->second;
aux_top++;
} else {
auto name = input_names[i];
CHECK_LT(args_top, args_len)
<< "Cannot find arg '" << name << "' in provided args to optimize_for";
if (in_args_ptr[args_top] != nullptr) {
const auto &in_arg = *(in_args_ptr[args_top]);
arg_shapes[i] = in_arg.shape();
arg_dtypes[i] = in_arg.dtype();
arg_stypes[i] = in_arg.storage_type();
} else {
// input_names[i] is not in args but can be in the optional
// shape/type/stype attribute dicts.
auto it_shape = input_shape_map.find(name);
if (it_shape != input_shape_map.end()) {
arg_shapes[i] = it_shape->second;
}
auto it_type = input_dtype_map.find(name);
if (it_type != input_dtype_map.end()) {
arg_dtypes[i] = it_type->second;
}
it_type = input_stype_map.find(name);
if (it_type != input_stype_map.end()) {
arg_stypes[i] = it_type->second;
}
}
args_top++;
}
args_top++;
}
}

g.attrs["context"] = std::make_shared<nnvm::any>(
g.attrs["context"] = std::make_shared<nnvm::any>(
exec::ContextVector(indexed_graph.num_nodes(), default_ctx));

// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
// infer shapes
g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
// infer dtypes
g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
// infer stypes
g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
}
// set args/aux as attributes on graph so that subgraph property can use them
std::vector<std::string> arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);

std::vector<std::string> aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
} else {
// args/aux were not specified, so set nullptr/empty-lists
NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
std::vector<std::string> arg_names;
g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);

NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
std::vector<std::string> aux_names;
g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
}
// set args/aux as attributes on graph so that subgraph property can use them
std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);

std::vector<std::string> aux_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);
g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
} else {
// args/aux were not specified, so set nullptr/empty-lists
NDArray **in_args_ptr = static_cast<NDArray**>(nullptr);
std::vector<std::string> arg_names;
g.attrs["in_args"] = std::make_shared<nnvm::any>(in_args_ptr);
g.attrs["in_arg_names"] = std::make_shared<nnvm::any>(arg_names);

NDArray **in_aux_ptr = static_cast<NDArray**>(nullptr);
std::vector<std::string> aux_names;
g.attrs["in_aux"] = std::make_shared<nnvm::any>(in_aux_ptr);
g.attrs["in_aux_names"] = std::make_shared<nnvm::any>(aux_names);
}
// create a data structure from pointer array
std::unordered_map<std::string, std::string> options_map;
for (mx_uint i = 0; i < num_options; ++i)
options_map.emplace(keys[i], vals[i]);

// set dedup option as attribute on graph to enable dedup during partitioning
if (options_map.count("dedup_subgraph") > 0 &&
options_map.at("dedup_subgraph").compare("True") == 0)
g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
// set dedup option as attribute on graph to enable dedup during partitioning
if (options_map.count("dedup_subgraph") > 0 &&
options_map.at("dedup_subgraph").compare("True") == 0)
g.attrs["dedup_subgraph"] = std::make_shared<nnvm::any>(std::string("True"));
return g;
};

if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) {
// use subgraph backend
const auto backend = mxnet::op::SubgraphBackendRegistry
::Get()->GetSubgraphBackend(backend_name);
const auto& subgraph_prop_list = backend->GetSubgraphProperties();
for (auto property : subgraph_prop_list) {
nnvm::Graph g = init_graph(s);
property->PrePartition(g, options_map);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(property);
g = ApplyPass(std::move(g), "BuildSubgraph");
g.attrs.erase("subgraph_property");
property->PostPartition(g);
s->outputs = g.outputs;
}
} else if (dmlc::Registry<nnvm::PassFunctionReg>::Find(backend_name) != nullptr) {
// use graph pass
nnvm::Graph g = init_graph(s);
g.attrs["options_map"] = std::make_shared<nnvm::any>(options_map);
g.attrs["pass_name"] = std::make_shared<nnvm::any>(backend_name);
g = ApplyPass(std::move(g), backend_name);
Expand All @@ -1441,6 +1449,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
g.attrs.erase("new_aux");
g.attrs.erase("new_arg_names");
g.attrs.erase("new_aux_names");
s->outputs = g.outputs;

NDArray** new_arg_arr = new NDArray*[new_arg_names.size()];
NDArray** new_aux_arr = new NDArray*[new_aux_names.size()];
Expand Down Expand Up @@ -1472,7 +1481,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
// cannot find graph pass or subgraph backend registered in this name
LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found";
}
s->outputs = g.outputs;

*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}