From 9a36e77f50dffee76422890ce2501d01356b0739 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Fri, 30 Jul 2021 12:58:17 +0300 Subject: [PATCH] Accept stream pointer instead of shared_ptr in paddle frontend (#6807) * Accept stream pointer instead of shared_ptr * Fix build * Fix build tests on centos --- .../frontend_manager/frontend_manager.hpp | 5 ++-- .../frontend_manager/src/frontend_manager.cpp | 2 +- ngraph/frontend/paddlepaddle/src/frontend.cpp | 23 ++++++++----------- ngraph/test/frontend/shared/src/load_from.cpp | 19 ++++++++------- tests/fuzz/src/import_pdpd-fuzzer.cc | 6 ++--- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp index 2b92a6386b5552..e917c89c83ae0b 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/frontend_manager.hpp @@ -96,11 +96,10 @@ namespace ngraph } // namespace frontend template <> - class FRONTEND_API VariantWrapper> - : public VariantImpl> + class FRONTEND_API VariantWrapper : public VariantImpl { public: - static constexpr VariantTypeInfo type_info{"Variant::std::shared_ptr", 0}; + static constexpr VariantTypeInfo type_info{"Variant::std::istream*", 0}; const VariantTypeInfo& get_type_info() const override { return type_info; } VariantWrapper(const value_type& value) : VariantImpl(value) diff --git a/ngraph/frontend/frontend_manager/src/frontend_manager.cpp b/ngraph/frontend/frontend_manager/src/frontend_manager.cpp index 2f994ffc59285f..051519341922c9 100644 --- a/ngraph/frontend/frontend_manager/src/frontend_manager.cpp +++ b/ngraph/frontend/frontend_manager/src/frontend_manager.cpp @@ -454,7 +454,7 @@ std::vector Place::get_consuming_operations(const std::string& outpu return {}; } -constexpr VariantTypeInfo VariantWrapper>::type_info; +constexpr VariantTypeInfo VariantWrapper::type_info; #if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32) constexpr VariantTypeInfo VariantWrapper::type_info; diff --git a/ngraph/frontend/paddlepaddle/src/frontend.cpp b/ngraph/frontend/paddlepaddle/src/frontend.cpp index 231fbb6cb5388f..ebd4b0e329ba66 100644 --- a/ngraph/frontend/paddlepaddle/src/frontend.cpp +++ b/ngraph/frontend/paddlepaddle/src/frontend.cpp @@ -134,11 +134,9 @@ namespace ngraph std::istream* variant_to_stream_ptr(const std::shared_ptr& variant, std::ifstream& ext_stream) { - if (is_type>>(variant)) + if (is_type>(variant)) { - auto m_stream = - as_type_ptr>>(variant)->get(); - return m_stream.get(); + return as_type_ptr>(variant)->get(); } else if (is_type>(variant)) { @@ -281,13 +279,13 @@ namespace ngraph return model_str && model_str.is_open(); } #endif - else if (is_type>>(variants[0])) + else if (is_type>(variants[0])) { // Validating first stream, it must contain a model - std::shared_ptr p_model_stream = - as_type_ptr>>(variants[0])->get(); + auto p_model_stream = + as_type_ptr>(variants[0])->get(); paddle::framework::proto::ProgramDesc fw; - return fw.ParseFromIstream(p_model_stream.get()); + return fw.ParseFromIstream(p_model_stream); } return false; } @@ -314,13 +312,12 @@ namespace ngraph #endif // The case with only model stream provided and no weights. This means model has // no learnable weights - else if (is_type>>(variants[0])) + else if (is_type>(variants[0])) { - std::shared_ptr p_model_stream = - as_type_ptr>>(variants[0]) - ->get(); + auto p_model_stream = + as_type_ptr>(variants[0])->get(); return std::make_shared( - std::vector{p_model_stream.get()}); + std::vector{p_model_stream}); } } else if (variants.size() == 2) diff --git a/ngraph/test/frontend/shared/src/load_from.cpp b/ngraph/test/frontend/shared/src/load_from.cpp index 67c2b0888f3a44..9578baaaa29bc9 100644 --- a/ngraph/test/frontend/shared/src/load_from.cpp +++ b/ngraph/test/frontend/shared/src/load_from.cpp @@ -65,10 +65,9 @@ TEST_P(FrontEndLoadFromTest, testLoadFromTwoFiles) TEST_P(FrontEndLoadFromTest, testLoadFromStream) { - auto ifs = std::make_shared( - FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_stream), - std::ios::in | std::ifstream::binary); - auto is = std::dynamic_pointer_cast(ifs); + std::ifstream ifs(FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_stream), + std::ios::in | std::ios::binary); + std::istream* is = &ifs; std::vector frontends; FrontEnd::Ptr fe; ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends()); @@ -85,14 +84,14 @@ TEST_P(FrontEndLoadFromTest, testLoadFromStream) TEST_P(FrontEndLoadFromTest, testLoadFromTwoStreams) { - auto model_ifs = std::make_shared( + std::ifstream model_ifs( FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_streams[0]), - std::ios::in | std::ifstream::binary); - auto weights_ifs = std::make_shared( + std::ios::in | std::ios::binary); + std::ifstream weights_ifs( FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_streams[1]), - std::ios::in | std::ifstream::binary); - auto model_is = std::dynamic_pointer_cast(model_ifs); - auto weights_is = std::dynamic_pointer_cast(weights_ifs); + std::ios::in | std::ios::binary); + std::istream* model_is(&model_ifs); + std::istream* weights_is(&weights_ifs); std::vector frontends; FrontEnd::Ptr fe; diff --git a/tests/fuzz/src/import_pdpd-fuzzer.cc b/tests/fuzz/src/import_pdpd-fuzzer.cc index 313faf46f3ee05..b25338c4356f83 100644 --- a/tests/fuzz/src/import_pdpd-fuzzer.cc +++ b/tests/fuzz/src/import_pdpd-fuzzer.cc @@ -28,12 +28,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { frontend_manager.load_by_framework(PDPD); ngraph::frontend::InputModel::Ptr input_model; std::stringstream model; + std::stringstream params; model << std::string((const char *)model_buf, model_size); - std::shared_ptr in_model(&model); + std::istream* in_model(&model); if (params_buf) { - std::stringstream params; params << std::string((const char *)params_buf, params_size); - std::shared_ptr in_params(¶ms); + std::istream* in_params(¶ms); input_model = frontend->load(in_model, in_params); } else input_model = frontend->load(in_model);