diff --git a/ci/onnxruntime_test.yml b/ci/onnxruntime_test.yml index b378937..e93e93a 100644 --- a/ci/onnxruntime_test.yml +++ b/ci/onnxruntime_test.yml @@ -27,7 +27,7 @@ pool: steps: - checkout: self submodules: true - - script: git clone --recursive --branch fix_android_build https://github.com/daquexian/onnxruntime $(Agent.HomeDirectory)/onnxruntime + - script: git clone --recursive --branch android https://github.com/daquexian/onnxruntime $(Agent.HomeDirectory)/onnxruntime displayName: Clone ONNX Runtime - script: rm -rf $(Agent.HomeDirectory)/onnxruntime/cmake/external/DNNLibrary && cp -r $(Build.SourcesDirectory) $(Agent.HomeDirectory)/onnxruntime/cmake/external/DNNLibrary displayName: Copy latest DNNLibrary diff --git a/include/tools/onnx2daq/OnnxConverter.h b/include/tools/onnx2daq/OnnxConverter.h index 0fba27c..8010061 100644 --- a/include/tools/onnx2daq/OnnxConverter.h +++ b/include/tools/onnx2daq/OnnxConverter.h @@ -215,7 +215,7 @@ class OnnxConverter { void Clear(); public: - std::vector> GetSupportedNodes( + expected>, std::string> GetSupportedNodes( ONNX_NAMESPACE::ModelProto model_proto); void Convert(const std::string &model_str, const std::string &filepath, const std::string &table_file = ""); diff --git a/tools/getsupportednodes/getsupportednodes.cpp b/tools/getsupportednodes/getsupportednodes.cpp index 92dc9a5..c2c819c 100644 --- a/tools/getsupportednodes/getsupportednodes.cpp +++ b/tools/getsupportednodes/getsupportednodes.cpp @@ -12,6 +12,14 @@ int main(int argc, char *argv[]) // FIXME: Handle the return value model_proto.ParseFromString(ss.str()); dnn::OnnxConverter converter; - PNT(converter.GetSupportedNodes(model_proto)); - return 0; + const auto nodes = converter.GetSupportedNodes(model_proto); + if (nodes) { + const auto &supported_ops = nodes.value(); + PNT(supported_ops); + return 0; + } else { + const auto &error = nodes.error(); + PNT(error); + return 1; + } } diff --git a/tools/onnx2daq/OnnxConverter.cpp b/tools/onnx2daq/OnnxConverter.cpp index 3be1337..b975d67 100644 --- a/tools/onnx2daq/OnnxConverter.cpp +++ b/tools/onnx2daq/OnnxConverter.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -271,8 +272,9 @@ void OnnxConverter::HandleInitializer() { ONNX_NAMESPACE::TensorProto_DataType_INT64) { // TODO: shape of reshape layer } else { - PNT(tensor.name(), tensor.data_type()); - DNN_ASSERT(false, ""); + DNN_ASSERT(false, "The data type \"" + std::to_string(tensor.data_type()) + + "\" of tensor \"" + + tensor.name() + "\" is not supported"); } operands_.push_back(name); } @@ -630,34 +632,38 @@ bool IsValidSupportedNodesVec(const std::vector &supported_node_vec, return false; } -std::vector> OnnxConverter::GetSupportedNodes( +expected>, std::string> OnnxConverter::GetSupportedNodes( ONNX_NAMESPACE::ModelProto model_proto) { GOOGLE_PROTOBUF_VERIFY_VERSION; ONNX_NAMESPACE::shape_inference::InferShapes(model_proto); model_proto_ = model_proto; - HandleInitializer(); + try { + HandleInitializer(); - std::vector> supported_node_vecs; - std::vector supported_node_vec; - for (int i = 0; i < model_proto.graph().node_size(); i++) { - bool supported; - std::string error_msg; - std::tie(supported, error_msg) = - IsNodeSupported(model_proto, model_proto.graph().node(i)); - if (supported) { - supported_node_vec.push_back(i); - } else { - if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { - supported_node_vecs.push_back(supported_node_vec); - supported_node_vec.clear(); + std::vector> supported_node_vecs; + std::vector supported_node_vec; + for (int i = 0; i < model_proto.graph().node_size(); i++) { + bool supported; + std::string error_msg; + std::tie(supported, error_msg) = + IsNodeSupported(model_proto, model_proto.graph().node(i)); + if (supported) { + supported_node_vec.push_back(i); + } else { + if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { + supported_node_vecs.push_back(supported_node_vec); + supported_node_vec.clear(); + } } } + if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { + supported_node_vecs.push_back(supported_node_vec); + } + Clear(); + return supported_node_vecs; + } catch (std::exception &e) { + return make_unexpected(e.what()); } - if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) { - supported_node_vecs.push_back(supported_node_vec); - } - Clear(); - return supported_node_vecs; } void OnnxConverter::Convert(const ONNX_NAMESPACE::ModelProto &model_proto,