Accept stream pointer instead of shared_ptr in paddle frontend #6807

Merged: 8 commits, Jul 30, 2021
@@ -96,11 +96,10 @@ namespace ngraph
} // namespace frontend

template <>
- class FRONTEND_API VariantWrapper<std::shared_ptr<std::istream>>
-     : public VariantImpl<std::shared_ptr<std::istream>>
+ class FRONTEND_API VariantWrapper<std::istream*> : public VariantImpl<std::istream*>
{
public:
- static constexpr VariantTypeInfo type_info{"Variant::std::shared_ptr<std::istream>", 0};
+ static constexpr VariantTypeInfo type_info{"Variant::std::istream*", 0};
const VariantTypeInfo& get_type_info() const override { return type_info; }
VariantWrapper(const value_type& value)
: VariantImpl<value_type>(value)
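
Reviewer note: for readers unfamiliar with the Variant machinery, the sketch below is a minimal, self-contained approximation of the wrapper pattern specialized above (simplified stand-ins, not the actual ngraph classes). It shows the key consequence of the std::istream* specialization: the variant stores a non-owning pointer, so the wrapped stream must stay alive on the caller's side.

    #include <fstream>
    #include <iostream>
    #include <memory>

    // Simplified stand-ins for ngraph's Variant/VariantImpl (assumption: the real
    // classes carry additional type-info machinery omitted here).
    struct Variant
    {
        virtual ~Variant() = default;
    };

    template <typename T>
    struct VariantImpl : Variant
    {
        using value_type = T;
        explicit VariantImpl(const value_type& value) : m_value(value) {}
        const value_type& get() const { return m_value; }
    private:
        value_type m_value;
    };

    template <typename T>
    struct VariantWrapper;

    // Specialization mirroring the diff: it holds a raw, non-owning stream pointer.
    template <>
    struct VariantWrapper<std::istream*> : VariantImpl<std::istream*>
    {
        VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
    };

    int main()
    {
        std::ifstream file("model.pdmodel", std::ios::in | std::ios::binary);
        auto variant = std::make_shared<VariantWrapper<std::istream*>>(&file);
        std::istream* stream = variant->get();   // the caller still owns 'file'
        std::cout << std::boolalpha << (stream == &file) << "\n";
        return 0;
    }
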
@@ -454,7 +454,7 @@ std::vector<Place::Ptr> Place::get_consuming_operations(const std::string& outpu
return {};
}

- constexpr VariantTypeInfo VariantWrapper<std::shared_ptr<std::istream>>::type_info;
+ constexpr VariantTypeInfo VariantWrapper<std::istream*>::type_info;

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
constexpr VariantTypeInfo VariantWrapper<std::wstring>::type_info;
23 changes: 10 additions & 13 deletions ngraph/frontend/paddlepaddle/src/frontend.cpp
@@ -77,11 +77,9 @@ namespace ngraph
std::istream* variant_to_stream_ptr(const std::shared_ptr<Variant>& variant,
std::ifstream& ext_stream)
{
- if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variant))
+ if (is_type<VariantWrapper<std::istream*>>(variant))
{
- auto m_stream =
-     as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variant)->get();
- return m_stream.get();
+ return as_type_ptr<VariantWrapper<std::istream*>>(variant)->get();
}
else if (is_type<VariantWrapper<std::string>>(variant))
{
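
As a rough illustration of the dispatch in variant_to_stream_ptr, the standalone sketch below substitutes std::variant for ngraph's Variant/is_type/as_type_ptr machinery (purely an assumption for illustration; the file name is a placeholder). A stream-pointer variant is passed through unchanged, while a path variant opens ext_stream and returns its address.

    #include <fstream>
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <variant>

    // Simplified analogue of variant_to_stream_ptr: a raw stream pointer passes
    // through unchanged, a path opens ext_stream and its address is returned.
    std::istream* variant_to_stream_ptr(const std::variant<std::istream*, std::string>& v,
                                        std::ifstream& ext_stream)
    {
        if (std::holds_alternative<std::istream*>(v))
            return std::get<std::istream*>(v);
        ext_stream.open(std::get<std::string>(v), std::ios::in | std::ios::binary);
        return ext_stream.is_open() ? &ext_stream : nullptr;
    }

    int main()
    {
        std::stringstream model("bytes");
        std::ifstream ext_stream;
        std::istream* from_ptr = variant_to_stream_ptr(&model, ext_stream);                     // stream case
        std::istream* from_path = variant_to_stream_ptr(std::string("model.pdmodel"), ext_stream); // path case
        std::cout << std::boolalpha << (from_ptr == &model) << " " << (from_path != nullptr) << "\n";
        return 0;
    }
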
@@ -225,13 +223,13 @@ namespace ngraph
return model_str && model_str.is_open();
}
#endif
- else if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0]))
+ else if (is_type<VariantWrapper<std::istream*>>(variants[0]))
{
// Validating first stream, it must contain a model
- std::shared_ptr<std::istream> p_model_stream =
-     as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0])->get();
+ auto p_model_stream =
+     as_type_ptr<VariantWrapper<std::istream*>>(variants[0])->get();
paddle::framework::proto::ProgramDesc fw;
- return fw.ParseFromIstream(p_model_stream.get());
+ return fw.ParseFromIstream(p_model_stream);
}
return false;
}
@@ -258,13 +256,12 @@ namespace ngraph
#endif
// The case with only model stream provided and no weights. This means model has
// no learnable weights
- else if (is_type<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0]))
+ else if (is_type<VariantWrapper<std::istream*>>(variants[0]))
{
- std::shared_ptr<std::istream> p_model_stream =
-     as_type_ptr<VariantWrapper<std::shared_ptr<std::istream>>>(variants[0])
-         ->get();
+ auto p_model_stream =
+     as_type_ptr<VariantWrapper<std::istream*>>(variants[0])->get();
return std::make_shared<InputModelPDPD>(
- std::vector<std::istream*>{p_model_stream.get()});
+ std::vector<std::istream*>{p_model_stream});
}
}
else if (variants.size() == 2)
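
Because InputModelPDPD now receives bare std::istream* values, the caller owns the streams and must keep them alive until loading finishes. The sketch below illustrates that contract with a hypothetical stand-in class rather than the real InputModelPDPD (which parses protobuf instead of counting bytes).

    #include <cstddef>
    #include <iostream>
    #include <iterator>
    #include <sstream>
    #include <string>
    #include <vector>

    // Stand-in for a model class that consumes non-owning stream pointers
    // (hypothetical; shown only to demonstrate the lifetime contract).
    class StreamBackedModel
    {
    public:
        explicit StreamBackedModel(const std::vector<std::istream*>& streams)
        {
            for (std::istream* s : streams)
                m_bytes += std::string(std::istreambuf_iterator<char>(*s),
                                       std::istreambuf_iterator<char>()).size();
        }
        std::size_t size() const { return m_bytes; }
    private:
        std::size_t m_bytes = 0;
    };

    int main()
    {
        std::stringstream model_stream("serialized model bytes");
        // 'model_stream' lives on the caller's stack and must outlive the
        // constructor call; the model never takes ownership of it.
        StreamBackedModel model({&model_stream});
        std::cout << model.size() << "\n";
        return 0;
    }
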
19 changes: 9 additions & 10 deletions ngraph/test/frontend/shared/src/load_from.cpp
@@ -65,10 +65,9 @@ TEST_P(FrontEndLoadFromTest, testLoadFromTwoFiles)

TEST_P(FrontEndLoadFromTest, testLoadFromStream)
{
- auto ifs = std::make_shared<std::ifstream>(
-     FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_stream),
-     std::ios::in | std::ifstream::binary);
- auto is = std::dynamic_pointer_cast<std::istream>(ifs);
+ std::ifstream ifs(FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_stream),
+     std::ios::in | std::ios::binary);
+ std::istream* is = &ifs;
std::vector<std::string> frontends;
FrontEnd::Ptr fe;
ASSERT_NO_THROW(frontends = m_fem.get_available_front_ends());
@@ -85,14 +84,14 @@ TEST_P(FrontEndLoadFromTest, testLoadFromStream)

TEST_P(FrontEndLoadFromTest, testLoadFromTwoStreams)
{
- auto model_ifs = std::make_shared<std::ifstream>(
+ std::ifstream model_ifs(
FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_streams[0]),
- std::ios::in | std::ifstream::binary);
- auto weights_ifs = std::make_shared<std::ifstream>(
+ std::ios::in | std::ios::binary);
+ std::ifstream weights_ifs(
FrontEndTestUtils::make_model_path(m_param.m_modelsPath + m_param.m_streams[1]),
- std::ios::in | std::ifstream::binary);
- auto model_is = std::dynamic_pointer_cast<std::istream>(model_ifs);
- auto weights_is = std::dynamic_pointer_cast<std::istream>(weights_ifs);
+ std::ios::in | std::ios::binary);
+ std::istream* model_is(&model_ifs);
+ std::istream* weights_is(&weights_ifs);

std::vector<std::string> frontends;
FrontEnd::Ptr fe;
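
The test changes also show why the std::dynamic_pointer_cast calls disappear: once the ifstream objects live on the stack, taking their address already yields a pointer that converts implicitly to std::istream*. A tiny sketch of that point (the file name is only a placeholder):

    #include <fstream>
    #include <iostream>

    int main()
    {
        std::ifstream ifs("weights.pdiparams", std::ios::in | std::ios::binary);
        std::istream* is = &ifs;   // implicit derived-to-base conversion, no cast needed
        std::cout << std::boolalpha << (is == &ifs) << "\n";
        return 0;
    }
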
6 changes: 3 additions & 3 deletions tests/fuzz/src/import_pdpd-fuzzer.cc
@@ -28,12 +28,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) {
frontend_manager.load_by_framework(PDPD);
ngraph::frontend::InputModel::Ptr input_model;
std::stringstream model;
- std::stringstream params;
model << std::string((const char *)model_buf, model_size);
- std::shared_ptr<std::istream> in_model(&model);
+ std::istream* in_model(&model);
if (params_buf) {
+ std::stringstream params;
params << std::string((const char *)params_buf, params_size);
- std::shared_ptr<std::istream> in_params(&params);
+ std::istream* in_params(&params);
input_model = frontend->load(in_model, in_params);
} else
input_model = frontend->load(in_model);
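
A note on the fuzzer change: the previous code wrapped stack-allocated stringstreams in default-deleting shared_ptrs, which would attempt to delete those stack objects when the last reference dropped. The raw pointer avoids the issue entirely; the sketch below contrasts the alternatives (the no-op deleter is shown only as a hypothetical fallback, not something the PR uses).

    #include <iostream>
    #include <memory>
    #include <sstream>

    int main()
    {
        std::stringstream model("model bytes");

        // What the old fuzzer effectively did; the default deleter would call
        // delete on a stack object, which is undefined behaviour:
        // std::shared_ptr<std::istream> owning(&model);

        // If shared_ptr were still required, a no-op deleter would be the fix:
        std::shared_ptr<std::istream> non_owning(&model, [](std::istream*) {});

        // Simplest, and what the new interface expects:
        std::istream* raw = &model;

        std::cout << std::boolalpha << (raw == non_owning.get()) << "\n";
        return 0;
    }
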