diff --git a/dnnlibrary/NeuralNetworksWrapper.cpp b/dnnlibrary/NeuralNetworksWrapper.cpp index b009f12..5638539 100644 --- a/dnnlibrary/NeuralNetworksWrapper.cpp +++ b/dnnlibrary/NeuralNetworksWrapper.cpp @@ -10,7 +10,9 @@ OperandType::OperandType(Type type, std::vector d, float scale, int32_t zeroPoint) : type(type), dimensions(std::move(d)), channelQuant(std::nullopt) { if (dimensions.empty()) { - DNN_ASSERT(isScalarType(type), typeToStr(type)); + if (!isScalarType(type)) { + dimensions = {1}; + } } else { DNN_ASSERT(!isScalarType(type), typeToStr(type)); } diff --git a/include/tools/onnx2daq/OnnxConverter.h b/include/tools/onnx2daq/OnnxConverter.h index 6afc77b..d72ea1c 100644 --- a/include/tools/onnx2daq/OnnxConverter.h +++ b/include/tools/onnx2daq/OnnxConverter.h @@ -80,11 +80,13 @@ class OnnxConverter { void HandleInitializer(); std::vector> GetInputOfOnnxModel(); - std::vector> GetOutputOfOnnxModel(); + std::vector> + GetOutputOfOnnxModel(); void ReadTableFile(const std::string &table_file); std::vector> ConvertQuantInfosToFbs(); std::pair IsNodeSupported( + const ONNX_NAMESPACE::ModelProto &model_proto, const ONNX_NAMESPACE::NodeProto &node_proto) const; void AddConv(const std::string &input_name, const std::vector &strides, @@ -184,7 +186,7 @@ class OnnxConverter { public: std::vector> GetSupportedNodes( - const ONNX_NAMESPACE::ModelProto &model); + ONNX_NAMESPACE::ModelProto model_proto); void Convert(const std::string &model_str, const std::string &filepath, const std::string &table_file = ""); void Convert(const ONNX_NAMESPACE::ModelProto &model, diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 0fdd21b..2c13029 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,4 +1,4 @@ add_subdirectory(onnx2daq) -if (NOT ${ONNX2DAQ_ONLY_LIB}) +if (NOT DEFINED ONNX2DAQ_ONLY_LIB OR NOT ${ONNX2DAQ_ONLY_LIB}) add_subdirectory(getsupportednodes) endif() diff --git a/tools/onnx2daq/OnnxConverter.cpp b/tools/onnx2daq/OnnxConverter.cpp index 68b4231..d17a4df 100644 --- a/tools/onnx2daq/OnnxConverter.cpp +++ b/tools/onnx2daq/OnnxConverter.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "NodeAttrHelper.h" using std::string; @@ -148,6 +149,9 @@ OnnxConverter::FbStrVector(const std::vector &std_str_vector) { * nnapi: [1, height, width, depth_out] */ OnnxConverter::Tensor OnnxConverter::OnnxToNnapiAxes1230(const Tensor &src) { + if (src.shape.size() != 4) { + return src; + } Tensor dest = src; size_t elemsize = 0; if (src.data_type == Tensor::DataType::UINT8) { @@ -180,6 +184,9 @@ OnnxConverter::Tensor OnnxConverter::OnnxToNnapiAxes1230(const Tensor &src) { } OnnxConverter::Tensor OnnxConverter::OnnxToNnapiAxes0231(const Tensor &src) { + if (src.shape.size() != 4) { + return src; + } Tensor dest = src; size_t elemsize = 0; if (src.data_type == Tensor::DataType::UINT8) { @@ -1026,7 +1033,40 @@ void OnnxConverter::Convert(const std::string &model_str, Save(filepath); } +dnn::optional GetShape( + const ONNX_NAMESPACE::ModelProto &model_proto, const std::string &name) { + for (const auto &value_info : model_proto.graph().value_info()) { + if (value_info.name() == name) { + if (!value_info.has_type()) { + return dnn::nullopt; + } else if (!value_info.type().has_tensor_type()) { + return dnn::nullopt; + } else if (!value_info.type().tensor_type().has_shape()) { + return dnn::nullopt; + } else if (value_info.type().tensor_type().shape().dim_size() == + 0) { + return dnn::nullopt; + } + + Shape shape; + for (const auto &dim : + value_info.type().tensor_type().shape().dim()) { + if (dim.has_dim_value()) { + shape.push_back(dim.dim_value()); + } else { + return dnn::nullopt; + } + } + + return shape; + } + } + + return dnn::nullopt; +} + std::pair OnnxConverter::IsNodeSupported( + const ONNX_NAMESPACE::ModelProto &model_proto, const ONNX_NAMESPACE::NodeProto &node) const { NodeAttrHelper helper(node); const auto &op = node.op_type(); @@ -1080,6 +1120,21 @@ std::pair OnnxConverter::IsNodeSupported( if (helper.get("kernel_shape", std::vector{1, 1}).size() != 2) { return {false, "Only pooling 2d is supported"}; } + if (helper.get("ceil_mode", 0) == 1) { + return {false, "ceil_mode == 1 is not supported for pooling"}; + } + if (helper.get("dilations", std::vector{1, 1}) != + std::vector{1, 1}) { + return {false, "Dilations of pooling is not supported"}; + } + if (node.output_size() != 1) { + return {false, "Argmax in maxpooling is not supported"}; + } + } else if (op == "GlobalAveragePool" || op == "GlobalMaxPool") { + const auto &input_shape = GetShape(model_proto, node.input(0)); + if (!input_shape.has_value() || input_shape.value().size() != 4) { + return {false, "Only rank-4 tensor is supported in " + op}; + } } else if (op == "PRelu") { const auto slope_name = m(node.input(1)); if (onnx_tensors_.has(slope_name)) { @@ -1106,6 +1161,22 @@ std::pair OnnxConverter::IsNodeSupported( "Your onnx model may be in training mode, please export " "it in test mode."}; } + const auto scale_name = m(node.input(1)); + const auto b_name = m(node.input(2)); + const auto mean_name = m(node.input(3)); + const auto var_name = m(node.input(4)); + if (!onnx_tensors_.has(scale_name)) { + return {false, "Scale of BN must be known"}; + } + if (!onnx_tensors_.has(b_name)) { + return {false, "B of BN must be known"}; + } + if (!onnx_tensors_.has(mean_name)) { + return {false, "Mean of BN must be known"}; + } + if (!onnx_tensors_.has(var_name)) { + return {false, "Var of BN must be known"}; + } } else if (op == "LRN") { const auto size = helper.get("size", 1); if (size % 2 == 0) { @@ -1128,6 +1199,15 @@ std::pair OnnxConverter::IsNodeSupported( if (axis != 1) { return {false, "Only axis == 1 is supported in Softmax"}; } + const auto &input_shape = GetShape(model_proto, node.input(0)); + if (!input_shape.has_value() || input_shape.value().size() != 4) { + return {false, "Only rank-4 tensor is supported in Softmax"}; + } + } else if (op == "Concat") { + const auto &input_shape = GetShape(model_proto, node.input(0)); + if (!input_shape.has_value() || input_shape.value().size() != 4) { + return {false, "Only rank-4 tensor is supported in Softmax"}; + } } return {true, ""}; } @@ -1139,7 +1219,8 @@ bool IsValidSupportedNodesVec(const std::vector &supported_node_vec, const auto &node = model_proto.graph().node(supported_node_vec[0]); // Reshape and Dropout are simply ignored in DNNLibrary, causing the // input == output, which is not allowed in NNAPI - if (node.op_type() == "Reshape" || node.op_type() == "Dropout") { + if (node.op_type() == "Reshape" || node.op_type() == "Dropout" || + node.op_type() == "Identity") { return false; } } @@ -1149,8 +1230,9 @@ bool IsValidSupportedNodesVec(const std::vector &supported_node_vec, } std::vector> OnnxConverter::GetSupportedNodes( - const ONNX_NAMESPACE::ModelProto &model_proto) { + ONNX_NAMESPACE::ModelProto model_proto) { GOOGLE_PROTOBUF_VERIFY_VERSION; + ONNX_NAMESPACE::shape_inference::InferShapes(model_proto); model_proto_ = model_proto; HandleInitializer(); @@ -1160,7 +1242,7 @@ std::vector> OnnxConverter::GetSupportedNodes( bool supported; std::string error_msg; std::tie(supported, error_msg) = - IsNodeSupported(model_proto.graph().node(i)); + IsNodeSupported(model_proto, model_proto.graph().node(i)); if (supported) { supported_node_vec.push_back(i); } else {