Skip to content

Commit

Permalink
refactor import network
Browse files Browse the repository at this point in the history
  • Loading branch information
aalbersk committed Aug 25, 2020
1 parent ead37fb commit 8931902
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 21 deletions.
19 changes: 12 additions & 7 deletions inference-engine/src/gna_plugin/gna_executable_network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
std::shared_ptr<GNAPlugin> plg;

public:
GNAExecutableNetwork(const std::string &aotFileName, std::shared_ptr<GNAPlugin> plg)
: plg(plg) {
plg->ImportNetwork(aotFileName);
_networkInputs = plg->GetInputs();
_networkOutputs = plg->GetOutputs();
}
GNAExecutableNetwork(const std::string& aotFileName, std::shared_ptr<GNAPlugin> plg)
: plg(plg) {
std::fstream inputStream(aotFileName, std::ios_base::in | std::ios_base::binary);
if (inputStream.fail()) {
THROW_GNA_EXCEPTION << "Cannot open file to import model: " << aotFileName;
}

plg->ImportNetwork(inputStream);
_networkInputs = plg->GetInputs();
_networkOutputs = plg->GetOutputs();
}

GNAExecutableNetwork(std::istream& networkModel, std::shared_ptr<GNAPlugin> plg)
: plg(plg) {
Expand All @@ -40,7 +45,7 @@ class GNAExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSafe
plg->LoadNetwork(network);
}

GNAExecutableNetwork(const std::string &aotFileName, const std::map<std::string, std::string> &config)
GNAExecutableNetwork(const std::string& aotFileName, const std::map<std::string, std::string>& config)
: GNAExecutableNetwork(aotFileName, std::make_shared<GNAPlugin>(config)) {
}

Expand Down
10 changes: 0 additions & 10 deletions inference-engine/src/gna_plugin/gna_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,16 +1134,6 @@ void GNAPlugin::SetName(const std::string & pluginName) noexcept {
_pluginName = pluginName;
}

InferenceEngine::IExecutableNetwork::Ptr GNAPlugin::ImportNetwork(const std::string &modelFileName) {
// no need to return anything dueto weird design of internal base classes
std::fstream inputStream(modelFileName, ios_base::in | ios_base::binary);
if (inputStream.fail()) {
THROW_GNA_EXCEPTION << "Cannot open file to import model: " << modelFileName;
}

return ImportNetwork(inputStream);
}

InferenceEngine::IExecutableNetwork::Ptr GNAPlugin::ImportNetwork(std::istream& networkModel) {
auto header = GNAModelSerial::ReadHeader(networkModel);

Expand Down
1 change: 0 additions & 1 deletion inference-engine/src/gna_plugin/gna_plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ class GNAPlugin : public InferenceEngine::IInferencePlugin {
THROW_GNA_EXCEPTION << "Not implemented";
}

InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(const std::string &modelFileName);
InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(std::istream& networkModel);

/**
Expand Down
3 changes: 2 additions & 1 deletion inference-engine/src/gna_plugin/gna_plugin_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ class GNAPluginInternal : public InferenceEngine::InferencePluginInternal {
defaultConfig.UpdateFromMap(config);
}

InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(
InferenceEngine::IExecutableNetwork::Ptr ImportNetwork(
const std::string &modelFileName,
const std::map<std::string, std::string> &config) override {
Config updated_config(defaultConfig);
updated_config.UpdateFromMap(config);
auto plg = std::make_shared<GNAPlugin>(updated_config.key_config_map);
plgPtr = plg;

return make_executable_network(std::make_shared<GNAExecutableNetwork>(modelFileName, plg));
}

Expand Down
14 changes: 12 additions & 2 deletions inference-engine/tests_deprecated/unit/engines/gna/gna_matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ void GNAPropagateMatcher :: match() {
};

auto loadNetworkFromAOT = [&] () {
auto sp = plugin.ImportNetwork(_env.importedModelFileName);
std::fstream inputStream(_env.importedModelFileName, std::ios_base::in | std::ios_base::binary);
if (inputStream.fail()) {
THROW_GNA_EXCEPTION << "Cannot open file to import model: " << _env.importedModelFileName;
}

auto sp = plugin.ImportNetwork(inputStream);
inputsInfo = plugin.GetInputs();
outputsInfo = plugin.GetOutputs();
};
Expand Down Expand Up @@ -505,7 +510,12 @@ void GNADumpXNNMatcher::load(std::shared_ptr<GNAPlugin> & plugin) {
};

auto loadNetworkFromAOT = [&]() {
plugin->ImportNetwork(_env.importedModelFileName);
std::fstream inputStream(_env.importedModelFileName, std::ios_base::in | std::ios_base::binary);
if (inputStream.fail()) {
THROW_GNA_EXCEPTION << "Cannot open file to import model: " << _env.importedModelFileName;
}

plugin->ImportNetwork(inputStream);
};

auto loadNetwork = [&]() {
Expand Down

0 comments on commit 8931902

Please sign in to comment.