Skip to content

Commit

Permalink
Accept stream pointer instead of shared_ptr in paddle frontend (#6807)
Browse files Browse the repository at this point in the history
* Accept stream pointer instead of shared_ptr

* Fix build

* Fix build tests on centos
  • Loading branch information
mvafin authored Jul 30, 2021
1 parent c1d8c23 commit 9a36e77
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,10 @@ namespace ngraph
} // namespace frontend

template <>
class FRONTEND_API VariantWrapper<std::shared_ptr<std::istream>>
: public VariantImpl<std::shared_ptr<std::istream>>
class FRONTEND_API VariantWrapper<std::istream*> : public VariantImpl<std::istream*>
{
public:
static constexpr VariantTypeInfo type_info{"Variant::std::shared_ptr<std::istream>", 0};
static constexpr VariantTypeInfo type_info{"Variant::std::istream*", 0};
const VariantTypeInfo& get_type_info() const override { return type_info; }
VariantWrapper(const value_type& value)
: VariantImpl<value_type>(value)
Expand Down
2 changes: 1 addition & 1 deletion ngraph/frontend/frontend_manager/src/frontend_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ std::vector<Place::Ptr> Place::get_consuming_operations(const std::string& outpu
return {};
}

constexpr VariantTypeInfo VariantWrapper<std::shared_ptr<std::istream>>::type_info;
constexpr VariantTypeInfo VariantWrapper<std::istream*>::type_info;

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
constexpr VariantTypeInfo VariantWrapper<std::wstring>::type_info;
Expand Down
23 changes: 10 additions & 13 deletions ngraph/frontend/paddlepaddle/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,9 @@ namespace ngraph
std::istream* variant_to_stream_ptr(const std::shared_ptr<Variant>& variant,
std::ifstream& ext_stream)
{
if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variant))
if (is_type<VariantWrapper<std::istream*>>(variant))
{
auto m_stream =
as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variant)->get();
return m_stream.get();
return as_type_ptr<VariantWrapper<std::istream*>>(variant)->get();
}
else if (is_type<VariantWrapper<std::string>>(variant))
{
Expand Down Expand Up @@ -281,13 +279,13 @@ namespace ngraph
return model_str && model_str.is_open();
}
#endif
else if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0]))
else if (is_type<VariantWrapper<std::istream*>>(variants[0]))
{
// Validating first stream, it must contain a model
std::shared_ptr<std::istream> p_model_stream =
as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0])->get();
auto p_model_stream =
as_type_ptr<VariantWrapper<std::istream*>>(variants[0])->get();
paddle::framework::proto::ProgramDesc fw;
return fw.ParseFromIstream(p_model_stream.get());
return fw.ParseFromIstream(p_model_stream);
}
return false;
}
Expand All @@ -314,13 +312,12 @@ namespace ngraph
#endif
// The case with only model stream provided and no weights. This means model has
// no learnable weights
else if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0]))
else if (is_type<VariantWrapper<std::istream*>>(variants[0]))
{
std::shared_ptr<std::istream> p_model_stream =
as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0])
->get();
auto p_model_stream =
as_type_ptr<VariantWrapper<std::istream*>>(variants[0])->get();
return std::make_shared<InputModelPDPD>(
std::vector<std::istream*>{p_model_stream.get()});
std::vector<std::istream*>{p_model_stream});
}
}
else if (variants.size() == 2)
Expand Down
19 changes: 9 additions & 10 deletions ngraph/test/frontend/shared/src/load_from.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ TEST_P(FrontEndLoadFromTest, testLoadFromTwoFiles)

TEST_P(FrontEndLoadFromTest, testLoadFromStream)
{
auto ifs = std::make_shared<std::ifstream>(
FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_stream),
std::ios::in | std::ifstream::binary);
auto is = std::dynamic_pointer_cast<std::istream>(ifs);
std::ifstream ifs(FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_stream),
std::ios::in | std::ios::binary);
std::istream* is = &ifs;
std::vector<std::string> frontends;
FrontEnd::Ptr fe;
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
Expand All @@ -85,14 +84,14 @@ TEST_P(FrontEndLoadFromTest, testLoadFromStream)

TEST_P(FrontEndLoadFromTest, testLoadFromTwoStreams)
{
auto model_ifs = std::make_shared<std::ifstream>(
std::ifstream model_ifs(
FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_streams[0]),
std::ios::in | std::ifstream::binary);
auto weights_ifs = std::make_shared<std::ifstream>(
std::ios::in | std::ios::binary);
std::ifstream weights_ifs(
FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_streams[1]),
std::ios::in | std::ifstream::binary);
auto model_is = std::dynamic_pointer_cast<std::istream>(model_ifs);
auto weights_is = std::dynamic_pointer_cast<std::istream>(weights_ifs);
std::ios::in | std::ios::binary);
std::istream* model_is(&model_ifs);
std::istream* weights_is(&weights_ifs);

std::vector<std::string> frontends;
FrontEnd::Ptr fe;
Expand Down
6 changes: 3 additions & 3 deletions tests/fuzz/src/import_pdpd-fuzzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
frontend_manager.load_by_framework(PDPD);
ngraph::frontend::InputModel::Ptr input_model;
std::stringstream model;
std::stringstream params;
model << std::string((const char *)model_buf, model_size);
std::shared_ptr<std::istream> in_model(&model);
std::istream* in_model(&model);
if (params_buf) {
std::stringstream params;
params << std::string((const char *)params_buf, params_size);
std::shared_ptr<std::istream> in_params(&params);
std::istream* in_params(&params);
input_model = frontend->load(in_model, in_params);
} else
input_model = frontend->load(in_model);
Expand Down

0 comments on commit 9a36e77

Please sign in to comment.