Skip to content

Commit

Permalink
Add getsupportednodes tool, check for single reshape/dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Jul 16, 2019
1 parent 79edcc3 commit 18bce4e
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .daq_pm/configs/all-28
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Configuration for [project_manager.vim](https://github.com/daquexian/project_manager.vim)
name DNNLibrary
type cpp
build_dir build-all-v28
cmake_options -DCMAKE_TOOLCHAIN_FILE=~/Android/Sdk/ndk-bundle/build/cmake/android.toolchain.cmake -DANDROID_PLATFORM=android-28 -DANDROID_ABI=arm64-v8a -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -GNinja -DDNN_READ_ONNX=ON -DDNN_CUSTOM_PROTOC_EXECUTABLE=/usr/bin/protoc
6 changes: 6 additions & 0 deletions .daq_pm/configs/x86-all
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# It is configuration file for [project_manager.vim](https://github.com/daquexian/project_manager.vim)
name DNNLibrary
type cpp
cmake_options -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DDNN_READ_ONNX=ON
build_dir build_x86all

1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(onnx2daq)
add_subdirectory(getsupportednodes)
12 changes: 12 additions & 0 deletions tools/getsupportednodes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_executable(get_supported_nodes
getsupportednodes.cpp)
target_link_libraries(get_supported_nodes
onnx2daq)
target_include_directories(get_supported_nodes
PRIVATE
${PROJECT_SOURCE_DIR}
)
if (DNN_SYSTEM_PROTOBUF)
treat_warnings_as_errors(get_supported_nodes)
endif()

17 changes: 17 additions & 0 deletions tools/getsupportednodes/getsupportednodes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <fstream>

#include <common/helper.h>
#include <tools/onnx2daq/OnnxConverter.h>

int main(int argc, char *argv[])
{
ONNX_NAMESPACE::ModelProto model_proto;
std::ifstream ifs(argv[1], std::ios::in | std::ios::binary);
std::stringstream ss;
ss << ifs.rdbuf();
// FIXME: Handle the return value
model_proto.ParseFromString(ss.str());
dnn::OnnxConverter converter;
PNT(converter.GetSupportedNodes(model_proto));
return 0;
}
22 changes: 19 additions & 3 deletions tools/onnx2daq/OnnxConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,22 @@ std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
return {true, ""};
}

bool IsValidSupportedNodesVec(const std::vector<int> &supported_node_vec,
const ONNX_NAMESPACE::ModelProto &model_proto) {
if (!supported_node_vec.empty()) {
if (supported_node_vec.size() == 1) {
const auto &node = model_proto.graph().node(supported_node_vec[0]);
// Reshape and Dropout are simply ignored in DNNLibrary, causing the
// input == output, which is not allowed in NNAPI
if (node.op_type() == "Reshape" || node.op_type() == "Dropout") {
return false;
}
}
return true;
}
return false;
}

std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
const ONNX_NAMESPACE::ModelProto &model_proto) {
GOOGLE_PROTOBUF_VERIFY_VERSION;
Expand All @@ -1144,17 +1160,17 @@ std::vector<std::vector<int>> OnnxConverter::GetSupportedNodes(
bool supported;
std::string error_msg;
std::tie(supported, error_msg) =
IsNodeSupported(model_proto_.graph().node(i));
IsNodeSupported(model_proto.graph().node(i));
if (supported) {
supported_node_vec.push_back(i);
} else {
if (!supported_node_vec.empty()) {
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
supported_node_vec.clear();
}
}
}
if (!supported_node_vec.empty()) {
if (IsValidSupportedNodesVec(supported_node_vec, model_proto)) {
supported_node_vecs.push_back(supported_node_vec);
}
Clear();
Expand Down

0 comments on commit 18bce4e

Please sign in to comment.