Skip to content

Commit

Permalink
Add GetSupportedNodes()
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Jun 18, 2019
1 parent 9120998 commit 869d560
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions tools/onnx2daq/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
const auto output_name = node.output(0);
for (const auto another_node : model_proto_.graph().node()) {
for (const auto input_name : another_node.input()) {
if (input_name == output_name && another_node.op_type() != "Gemm") {
if (input_name == output_name &&
another_node.op_type() != "Gemm") {
return {false,
"Reshape can only be the last layer or precede a "
"gemm layer for now"};
Expand All @@ -843,8 +844,30 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
const ONNX_NAMESPACE::ModelProto &model_proto) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
model_proto_ = model_proto;
HandleInitializer();


std::vector<std::vector<int>> supported_node_vecs;
std::vector<int> supported_node_vec;
for (size_t i = 0; i < model_proto.graph().node_size(); i++) {
bool supported;
std::string error_msg;
std::tie(supported, error_msg) =
IsNodeSupported(model_proto_.graph().node(i));
if (supported) {
supported_node_vec.push_back(i);
} else {
if (!supported_node_vec.empty()) {
supported_node_vecs.push_back(supported_node_vec);
supported_node_vec.clear();
}
}
}
if (!supported_node_vec.empty()) {
supported_node_vecs.push_back(supported_node_vec);
}
Clear();
return supported_node_vecs;
}

void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,
Expand Down

0 comments on commit 869d560

Please sign in to comment.