Skip to content

Commit

Permalink
Merge pull request #59 from JDAI-CV/check_conv_weight
Browse files Browse the repository at this point in the history
Update onnx->daq checker
  • Loading branch information
daquexian authored Jul 16, 2019
2 parents 4a4ded8 + a724168 commit ab22710
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 7 deletions.
4 changes: 3 additions & 1 deletion dnnlibrary/NeuralNetworksWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ OperandType::OperandType(Type type, std::vector<uint32_t> d, float scale,
int32_t zeroPoint)
: type(type), dimensions(std::move(d)), channelQuant(std::nullopt) {
if (dimensions.empty()) {
DNN_ASSERT(isScalarType(type), typeToStr(type));
if (!isScalarType(type)) {
dimensions = {1};
}
} else {
DNN_ASSERT(!isScalarType(type), typeToStr(type));
}
Expand Down
6 changes: 4 additions & 2 deletions include/tools/onnx2daq/OnnxConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ class OnnxConverter {

void HandleInitializer();
std::vector<flatbuffers::Offset<DNN::Input>> GetInputOfOnnxModel();
std::vector<flatbuffers::Offset<flatbuffers::String>> GetOutputOfOnnxModel();
std::vector<flatbuffers::Offset<flatbuffers::String>>
GetOutputOfOnnxModel();
void ReadTableFile(const std::string &table_file);
std::vector<flatbuffers::Offset<DNN::QuantInfo>> ConvertQuantInfosToFbs();

std::pair<bool, std::string> IsNodeSupported(
const ONNX_NAMESPACE::ModelProto &model_proto,
const ONNX_NAMESPACE::NodeProto &node_proto) const;

void AddConv(const std::string &input_name, const std::vector<int> &strides,
Expand Down Expand Up @@ -184,7 +186,7 @@ class OnnxConverter {

public:
std::vector<std::vector<int>> GetSupportedNodes(
const ONNX_NAMESPACE::ModelProto &model);
ONNX_NAMESPACE::ModelProto model_proto);
void Convert(const std::string &model_str, const std::string &filepath,
const std::string &table_file = "");
void Convert(const ONNX_NAMESPACE::ModelProto &model,
Expand Down
2 changes: 1 addition & 1 deletion tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_subdirectory(onnx2daq)
if (NOT ${ONNX2DAQ_ONLY_LIB})
if (NOT DEFINED ONNX2DAQ_ONLY_LIB OR NOT ${ONNX2DAQ_ONLY_LIB})
add_subdirectory(getsupportednodes)
endif()
88 changes: 85 additions & 3 deletions tools/onnx2daq/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <common/helper.h>
#include <glog/logging.h>
#include <onnx/optimizer/optimize.h>
#include <onnx/shape_inference/implementation.h>
#include "NodeAttrHelper.h"

using std::string;
Expand Down Expand Up @@ -148,6 +149,9 @@ OnnxConverter::FbStrVector(const std::vector<std::string> &std_str_vector) {
* nnapi: [1, height, width, depth_out]
*/
OnnxConverter::Tensor OnnxConverter::OnnxToNnapiAxes1230(const Tensor &src) {
if (src.shape.size() != 4) {
return src;
}
Tensor dest = src;
size_t elemsize = 0;
if (src.data_type == Tensor::DataType::UINT8) {
Expand Down Expand Up @@ -180,6 +184,9 @@ OnnxConverter::Tensor OnnxConverter::OnnxToNnapiAxes1230(const Tensor &src) {
}

OnnxConverter::Tensor OnnxConverter::OnnxToNnapiAxes0231(const Tensor &src) {
if (src.shape.size() != 4) {
return src;
}
Tensor dest = src;
size_t elemsize = 0;
if (src.data_type == Tensor::DataType::UINT8) {
Expand Down Expand Up @@ -1026,7 +1033,40 @@ void OnnxConverter::Convert(const std::string &model_str,
Save(filepath);
}

dnn::optional<Shaper::Shape> GetShape(
const ONNX_NAMESPACE::ModelProto &model_proto, const std::string &name) {
for (const auto &value_info : model_proto.graph().value_info()) {
if (value_info.name() == name) {
if (!value_info.has_type()) {
return dnn::nullopt;
} else if (!value_info.type().has_tensor_type()) {
return dnn::nullopt;
} else if (!value_info.type().tensor_type().has_shape()) {
return dnn::nullopt;
} else if (value_info.type().tensor_type().shape().dim_size() ==
0) {
return dnn::nullopt;
}

Shape shape;
for (const auto &dim :
value_info.type().tensor_type().shape().dim()) {
if (dim.has_dim_value()) {
shape.push_back(dim.dim_value());
} else {
return dnn::nullopt;
}
}

return shape;
}
}

return dnn::nullopt;
}

std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
const ONNX_NAMESPACE::ModelProto &model_proto,
const ONNX_NAMESPACE::NodeProto &node) const {
NodeAttrHelper helper(node);
const auto &op = node.op_type();
Expand Down Expand Up @@ -1080,6 +1120,21 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
if (helper.get("kernel_shape", std::vector<int>{1, 1}).size() != 2) {
return {false, "Only pooling 2d is supported"};
}
if (helper.get("ceil_mode", 0) == 1) {
return {false, "ceil_mode == 1 is not supported for pooling"};
}
if (helper.get("dilations", std::vector<int>{1, 1}) !=
std::vector<int>{1, 1}) {
return {false, "Dilations of pooling is not supported"};
}
if (node.output_size() != 1) {
return {false, "Argmax in maxpooling is not supported"};
}
} else if (op == "GlobalAveragePool" || op == "GlobalMaxPool") {
const auto &input_shape = GetShape(model_proto, node.input(0));
if (!input_shape.has_value() || input_shape.value().size() != 4) {
return {false, "Only rank-4 tensor is supported in " + op};
}
} else if (op == "PRelu") {
const auto slope_name = m(node.input(1));
if (onnx_tensors_.has(slope_name)) {
Expand All @@ -1106,6 +1161,22 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
"Your onnx model may be in training mode, please export "
"it in test mode."};
}
const auto scale_name = m(node.input(1));
const auto b_name = m(node.input(2));
const auto mean_name = m(node.input(3));
const auto var_name = m(node.input(4));
if (!onnx_tensors_.has(scale_name)) {
return {false, "Scale of BN must be known"};
}
if (!onnx_tensors_.has(b_name)) {
return {false, "B of BN must be known"};
}
if (!onnx_tensors_.has(mean_name)) {
return {false, "Mean of BN must be known"};
}
if (!onnx_tensors_.has(var_name)) {
return {false, "Var of BN must be known"};
}
} else if (op == "LRN") {
const auto size = helper.get("size", 1);
if (size % 2 == 0) {
Expand All @@ -1128,6 +1199,15 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
if (axis != 1) {
return {false, "Only axis == 1 is supported in Softmax"};
}
const auto &input_shape = GetShape(model_proto, node.input(0));
if (!input_shape.has_value() || input_shape.value().size() != 4) {
return {false, "Only rank-4 tensor is supported in Softmax"};
}
} else if (op == "Concat") {
const auto &input_shape = GetShape(model_proto, node.input(0));
if (!input_shape.has_value() || input_shape.value().size() != 4) {
return {false, "Only rank-4 tensor is supported in Softmax"};
}
}
return {true, ""};
}
Expand All @@ -1139,7 +1219,8 @@ bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
const auto &node = model_proto.graph().node(supported_node_vec[0]);
// Reshape and Dropout are simply ignored in DNNLibrary, causing the
// input == output, which is not allowed in NNAPI
if (node.op_type() == "Reshape" || node.op_type() == "Dropout") {
if (node.op_type() == "Reshape" || node.op_type() == "Dropout" ||
node.op_type() == "Identity") {
return false;
}
}
Expand All @@ -1149,8 +1230,9 @@ bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
}

std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
const ONNX_NAMESPACE::ModelProto &model_proto) {
ONNX_NAMESPACE::ModelProto model_proto) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
ONNX_NAMESPACE::shape_inference::InferShapes(model_proto);
model_proto_ = model_proto;
HandleInitializer();

Expand All @@ -1160,7 +1242,7 @@ std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
bool supported;
std::string error_msg;
std::tie(supported, error_msg) =
IsNodeSupported(model_proto.graph().node(i));
IsNodeSupported(model_proto, model_proto.graph().node(i));
if (supported) {
supported_node_vec.push_back(i);
} else {
Expand Down

0 comments on commit ab22710

Please sign in to comment.