diff --git a/src/plugins/auto/auto_schedule.cpp b/src/plugins/auto/auto_schedule.cpp index 6fc8c8fbb78626..a0fda78864a9d2 100644 --- a/src/plugins/auto/auto_schedule.cpp +++ b/src/plugins/auto/auto_schedule.cpp @@ -234,9 +234,7 @@ void AutoSchedule::init(const ScheduleContext::Ptr& sContext) { bool isCumulative = (_autoSContext->_performanceHint == IE::PluginConfigParams::CUMULATIVE_THROUGHPUT) ? true : false; if (isCumulative) { - std::list validDevices = - _autoSContext->_plugin->GetValidDevice(_autoSContext->_devicePriorities, - _loadContext[ACTUALDEVICE].networkPrecision); + const auto& validDevices = _autoSContext->_devicePriorities; // When the hint is ctput and there is only one device, the single-device logic is used if (validDevices.size() == 1) { _loadContext[ACTUALDEVICE].deviceInfo = validDevices.front(); @@ -244,10 +242,6 @@ void AutoSchedule::init(const ScheduleContext::Ptr& sContext) { IE::PluginConfigParams::THROUGHPUT; } else if (validDevices.size() > 1) { _loadContext[ACTUALDEVICE].isEnabled = false; - _autoSContext->_devicePriorities.clear(); - std::copy(std::begin(validDevices), - std::end(validDevices), - std::back_inserter(_autoSContext->_devicePriorities)); // Total number of devices in CTPUT _nCTputDeviceNums = validDevices.size(); // Generate contexts for loading each device @@ -527,7 +521,7 @@ void AutoSchedule::init(const ScheduleContext::Ptr& sContext) { _passthroughExeNet = _loadContext[ACTUALDEVICE].executableNetwork; } } - WaitFirstNetworkReady(); + _autoSContext->_hwExecutableNetwork = WaitFirstNetworkReady(); } void AutoSchedule::TryToLoadNetWork(AutoLoadContext& context, const std::string& modelPath, const IE::CNNNetwork& network, bool isCumulative) { @@ -627,7 +621,7 @@ void AutoSchedule::TryToLoadNetWork(AutoLoadContext& context, const std::string& TryToLoadNetWork(context, modelPath, network, isCumulative); } -void AutoSchedule::WaitFirstNetworkReady() { +SoExecNetwork AutoSchedule::WaitFirstNetworkReady() { if (_firstLoadFuture.valid()) { // wait for the first loading finished _firstLoadFuture.wait(); @@ -635,7 +629,7 @@ void AutoSchedule::WaitFirstNetworkReady() { // check if there is any device that have loaded network successfully for (int i = CONTEXTNUM - 2; i >= 0; i--) { if (_loadContext[i].isEnabled && _loadContext[i].isAlready) { - return; + return _loadContext[i].executableNetwork; } } // the first loading is failed, wait for another loading @@ -644,7 +638,7 @@ void AutoSchedule::WaitFirstNetworkReady() { _loadContext[i].future.wait(); // check if loading is successful if (_loadContext[i].isAlready) { - return; + return _loadContext[i].executableNetwork; } } } @@ -655,17 +649,21 @@ void AutoSchedule::WaitFirstNetworkReady() { } } // devices loaded successfully in CTPUT + SoExecNetwork execNetwork; if (_pCTPUTLoadContext) { int nLoadSucNums = 0; for (size_t i = 0; i < _nCTputDeviceNums; i++) { // check if device loaded successfully if (_pCTPUTLoadContext[i].isAlready) { + if (!execNetwork) { + execNetwork = _pCTPUTLoadContext[i].executableNetwork; + } nLoadSucNums++; } } // one or more devices loaded successfully if (nLoadSucNums > 0) { - return; + return execNetwork; } } IE_THROW() << GetLogTag() << "load all devices failed"; diff --git a/src/plugins/auto/auto_schedule.hpp b/src/plugins/auto/auto_schedule.hpp index bd174d80746ae3..f836391752e38d 100644 --- a/src/plugins/auto/auto_schedule.hpp +++ b/src/plugins/auto/auto_schedule.hpp @@ -61,7 +61,12 @@ class AutoSchedule : public MultiSchedule { AutoScheduleContext::Ptr _autoSContext; private: - void WaitFirstNetworkReady(); + /** + * @brief wait for one of the executable network to finish loading. + * @return An SoPtr object hold an available executable network loaded to HW device. + * @note An exception will be thrown if all loading of network to hw device fails. + */ + SoExecNetwork WaitFirstNetworkReady(); void TryToLoadNetWork(AutoLoadContext& context, const std::string& modelPath, const IE::CNNNetwork& network, bool isCumulative); bool selectOtherDevice(const std::string& currentDeviceName); IE::Task releaseActualdeviceTask; diff --git a/src/plugins/auto/common.hpp b/src/plugins/auto/common.hpp index 2c5dc588d6a6f8..891e0bebb015cd 100644 --- a/src/plugins/auto/common.hpp +++ b/src/plugins/auto/common.hpp @@ -160,6 +160,7 @@ class AutoScheduleContext : public MultiScheduleContext { std::mutex _confMutex; std::mutex _fallbackMutex; MultiDeviceInferencePlugin* _plugin; + SoExecNetwork _hwExecutableNetwork; virtual ~AutoScheduleContext() = default; }; diff --git a/src/plugins/auto/plugin.cpp b/src/plugins/auto/plugin.cpp index e464db48421e38..c688f1ef27c586 100644 --- a/src/plugins/auto/plugin.cpp +++ b/src/plugins/auto/plugin.cpp @@ -425,27 +425,32 @@ IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetworkImpl(cons auto supportDevices = supportDevicesByConfig; CNNNetwork clonedNetwork; std::string clonedModelPath = modelPath; - if (modelPath.empty()) { - // if network is valid - LOG_INFO_TAG("load with CNN network"); - supportDevices = FilterDeviceByNetwork(supportDevicesByConfig, network); - // clone the network, in case of reshape conflict - clonedNetwork = InferenceEngine::details::cloneNetwork(network); - } else { - // model path, enable model load with single device situation - if (supportDevices.size() > 1 && GetName() != "MULTI") { - clonedNetwork = GetCore()->ReadNetwork(modelPath, std::string()); - // do we really need to disable model path? - clonedModelPath = ""; + // reset the strDevices to support devices + strDevices = ""; + // calling GetValidDevices() to get a prioritized list of devices + bool isCumulative = + (autoSContext->_performanceHint == IE::PluginConfigParams::CUMULATIVE_THROUGHPUT) ? true : false; + std::list devicesWithPriority(supportDevices.begin(), supportDevices.end()); + if (!isCumulative) { + if (modelPath.empty()) { + // if network is valid LOG_INFO_TAG("load with CNN network"); + supportDevices = FilterDeviceByNetwork(supportDevicesByConfig, network); + // clone the network, in case of reshape conflict + clonedNetwork = InferenceEngine::details::cloneNetwork(network); } else { - LOG_INFO_TAG("load with model path"); + // model path, enable model load with single device situation + if (supportDevices.size() > 1) { + clonedNetwork = GetCore()->ReadNetwork(modelPath, std::string()); + // do we really need to disable model path? + clonedModelPath = ""; + LOG_INFO_TAG("load with CNN network"); + } else { + LOG_INFO_TAG("load with model path"); + } } + devicesWithPriority = GetValidDevice(supportDevices, networkPrecision); } - // reset the strDevices to support devices - strDevices = ""; - // calling GetValidDevices() to get a prioritized list of devices - auto devicesWithPriority = GetValidDevice(supportDevices, networkPrecision); for (auto iter = devicesWithPriority.begin(); iter != devicesWithPriority.end(); iter++) { strDevices += iter->deviceName; strDevices += ","; @@ -488,6 +493,13 @@ IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetworkImpl(cons } else { impl = std::make_shared(autoSContext, std::make_shared()); } + if (!modelPath.empty()) { + SetExeNetworkInfo(impl, + autoSContext->_hwExecutableNetwork->GetInputsInfo(), + autoSContext->_hwExecutableNetwork->GetOutputsInfo()); + impl->setInputs(autoSContext->_hwExecutableNetwork->getInputs()); + impl->setOutputs(autoSContext->_hwExecutableNetwork->getOutputs()); + } return impl; }