
Merge pull request #58 from JDAI-CV/check_conv_weight
check whether onnx convolution weight exists in initializer
daquexian authored Jul 16, 2019
2 parents 65b14c9 + 85cad29 commit 4a4ded8
Showing 7 changed files with 96 additions and 17 deletions.
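The change follows one pattern throughout: before the converter looks up a convolution weight (or a PRelu slope) in onnx_tensors_, it first checks that the tensor is actually present as an initializer, and either rejects the node in IsNodeSupported or throws in AddConv instead of crashing on a missing key. A minimal standalone sketch of that guard, assuming a plain std::map and hypothetical names rather than the real converter types:

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the converter's map of ONNX initializers.
struct Tensor {
    std::vector<int> shape;
};
using Initializers = std::map<std::string, Tensor>;

// Sketch of the guarded lookup: fail with a clear message when the weight is
// not a constant initializer (e.g. it is produced by another node at runtime).
const Tensor &GetConvWeightOrThrow(const Initializers &tensors,
                                   const std::string &weight_name) {
    const auto it = tensors.find(weight_name);
    if (it == tensors.end()) {
        throw std::invalid_argument("The weight of convolution must be known");
    }
    return it->second;
}

int main() {
    Initializers tensors{{"conv1.weight", Tensor{{32, 3, 3, 3}}}};
    std::cout << GetConvWeightOrThrow(tensors, "conv1.weight").shape[0] << "\n";
    try {
        GetConvWeightOrThrow(tensors, "runtime_weight");  // not an initializer
    } catch (const std::invalid_argument &e) {
        std::cout << e.what() << "\n";
    }
}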
5 changes: 5 additions & 0 deletions .daq_pm/configs/all-28
@@ -0,0 +1,5 @@
# Configuration for [project_manager.vim](https://github.com/daquexian/project_manager.vim)
name DNNLibrary
type cpp
build_dir build-all-v28
cmake_options -DCMAKE_TOOLCHAIN_FILE=~/Android/Sdk/ndk-bundle/build/cmake/android.toolchain.cmake -DANDROID_PLATFORM=android-28 -DANDROID_ABI=arm64-v8a -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -GNinja -DDNN_READ_ONNX=ON -DDNN_CUSTOM_PROTOC_EXECUTABLE=/usr/bin/protoc
6 changes: 6 additions & 0 deletions .daq_pm/configs/x86-all
@@ -0,0 +1,6 @@
# This is the configuration file for [project_manager.vim](https://github.com/daquexian/project_manager.vim)
name DNNLibrary
type cpp
cmake_options -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DDNN_READ_ONNX=ON
build_dir build_x86all

2 changes: 1 addition & 1 deletion include/common/StrKeyMap.h
@@ -40,7 +40,7 @@ class StrKeyMap {
void clear() {
map_.clear();
}
bool has(const std::string &key) {
bool has(const std::string &key) const {
return map_.find(key) != map_.end();
}

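The const qualifier added to StrKeyMap::has() above lets read-only code call it through const references, which the new onnx_tensors_.has(...) checks presumably rely on. A simplified sketch of why the qualifier matters, with a hypothetical class name and a plain std::map standing in for the real storage:

#include <map>
#include <string>

// Simplified stand-in for StrKeyMap; only the part relevant to the change.
class StrKeyMapSketch {
   public:
    // Marked const so it can be called on const objects and references.
    bool has(const std::string &key) const {
        return map_.find(key) != map_.end();
    }

   private:
    std::map<std::string, int> map_;
};

// Callers that only need read access can take the map by const reference,
// which requires has() to be a const member function.
bool WeightIsKnown(const StrKeyMapSketch &tensors, const std::string &name) {
    return tensors.has(name);
}

int main() {
    StrKeyMapSketch tensors;
    const StrKeyMapSketch &view = tensors;
    // Without the const qualifier on has(), this call would not compile.
    return view.has("conv1.weight") ? 1 : 0;
}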
3 changes: 3 additions & 0 deletions tools/CMakeLists.txt
@@ -1 +1,4 @@
add_subdirectory(onnx2daq)
if (NOT ${ONNX2DAQ_ONLY_LIB})
    add_subdirectory(getsupportednodes)
endif()
12 changes: 12 additions & 0 deletions tools/getsupportednodes/CMakeLists.txt
@@ -0,0 +1,12 @@
add_executable(get_supported_nodes
    getsupportednodes.cpp)
target_link_libraries(get_supported_nodes
    onnx2daq)
target_include_directories(get_supported_nodes
    PRIVATE
        ${PROJECT_SOURCE_DIR}
    )
if (DNN_SYSTEM_PROTOBUF)
    treat_warnings_as_errors(get_supported_nodes)
endif()

17 changes: 17 additions & 0 deletions tools/getsupportednodes/getsupportednodes.cpp
@@ -0,0 +1,17 @@
#include <fstream>
#include <sstream>

#include <common/helper.h>
#include <tools/onnx2daq/OnnxConverter.h>

int main(int argc, char *argv[])
{
    // Read the serialized ONNX model given as the first argument.
    ONNX_NAMESPACE::ModelProto model_proto;
    std::ifstream ifs(argv[1], std::ios::in | std::ios::binary);
    std::stringstream ss;
    ss << ifs.rdbuf();
    // FIXME: Handle the return value
    model_proto.ParseFromString(ss.str());
    dnn::OnnxConverter converter;
    // Print the groups of consecutively supported nodes.
    PNT(converter.GetSupportedNodes(model_proto));
    return 0;
}
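Once built with -DDNN_READ_ONNX=ON, the tool presumably takes the path to a serialized ONNX model as its only argument (for example, get_supported_nodes model.onnx) and prints the groups of consecutively supported nodes returned by GetSupportedNodes via the PNT helper from common/helper.h.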
68 changes: 52 additions & 16 deletions tools/onnx2daq/OnnxConverter.cpp
@@ -241,7 +241,7 @@ void OnnxConverter::AddConv(const string &input_name,
dilations[1] * dilations[1] -
input_shape[2];
VLOG(5) << input_shape << ", " << pads << ", " << dilations << ", "
<< new_pads;
<< new_pads;
// Why "AllowShortBlocksOnASingleLine: false" doesn't work on it?
// clang-format off
{
@@ -278,6 +278,9 @@ void OnnxConverter::AddConv(const string &input_name,
return;
}

if (!onnx_tensors_.has(ori_weight_name)) {
throw std::invalid_argument("The weight of convolution must be known");
}
const auto &onnx_weight = onnx_tensors_.at(ori_weight_name);
if (group == 1) {
VLOG(5) << "Vanilla conv";
@@ -286,7 +289,8 @@ void OnnxConverter::AddConv(const string &input_name,
} else if (onnx_weight.shape[1] == 1) { // depthwise
VLOG(5) << "Depthwise conv";
AddLayerDepthwiseConvImpl(input_name, ori_weight_name, bias_name, pads,
strides, 1, output_name);
strides, onnx_weight.shape[0] / group,
output_name);
} else {
// TODO: Support it
throw std::invalid_argument("group != 1 is not supported");
@@ -921,8 +925,8 @@ OnnxConverter::GetInputOfOnnxModel() {
nnapi_shape = shape;
}
shaper_.AddShape(input.name(), nnapi_shape);
const auto flat_input =
DNN::CreateInputDirect(builder_, &nnapi_shape, input.name().c_str());
const auto flat_input = DNN::CreateInputDirect(builder_, &nnapi_shape,
input.name().c_str());
inputs.push_back(flat_input);
}

@@ -1044,18 +1048,22 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
const auto strides = helper.get("strides", vector<int>{1, 1});
const auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
const auto dilations = helper.get("dilations", vector<int>{1, 1});
CHECK_EQ(pads.size(), 4ul);
CHECK_EQ(strides.size(), 2ul);
CHECK_EQ(dilations.size(), 2ul);
const auto group = helper.get("group", 1);
if (dilations != vector<int>{1, 1} && strides != vector<int>{1, 1}) {
return {false,
"Both dilations and strides > 1 is not supported for now"};
}
const auto weight_name = m(node.input(1));
const auto &onnx_weight = onnx_tensors_.at(weight_name);
if (group != 1 && onnx_weight.shape[1] != 1) {
return {false, "group != 1 is not supported"};
if (onnx_tensors_.has(weight_name)) {
const auto &onnx_weight = onnx_tensors_.at(weight_name);
if (group != 1 && onnx_weight.shape[1] != 1) {
return {false, "group != 1 is not supported"};
}
if (onnx_weight.shape.size() != 4) {
return {false, "Only conv 2d is supported."};
}
} else {
return {false, "The weight of convolution must be known"};
}
} else if (op == "AveragePool" || op == "MaxPool") {
const auto count_include_pad = helper.get("count_include_pad", 0);
@@ -1069,11 +1077,18 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
if (helper.get("auto_pad", "NOTSET") != "NOTSET") {
return {false, "auto_pad is not supported"};
}
if (helper.get("kernel_shape", std::vector<int>{1, 1}).size() != 2) {
return {false, "Only pooling 2d is supported"};
}
} else if (op == "PRelu") {
const auto slope_name = m(node.input(1));
if (onnx_tensors_.at(slope_name).shape != Shape{1}) {
// TODO: support it
return {false, "Only support one element slope."};
if (onnx_tensors_.has(slope_name)) {
if (onnx_tensors_.at(slope_name).shape != Shape{1}) {
// TODO: support it
return {false, "PRelu only support one element slope."};
}
} else {
return {false, "PRelu slope must be known"};
}
} else if (op == "Gemm") {
const auto transA = helper.get("transA", 0);
@@ -1108,10 +1123,31 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
}
}
}
} else if (op == "Softmax") {
const auto axis = helper.get("axis", 1);
if (axis != 1) {
return {false, "Only axis == 1 is supported in Softmax"};
}
}
return {true, ""};
}

bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
const ONNX_NAMESPACE::ModelProto &model_proto) {
if (!supported_node_vec.empty()) {
if (supported_node_vec.size() == 1) {
const auto &node = model_proto.graph().node(supported_node_vec[0]);
// Reshape and Dropout are simply ignored in DNNLibrary, causing the
// input == output, which is not allowed in NNAPI
if (node.op_type() == "Reshape" || node.op_type() == "Dropout") {
return false;
}
}
return true;
}
return false;
}

std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
const ONNX_NAMESPACE::ModelProto &model_proto) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
@@ -1124,17 +1160,17 @@ std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
bool supported;
std::string error_msg;
std::tie(supported, error_msg) =
IsNodeSupported(model_proto_.graph().node(i));
IsNodeSupported(model_proto.graph().node(i));
if (supported) {
supported_node_vec.push_back(i);
} else {
if (!supported_node_vec.empty()) {
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
supported_node_vec.clear();
}
}
}
if (!supported_node_vec.empty()) {
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
}
Clear();
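To make the grouping behaviour concrete: GetSupportedNodes walks the graph in order, collects consecutive supported node indices into runs, and IsValidSupportedNodesVec drops runs that NNAPI could not accept, such as a run consisting of a single Reshape or Dropout. A self-contained sketch of the same idea, with hypothetical names and a plain supported-flag vector standing in for the real IsNodeSupported check:

#include <iostream>
#include <string>
#include <vector>

bool IsValidRun(const std::vector<int> &run, const std::vector<std::string> &ops) {
    if (run.empty()) return false;
    if (run.size() == 1) {
        const std::string &op = ops[run[0]];
        // Reshape and Dropout are pass-through in DNNLibrary, so a run made of
        // only one of them would have input == output, which NNAPI rejects.
        if (op == "Reshape" || op == "Dropout") return false;
    }
    return true;
}

// Group consecutive supported node indices into runs, as GetSupportedNodes does.
std::vector<std::vector<int>> GroupSupportedNodes(
    const std::vector<std::string> &ops, const std::vector<bool> &supported) {
    std::vector<std::vector<int>> runs;
    std::vector<int> run;
    for (size_t i = 0; i < ops.size(); ++i) {
        if (supported[i]) {
            run.push_back(static_cast<int>(i));
            continue;
        }
        if (IsValidRun(run, ops)) runs.push_back(run);
        run.clear();
    }
    if (IsValidRun(run, ops)) runs.push_back(run);
    return runs;
}

int main() {
    // Hypothetical graph: the lone Reshape between unsupported nodes is dropped.
    const std::vector<std::string> ops = {"Conv", "Relu", "Loop", "Reshape", "Loop", "Conv"};
    const std::vector<bool> supported = {true, true, false, true, false, true};
    for (const auto &run : GroupSupportedNodes(ops, supported)) {
        for (int i : run) std::cout << i << ' ';
        std::cout << '\n';  // prints "0 1 " then "5 "
    }
}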
