From a7d95bbfa330a2cbbd2936000cb05e95b6ae4412 Mon Sep 17 00:00:00 2001
From: Ilya Lavrenov
Date: Thu, 20 May 2021 14:28:01 +0300
Subject: [PATCH] Added SetExeNetworkInfo to Plugin API level (#5715)

---
 inference-engine/include/ie_input_info.hpp    |  9 ++++++-
 .../src/auto_plugin/auto_plugin.hpp           | 10 +++++++-
 .../src/multi_device/multi_device_plugin.cpp  | 25 +------------------
 .../src/multi_device/multi_device_plugin.hpp  |  4 ---
 .../impl/ie_executable_network_internal.hpp   |  2 +-
 .../impl/ie_plugin_internal.hpp               | 22 ++++++++++------
 .../interface/ie_iplugin_internal.hpp         |  6 +++--
 .../inference_engine/caching_test.cpp         |  2 +-
 8 files changed, 39 insertions(+), 41 deletions(-)

diff --git a/inference-engine/include/ie_input_info.hpp b/inference-engine/include/ie_input_info.hpp
index 9ea794c3bbf765..e844bf7b8442ef 100644
--- a/inference-engine/include/ie_input_info.hpp
+++ b/inference-engine/include/ie_input_info.hpp
@@ -140,13 +140,20 @@ class InputInfo {
 
     /**
      * @brief Gets pre-process info for the input
-     *
      * @return A reference to the PreProcessInfo instance that contains pre-process info for this input
      */
     PreProcessInfo& getPreProcess() {
         return _preProcessInfo;
     }
 
+    /**
+     * @brief Gets pre-process info for the input
+     * @return A reference to the PreProcessInfo instance that contains pre-process info for this input
+     */
+    const PreProcessInfo& getPreProcess() const {
+        return _preProcessInfo;
+    }
+
 protected:
     /**
      * @brief Pre-process info for the input
diff --git a/inference-engine/src/auto_plugin/auto_plugin.hpp b/inference-engine/src/auto_plugin/auto_plugin.hpp
index 8fae37870fffa6..3fbe2bbaa284fd 100644
--- a/inference-engine/src/auto_plugin/auto_plugin.hpp
+++ b/inference-engine/src/auto_plugin/auto_plugin.hpp
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include <type_traits>
 #include
 #include
 
@@ -62,7 +63,14 @@ class AutoInferencePlugin : public IE::InferencePluginInternal {
         if (!executableNetwork) {
             IE_THROW() << "Failed to load network by AUTO plugin";
         }
-        return std::make_shared<AutoExecutableNetwork>(executableNetwork);
+        auto impl = std::make_shared<AutoExecutableNetwork>(executableNetwork);
+
+        if (std::is_same<std::string, T>::value) {
+            SetExeNetworkInfo(impl, executableNetwork->GetInputsInfo(),
+                                    executableNetwork->GetOutputsInfo());
+        }
+
+        return impl;
     }
 };
 
diff --git a/inference-engine/src/multi_device/multi_device_plugin.cpp b/inference-engine/src/multi_device/multi_device_plugin.cpp
index c8ab5fa3bdd5be..57c4cd2df1ddbf 100644
--- a/inference-engine/src/multi_device/multi_device_plugin.cpp
+++ b/inference-engine/src/multi_device/multi_device_plugin.cpp
@@ -142,33 +142,10 @@ InferenceEngine::Parameter MultiDeviceInferencePlugin::GetMetric(const std::stri
     }
 }
 
-void MultiDeviceInferencePlugin::SetExeNetworkInfo(InferenceEngine::ExecutableNetworkInternal::Ptr exeNetwork,
-                                                   const InferenceEngine::ConstInputsDataMap& devInputs,
-                                                   const InferenceEngine::ConstOutputsDataMap& devOutputs) {
-    // Set inputs/outputs and pointer to plugin manually here
-    InputsDataMap _inputs, clonedInputs;
-    OutputsDataMap _outputs, clonedOutputs;
-    for (auto& it : devInputs) {
-        InputInfo::CPtr devData = it.second;
-        InputInfo::Ptr data = std::make_shared<InputInfo>(*devData);
-        _inputs[it.first] = data;
-    }
-    for (auto& it : devOutputs) {
-        CDataPtr devData = it.second;
-        DataPtr data = std::make_shared<Data>(*devData);
-        _outputs[it.first] = data;
-    }
-    copyInputOutputInfo(_inputs, _outputs, clonedInputs, clonedOutputs);
-    exeNetwork->setNetworkInputs(clonedInputs);
-    exeNetwork->setNetworkOutputs(clonedOutputs);
-    exeNetwork->SetPointerToPlugin(shared_from_this());
-}
-
 // Is called only when caching is enabled
 IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetwork(const std::string& modelPath,
                                                                         const std::map<std::string, std::string>& config) {
-    CNNNetwork network;
-    return LoadExeNetworkImpl(modelPath, network, config);
+    return LoadExeNetworkImpl(modelPath, {}, config);
 }
 
 ExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadExeNetworkImpl(const CNNNetwork &network,
diff --git a/inference-engine/src/multi_device/multi_device_plugin.hpp b/inference-engine/src/multi_device/multi_device_plugin.hpp
index deb54bdd819879..bd07b5801f69ac 100644
--- a/inference-engine/src/multi_device/multi_device_plugin.hpp
+++ b/inference-engine/src/multi_device/multi_device_plugin.hpp
@@ -44,10 +44,6 @@ class MultiDeviceInferencePlugin : public InferenceEngine::InferencePluginIntern
     InferenceEngine::ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(const std::string& modelPath,
                                                                        InferenceEngine::CNNNetwork network,
                                                                        const std::map<std::string, std::string>& config);
-
-    void SetExeNetworkInfo(InferenceEngine::ExecutableNetworkInternal::Ptr exeNetwork,
-                           const InferenceEngine::ConstInputsDataMap& inputs,
-                           const InferenceEngine::ConstOutputsDataMap& outputs);
 };
 
 }  // namespace MultiDevicePlugin
diff --git a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp
index e5831b5e7005d4..f7dab9f2a250f7 100644
--- a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp
+++ b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_executable_network_internal.hpp
@@ -86,7 +86,7 @@ class ExecutableNetworkInternal : public IExecutableNetworkInternal {
      * @param[in] plugin The plugin
      * @note Needed to correctly handle ownership between objects.
      */
-    void SetPointerToPlugin(const IInferencePlugin::Ptr& plugin) {
+    virtual void SetPointerToPlugin(const IInferencePlugin::Ptr& plugin) {
         _plugin = plugin;
     }
 
diff --git a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_plugin_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_plugin_internal.hpp
index 63979814b1fad8..26ad80737065f8 100644
--- a/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_plugin_internal.hpp
+++ b/inference-engine/src/plugin_api/cpp_interfaces/impl/ie_plugin_internal.hpp
@@ -54,10 +54,6 @@ class InferencePluginInternal : public IInferencePlugin {
     IExecutableNetworkInternal::Ptr LoadNetwork(const CNNNetwork& network,
                                                 const std::map<std::string, std::string>& config,
                                                 RemoteContext::Ptr context) override {
-        InputsDataMap networkInputs = network.getInputsInfo(), networkInputsCloned;
-        OutputsDataMap networkOutputs = network.getOutputsInfo(), networkOutputsCloned;
-        copyInputOutputInfo(networkInputs, networkOutputs, networkInputsCloned, networkOutputsCloned);
-
         ExecutableNetworkInternal::Ptr impl;
         if (nullptr == context) {
             impl = LoadExeNetworkImpl(network, config);
@@ -65,9 +61,7 @@ class InferencePluginInternal : public IInferencePlugin {
             impl = LoadExeNetworkImpl(network, context, config);
         }
 
-        impl->setNetworkInputs(networkInputsCloned);
-        impl->setNetworkOutputs(networkOutputsCloned);
-        impl->SetPointerToPlugin(shared_from_this());
+        SetExeNetworkInfo(impl, network.getInputsInfo(), network.getOutputsInfo());
 
         return impl;
     }
@@ -216,6 +210,20 @@ class InferencePluginInternal : public IInferencePlugin {
         IE_THROW(NotImplemented);
     }
 
+    template <typename T1, typename T2>
+    void SetExeNetworkInfo(const InferenceEngine::ExecutableNetworkInternal::Ptr& exeNetwork,
+                           const std::map<std::string, std::shared_ptr<T1> >& inputs,
+                           const std::map<std::string, std::shared_ptr<T2> >& outputs) {
+        // Set inputs/outputs and pointer to plugin manually here
+        InferenceEngine::InputsDataMap clonedInputs;
+        InferenceEngine::OutputsDataMap clonedOutputs;
+        copyInputOutputInfo(inputs, outputs, clonedInputs, clonedOutputs);
+
+        exeNetwork->setNetworkInputs(clonedInputs);
+        exeNetwork->setNetworkOutputs(clonedOutputs);
+        exeNetwork->SetPointerToPlugin(shared_from_this());
+    }
+
     std::string _pluginName;  //!< A device name that plugins enables
     std::map<std::string, std::string> _config;  //!< A map config keys -> values
     ICore* _core = nullptr;  //!< A pointer to ICore interface
diff --git a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iplugin_internal.hpp b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iplugin_internal.hpp
index 50ccc72cfbc978..9bf6c64e349f4a 100644
--- a/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iplugin_internal.hpp
+++ b/inference-engine/src/plugin_api/cpp_interfaces/interface/ie_iplugin_internal.hpp
@@ -32,7 +32,7 @@ class IExecutableNetworkInternal;
  * @param[in] from PreProcessInfo to copy from
  * @param to PreProcessInfo to copy to
  */
-static void copyPreProcess(const PreProcessInfo& from, PreProcessInfo& to) {
+inline void copyPreProcess(const PreProcessInfo& from, PreProcessInfo& to) {
     to = from;
     if (from.getMeanVariant() == MEAN_IMAGE) {
         for (size_t i = 0; i < from.getNumberOfChannels(); i++) {
@@ -54,7 +54,9 @@ static void copyPreProcess(const PreProcessInfo& from, PreProcessInfo& to) {
  * @param _networkInputs The network inputs to copy to
  * @param _networkOutputs The network outputs to copy to
  */
-inline void copyInputOutputInfo(const InputsDataMap & networkInputs, const OutputsDataMap & networkOutputs,
+template <typename T1, typename T2>
+inline void copyInputOutputInfo(const std::map<std::string, std::shared_ptr<T1> > & networkInputs,
+                                const std::map<std::string, std::shared_ptr<T2> > & networkOutputs,
                                 InputsDataMap & _networkInputs, OutputsDataMap & _networkOutputs) {
     _networkInputs.clear();
     _networkOutputs.clear();
diff --git a/inference-engine/tests/functional/inference_engine/caching_test.cpp b/inference-engine/tests/functional/inference_engine/caching_test.cpp
index 4a20b2eb2a0680..cdd5d2efa533a1 100644
--- a/inference-engine/tests/functional/inference_engine/caching_test.cpp
+++ b/inference-engine/tests/functional/inference_engine/caching_test.cpp
@@ -133,7 +133,7 @@ class MockExecutableNetwork : public ExecutableNetworkInternal {
         ExecutableNetworkInternal::Export(networkModel);
     }
 
-    void SetPointerToPlugin(IInferencePlugin::Ptr plugin) override {
+    void SetPointerToPlugin(const IInferencePlugin::Ptr& plugin) override {
         std::lock_guard<std::mutex> guard(m_pluginMutex);
         ExecutableNetworkInternal::SetPointerToPlugin(plugin);
     }
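
Note on the new helper: SetExeNetworkInfo and copyInputOutputInfo are templated on the map element types so that one helper accepts both the mutable maps returned by CNNNetwork::getInputsInfo()/getOutputsInfo() and the const maps returned by an executable network's GetInputsInfo()/GetOutputsInfo(); the entries are cloned before being handed to the executable network together with the plugin pointer. The standalone sketch below only illustrates that clone-then-assign pattern; StubInputInfo, StubData and CloneNetworkInfo are invented stand-ins rather than Inference Engine types, and the snippet is not the plugin API itself.

// Standalone sketch of the clone-then-assign pattern behind SetExeNetworkInfo /
// copyInputOutputInfo. StubInputInfo and StubData are stand-ins for
// InferenceEngine::InputInfo and InferenceEngine::Data; only the copying logic
// mirrors the patch.
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct StubInputInfo {
    std::string name;
};

struct StubData {
    std::string name;
};

// Templated on the element types so the same helper accepts both const and
// non-const maps, which is why the patched helpers are templates as well.
template <typename T1, typename T2>
void CloneNetworkInfo(const std::map<std::string, std::shared_ptr<T1> >& inputs,
                      const std::map<std::string, std::shared_ptr<T2> >& outputs,
                      std::map<std::string, std::shared_ptr<StubInputInfo> >& clonedInputs,
                      std::map<std::string, std::shared_ptr<StubData> >& clonedOutputs) {
    clonedInputs.clear();
    clonedOutputs.clear();
    for (const auto& it : inputs) {
        // Copy each entry so the "executable network" owns its own metadata.
        clonedInputs[it.first] = std::make_shared<StubInputInfo>(*it.second);
    }
    for (const auto& it : outputs) {
        clonedOutputs[it.first] = std::make_shared<StubData>(*it.second);
    }
}

int main() {
    // A const view of the device's inputs/outputs, analogous to what
    // GetInputsInfo()/GetOutputsInfo() return in the AUTO plugin case.
    std::map<std::string, std::shared_ptr<const StubInputInfo> > devInputs{
        {"in0", std::make_shared<const StubInputInfo>(StubInputInfo{"in0"})}};
    std::map<std::string, std::shared_ptr<const StubData> > devOutputs{
        {"out0", std::make_shared<const StubData>(StubData{"out0"})}};

    std::map<std::string, std::shared_ptr<StubInputInfo> > clonedInputs;
    std::map<std::string, std::shared_ptr<StubData> > clonedOutputs;
    CloneNetworkInfo(devInputs, devOutputs, clonedInputs, clonedOutputs);

    std::cout << "cloned " << clonedInputs.size() << " input(s), "
              << clonedOutputs.size() << " output(s)\n";
    return 0;
}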