Skip to content

Commit

Permalink
Use ONNX FE instead of ONNX Reader leftovers (#7252)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Bencer authored Aug 28, 2021
1 parent d0f49fe commit 4a07a0b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 33 deletions.
17 changes: 9 additions & 8 deletions inference-engine/src/inference_engine/src/ie_network_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ namespace {
// Extension to plugins creator
std::multimap<std::string, Reader::Ptr> readers;

static ngraph::frontend::FrontEndManager* get_frontend_manager() {
static ngraph::frontend::FrontEndManager& get_frontend_manager() {
static ngraph::frontend::FrontEndManager manager;
return &manager;
return manager;
}

void registerReaders() {
Expand Down Expand Up @@ -233,7 +233,7 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath,
}
}
// Try to load with FrontEndManager
const auto manager = get_frontend_manager();
auto& manager = get_frontend_manager();
ngraph::frontend::FrontEnd::Ptr FE;
ngraph::frontend::InputModel::Ptr inputModel;
if (!binPath.empty()) {
Expand All @@ -242,11 +242,11 @@ CNNNetwork details::ReadNetwork(const std::string& modelPath,
#else
std::string weights_path = binPath;
#endif
FE = manager->load_by_model(model_path, weights_path);
FE = manager.load_by_model(model_path, weights_path);
if (FE)
inputModel = FE->load(model_path, weights_path);
} else {
FE = manager->load_by_model(model_path);
FE = manager.load_by_model(model_path);
if (FE)
inputModel = FE->load(model_path);
}
Expand All @@ -264,7 +264,8 @@ CNNNetwork details::ReadNetwork(const std::string& model,
const std::vector<IExtensionPtr>& exts) {
// Register readers if it is needed
registerReaders();
std::istringstream modelStream(model);
std::istringstream modelStringStream(model);
std::istream& modelStream = modelStringStream;

assertIfIRv7LikeModel(modelStream);

Expand All @@ -278,10 +279,10 @@ CNNNetwork details::ReadNetwork(const std::string& model,
}
// Try to load with FrontEndManager
// NOTE: weights argument is ignored
const auto manager = get_frontend_manager();
auto& manager = get_frontend_manager();
ngraph::frontend::FrontEnd::Ptr FE;
ngraph::frontend::InputModel::Ptr inputModel;
FE = manager->load_by_model(&modelStream);
FE = manager.load_by_model(&modelStream);
if (FE)
inputModel = FE->load(&modelStream);
if (inputModel) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,6 @@ class FRONTEND_API VariantWrapper<std::istream*> : public VariantImpl<std::istre
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
};

template <>
class FRONTEND_API VariantWrapper<std::istringstream*> : public VariantImpl<std::istringstream*> {
public:
static constexpr VariantTypeInfo type_info{"Variant::std::istringstream*", 0};
const VariantTypeInfo& get_type_info() const override {
return type_info;
}
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
};

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
template <>
class FRONTEND_API VariantWrapper<std::wstring> : public VariantImpl<std::wstring> {
Expand Down
2 changes: 0 additions & 2 deletions ngraph/frontend/frontend_manager/src/frontend_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,6 @@ std::vector<Place::Ptr> Place::get_consuming_operations(const std::string& outpu

constexpr VariantTypeInfo VariantWrapper<std::istream*>::type_info;

constexpr VariantTypeInfo VariantWrapper<std::istringstream*>::type_info;

#if defined(ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
constexpr VariantTypeInfo VariantWrapper<std::wstring>::type_info;
#endif
15 changes: 2 additions & 13 deletions ngraph/frontend/onnx/frontend/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ using namespace ngraph::frontend;
using VariantString = VariantWrapper<std::string>;
using VariantWString = VariantWrapper<std::wstring>;
using VariantIstreamPtr = VariantWrapper<std::istream*>;
using VariantIstringstreamPtr = VariantWrapper<std::istringstream*>;

extern "C" ONNX_FRONTEND_API FrontEndVersion GetAPIVersion() {
return OV_FRONTEND_API_VERSION;
Expand Down Expand Up @@ -48,13 +47,8 @@ InputModel::Ptr FrontEndONNX::load_impl(const std::vector<std::shared_ptr<Varian
return std::make_shared<InputModelONNX>(path);
}
#endif
std::istream* stream = nullptr;
if (ov::is_type<VariantIstreamPtr>(variants[0])) {
stream = ov::as_type_ptr<VariantIstreamPtr>(variants[0])->get();
} else if (ov::is_type<VariantIstringstreamPtr>(variants[0])) {
stream = ov::as_type_ptr<VariantIstringstreamPtr>(variants[0])->get();
}
if (stream != nullptr) {
const auto stream = ov::as_type_ptr<VariantIstreamPtr>(variants[0])->get();
if (variants.size() > 1 && ov::is_type<VariantString>(variants[1])) {
const auto path = ov::as_type_ptr<VariantString>(variants[1])->get();
return std::make_shared<InputModelONNX>(*stream, path);
Expand Down Expand Up @@ -133,13 +127,8 @@ bool FrontEndONNX::supported_impl(const std::vector<std::shared_ptr<Variant>>& v
model_stream.close();
return is_valid_model;
}
std::istream* stream = nullptr;
if (ov::is_type<VariantIstreamPtr>(variants[0])) {
stream = ov::as_type_ptr<VariantIstreamPtr>(variants[0])->get();
} else if (ov::is_type<VariantIstringstreamPtr>(variants[0])) {
stream = ov::as_type_ptr<VariantIstringstreamPtr>(variants[0])->get();
}
if (stream != nullptr) {
const auto stream = ov::as_type_ptr<VariantIstreamPtr>(variants[0])->get();
StreamRewinder rwd{*stream};
return onnx_common::is_valid_model(*stream);
}
Expand Down

0 comments on commit 4a07a0b

Please sign in to comment.