Skip to content

Commit

Permalink
Rework model loading in FE manager, implement PDPD probing
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed May 19, 2021
1 parent d04a866 commit 95aad6e
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 127 deletions.
7 changes: 0 additions & 7 deletions inference-engine/samples/benchmark_app/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,6 @@ int main(int argc, char* argv[]) {

auto startTime = Time::now();
CNNNetwork cnnNetwork = ie.ReadNetwork(FLAGS_m);
// ngraph::frontend::FrontEndManager manager;
// auto FE = manager.loadByFramework("pdpd");
// auto inputModel = FE->loadFromFile(FLAGS_m);
// //inputModel->setPartialShape(inputModel->getInputs()[0], ngraph::PartialShape({1, 224, 224, 3}));
// auto ngFunc = FE->convert(inputModel);
// CNNNetwork cnnNetwork(ngFunc);
// cnnNetwork.serialize("benchmark_app_loaded_network.xml");

auto duration_ms = double_to_string(get_total_ms_time(startTime));
slog::info << "Read network took " << duration_ms << " ms" << slog::endl;
Expand Down
16 changes: 15 additions & 1 deletion inference-engine/src/inference_engine/ie_network_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <file_utils.h>
#include <ie_reader.hpp>
#include <ie_ir_version.hpp>
#include <frontend_manager/frontend_manager.hpp>

#include <fstream>
#include <istream>
Expand Down Expand Up @@ -151,6 +152,19 @@ void assertIfIRv7LikeModel(std::istream & modelStream) {
} // namespace

CNNNetwork details::ReadNetwork(const std::string& modelPath, const std::string& binPath, const std::vector<IExtensionPtr>& exts) {
// Try to load with FrontEndManager
ngraph::frontend::FrontEndManager manager;
std::vector<std::shared_ptr<ngraph::Variant>> variants{std::make_shared<ngraph::VariantWrapper<std::string>>(modelPath)};
if (!binPath.empty()) {
variants.push_back(std::make_shared<ngraph::VariantWrapper<std::string>>(binPath));
}
auto FE = manager.load_by_variants(variants);
if (FE) {
ngraph::frontend::InputModel::Ptr inputModel = FE->load(variants);
auto ngFunc = FE->convert(inputModel);
return CNNNetwork(ngFunc);
}

// Register readers if it is needed
registerReaders();

Expand Down Expand Up @@ -248,4 +262,4 @@ CNNNetwork details::ReadNetwork(const std::string& model, const Blob::CPtr& weig
IE_THROW() << "Unknown model format! Cannot find reader for the model and read it. Please check that reader library exists in your PATH.";
}

} // namespace InferenceEngine
} // namespace InferenceEngine
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <string>
#include "frontend_manager_defs.hpp"
#include "ngraph/function.hpp"
#include "ngraph/variant.hpp"

namespace ngraph
{
Expand Down Expand Up @@ -359,6 +360,8 @@ namespace ngraph

virtual ~FrontEnd();

virtual bool supported(const std::vector<std::shared_ptr<Variant>>& variants) const;

/// \brief Loads an input model by specified model file path
/// If model is stored in several files (e.g. model topology and model weights) -
/// frontend implementation is responsible to handle this case, generally frontend may
Expand All @@ -367,35 +370,8 @@ namespace ngraph
/// \return Loaded input model
virtual InputModel::Ptr load_from_file(const std::string& path) const;

/// \brief Loads an input model by specified number of model files
/// This shall be used for cases when client knows all model files (model, weights, etc)
/// \param paths Array of model files
/// \return Loaded input model
virtual InputModel::Ptr load_from_files(const std::vector<std::string>& paths) const;

/// \brief Loads an input model by already loaded memory buffer
/// Memory structure is frontend-defined and is not specified in generic API
/// \param model Model memory buffer
/// \return Loaded input model
virtual InputModel::Ptr load_from_memory(const void* model) const;

/// \brief Loads an input model from set of memory buffers
/// Memory structure is frontend-defined and is not specified in generic API
/// \param modelParts Array of model memory buffers
/// \return Loaded input model
virtual InputModel::Ptr
load_from_memory_fragments(const std::vector<const void*>& modelParts) const;

/// \brief Loads an input model by input stream representing main model file
/// \param stream Input stream of main model
/// \return Loaded input model
virtual InputModel::Ptr load_from_stream(std::istream& stream) const;

/// \brief Loads an input model by input streams representing all model files
/// \param streams Array of input streams for model
/// \return Loaded input model
virtual InputModel::Ptr
load_from_streams(const std::vector<std::istream*>& streams) const;
load(const std::vector<std::shared_ptr<Variant>>& variants) const;

/// \brief Completely convert and normalize entire function, throws if it is not
/// possible
Expand Down Expand Up @@ -484,8 +460,9 @@ namespace ngraph
/// \param fec Frontend capabilities. It is recommended to use only those capabilities
/// which are needed to minimize load time
/// \return Frontend interface for further loading of model
FrontEnd::Ptr load_by_model(const std::string& path,
FrontEndCapFlags fec = FrontEndCapabilities::FEC_DEFAULT);
FrontEnd::Ptr
load_by_variants(const std::vector<std::shared_ptr<Variant>>& variants,
FrontEndCapFlags fec = FrontEndCapabilities::FEC_DEFAULT);

/// \brief Gets list of registered frontends
std::vector<std::string> get_available_front_ends() const;
Expand Down Expand Up @@ -518,4 +495,16 @@ namespace ngraph

} // namespace frontend

template <>
class NGRAPH_API VariantWrapper<std::istream*> : public VariantImpl<std::istream*>
{
public:
static constexpr VariantTypeInfo type_info{"Variant::std::istream*", 0};
const VariantTypeInfo& get_type_info() const override { return type_info; }
VariantWrapper(const value_type& value)
: VariantImpl<value_type>(value)
{
}
};

} // namespace ngraph
50 changes: 23 additions & 27 deletions ngraph/frontend/frontend_manager/src/frontend_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,18 @@ class FrontEndManager::Impl
return keys;
}

FrontEnd::Ptr loadByModel(const std::string& path, FrontEndCapFlags fec)
FrontEnd::Ptr loadByVariants(const std::vector<std::shared_ptr<Variant>>& variants,
FrontEndCapFlags fec)
{
FRONT_END_NOT_IMPLEMENTED(loadByModel);
for (const auto& factory : m_factories)
{
auto FE = factory.second(fec);
if (FE->supported(variants))
{
return FE;
}
}
return FrontEnd::Ptr();
}

void registerFrontEnd(const std::string& name, FrontEndFactory creator)
Expand Down Expand Up @@ -107,9 +116,11 @@ FrontEnd::Ptr FrontEndManager::load_by_framework(const std::string& framework, F
return m_impl->loadByFramework(framework, fec);
}

FrontEnd::Ptr FrontEndManager::load_by_model(const std::string& path, FrontEndCapFlags fec)
FrontEnd::Ptr
FrontEndManager::load_by_variants(const std::vector<std::shared_ptr<Variant>>& variants,
FrontEndCapFlags fec)
{
return m_impl->loadByModel(path, fec);
return m_impl->loadByVariants(variants, fec);
}

std::vector<std::string> FrontEndManager::get_available_front_ends() const
Expand All @@ -128,37 +139,20 @@ FrontEnd::FrontEnd() = default;

FrontEnd::~FrontEnd() = default;

InputModel::Ptr FrontEnd::load_from_file(const std::string& path) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_file);
}

InputModel::Ptr FrontEnd::load_from_files(const std::vector<std::string>& paths) const
bool FrontEnd::supported(const std::vector<std::shared_ptr<Variant>>& variants) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_files);
return false;
}

InputModel::Ptr FrontEnd::load_from_memory(const void* model) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_memory);
}

InputModel::Ptr
FrontEnd::load_from_memory_fragments(const std::vector<const void*>& modelParts) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_memory_fragments);
}

InputModel::Ptr FrontEnd::load_from_stream(std::istream& path) const
InputModel::Ptr FrontEnd::load_from_file(const std::string& path) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_stream);
FRONT_END_NOT_IMPLEMENTED(load_from_file);
}

InputModel::Ptr FrontEnd::load_from_streams(const std::vector<std::istream*>& paths) const
InputModel::Ptr FrontEnd::load(const std::vector<std::shared_ptr<Variant>>& params) const
{
FRONT_END_NOT_IMPLEMENTED(load_from_streams);
FRONT_END_NOT_IMPLEMENTED(load);
}

std::shared_ptr<ngraph::Function> FrontEnd::convert(InputModel::Ptr model) const
{
FRONT_END_NOT_IMPLEMENTED(convert);
Expand Down Expand Up @@ -388,3 +382,5 @@ Place::Ptr Place::get_source_tensor(int inputPortIndex) const
{
FRONT_END_NOT_IMPLEMENTED(get_source_tensor);
}

constexpr VariantTypeInfo VariantWrapper<std::istream*>::type_info;
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ namespace ngraph
public:
FrontEndPDPD() {}

virtual bool supported(const std::vector<std::shared_ptr<Variant>>& variants) const;

/**
* @brief Reads model from file and deducts file names of weights
* @param path path to folder which contains __model__ file or path to .pdmodel file
Expand All @@ -28,29 +30,14 @@ namespace ngraph
virtual InputModel::Ptr load_from_file(const std::string& path) const override;

/**
* @brief Reads model and weights from files
* @param paths vector containing path to .pdmodel and .pdiparams files
* @return InputModel::Ptr
*/
virtual InputModel::Ptr
load_from_files(const std::vector<std::string>& paths) const override;

/**
* @brief Reads model from stream
* @param model_stream stream containing .pdmodel or __model__ files. Can only be used
* if model have no weights
* @return InputModel::Ptr
*/
virtual InputModel::Ptr load_from_stream(std::istream& model_stream) const override;

/**
* @brief Reads model from stream
* @param paths vector of streams containing .pdmodel and .pdiparams files. Can't be
* used in case of multiple weight files
* @brief Reads model from 1 or 2 given file names or 1 or 2 std::istream containing
* model in protobuf format and weights
* @param params Can be path to folder which contains __model__ file or path to .pdmodel
* file
* @return InputModel::Ptr
*/
virtual InputModel::Ptr
load_from_streams(const std::vector<std::istream*>& paths) const override;
load(const std::vector<std::shared_ptr<Variant>>& params) const override;

virtual std::shared_ptr<Function> convert(InputModel::Ptr model) const override;
};
Expand Down
Loading

0 comments on commit 95aad6e

Please sign in to comment.