From 869d5609c684ca3026cd9574350a54c36da0edbd Mon Sep 17 00:00:00 2001 From: daquexian Date: Tue, 18 Jun 2019 09:55:11 +0800 Subject: [PATCH] Add GetSupportedNodes() --- tools/onnx2daq/OnnxConverter.cpp | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tools/onnx2daq/OnnxConverter.cpp b/tools/onnx2daq/OnnxConverter.cpp index 327563c..7e248c8 100644 --- a/tools/onnx2daq/OnnxConverter.cpp +++ b/tools/onnx2daq/OnnxConverter.cpp @@ -829,7 +829,8 @@ std::pair OnnxConverter::IsNodeSupported( const auto output_name = node.output(0); for (const auto another_node : model_proto_.graph().node()) { for (const auto input_name : another_node.input()) { - if (input_name == output_name && another_node.op_type() != "Gemm") { + if (input_name == output_name && + another_node.op_type() != "Gemm") { return {false, "Reshape can only be the last layer or precede a " "gemm layer for now"}; @@ -843,8 +844,30 @@ std::pair OnnxConverter::IsNodeSupported( std::vector> OnnxConverter::GetSupportedNodes( const ONNX_NAMESPACE::ModelProto &model_proto) { GOOGLE_PROTOBUF_VERIFY_VERSION; + model_proto_ = model_proto; HandleInitializer(); - + + std::vector> supported_node_vecs; + std::vector supported_node_vec; + for (size_t i = 0; i < model_proto.graph().node_size(); i++) { + bool supported; + std::string error_msg; + std::tie(supported, error_msg) = + IsNodeSupported(model_proto_.graph().node(i)); + if (supported) { + supported_node_vec.push_back(i); + } else { + if (!supported_node_vec.empty()) { + supported_node_vecs.push_back(supported_node_vec); + supported_node_vec.clear(); + } + } + } + if (!supported_node_vec.empty()) { + supported_node_vecs.push_back(supported_node_vec); + } + Clear(); + return supported_node_vecs; } void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,