diff --git a/inference-engine/tests/functional/inference_engine/pdpd_reader/read_pdpd_model_test.cpp b/inference-engine/tests/functional/inference_engine/pdpd_reader/read_pdpd_model_test.cpp index 5ec2077da1ef0b..1f2a333e3c5419 100644 --- a/inference-engine/tests/functional/inference_engine/pdpd_reader/read_pdpd_model_test.cpp +++ b/inference-engine/tests/functional/inference_engine/pdpd_reader/read_pdpd_model_test.cpp @@ -43,7 +43,7 @@ TEST(PDPD_Reader_Tests, ImportBasicModelToCore) { "RefPDPDFunction"); const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES); const FunctionsComparator::Result res = func_comparator(function, reference); - ASSERT_TRUE(res.valid); + ASSERT_TRUE(res.valid) << res.message; } #if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) @@ -79,6 +79,6 @@ TEST(PDPD_Reader_Tests, ImportBasicModelToCoreWstring) { "RefPDPDFunction"); const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::NAMES); const FunctionsComparator::Result res = func_comparator(function, reference); - ASSERT_TRUE(res.valid); + ASSERT_TRUE(res.valid) << res.message; } #endif diff --git a/ngraph/frontend/paddlepaddle/CMakeLists.txt b/ngraph/frontend/paddlepaddle/CMakeLists.txt index 198d0f2124351c..ab9c5bcef84da7 100644 --- a/ngraph/frontend/paddlepaddle/CMakeLists.txt +++ b/ngraph/frontend/paddlepaddle/CMakeLists.txt @@ -71,7 +71,7 @@ endif() link_system_libraries(${TARGET_NAME} PRIVATE ${Protobuf_LITE_LIBRARIES}) target_link_libraries(${TARGET_NAME} PRIVATE ngraph::frontend_manager::static - PRIVATE ngraph::builder) + PRIVATE ngraph::builder inference_engine_transformations) add_clang_format_target(${TARGET_NAME}_clang FOR_TARGETS ${TARGET_NAME} EXCLUDE_PATTERNS ${PROTO_SRCS} ${PROTO_HDRS}) diff --git a/ngraph/frontend/paddlepaddle/include/paddlepaddle_frontend/frontend.hpp b/ngraph/frontend/paddlepaddle/include/paddlepaddle_frontend/frontend.hpp index 
410068b2e26fcc..66e9c053ad2c20 100644 --- a/ngraph/frontend/paddlepaddle/include/paddlepaddle_frontend/frontend.hpp +++ b/ngraph/frontend/paddlepaddle/include/paddlepaddle_frontend/frontend.hpp @@ -7,6 +7,7 @@ #include #include "exceptions.hpp" #include "model.hpp" +#include "place.hpp" namespace ngraph { @@ -22,6 +23,32 @@ namespace ngraph /// \return fully converted nGraph function std::shared_ptr convert(InputModel::Ptr model) const override; + /// \brief Completely convert the remaining, not converted part of a function. + /// \param partiallyConverted partially converted nGraph function + /// \return fully converted nGraph function + std::shared_ptr + convert(std::shared_ptr partiallyConverted) const override; + + /// \brief Convert only those parts of the model that can be converted leaving others + /// as-is. Converted parts are not normalized by additional transformations; normalize + /// function or another form of convert function should be called to finalize the + /// conversion process. + /// \param model Input model + /// \return partially converted nGraph function + std::shared_ptr + convert_partially(InputModel::Ptr model) const override; + + /// \brief Convert operations with one-to-one mapping with decoding nodes. + /// Each decoding node is an nGraph node representing a single FW operation node with + /// all attributes represented in FW-independent way. 
+ /// \param model Input model + /// \return nGraph function after decoding + std::shared_ptr decode(InputModel::Ptr model) const override; + + /// \brief Runs normalization passes on function that was loaded with partial conversion + /// \param function partially converted nGraph function + void normalize(std::shared_ptr function) const override; + protected: /// \brief Check if FrontEndPDPD can recognize model from given parts /// \param params Can be path to folder which contains __model__ file or path to @@ -40,7 +67,10 @@ namespace ngraph private: static std::shared_ptr - convert_model(const std::shared_ptr& model); + convert_each_node(const std::shared_ptr& model, + std::function( + const std::map>&, + const std::shared_ptr&)> func); }; } // namespace frontend diff --git a/ngraph/frontend/paddlepaddle/src/decoder.cpp b/ngraph/frontend/paddlepaddle/src/decoder.cpp index 1758893b254c2a..4669ab1eb2efc6 100644 --- a/ngraph/frontend/paddlepaddle/src/decoder.cpp +++ b/ngraph/frontend/paddlepaddle/src/decoder.cpp @@ -99,6 +99,16 @@ namespace ngraph return output_names; } + size_t DecoderPDPDProto::get_output_size() const + { + size_t res = 0; + for (const auto& output : op_place->get_desc().outputs()) + { + res += output.arguments().size(); + } + return res; + } + ngraph::element::Type DecoderPDPDProto::get_out_port_type(const std::string& port_name) const { @@ -135,5 +145,40 @@ namespace ngraph " Expected number: 0 or 1"); return attrs; } + + std::map DecoderPDPDProto::map_for_each_input( + std::function(const std::string&)> func) const + { + std::map res; + for (const auto& port : op_place->get_desc().inputs()) + { + std::vector> v; + v.reserve(port.arguments_size()); + for (const auto& inp : port.arguments()) + { + v.push_back(func(inp)); + } + res.emplace(std::make_pair(port.parameter(), v)); + } + return res; + } + + std::map DecoderPDPDProto::map_for_each_output( + std::function(const std::string&)> func) const + { + std::map res; + for (const auto& port : 
op_place->get_desc().outputs()) + { + std::vector> v; + v.reserve(port.arguments_size()); + for (const auto& out : port.arguments()) + { + v.push_back(func(out)); + } + res.emplace(std::make_pair(port.parameter(), v)); + } + return res; + } + } // namespace frontend } // namespace ngraph \ No newline at end of file diff --git a/ngraph/frontend/paddlepaddle/src/decoder.hpp b/ngraph/frontend/paddlepaddle/src/decoder.hpp index 67be6694f860a4..87c7cd978846fe 100644 --- a/ngraph/frontend/paddlepaddle/src/decoder.hpp +++ b/ngraph/frontend/paddlepaddle/src/decoder.hpp @@ -40,10 +40,18 @@ namespace ngraph std::vector get_output_names() const override; + size_t get_output_size() const override; + ngraph::element::Type get_out_port_type(const std::string& port_name) const override; std::string get_op_type() const override; + std::map + map_for_each_input(std::function(const std::string&)> func) const; + + std::map + map_for_each_output(std::function(const std::string&)> func) const; + private: std::vector decode_attribute_helper(const std::string& name) const; diff --git a/ngraph/frontend/paddlepaddle/src/frontend.cpp b/ngraph/frontend/paddlepaddle/src/frontend.cpp index 84c5e9301d6e3a..3668ad15451f6d 100644 --- a/ngraph/frontend/paddlepaddle/src/frontend.cpp +++ b/ngraph/frontend/paddlepaddle/src/frontend.cpp @@ -21,10 +21,9 @@ #include "decoder.hpp" #include "node_context.hpp" #include "op_table.hpp" +#include "pdpd_fw_node.hpp" #include "pdpd_utils.hpp" -#include "frontend_manager/frontend_manager.hpp" - using namespace ngraph::opset7; using namespace ngraph; using namespace ngraph::frontend; @@ -35,25 +34,25 @@ namespace ngraph { namespace pdpd { - NamedOutputs make_ng_node(std::map>& nodes, + NamedOutputs make_ng_node(const std::map>& nodes, const std::shared_ptr& op_place, const std::map& CREATORS_MAP) { - const auto& op = op_place->get_desc(); + const auto& op_desc = op_place->get_desc(); - FRONT_END_OP_CONVERSION_CHECK(CREATORS_MAP.find(op.type()) != 
CREATORS_MAP.end(), + auto creator_it = CREATORS_MAP.find(op_desc.type()); + FRONT_END_OP_CONVERSION_CHECK(creator_it != CREATORS_MAP.end(), "No creator found for ", - op.type(), + op_desc.type(), " node."); - pdpd::NamedInputs named_inputs; - const auto& input_ports = op_place->get_input_ports(); - for (const auto& name_to_ports : input_ports) + NamedInputs named_inputs; + for (const auto& input_port : op_desc.inputs()) { - for (const auto& port : name_to_ports.second) + for (const auto& in_tensor_name : input_port.arguments()) { - const auto& var_desc = port->get_source_tensor_pdpd()->get_desc(); - if (nodes.count(var_desc.name())) - named_inputs[name_to_ports.first].push_back(nodes.at(var_desc.name())); + auto node_it = nodes.find(in_tensor_name); + if (node_it != nodes.end()) + named_inputs[input_port.parameter()].push_back(node_it->second); else // return empty map when not all inputs exist. It usually means that // these nodes are not used because model inputs were overwritten @@ -61,17 +60,75 @@ namespace ngraph } } - try + return creator_it->second(NodeContext(DecoderPDPDProto(op_place), named_inputs)); + } + + NamedOutputs make_framework_node(const std::map>& nodes, + const std::shared_ptr& op_place) + { + const auto& op_desc = op_place->get_desc(); + + OutputVector inputs_vector; + std::vector inputs_names; + NamedOutputs named_outputs; + for (const auto& input_port : op_desc.inputs()) { - return CREATORS_MAP.at(op.type())( - NodeContext(DecoderPDPDProto(op_place), named_inputs)); + for (const auto& in_tensor_name : input_port.arguments()) + { + auto it = nodes.find(in_tensor_name); + if (it != nodes.end()) + { + inputs_vector.push_back(it->second); + inputs_names.push_back(in_tensor_name); + } + else + { + // return empty map when not all inputs exist. It usually means that + // these nodes are not used because model inputs were overwritten + return named_outputs; + } + } } - catch (...) 
+ + auto node = std::make_shared( + DecoderPDPDProto(op_place), inputs_vector, inputs_names); + + return node->get_named_outputs(); + } + + bool + normalize_framework_node(const std::shared_ptr& node, + const std::map& CREATORS_MAP) + { + auto type = node->get_op_type(); + auto creator_it = CREATORS_MAP.find(type); + FRONT_END_OP_CONVERSION_CHECK( + creator_it != CREATORS_MAP.end(), "No creator found for ", type, " node."); + + auto new_node_outputs = + creator_it->second(NodeContext(node->get_decoder(), node->get_named_inputs())); + auto new_node = new_node_outputs.begin()->second[0].get_node_shared_ptr(); + new_node->set_friendly_name(node->get_friendly_name()); + auto node_outputs = node->get_named_outputs(); + + auto new_ports = new_node_outputs.begin(); + auto old_ports = node_outputs.begin(); + for (; new_ports != new_node_outputs.end() && old_ports != node_outputs.end(); + ++new_ports, ++old_ports) { - // TODO: define exception types - // In case of partial conversion we need to create generic ngraph op here - return NamedOutputs(); + FRONT_END_OP_CONVERSION_CHECK(new_ports->first == old_ports->first, + "Node outputs inconsistent after normalization: ", + node->get_friendly_name()); + auto new_output = new_ports->second.begin(); + auto old_output = old_ports->second.begin(); + for (; new_output != new_ports->second.end() && + old_output != old_ports->second.end(); + ++old_output, ++new_output) + { + old_output->replace(*new_output); + } } + return true; } std::istream* variant_to_stream_ptr(const std::shared_ptr& variant, @@ -104,16 +161,16 @@ namespace ngraph } // namespace pdpd - std::shared_ptr - FrontEndPDPD::convert_model(const std::shared_ptr& model) + std::shared_ptr FrontEndPDPD::convert_each_node( + const std::shared_ptr& model, + std::function( + const std::map>&, const std::shared_ptr&)> + func) { - // std::cout << "Convert Model Start" << std::endl; - - std::map> nodes_dict(model->getTensorValues()); + auto nodes_dict(model->getTensorValues()); 
 ParameterVector parameter_nodes; ResultVector result_nodes; - std::map CREATORS_MAP = pdpd::get_supported_ops(); for (const auto& _inp_place : model->get_inputs()) { const auto& inp_place = std::dynamic_pointer_cast(_inp_place); @@ -130,45 +187,54 @@ namespace ngraph const auto& op_places = model->getOpPlaces(); for (const auto& op_place : op_places) { - const auto& op_type = op_place->get_desc().type(); - if (op_type == "feed" || op_type == "fetch") + const auto& op_desc = op_place->get_desc(); + if (op_desc.type() == "feed" || op_desc.type() == "fetch") { // inputs and outputs are stored in the model already continue; } else { - const auto& named_outputs = - pdpd::make_ng_node(nodes_dict, op_place, CREATORS_MAP); + pdpd::NamedOutputs named_outputs; + try + { + named_outputs = func(nodes_dict, op_place); + } + catch (const OpConversionFailure&) + { + // TODO: define exception types + // In case of partial conversion we need to create generic ngraph op here + continue; + } - // set layer name by the name of first output var if (!named_outputs.empty()) { - const auto& first_output_var = op_place->get_output_ports() - .begin() - ->second.at(0) - ->get_target_tensor_pdpd() - ->get_desc(); + // set layer name by the name of first output var + const auto& tensor_name = op_desc.outputs().begin()->arguments()[0]; auto node = named_outputs.begin()->second[0].get_node_shared_ptr(); - node->set_friendly_name(first_output_var.name()); - } + node->set_friendly_name(tensor_name); - const auto& out_ports = op_place->get_output_ports(); - for (const auto& name_to_outputs : named_outputs) - { - const auto& ports = out_ports.at(name_to_outputs.first); - FRONT_END_OP_CONVERSION_CHECK( - ports.size() == name_to_outputs.second.size(), - "The number of output tensors must be equal to " - "the number of outputs of the ngraph node."); - for (size_t idx = 0; idx < ports.size(); ++idx) + const auto& out_ports = op_desc.outputs(); + for (const auto& port : out_ports) { - const auto& var = 
ports[idx]->get_target_tensor_pdpd()->get_desc(); - name_to_outputs.second[idx].get_tensor().set_names({var.name()}); - // if nodes_dict already has node mapped to this tensor name it usually - // means that it was overwritten using setTensorValue - if (!nodes_dict.count(var.name())) - nodes_dict[var.name()] = name_to_outputs.second[idx]; + // TODO: figure a way to safely handle unused outputs + if (named_outputs.count(port.parameter())) + { + const auto& ng_outputs = named_outputs.at(port.parameter()); + FRONT_END_OP_CONVERSION_CHECK( + ng_outputs.size() == port.arguments_size(), + "The number of output tensors must be equal to " + "the number of outputs of the ngraph node."); + for (size_t idx = 0; idx < ng_outputs.size(); ++idx) + { + const auto& var_name = port.arguments()[idx]; + ng_outputs[idx].get_tensor().set_names({var_name}); + // if nodes_dict already has node mapped to this tensor name it + // usually means that it was overwritten using setTensorValue + if (!nodes_dict.count(var_name)) + nodes_dict[var_name] = ng_outputs[idx]; + } + } } } } @@ -288,10 +354,92 @@ namespace ngraph std::shared_ptr FrontEndPDPD::convert(InputModel::Ptr model) const { auto pdpd_model = std::dynamic_pointer_cast(model); - auto f = convert_model(pdpd_model); + std::map CREATORS_MAP = pdpd::get_supported_ops(); + auto f = + convert_each_node(pdpd_model, + [&](const std::map>& nodes_dict, + const std::shared_ptr& op_place) { + return pdpd::make_ng_node(nodes_dict, op_place, CREATORS_MAP); + }); + return f; + } + + std::shared_ptr + FrontEndPDPD::convert(std::shared_ptr partiallyConverted) const + { + auto function = clone_function(*partiallyConverted.get()); + for (const auto& node : function->get_ordered_ops()) + { + if (is_type(node)) + { + pdpd::normalize_framework_node( + std::dynamic_pointer_cast(node), + pdpd::get_supported_ops()); + } + } + for (auto result : function->get_results()) + { + result->validate_and_infer_types(); + } + return function; + } + + std::shared_ptr 
+ FrontEndPDPD::convert_partially(InputModel::Ptr model) const + { + auto pdpd_model = std::dynamic_pointer_cast(model); + std::map CREATORS_MAP = pdpd::get_supported_ops(); + auto f = convert_each_node( + pdpd_model, + [&](const std::map>& nodes_dict, + const std::shared_ptr& op_place) { + pdpd::NamedOutputs named_outputs; + try + { + named_outputs = pdpd::make_ng_node(nodes_dict, op_place, CREATORS_MAP); + } + catch (const OpConversionFailure&) + { + named_outputs = pdpd::make_framework_node(nodes_dict, op_place); + } + return named_outputs; + }); return f; } + std::shared_ptr FrontEndPDPD::decode(InputModel::Ptr model) const + { + auto pdpd_model = std::dynamic_pointer_cast(model); + std::map CREATORS_MAP = pdpd::get_supported_ops(); + auto f = convert_each_node(pdpd_model, pdpd::make_framework_node); + return f; + } + + void FrontEndPDPD::normalize(std::shared_ptr function) const + { + for (const auto& node : function->get_ordered_ops()) + { + if (is_type(node)) + { + try + { + pdpd::normalize_framework_node( + std::dynamic_pointer_cast(node), + pdpd::get_supported_ops()); + } + catch (const OpConversionFailure&) + { + // do nothing if conversion failed + continue; + } + } + } + for (auto result : function->get_results()) + { + result->validate_and_infer_types(); + } + } + } // namespace frontend } // namespace ngraph diff --git a/ngraph/frontend/paddlepaddle/src/node_context.hpp b/ngraph/frontend/paddlepaddle/src/node_context.hpp index 3cee4812d712e5..21201003c345f2 100644 --- a/ngraph/frontend/paddlepaddle/src/node_context.hpp +++ b/ngraph/frontend/paddlepaddle/src/node_context.hpp @@ -54,6 +54,8 @@ namespace ngraph virtual std::vector get_output_names() const = 0; + virtual size_t get_output_size() const = 0; + /// \brief Get output port type /// /// Current API assumes that output port has only one output type. @@ -141,6 +143,18 @@ namespace ngraph return name_map.at(name); } + /// Returns all inputs in order they appear in map. 
This is used for FrameworkNode + /// creation + OutputVector get_all_ng_inputs() const + { + OutputVector res; + for (const auto& entry : name_map) + { + res.insert(res.end(), entry.second.begin(), entry.second.end()); + } + return res; + } + std::vector get_output_names() const { return decoder.get_output_names(); diff --git a/ngraph/frontend/paddlepaddle/src/pdpd_fw_node.cpp b/ngraph/frontend/paddlepaddle/src/pdpd_fw_node.cpp new file mode 100644 index 00000000000000..a8ea630732101c --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/pdpd_fw_node.cpp @@ -0,0 +1,35 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +namespace ngraph +{ + namespace frontend + { + NGRAPH_RTTI_DEFINITION(PDPDFrameworkNode, "PDPDFrameworkNode", 1); + + std::map PDPDFrameworkNode::get_named_inputs() const + { + return m_decoder.map_for_each_input([&](std::string name) { + auto it = std::find(m_inputs_names.begin(), m_inputs_names.end(), name); + if (it != m_inputs_names.end()) + { + return input(it - m_inputs_names.begin()).get_source_output(); + } + else + { + return Output(); + } + }); + } + + std::map PDPDFrameworkNode::get_named_outputs() + { + size_t idx = 0; + return m_decoder.map_for_each_output([&](std::string name) { return output(idx++); }); + } + + } // namespace frontend +} // namespace ngraph diff --git a/ngraph/frontend/paddlepaddle/src/pdpd_fw_node.hpp b/ngraph/frontend/paddlepaddle/src/pdpd_fw_node.hpp new file mode 100644 index 00000000000000..7f3482235087e4 --- /dev/null +++ b/ngraph/frontend/paddlepaddle/src/pdpd_fw_node.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "decoder.hpp" + +namespace ngraph +{ + namespace frontend + { + class PDPDFrameworkNode : public op::FrameworkNode + { + public: + NGRAPH_RTTI_DECLARATION; + + PDPDFrameworkNode(const DecoderPDPDProto& decoder, + const OutputVector& inputs, + const 
std::vector& inputs_names) + : FrameworkNode(inputs, decoder.get_output_size()) + , m_decoder{decoder} + , m_inputs_names{inputs_names} + { + op::FrameworkNodeAttrs attrs; + attrs.set_type_name(m_decoder.get_op_type()); + set_attrs(attrs); + } + + virtual std::shared_ptr + clone_with_new_inputs(const OutputVector& inputs) const override + { + return std::make_shared(m_decoder, inputs, m_inputs_names); + } + + std::string get_op_type() const { return m_decoder.get_op_type(); } + + const DecoderPDPDProto& get_decoder() const { return m_decoder; } + + std::map get_named_inputs() const; + + std::map get_named_outputs(); + + private: + const DecoderPDPDProto m_decoder; + std::vector m_inputs_names; + }; + } // namespace frontend +} // namespace ngraph \ No newline at end of file diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 5cc1197307bd5a..f2fcdd42cd495d 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -632,7 +632,7 @@ install(TARGETS unit-test ############ FRONTEND ############ target_include_directories(unit-test PRIVATE ${FRONTEND_INCLUDE_PATH} frontend/shared/include) -target_link_libraries(unit-test PRIVATE frontend_manager cnpy) +target_link_libraries(unit-test PRIVATE frontend_manager cnpy commonTestUtils) add_subdirectory(frontend) ### END FRONTEND ### diff --git a/ngraph/test/frontend/paddlepaddle/convert_model.cpp b/ngraph/test/frontend/paddlepaddle/convert_model.cpp new file mode 100644 index 00000000000000..3c71b1d32fa877 --- /dev/null +++ b/ngraph/test/frontend/paddlepaddle/convert_model.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "convert_model.hpp" + +using namespace ngraph; +using namespace ngraph::frontend; + +static const std::string PDPD = "pdpd"; + +using PDPDConvertModelTest = FrontEndConvertModelTest; + +static const std::vector models{ + std::string("conv2d"), + std::string("conv2d_s/conv2d.pdmodel"), + 
std::string("conv2d_relu/conv2d_relu.pdmodel"), + std::string("2in_2out/2in_2out.pdmodel"), + std::string("multi_tensor_split/multi_tensor_split.pdmodel"), + std::string("2in_2out_dynbatch/2in_2out_dynbatch.pdmodel"), +}; + +INSTANTIATE_TEST_SUITE_P(PDPDConvertModelTest, + FrontEndConvertModelTest, + ::testing::Combine(::testing::Values(PDPD), + ::testing::Values(std::string(TEST_PDPD_MODELS)), + ::testing::ValuesIn(models)), + FrontEndConvertModelTest::getTestCaseName); diff --git a/ngraph/test/frontend/shared/include/convert_model.hpp b/ngraph/test/frontend/shared/include/convert_model.hpp new file mode 100644 index 00000000000000..5fef0354671338 --- /dev/null +++ b/ngraph/test/frontend/shared/include/convert_model.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include + +using ConvertParam = std::tuple; // Model name + +class FrontEndConvertModelTest : public ::testing::TestWithParam +{ +public: + std::string m_feName; + std::string m_pathToModels; + std::string m_modelFile; + ngraph::frontend::FrontEndManager m_fem; + ngraph::frontend::FrontEnd::Ptr m_frontEnd; + ngraph::frontend::InputModel::Ptr m_inputModel; + + static std::string getTestCaseName(const testing::TestParamInfo& obj); + + void SetUp() override; + +protected: + void initParamTest(); + + void doLoadFromFile(); +}; diff --git a/ngraph/test/frontend/shared/src/convert_model.cpp b/ngraph/test/frontend/shared/src/convert_model.cpp new file mode 100644 index 00000000000000..e78520bc2942e3 --- /dev/null +++ b/ngraph/test/frontend/shared/src/convert_model.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "convert_model.hpp" +#include "common_test_utils/ngraph_test_utils.hpp" +#include "utils.hpp" + +using namespace ngraph; +using namespace ngraph::frontend; + +std::string + FrontEndConvertModelTest::getTestCaseName(const 
testing::TestParamInfo& obj) +{ + std::string fe, path, fileName; + std::tie(fe, path, fileName) = obj.param; + return fe + "_" + FrontEndTestUtils::fileToTestName(fileName); +} + +void FrontEndConvertModelTest::SetUp() +{ + FrontEndTestUtils::setupTestEnv(); + m_fem = FrontEndManager(); // re-initialize after setting up environment + initParamTest(); +} + +void FrontEndConvertModelTest::initParamTest() +{ + std::tie(m_feName, m_pathToModels, m_modelFile) = GetParam(); + m_modelFile = m_pathToModels + m_modelFile; +} + +void FrontEndConvertModelTest::doLoadFromFile() +{ + std::vector frontends; + ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends()); + ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework(m_feName)); + ASSERT_NE(m_frontEnd, nullptr); + ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(m_modelFile)); + ASSERT_NE(m_inputModel, nullptr); +} + +TEST_P(FrontEndConvertModelTest, test_convert_partially_equal_convert) +{ + ASSERT_NO_THROW(doLoadFromFile()); + std::shared_ptr function_ref; + ASSERT_NO_THROW(function_ref = m_frontEnd->convert(m_inputModel)); + ASSERT_NE(function_ref, nullptr); + std::shared_ptr function; + ASSERT_NO_THROW(function = m_frontEnd->convert_partially(m_inputModel)); + ASSERT_NE(function, nullptr); + + const FunctionsComparator func_comparator = + FunctionsComparator::with_default().enable(FunctionsComparator::NAMES); + const FunctionsComparator::Result res = func_comparator(function, function_ref); + ASSERT_TRUE(res.valid) << res.message; +} + +TEST_P(FrontEndConvertModelTest, test_decode_convert_equal_convert) +{ + ASSERT_NO_THROW(doLoadFromFile()); + std::shared_ptr function_ref; + ASSERT_NO_THROW(function_ref = m_frontEnd->convert(m_inputModel)); + ASSERT_NE(function_ref, nullptr); + std::shared_ptr function_tmp; + ASSERT_NO_THROW(function_tmp = m_frontEnd->decode(m_inputModel)); + std::shared_ptr function; + ASSERT_NO_THROW(function = m_frontEnd->convert(function_tmp)); + ASSERT_NE(function, nullptr); + + const 
FunctionsComparator func_comparator = + FunctionsComparator::with_default().enable(FunctionsComparator::NAMES); + const FunctionsComparator::Result res = func_comparator(function, function_ref); + ASSERT_TRUE(res.valid) << res.message; +} + +TEST_P(FrontEndConvertModelTest, test_decode_normalize_equal_convert) +{ + ASSERT_NO_THROW(doLoadFromFile()); + std::shared_ptr function_ref; + ASSERT_NO_THROW(function_ref = m_frontEnd->convert(m_inputModel)); + ASSERT_NE(function_ref, nullptr); + std::shared_ptr function; + ASSERT_NO_THROW(function = m_frontEnd->decode(m_inputModel)); + ASSERT_NE(function, nullptr); + ASSERT_NO_THROW(m_frontEnd->normalize(function)); + + const FunctionsComparator func_comparator = + FunctionsComparator::with_default().enable(FunctionsComparator::NAMES); + const FunctionsComparator::Result res = func_comparator(function, function_ref); + ASSERT_TRUE(res.valid) << res.message; +}