diff --git a/cmake/developer_package/plugins/plugins.cmake b/cmake/developer_package/plugins/plugins.cmake index 6210ede333ad8b..cec023f3062513 100644 --- a/cmake/developer_package/plugins/plugins.cmake +++ b/cmake/developer_package/plugins/plugins.cmake @@ -112,8 +112,8 @@ function(ie_add_plugin) if(TARGET inference_engine_ir_v7_reader) add_dependencies(${IE_PLUGIN_NAME} inference_engine_ir_v7_reader) endif() - if(TARGET inference_engine_onnx_reader) - add_dependencies(${IE_PLUGIN_NAME} inference_engine_onnx_reader) + if(TARGET onnx_ngraph_frontend) + add_dependencies(${IE_PLUGIN_NAME} onnx_ngraph_frontend) endif() # install rules diff --git a/docs/IE_DG/Deep_Learning_Inference_Engine_DevGuide.md b/docs/IE_DG/Deep_Learning_Inference_Engine_DevGuide.md index 5fc2b3f910255f..0f07f5503811f5 100644 --- a/docs/IE_DG/Deep_Learning_Inference_Engine_DevGuide.md +++ b/docs/IE_DG/Deep_Learning_Inference_Engine_DevGuide.md @@ -43,10 +43,10 @@ This library contains the classes to: Starting from 2020.4 release, Inference Engine introduced a concept of `CNNNetwork` reader plugins. Such plugins can be automatically dynamically loaded by Inference Engine in runtime depending on file format: * Linux* OS: - `libinference_engine_ir_reader.so` to read a network from IR - - `libinference_engine_onnx_reader.so` to read a network from ONNX model format + - `onnx_ngraph_frontend.so` to read a network from ONNX model format * Windows* OS: - `inference_engine_ir_reader.dll` to read a network from IR - - `inference_engine_onnx_reader.dll` to read a network from ONNX model format + - `onnx_ngraph_frontend.dll` to read a network from ONNX model format ### Device-Specific Plugin Libraries diff --git a/docs/IE_DG/inference_engine_intro.md b/docs/IE_DG/inference_engine_intro.md index a4c33fd6f856ec..89d80654fe4480 100644 --- a/docs/IE_DG/inference_engine_intro.md +++ b/docs/IE_DG/inference_engine_intro.md @@ -46,10 +46,10 @@ This library contains the classes to: Starting from 2020.4 release, Inference Engine introduced a concept of `CNNNetwork` reader plugins. Such plugins can be automatically dynamically loaded by Inference Engine in runtime depending on file format: * Unix* OS: - `libinference_engine_ir_reader.so` to read a network from IR - - `libinference_engine_onnx_reader.so` to read a network from ONNX model format + - `onnx_ngraph_frontend.so` to read a network from ONNX model format * Windows* OS: - `inference_engine_ir_reader.dll` to read a network from IR - - `inference_engine_onnx_reader.dll` to read a network from ONNX model format + - `onnx_ngraph_frontend.dll` to read a network from ONNX model format ### Device-specific Plugin Libraries ### diff --git a/inference-engine/src/CMakeLists.txt b/inference-engine/src/CMakeLists.txt index 5f3959223fbe3d..889c44e2780490 100644 --- a/inference-engine/src/CMakeLists.txt +++ b/inference-engine/src/CMakeLists.txt @@ -53,5 +53,5 @@ add_custom_target(ie_libraries ALL inference_engine_lp_transformations inference_engine_snippets) if(NGRAPH_ONNX_FRONTEND_ENABLE) - add_dependencies(ie_libraries inference_engine_onnx_reader) + add_dependencies(ie_libraries onnx_ngraph_frontend) endif() diff --git a/inference-engine/src/inference_engine/src/ie_network_reader.cpp b/inference-engine/src/inference_engine/src/ie_network_reader.cpp index f3c95ac4ed2526..f8c15a4f00d4e1 100644 --- a/inference-engine/src/inference_engine/src/ie_network_reader.cpp +++ b/inference-engine/src/inference_engine/src/ie_network_reader.cpp @@ -96,6 +96,11 @@ namespace { // Extension to plugins creator std::multimap readers; +static ngraph::frontend::FrontEndManager* get_frontend_manager() { + static ngraph::frontend::FrontEndManager manager; + return &manager; +} + void registerReaders() { OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "registerReaders"); static bool initialized = false; @@ -115,14 +120,6 @@ void registerReaders() { return std::make_shared(name, library_name); }; - // try to load ONNX reader if library exists - auto onnxReader = - create_if_exists("ONNX", std::string("inference_engine_onnx_reader") + std::string(IE_BUILD_POSTFIX)); - if (onnxReader) { - readers.emplace("onnx", onnxReader); - readers.emplace("prototxt", onnxReader); - } - // try to load IR reader v10 if library exists auto irReaderv10 = create_if_exists("IRv10", std::string("inference_engine_ir_reader") + std::string(IE_BUILD_POSTFIX)); @@ -174,10 +171,6 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, #endif // Try to open model file std::ifstream modelStream(model_path, std::ios::binary); - // save path in extensible array of stream - // notice: lifetime of path pointed by pword(0) is limited by current scope - const std::string path_to_save_in_stream = modelPath; - modelStream.pword(0) = const_cast(path_to_save_in_stream.c_str()); if (!modelStream.is_open()) IE_THROW() << "Model file " << modelPath << " cannot be opened!"; @@ -240,7 +233,7 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, } } // Try to load with FrontEndManager - static ngraph::frontend::FrontEndManager manager; + const auto manager = get_frontend_manager(); ngraph::frontend::FrontEnd::Ptr FE; ngraph::frontend::InputModel::Ptr inputModel; if (!binPath.empty()) { @@ -249,17 +242,17 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath, #else std::string weights_path = binPath; #endif - FE = manager.load_by_model(model_path, weights_path); + FE = manager->load_by_model(model_path, weights_path); if (FE) inputModel = FE->load(model_path, weights_path); } else { - FE = manager.load_by_model(model_path); + FE = manager->load_by_model(model_path); if (FE) inputModel = FE->load(model_path); } if (inputModel) { auto ngFunc = FE->convert(inputModel); - return CNNNetwork(ngFunc); + return CNNNetwork(ngFunc, exts); } IE_THROW() << "Unknown model format! Cannot find reader for model format: " << fileExt << " and read the model: " << modelPath << ". Please check that reader library exists in your PATH."; @@ -282,6 +275,19 @@ CNNNetwork details::ReadNetwork(const std::string& model, return reader->read(modelStream, exts); } } + // Try to load with FrontEndManager + // NOTE: weights argument is ignored + const auto manager = get_frontend_manager(); + ngraph::frontend::FrontEnd::Ptr FE; + ngraph::frontend::InputModel::Ptr inputModel; + FE = manager->load_by_model(&modelStream); + if (FE) + inputModel = FE->load(&modelStream); + if (inputModel) { + auto ngFunc = FE->convert(inputModel); + return CNNNetwork(ngFunc, exts); + } + IE_THROW() << "Unknown model format! Cannot find reader for the model and read it. Please check that reader " "library exists in your PATH."; } diff --git a/inference-engine/src/readers/CMakeLists.txt b/inference-engine/src/readers/CMakeLists.txt index 139a515f3fa560..b1864152adeb09 100644 --- a/inference-engine/src/readers/CMakeLists.txt +++ b/inference-engine/src/readers/CMakeLists.txt @@ -17,7 +17,3 @@ add_cpplint_target(${TARGET_NAME}_cpplint FOR_SOURCES ${reader_api_hpp}) add_subdirectory(ir_reader) add_subdirectory(ir_reader_v7) - -if(NGRAPH_ONNX_FRONTEND_ENABLE) - add_subdirectory(onnx_reader) -endif() diff --git a/inference-engine/src/readers/onnx_reader/CMakeLists.txt b/inference-engine/src/readers/onnx_reader/CMakeLists.txt deleted file mode 100644 index b5b409f99c7786..00000000000000 --- a/inference-engine/src/readers/onnx_reader/CMakeLists.txt +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (C) 2018-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 -# - -set(TARGET_NAME "inference_engine_onnx_reader") - -file(GLOB_RECURSE LIBRARY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/*.hpp) - -# Create named folders for the sources within the .vcproj -# Empty name lists them directly under the .vcproj - -source_group("src" FILES ${LIBRARY_SRC}) - -# Create module library - -add_library(${TARGET_NAME} MODULE ${LIBRARY_SRC}) - -ie_add_vs_version_file(NAME ${TARGET_NAME} - FILEDESCRIPTION "Inference Engine ONNX reader plugin") - -target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) - -target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_ENGINE_PLUGIN) - -target_link_libraries(${TARGET_NAME} PRIVATE inference_engine_reader_api onnx_ngraph_frontend inference_engine) - -ie_add_api_validator_post_build_step(TARGET ${TARGET_NAME}) - -set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO}) - -# code style - -add_cpplint_target(${TARGET_NAME}_cpplint FOR_TARGETS ${TARGET_NAME}) - -# install - -install(TARGETS ${TARGET_NAME} - LIBRARY DESTINATION ${IE_CPACK_RUNTIME_PATH} COMPONENT core) diff --git a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp deleted file mode 100644 index 0af45c259aea2d..00000000000000 --- a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.cpp +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (C) 2018-2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "ie_onnx_reader.hpp" -#include "onnx_model_validator.hpp" -#include -#include - -using namespace InferenceEngine; - -namespace { -std::string readPathFromStream(std::istream& stream) { - if (stream.pword(0) == nullptr) { - return {}; - } - // read saved path from extensible array - return std::string{static_cast(stream.pword(0))}; -} - -/** - * This helper struct uses RAII to rewind/reset the stream so that it points to the beginning - * of the underlying resource (string, file, ...). It works similarily to std::lock_guard - * which releases a mutex upon destruction. - * - * This makes sure that the stream is always reset (exception, successful and unsuccessful - * model validation). - */ -struct StreamRewinder { - StreamRewinder(std::istream& stream) : m_stream(stream) { - m_stream.seekg(0, m_stream.beg); - } - ~StreamRewinder() { - m_stream.seekg(0, m_stream.beg); - } -private: - std::istream& m_stream; -}; -} // namespace - -bool ONNXReader::supportModel(std::istream& model) const { - StreamRewinder rwd{model}; - - const auto model_path = readPathFromStream(model); - - // this might mean that the model is loaded from a string in memory - // let's try to figure out if it's any of the supported formats - if (model_path.empty()) { - if (!is_valid_model(model, onnx_format{})) { - model.seekg(0, model.beg); - return is_valid_model(model, prototxt_format{}); - } else { - return true; - } - } - - if (model_path.find(".prototxt", 0) != std::string::npos) { - return is_valid_model(model, prototxt_format{}); - } else { - return is_valid_model(model, onnx_format{}); - } -} - -CNNNetwork ONNXReader::read(std::istream& model, const std::vector& exts) const { - return CNNNetwork(ngraph::onnx_import::import_onnx_model(model, readPathFromStream(model)), exts); -} - -INFERENCE_PLUGIN_API(void) InferenceEngine::CreateReader(std::shared_ptr& reader) { - reader = std::make_shared(); -} diff --git a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.hpp b/inference-engine/src/readers/onnx_reader/ie_onnx_reader.hpp deleted file mode 100644 index 7d797a4230a96e..00000000000000 --- a/inference-engine/src/readers/onnx_reader/ie_onnx_reader.hpp +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (C) 2018-2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace InferenceEngine { - -class ONNXReader: public IReader { -public: - /** - * @brief Checks that reader supports format of the model - * @param model stream with model - * @return true if format is supported - */ - bool supportModel(std::istream& model) const override; - /** - * @brief Reads the model to CNNNetwork - * @param model stream with model - * @param exts vector with extensions - * - * @return CNNNetwork - */ - CNNNetwork read(std::istream& model, const std::vector& exts) const override; - /** - * @brief Reads the model to CNNNetwork - * @param model stream with model - * @param weights blob with binary data - * @param exts vector with extensions - * - * @return CNNNetwork - */ - CNNNetwork read(std::istream& model, const Blob::CPtr& weights, const std::vector& exts) const override { - IE_THROW() << "ONNX reader cannot read model with weights!"; - } - - std::vector getDataFileExtensions() const override { - return {}; - } -}; - -} // namespace InferenceEngine - diff --git a/inference-engine/src/readers/onnx_reader/onnx_model_validator.cpp b/inference-engine/src/readers/onnx_reader/onnx_model_validator.cpp deleted file mode 100644 index ae472990c6a312..00000000000000 --- a/inference-engine/src/readers/onnx_reader/onnx_model_validator.cpp +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright (C) 2018-2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "onnx_model_validator.hpp" - -#include -#include -#include -#include -#include - -namespace detail { -namespace onnx { - enum Field { - IR_VERSION = 1, - PRODUCER_NAME = 2, - PRODUCER_VERSION = 3, - DOMAIN_ = 4, // DOMAIN collides with some existing symbol in MSVC thus - underscore - MODEL_VERSION = 5, - DOC_STRING = 6, - GRAPH = 7, - OPSET_IMPORT = 8, - METADATA_PROPS = 14, - TRAINING_INFO = 20 - }; - - enum WireType { - VARINT = 0, - BITS_64 = 1, - LENGTH_DELIMITED = 2, - START_GROUP = 3, - END_GROUP = 4, - BITS_32 = 5 - }; - - // A PB key consists of a field number (defined in onnx.proto) and a type of data that follows this key - using PbKey = std::pair; - - // This pair represents a key found in the encoded model and optional size of the payload - // that follows the key (in bytes). They payload should be skipped for the fast check purposes. - using ONNXField = std::pair; - - bool is_correct_onnx_field(const PbKey& decoded_key) { - static const std::map onnx_fields = { - {IR_VERSION, VARINT}, - {PRODUCER_NAME, LENGTH_DELIMITED}, - {PRODUCER_VERSION, LENGTH_DELIMITED}, - {DOMAIN_, LENGTH_DELIMITED}, - {MODEL_VERSION, VARINT}, - {DOC_STRING, LENGTH_DELIMITED}, - {GRAPH, LENGTH_DELIMITED}, - {OPSET_IMPORT, LENGTH_DELIMITED}, - {METADATA_PROPS, LENGTH_DELIMITED}, - {TRAINING_INFO, LENGTH_DELIMITED}, - }; - - if (!onnx_fields.count(static_cast(decoded_key.first))) { - return false; - } - - return onnx_fields.at(static_cast(decoded_key.first)) == static_cast(decoded_key.second); - } - - /** - * Only 7 bits in each component of a varint count in this algorithm. The components form - * a decoded number when they are concatenated bitwise in a reverse order. For example: - * bytes = [b1, b2, b3, b4] - * varint = b4 ++ b3 ++ b2 ++ b1 <== only 7 bits of each byte should be extracted before concat - * - * b1 b2 - * bytes = [00101100, 00000010] - * b2 b1 - * varint = 0000010 ++ 0101100 = 100101100 => decimal: 300 - * Each consecutive varint byte needs to be left shifted "7 x its position in the vector" - * and bitwise added to the accumulator afterwards. - */ - uint32_t varint_bytes_to_number(const std::vector& bytes) { - uint32_t accumulator = 0u; - - for (size_t i = 0; i < bytes.size(); ++i) { - uint32_t b = bytes[i]; - b <<= 7 * i; - accumulator |= b; - } - - return accumulator; - } - - uint32_t decode_varint(std::istream& model) { - std::vector bytes; - // max 4 bytes for a single value because this function returns a 32-bit long decoded varint - const size_t MAX_VARINT_BYTES = 4u; - // optimization to avoid allocations during push_back calls - bytes.reserve(MAX_VARINT_BYTES); - - char key_component = 0; - model.get(key_component); - - // keep reading all bytes from the stream which have the MSB on - while (key_component & 0x80 && bytes.size() < MAX_VARINT_BYTES) { - // drop the most significant bit - const char component = key_component & ~0x80; - bytes.push_back(component); - model.get(key_component); - } - // add the last byte - the one with MSB off - bytes.push_back(key_component); - - return varint_bytes_to_number(bytes); - } - - PbKey decode_key(const char key) { - // 3 least significant bits - const char wire_type = key & 0b111; - // remaining bits - const char field_number = key >> 3; - return {field_number, wire_type}; - } - - ONNXField decode_next_field(std::istream& model) { - char key = 0; - model.get(key); - - const auto decoded_key = decode_key(key); - - if (!is_correct_onnx_field(decoded_key)) { - throw std::runtime_error{"Incorrect field detected in the processed model"}; - } - - const auto onnx_field = static_cast(decoded_key.first); - - switch (decoded_key.second) { - case VARINT: { - // the decoded varint is the payload in this case but its value doesnt matter - // in the fast check process so we just discard it - decode_varint(model); - return {onnx_field, 0}; - } - case LENGTH_DELIMITED: - // the varint following the key determines the payload length - return {onnx_field, decode_varint(model)}; - case BITS_64: - return {onnx_field, 8}; - case BITS_32: - return {onnx_field, 4}; - case START_GROUP: - case END_GROUP: - throw std::runtime_error{"StartGroup and EndGroup are not used in ONNX models"}; - default: - throw std::runtime_error{"Unknown WireType encountered in the model"}; - } - } - - inline void skip_payload(std::istream& model, uint32_t payload_size) { - model.seekg(payload_size, std::ios::cur); - } -} // namespace onnx - -namespace prototxt { - bool contains_onnx_model_keys(const std::string& model, const size_t expected_keys_num) { - size_t keys_found = 0; - - const std::vector onnx_keys = { - "ir_version", "producer_name", "producer_version", "domain", "model_version", - "doc_string", "graph", "opset_import", "metadata_props", "training_info" - }; - - size_t search_start_pos = 0; - - while (keys_found < expected_keys_num) { - const auto key_finder = [&search_start_pos, &model](const std::string& key) { - const auto key_pos = model.find(key, search_start_pos); - if (key_pos != model.npos) { - // don't search from the beginning each time - search_start_pos = key_pos + key.size(); - return true; - } else { - return false; - } - }; - - const auto found = std::any_of(std::begin(onnx_keys), std::end(onnx_keys), key_finder); - if (!found) { - break; - } else { - ++keys_found; - } - } - - return keys_found == expected_keys_num; - } -} // namespace prototxt -} // namespace detail - -namespace InferenceEngine { - bool is_valid_model(std::istream& model, onnx_format) { - // the model usually starts with a 0x08 byte indicating the ir_version value - // so this checker expects at least 2 valid ONNX keys to be found in the validated model - const unsigned int EXPECTED_FIELDS_FOUND = 2u; - unsigned int valid_fields_found = 0u; - try { - while (!model.eof() && valid_fields_found < EXPECTED_FIELDS_FOUND) { - const auto field = detail::onnx::decode_next_field(model); - - ++valid_fields_found; - - if (field.second > 0) { - detail::onnx::skip_payload(model, field.second); - } - } - - return valid_fields_found == EXPECTED_FIELDS_FOUND; - } catch (...) { - return false; - } - } - - bool is_valid_model(std::istream& model, prototxt_format) { - std::array head_of_file; - - model.seekg(0, model.beg); - model.read(head_of_file.data(), head_of_file.size()); - model.clear(); - model.seekg(0, model.beg); - - return detail::prototxt::contains_onnx_model_keys( - std::string{std::begin(head_of_file), std::end(head_of_file)}, 2); - } -} // namespace InferenceEngine diff --git a/inference-engine/src/readers/onnx_reader/onnx_model_validator.hpp b/inference-engine/src/readers/onnx_reader/onnx_model_validator.hpp deleted file mode 100644 index ad7af077cd72c5..00000000000000 --- a/inference-engine/src/readers/onnx_reader/onnx_model_validator.hpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (C) 2018-2021 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include - -namespace InferenceEngine { - // 2 empty structs used for tag dispatch below - struct onnx_format {}; - struct prototxt_format {}; - - bool is_valid_model(std::istream& model, onnx_format); - - bool is_valid_model(std::istream& model, prototxt_format); -} // namespace InferenceEngine diff --git a/inference-engine/tests/functional/inference_engine/CMakeLists.txt b/inference-engine/tests/functional/inference_engine/CMakeLists.txt index 56258e5143743e..d5a9d2fd2d7529 100644 --- a/inference-engine/tests/functional/inference_engine/CMakeLists.txt +++ b/inference-engine/tests/functional/inference_engine/CMakeLists.txt @@ -60,7 +60,7 @@ if(NGRAPH_ONNX_FRONTEND_ENABLE) target_compile_definitions(${TARGET_NAME} PRIVATE NGRAPH_ONNX_FRONTEND_ENABLE ONNX_TEST_MODELS="${TEST_MODEL_ZOO}/onnx_reader/models/") - add_dependencies(${TARGET_NAME} inference_engine_onnx_reader) + add_dependencies(${TARGET_NAME} onnx_ngraph_frontend) endif() if(NGRAPH_PDPD_FRONTEND_ENABLE) diff --git a/model-optimizer/mo/main.py b/model-optimizer/mo/main.py index 0747107c8c5e5f..3a9553a2bc0855 100644 --- a/model-optimizer/mo/main.py +++ b/model-optimizer/mo/main.py @@ -109,6 +109,9 @@ def prepare_ir(argv: argparse.Namespace): if argv.input_model: if not argv.framework: moc_front_end = fem.load_by_model(argv.input_model) + # skip onnx frontend as not fully supported yet (63050) + if moc_front_end and moc_front_end.get_name() == "onnx": + moc_front_end = None if moc_front_end: argv.framework = moc_front_end.get_name() elif argv.framework in available_moc_front_ends: diff --git a/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp index 5cc181935efb74..6a72fc18fc8bd1 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include "frontend.hpp" @@ -107,6 +108,16 @@ class FRONTEND_API VariantWrapper : public VariantImpl(value) {} }; +template <> +class FRONTEND_API VariantWrapper : public VariantImpl { +public: + static constexpr VariantTypeInfo type_info{"Variant::std::istringstream*", 0}; + const VariantTypeInfo& get_type_info() const override { + return type_info; + } + VariantWrapper(const value_type& value) : VariantImpl(value) {} +}; + #if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) template <> class FRONTEND_API VariantWrapper : public VariantImpl { diff --git a/ngraph/frontend/frontend_manager/src/frontend_manager.cpp b/ngraph/frontend/frontend_manager/src/frontend_manager.cpp index e48403678c645f..70f6edd6cd0524 100644 --- a/ngraph/frontend/frontend_manager/src/frontend_manager.cpp +++ b/ngraph/frontend/frontend_manager/src/frontend_manager.cpp @@ -375,6 +375,8 @@ std::vector Place::get_consuming_operations(const std::string& outpu constexpr VariantTypeInfo VariantWrapper::type_info; +constexpr VariantTypeInfo VariantWrapper::type_info; + #if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) constexpr VariantTypeInfo VariantWrapper::type_info; #endif diff --git a/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp b/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp index 38d0b22975a51d..412d8837fe9cf9 100644 --- a/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp +++ b/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp @@ -19,6 +19,8 @@ class ONNX_FRONTEND_API FrontEndONNX : public FrontEnd { std::shared_ptr convert(InputModel::Ptr model) const override; void convert(std::shared_ptr partially_converted) const override; std::shared_ptr decode(InputModel::Ptr model) const override; + std::string get_name() const override; + bool supported_impl(const std::vector>& variants) const override; protected: InputModel::Ptr load_impl(const std::vector>& params) const override; diff --git a/ngraph/frontend/onnx/frontend/src/editor.cpp b/ngraph/frontend/onnx/frontend/src/editor.cpp index ccbd18b0e1dcfb..0b1ca6d8b2ee28 100644 --- a/ngraph/frontend/onnx/frontend/src/editor.cpp +++ b/ngraph/frontend/onnx/frontend/src/editor.cpp @@ -11,6 +11,7 @@ #include "detail/subgraph_extraction.hpp" #include "edge_mapper.hpp" +#include "ngraph/file_util.hpp" #include "ngraph/log.hpp" #include "onnx_common/parser.hpp" #include "onnx_common/utils.hpp" @@ -19,6 +20,8 @@ using namespace ngraph; using namespace ngraph::onnx_editor; +NGRAPH_SUPPRESS_DEPRECATED_START + namespace { using namespace ONNX_NAMESPACE; @@ -191,6 +194,14 @@ struct onnx_editor::ONNXModelEditor::Impl { Impl(const std::string& model_path) : m_model_proto{std::make_shared(onnx_common::parse_from_file(model_path))} {} + + Impl(std::istream& model_stream) + : m_model_proto{std::make_shared(onnx_common::parse_from_istream(model_stream))} {} + +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + Impl(const std::wstring& model_path) + : m_model_proto{std::make_shared(onnx_common::parse_from_file(model_path))} {} +#endif }; onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path) @@ -199,6 +210,20 @@ onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path) delete impl; }} {} +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) +onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::wstring& model_path) + : m_model_path{file_util::wstring_to_string(model_path)}, + m_pimpl{new ONNXModelEditor::Impl{model_path}, [](Impl* impl) { + delete impl; + }} {} +#endif + +onnx_editor::ONNXModelEditor::ONNXModelEditor(std::istream& model_stream, const std::string& model_path) + : m_model_path{model_path}, + m_pimpl{new ONNXModelEditor::Impl{model_stream}, [](Impl* impl) { + delete impl; + }} {} + const std::string& onnx_editor::ONNXModelEditor::model_path() const { return m_model_path; } diff --git a/ngraph/frontend/onnx/frontend/src/editor.hpp b/ngraph/frontend/onnx/frontend/src/editor.hpp index bec2f1b47e1aa7..0605eb03bb254d 100644 --- a/ngraph/frontend/onnx/frontend/src/editor.hpp +++ b/ngraph/frontend/onnx/frontend/src/editor.hpp @@ -31,6 +31,17 @@ class ONNX_IMPORTER_API ONNXModelEditor final { /// /// \param model_path Path to the file containing the model. ONNXModelEditor(const std::string& model_path); +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + ONNXModelEditor(const std::wstring& model_path); +#endif + + /// \brief Creates an editor from a model stream. The stream is parsed and loaded + /// into the m_model_proto member variable. + /// + /// \param model_stream The stream containing the model. + /// \param model_path Path to the file containing the model. This information can be used + /// for ONNX external weights feature support. + ONNXModelEditor(std::istream& model_stream, const std::string& path = ""); /// \brief Modifies the in-memory representation of the model by setting /// custom input types for all inputs specified in the provided map. diff --git a/ngraph/frontend/onnx/frontend/src/frontend.cpp b/ngraph/frontend/onnx/frontend/src/frontend.cpp index 8ec3fdf952555c..db7eeb7d094d14 100644 --- a/ngraph/frontend/onnx/frontend/src/frontend.cpp +++ b/ngraph/frontend/onnx/frontend/src/frontend.cpp @@ -4,14 +4,23 @@ #include #include +#include #include #include #include +#include #include +#include "onnx_common/onnx_model_validator.hpp" + using namespace ngraph; using namespace ngraph::frontend; +using VariantString = VariantWrapper; +using VariantWString = VariantWrapper; +using VariantIstreamPtr = VariantWrapper; +using VariantIstringstreamPtr = VariantWrapper; + extern "C" ONNX_FRONTEND_API FrontEndVersion GetAPIVersion() { return OV_FRONTEND_API_VERSION; } @@ -26,12 +35,39 @@ extern "C" ONNX_FRONTEND_API void* GetFrontEndData() { } InputModel::Ptr FrontEndONNX::load_impl(const std::vector>& variants) const { - NGRAPH_CHECK(variants.size() == 1, - "Only one parameter to load function is expected. Got " + std::to_string(variants.size())); - NGRAPH_CHECK(ov::is_type>(variants[0]), - "Parameter to load function need to be a std::string"); - auto path = ov::as_type_ptr>(variants[0])->get(); - return std::make_shared(path); + if (variants.size() == 0) { + return nullptr; + } + if (ov::is_type(variants[0])) { + const auto path = ov::as_type_ptr(variants[0])->get(); + return std::make_shared(path); + } +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + if (ov::is_type(variants[0])) { + const auto path = ov::as_type_ptr(variants[0])->get(); + return std::make_shared(path); + } +#endif + std::istream* stream = nullptr; + if (ov::is_type(variants[0])) { + stream = ov::as_type_ptr(variants[0])->get(); + } else if (ov::is_type(variants[0])) { + stream = ov::as_type_ptr(variants[0])->get(); + } + if (stream != nullptr) { + if (variants.size() > 1 && ov::is_type(variants[1])) { + const auto path = ov::as_type_ptr(variants[1])->get(); + return std::make_shared(*stream, path); + } +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + if (variants.size() > 1 && ov::is_type(variants[1])) { + const auto path = ov::as_type_ptr(variants[1])->get(); + return std::make_shared(*stream, path); + } +#endif + return std::make_shared(*stream); + } + return nullptr; } std::shared_ptr FrontEndONNX::convert(InputModel::Ptr model) const { @@ -49,3 +85,63 @@ std::shared_ptr FrontEndONNX::decode(InputModel::Ptr model) co NGRAPH_CHECK(model_onnx != nullptr, "Invalid input model"); return model_onnx->decode(); } + +std::string FrontEndONNX::get_name() const { + return "onnx"; +} + +namespace { +/** + * This helper struct uses RAII to rewind/reset the stream so that it points to the beginning + * of the underlying resource (string, file, and so on). It works similarly to std::lock_guard, + * which releases a mutex upon destruction. + * + * This ensures that the stream is always reset (exception, successful and unsuccessful + * model validation). + */ +struct StreamRewinder { + StreamRewinder(std::istream& stream) : m_stream(stream) { + m_stream.seekg(0, m_stream.beg); + } + ~StreamRewinder() { + m_stream.seekg(0, m_stream.beg); + } + +private: + std::istream& m_stream; +}; +} // namespace + +bool FrontEndONNX::supported_impl(const std::vector>& variants) const { + if (variants.size() == 0) { + return false; + } + std::ifstream model_stream; + if (ov::is_type(variants[0])) { + const auto path = ov::as_type_ptr(variants[0])->get(); + model_stream.open(path, std::ios::in | std::ifstream::binary); + } +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + else if (ov::is_type(variants[0])) { + const auto path = ov::as_type_ptr(variants[0])->get(); + model_stream.open(path, std::ios::in | std::ifstream::binary); + } +#endif + if (model_stream.is_open()) { + model_stream.seekg(0, model_stream.beg); + const bool is_valid_model = onnx_common::is_valid_model(model_stream); + model_stream.close(); + return is_valid_model; + } + std::istream* stream = nullptr; + if (ov::is_type(variants[0])) { + stream = ov::as_type_ptr(variants[0])->get(); + } else if (ov::is_type(variants[0])) { + stream = ov::as_type_ptr(variants[0])->get(); + } + if (stream != nullptr) { + StreamRewinder rwd{*stream}; + return onnx_common::is_valid_model(*stream); + } + return false; +} diff --git a/ngraph/frontend/onnx/frontend/src/input_model.cpp b/ngraph/frontend/onnx/frontend/src/input_model.cpp index ae0a82f840edf6..9db743a74b93b6 100644 --- a/ngraph/frontend/onnx/frontend/src/input_model.cpp +++ b/ngraph/frontend/onnx/frontend/src/input_model.cpp @@ -5,15 +5,32 @@ #include "input_model.hpp" #include +#include #include "place.hpp" using namespace ngraph; using namespace ngraph::frontend; +NGRAPH_SUPPRESS_DEPRECATED_START + InputModelONNX::InputModelONNX(const std::string& path) : m_editor{std::make_shared(path)} {} +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) +InputModelONNX::InputModelONNX(const std::wstring& path) + : m_editor{std::make_shared(path)} {} +#endif + +InputModelONNX::InputModelONNX(std::istream& model_stream) + : m_editor{std::make_shared(model_stream)} {} + +InputModelONNX::InputModelONNX(std::istream& model_stream, const std::string& path) + : m_editor{std::make_shared(model_stream, path)} {} + +InputModelONNX::InputModelONNX(std::istream& model_stream, const std::wstring& path) + : InputModelONNX(model_stream, file_util::wstring_to_string(path)) {} + std::vector InputModelONNX::get_inputs() const { const auto& inputs = m_editor->model_inputs(); std::vector in_places; diff --git a/ngraph/frontend/onnx/frontend/src/input_model.hpp b/ngraph/frontend/onnx/frontend/src/input_model.hpp index 300b5a94d51e4b..2ed2ff0e8e4f0a 100644 --- a/ngraph/frontend/onnx/frontend/src/input_model.hpp +++ b/ngraph/frontend/onnx/frontend/src/input_model.hpp @@ -6,12 +6,21 @@ #include #include +#include namespace ngraph { namespace frontend { class InputModelONNX : public InputModel { public: InputModelONNX(const std::string& path); +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) + InputModelONNX(const std::wstring& path); +#endif + InputModelONNX(std::istream& model_stream); + // The path can be required even if the model is passed as a stream because it is necessary + // for ONNX external data feature + InputModelONNX(std::istream& model_stream, const std::string& path); + InputModelONNX(std::istream& model_stream, const std::wstring& path); std::vector get_inputs() const override; std::vector get_outputs() const override; diff --git a/ngraph/frontend/onnx/onnx_common/include/onnx_common/onnx_model_validator.hpp b/ngraph/frontend/onnx/onnx_common/include/onnx_common/onnx_model_validator.hpp new file mode 100644 index 00000000000000..e41bb8134595a1 --- /dev/null +++ b/ngraph/frontend/onnx/onnx_common/include/onnx_common/onnx_model_validator.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ngraph { +namespace onnx_common { + +bool is_valid_model(std::istream& model); +} // namespace onnx_common +} // namespace ngraph diff --git a/ngraph/frontend/onnx/onnx_common/include/onnx_common/parser.hpp b/ngraph/frontend/onnx/onnx_common/include/onnx_common/parser.hpp index 26675840afbebf..7310f150c5b33c 100644 --- a/ngraph/frontend/onnx/onnx_common/include/onnx_common/parser.hpp +++ b/ngraph/frontend/onnx/onnx_common/include/onnx_common/parser.hpp @@ -18,6 +18,9 @@ namespace onnx_common { /// /// \return The parsed in-memory representation of the ONNX model ONNX_NAMESPACE::ModelProto parse_from_file(const std::string& file_path); +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) +ONNX_NAMESPACE::ModelProto parse_from_file(const std::wstring& file_path); +#endif /// \brief Parses an ONNX model from a stream (representing for example a file) /// diff --git a/ngraph/frontend/onnx/onnx_common/src/onnx_model_validator.cpp b/ngraph/frontend/onnx/onnx_common/src/onnx_model_validator.cpp new file mode 100644 index 00000000000000..82f2b3540fe625 --- /dev/null +++ b/ngraph/frontend/onnx/onnx_common/src/onnx_model_validator.cpp @@ -0,0 +1,219 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "onnx_common/onnx_model_validator.hpp" + +#include +#include +#include +#include +#include + +namespace detail { +namespace onnx { +enum Field { + IR_VERSION = 1, + PRODUCER_NAME = 2, + PRODUCER_VERSION = 3, + DOMAIN_ = 4, // DOMAIN collides with some existing symbol in MSVC thus - underscore + MODEL_VERSION = 5, + DOC_STRING = 6, + GRAPH = 7, + OPSET_IMPORT = 8, + METADATA_PROPS = 14, + TRAINING_INFO = 20 +}; + +enum WireType { VARINT = 0, BITS_64 = 1, LENGTH_DELIMITED = 2, START_GROUP = 3, END_GROUP = 4, BITS_32 = 5 }; + +// A PB key consists of a field number (defined in onnx.proto) and a type of data that follows this key +using PbKey = std::pair; + +// This pair represents a key found in the encoded model and optional size of the payload +// that follows the key (in bytes). The payload should be skipped for fast check purposes. +using ONNXField = std::pair; + +bool is_correct_onnx_field(const PbKey& decoded_key) { + static const std::map onnx_fields = { + {IR_VERSION, VARINT}, + {PRODUCER_NAME, LENGTH_DELIMITED}, + {PRODUCER_VERSION, LENGTH_DELIMITED}, + {DOMAIN_, LENGTH_DELIMITED}, + {MODEL_VERSION, VARINT}, + {DOC_STRING, LENGTH_DELIMITED}, + {GRAPH, LENGTH_DELIMITED}, + {OPSET_IMPORT, LENGTH_DELIMITED}, + {METADATA_PROPS, LENGTH_DELIMITED}, + {TRAINING_INFO, LENGTH_DELIMITED}, + }; + + if (!onnx_fields.count(static_cast(decoded_key.first))) { + return false; + } + + return onnx_fields.at(static_cast(decoded_key.first)) == static_cast(decoded_key.second); +} + +/** + * Only 7 bits in each component of a varint count in this algorithm. The components form + * a decoded number when they are concatenated bitwise in reverse order. For example: + * bytes = [b1, b2, b3, b4] + * varint = b4 ++ b3 ++ b2 ++ b1 <== only 7 bits of each byte should be extracted before concat + * + * b1 b2 + * bytes = [00101100, 00000010] + * b2 b1 + * varint = 0000010 ++ 0101100 = 100101100 => decimal: 300 + * Each consecutive varint byte needs to be left-shifted "7 x its position in the vector" + * and bitwise added to the accumulator afterward. + */ +uint32_t varint_bytes_to_number(const std::vector& bytes) { + uint32_t accumulator = 0u; + + for (size_t i = 0; i < bytes.size(); ++i) { + uint32_t b = bytes[i]; + b <<= 7 * i; + accumulator |= b; + } + + return accumulator; +} + +uint32_t decode_varint(std::istream& model) { + std::vector bytes; + // max 4 bytes for a single value because this function returns a 32-bit long decoded varint + const size_t MAX_VARINT_BYTES = 4u; + // optimization to avoid allocations during push_back calls + bytes.reserve(MAX_VARINT_BYTES); + + char key_component = 0; + model.get(key_component); + + // keep reading all bytes which have the MSB on from the stream + while (key_component & 0x80 && bytes.size() < MAX_VARINT_BYTES) { + // drop the most significant bit + const char component = key_component & ~0x80; + bytes.push_back(component); + model.get(key_component); + } + // add the last byte - the one with MSB off + bytes.push_back(key_component); + + return varint_bytes_to_number(bytes); +} + +PbKey decode_key(const char key) { + // 3 least significant bits + const char wire_type = key & 0b111; + // remaining bits + const char field_number = key >> 3; + return {field_number, wire_type}; +} + +ONNXField decode_next_field(std::istream& model) { + char key = 0; + model.get(key); + + const auto decoded_key = decode_key(key); + + if (!is_correct_onnx_field(decoded_key)) { + throw std::runtime_error{"Incorrect field detected in the processed model"}; + } + + const auto onnx_field = static_cast(decoded_key.first); + + switch (decoded_key.second) { + case VARINT: { + // the decoded varint is the payload in this case but its value does not matter + // in the fast check process so you can discard it + decode_varint(model); + return {onnx_field, 0}; + } + case LENGTH_DELIMITED: + // the varint following the key determines the payload length + return {onnx_field, decode_varint(model)}; + case BITS_64: + return {onnx_field, 8}; + case BITS_32: + return {onnx_field, 4}; + case START_GROUP: + case END_GROUP: + throw std::runtime_error{"StartGroup and EndGroup are not used in ONNX models"}; + default: + throw std::runtime_error{"Unknown WireType encountered in the model"}; + } +} + +inline void skip_payload(std::istream& model, uint32_t payload_size) { + model.seekg(payload_size, std::ios::cur); +} +} // namespace onnx + +namespace prototxt { +bool contains_onnx_model_keys(const std::string& model, const size_t expected_keys_num) { + size_t keys_found = 0; + + const std::vector onnx_keys = {"ir_version", + "producer_name", + "producer_version", + "domain", + "model_version", + "doc_string", + "graph", + "opset_import", + "metadata_props", + "training_info"}; + + size_t search_start_pos = 0; + + while (keys_found < expected_keys_num) { + const auto key_finder = [&search_start_pos, &model](const std::string& key) { + const auto key_pos = model.find(key, search_start_pos); + if (key_pos != model.npos) { + // don't search from the beginning each time + search_start_pos = key_pos + key.size(); + return true; + } else { + return false; + } + }; + + const auto found = std::any_of(std::begin(onnx_keys), std::end(onnx_keys), key_finder); + if (!found) { + break; + } else { + ++keys_found; + } + } + + return keys_found == expected_keys_num; +} +} // namespace prototxt +} // namespace detail + +namespace ngraph { +namespace onnx_common { +bool is_valid_model(std::istream& model) { + // the model usually starts with a 0x08 byte indicating the ir_version value + // so this checker expects at least 2 valid ONNX keys to be found in the validated model + const unsigned int EXPECTED_FIELDS_FOUND = 2u; + unsigned int valid_fields_found = 0u; + try { + while (!model.eof() && valid_fields_found < EXPECTED_FIELDS_FOUND) { + const auto field = detail::onnx::decode_next_field(model); + + ++valid_fields_found; + + if (field.second > 0) { + detail::onnx::skip_payload(model, field.second); + } + } + + return valid_fields_found == EXPECTED_FIELDS_FOUND; + } catch (...) { + return false; + } +} +} // namespace onnx_common +} // namespace ngraph diff --git a/ngraph/frontend/onnx/onnx_common/src/parser.cpp b/ngraph/frontend/onnx/onnx_common/src/parser.cpp index 03fbc358f229ae..3cb7f17ebd1936 100644 --- a/ngraph/frontend/onnx/onnx_common/src/parser.cpp +++ b/ngraph/frontend/onnx/onnx_common/src/parser.cpp @@ -8,6 +8,8 @@ #include #include +#include + #include "ngraph/except.hpp" namespace ngraph { @@ -19,9 +21,27 @@ ONNX_NAMESPACE::ModelProto parse_from_file(const std::string& file_path) { throw ngraph_error("Could not open the file: " + file_path); }; - return parse_from_istream(file_stream); + auto model_proto = parse_from_istream(file_stream); + file_stream.close(); + return model_proto; } +#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) +ONNX_NAMESPACE::ModelProto parse_from_file(const std::wstring& file_path) { + std::ifstream file_stream{file_path, std::ios::in | std::ios::binary}; + + if (!file_stream.is_open()) { + NGRAPH_SUPPRESS_DEPRECATED_START + throw ngraph_error("Could not open the file: " + file_util::wstring_to_string(file_path)); + NGRAPH_SUPPRESS_DEPRECATED_END + }; + + auto model_proto = parse_from_istream(file_stream); + file_stream.close(); + return model_proto; +} +#endif + ONNX_NAMESPACE::ModelProto parse_from_istream(std::istream& model_stream) { if (!model_stream.good()) { model_stream.clear(); diff --git a/ngraph/python/tests/test_frontend/test_frontend_onnx.py b/ngraph/python/tests/test_frontend/test_frontend_onnx.py index 4290c7de2ce4e8..81bb89217fece4 100644 --- a/ngraph/python/tests/test_frontend/test_frontend_onnx.py +++ b/ngraph/python/tests/test_frontend/test_frontend_onnx.py @@ -96,3 +96,18 @@ def test_decode_and_convert(): b = np.array([[2, 3], [4, 5]], dtype=np.float32) expected = np.array([[1.5, 5], [10.5, 18]], dtype=np.float32) run_function(decoded_function, a, b, expected=[expected]) + + +def test_load_by_model(): + skip_if_onnx_frontend_is_disabled() + + fe = fem.load_by_model(onnx_model_filename) + assert fe + assert fe.get_name() == "onnx" + model = fe.load(onnx_model_filename) + assert model + decoded_function = fe.decode(model) + assert decoded_function + + assert not fem.load_by_model("test.xx") + assert not fem.load_by_model("onnx.yy") diff --git a/ngraph/test/frontend/CMakeLists.txt b/ngraph/test/frontend/CMakeLists.txt index 67ffb5a992ac1b..901d917edc7295 100644 --- a/ngraph/test/frontend/CMakeLists.txt +++ b/ngraph/test/frontend/CMakeLists.txt @@ -7,6 +7,10 @@ if (NGRAPH_PDPD_FRONTEND_ENABLE) add_subdirectory(paddlepaddle) endif() +if (NGRAPH_ONNX_FRONTEND_ENABLE) + add_subdirectory(onnx) +endif() + set(SRC ${CMAKE_CURRENT_SOURCE_DIR}/mock_frontend.cpp) add_library(mock1_ngraph_frontend SHARED ${SRC}) diff --git a/ngraph/test/frontend/onnx/CMakeLists.txt b/ngraph/test/frontend/onnx/CMakeLists.txt new file mode 100644 index 00000000000000..d78a8ed4274f7e --- /dev/null +++ b/ngraph/test/frontend/onnx/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +set(TARGET_NAME "onnx_frontend_tests") + +file(GLOB_RECURSE SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +add_executable(${TARGET_NAME} ${SRC}) + +target_link_libraries(${TARGET_NAME} PRIVATE frontend_shared_test_classes) + +add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME}) + +install(TARGETS ${TARGET_NAME} + RUNTIME DESTINATION tests + COMPONENT tests + EXCLUDE_FROM_ALL) + +set(TEST_ONNX_MODELS_DIRNAME ${TEST_MODEL_ZOO}/ngraph/models/onnx) +target_compile_definitions(${TARGET_NAME} PRIVATE -D TEST_ONNX_MODELS_DIRNAME=\"${TEST_ONNX_MODELS_DIRNAME}/\") +set(MANIFEST ${CMAKE_CURRENT_SOURCE_DIR}/unit_test.manifest) +target_compile_definitions(${TARGET_NAME} PRIVATE -D MANIFEST=\"${MANIFEST}\") + +add_dependencies(${TARGET_NAME} onnx_ngraph_frontend) +add_dependencies(${TARGET_NAME} test_model_zoo) diff --git a/ngraph/test/frontend/onnx/load_from.cpp b/ngraph/test/frontend/onnx/load_from.cpp new file mode 100644 index 00000000000000..4d98644bc722de --- /dev/null +++ b/ngraph/test/frontend/onnx/load_from.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "load_from.hpp" + +#include + +#include +#include + +#include "onnx_utils.hpp" +#include "utils.hpp" + +using namespace ngraph; +using namespace ngraph::frontend; + +using ONNXLoadTest = FrontEndLoadFromTest; + +static LoadFromFEParam getTestData() { + LoadFromFEParam res; + res.m_frontEndName = ONNX_FE; + res.m_modelsPath = std::string(TEST_ONNX_MODELS_DIRNAME); + res.m_file = "external_data/external_data.onnx"; + res.m_stream = "add_abc.onnx"; + return res; +} + +TEST_P(FrontEndLoadFromTest, testLoadFromStreamAndPassPath) { + NGRAPH_SUPPRESS_DEPRECATED_START + const auto path = file_util::path_join(TEST_ONNX_MODELS_DIRNAME, "external_data/external_data.onnx"); + NGRAPH_SUPPRESS_DEPRECATED_END + std::ifstream ifs(path, std::ios::in | std::ios::binary); + ASSERT_TRUE(ifs.is_open()); + std::istream* is = &ifs; + std::vector frontends; + FrontEnd::Ptr fe; + ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends()); + ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_model(is)); + ASSERT_NE(m_frontEnd, nullptr); + + ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(is, path)); + ASSERT_NE(m_inputModel, nullptr); + + std::shared_ptr function; + ASSERT_NO_THROW(function = m_frontEnd->convert(m_inputModel)); + ASSERT_NE(function, nullptr); +} + +INSTANTIATE_TEST_SUITE_P(ONNXLoadTest, + FrontEndLoadFromTest, + ::testing::Values(getTestData()), + FrontEndLoadFromTest::getTestCaseName); diff --git a/ngraph/test/frontend/onnx/main.cpp b/ngraph/test/frontend/onnx/main.cpp new file mode 100644 index 00000000000000..3f72a06cf5db35 --- /dev/null +++ b/ngraph/test/frontend/onnx/main.cpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "gtest/gtest.h" +#include "utils.hpp" + +std::string get_disabled_tests() { + std::string result = "-"; + const std::string manifest_path = MANIFEST; + std::ifstream manifest_stream(manifest_path); + std::string line; + while (std::getline(manifest_stream, line)) { + if (line.empty()) { + continue; + } + if (line.size() > 0 && line[0] == '#') { + continue; + } + result += ":" + line; + } + manifest_stream.close(); + return result; +} + +int main(int argc, char** argv) { + ::testing::GTEST_FLAG(filter) += get_disabled_tests(); + return FrontEndTestUtils::run_tests(argc, argv); +} diff --git a/ngraph/test/frontend/onnx/onnx_utils.hpp b/ngraph/test/frontend/onnx/onnx_utils.hpp new file mode 100644 index 00000000000000..86ce5c94ee51df --- /dev/null +++ b/ngraph/test/frontend/onnx/onnx_utils.hpp @@ -0,0 +1,9 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +static const std::string ONNX_FE = "onnx_experimental"; diff --git a/ngraph/test/frontend/onnx/unit_test.manifest b/ngraph/test/frontend/onnx/unit_test.manifest new file mode 100644 index 00000000000000..790acbe661ff67 --- /dev/null +++ b/ngraph/test/frontend/onnx/unit_test.manifest @@ -0,0 +1,3 @@ +# external weights are not supported for ONNX Frontend +ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoFiles/onnx_experimental +ONNXLoadTest/FrontEndLoadFromTest.testLoadFromTwoStreams/onnx_experimental diff --git a/ngraph/test/onnx/onnx_editor.cpp b/ngraph/test/onnx/onnx_editor.cpp index 1e7e36f5dc5e45..590dc7e9350e1a 100644 --- a/ngraph/test/onnx/onnx_editor.cpp +++ b/ngraph/test/onnx/onnx_editor.cpp @@ -1183,6 +1183,21 @@ NGRAPH_TEST(onnx_editor, values__append_two_initializers_mixed_types) { test_case.run(); } +NGRAPH_TEST(onnx_editor, read_model_from_stream) { + std::string path = file_util::path_join(SERIALIZED_ZOO, "onnx/external_data/external_data.onnx"); + std::ifstream stream{path, std::ios::in | std::ios::binary}; + ASSERT_TRUE(stream.is_open()); + ONNXModelEditor editor{stream, path}; + + auto test_case = test::TestCase(editor.get_function()); + test_case.add_input({1.f, 2.f, 3.f, 4.f}); + test_case.add_expected_output(Shape{2, 2}, {3.f, 6.f, 9.f, 12.f}); + + test_case.run(); + + stream.close(); +} + NGRAPH_TEST(onnx_editor, combined__cut_and_replace_shape) { ONNXModelEditor editor{file_util::path_join(SERIALIZED_ZOO, "onnx/model_editor/subgraph__inception_head.onnx")}; diff --git a/ngraph/test/onnx/onnx_import_library.cpp b/ngraph/test/onnx/onnx_import_library.cpp index 2180feeac170e5..61df4cf84e3018 100644 --- a/ngraph/test/onnx/onnx_import_library.cpp +++ b/ngraph/test/onnx/onnx_import_library.cpp @@ -15,7 +15,7 @@ NGRAPH_TEST(onnx, check_ir_version_support) { // It appears you've changed the ONNX library version used by nGraph. Please update the value // tested below to make sure it equals the current IR_VERSION enum value defined in ONNX headers // - // You should also check the onnx_reader/onnx_model_validator.cpp file and make sure that + // You should also check the onnx_common/src/onnx_model_validator.cpp file and make sure that // the details::onnx::is_correct_onnx_field() handles any new fields added in the new release // of the ONNX library. Make sure to update the "Field" enum and the function mentioned above. // diff --git a/scripts/deployment_manager/configs/darwin.json b/scripts/deployment_manager/configs/darwin.json index f3581e31cbae93..09be8f75ed9913 100644 --- a/scripts/deployment_manager/configs/darwin.json +++ b/scripts/deployment_manager/configs/darwin.json @@ -21,7 +21,6 @@ "deployment_tools/inference_engine/lib/intel64/libinference_engine_transformations.dylib", "deployment_tools/inference_engine/lib/intel64/libinference_engine_preproc.so", "deployment_tools/inference_engine/lib/intel64/libinference_engine_ir_reader.so", - "deployment_tools/inference_engine/lib/intel64/libinference_engine_onnx_reader.so", "deployment_tools/inference_engine/lib/intel64/libinference_engine_c_api.dylib", "deployment_tools/inference_engine/lib/intel64/libAutoPlugin.so", "deployment_tools/inference_engine/lib/intel64/libHeteroPlugin.so", diff --git a/scripts/deployment_manager/configs/linux.json b/scripts/deployment_manager/configs/linux.json index fd4600abb8e401..0c39eeb82a4963 100644 --- a/scripts/deployment_manager/configs/linux.json +++ b/scripts/deployment_manager/configs/linux.json @@ -27,7 +27,6 @@ "deployment_tools/inference_engine/lib/intel64/libinference_engine_transformations.so", "deployment_tools/inference_engine/lib/intel64/libinference_engine_preproc.so", "deployment_tools/inference_engine/lib/intel64/libinference_engine_ir_reader.so", - "deployment_tools/inference_engine/lib/intel64/libinference_engine_onnx_reader.so", "deployment_tools/inference_engine/lib/intel64/libinference_engine_c_api.so", "deployment_tools/inference_engine/lib/intel64/libAutoPlugin.so", "deployment_tools/inference_engine/lib/intel64/libHeteroPlugin.so", diff --git a/scripts/deployment_manager/configs/windows.json b/scripts/deployment_manager/configs/windows.json index ba95c29e2440b5..14ceedbff8a3fb 100644 --- a/scripts/deployment_manager/configs/windows.json +++ b/scripts/deployment_manager/configs/windows.json @@ -21,7 +21,6 @@ "deployment_tools/inference_engine/bin/intel64/Release/inference_engine_transformations.dll", "deployment_tools/inference_engine/bin/intel64/Release/inference_engine_preproc.dll", "deployment_tools/inference_engine/bin/intel64/Release/inference_engine_ir_reader.dll", - "deployment_tools/inference_engine/bin/intel64/Release/inference_engine_onnx_reader.dll", "deployment_tools/inference_engine/bin/intel64/Release/inference_engine_c_api.dll", "deployment_tools/inference_engine/bin/intel64/Release/AutoPlugin.dll", "deployment_tools/inference_engine/lib/intel64/Release/HeteroPlugin.dll",