diff --git a/inference-engine/src/auto_plugin/auto_plugin.cpp b/inference-engine/src/auto_plugin/auto_plugin.cpp
index 692d8955a09424..2d533d1a26a970 100644
--- a/inference-engine/src/auto_plugin/auto_plugin.cpp
+++ b/inference-engine/src/auto_plugin/auto_plugin.cpp
@@ -24,13 +24,65 @@ namespace {
         return config;
     }
 
-    DeviceInformation SelectDevice(const std::vector<DeviceInformation>& metaDevices) {
+    std::string GetNetworkPrecision(const InferenceEngine::CNNNetwork &network) {
+        for (auto&& layer : network.getInputsInfo()) {
+            auto precision = layer.second->getPrecision();
+            auto name = std::string(precision.name());
+            if (name == "I8") {
+                name = "INT8";
+            }
+            return name;
+        }
+        return {};
+    }
+
+    DeviceInformation SelectDevice(const InferenceEngine::CNNNetwork &network,
+                                   const std::vector<DeviceInformation>& metaDevices,
+                                   const std::vector<std::string>& optCap) {
+        if (metaDevices.empty()) {
+            IE_THROW(NotFound) << "No available device to select in AUTO plugin";
+        }
+        if (metaDevices.size() == 1) {
+            return metaDevices.at(0);
+        }
+
+        std::vector<DeviceInformation> CPU;
+        std::vector<DeviceInformation> GPU;
+
         for (auto& item : metaDevices) {
             if (item.deviceName.find("CPU") == 0) {
-                return item;
+                CPU.push_back(item);
+                continue;
+            }
+            if (item.deviceName.find("GPU") == 0) {
+                GPU.push_back(item);
+                continue;
+            }
+        }
+
+        auto networkPrecision = GetNetworkPrecision(network);
+        auto getCap = [&](std::string&& substr){
+            auto capability = std::find_if(optCap.begin(), optCap.end(),
+                                           [&](const std::string& c)->bool{ return (c.find(substr) != std::string::npos);});
+            return capability;
+        };
+
+        if (CPU.empty() && GPU.empty()) {
+            IE_THROW(NotFound) << "No available device found";
+        }
+
+        std::sort(GPU.begin(), GPU.end(), [](const DeviceInformation& a, const DeviceInformation& b)->bool{return b.deviceName < a.deviceName;});
+
+        if (!GPU.empty()) {
+            auto capability = getCap("GPU");
+            if (capability != optCap.end() && capability->find(networkPrecision) != std::string::npos) {
+                return GPU[0];
             }
         }
-        IE_THROW(NotFound) << "No available device could be used";
+        if (CPU.empty()) {
+            IE_THROW() << "Cannot select any device";
+        }
+        return CPU[0];
     }
 } // namespace
 
@@ -50,9 +102,9 @@ IE::ExecutableNetworkInternal::Ptr AutoInferencePlugin::LoadExeNetworkImpl(const
     auto fullConfig = mergeConfigs(_config, config);
     auto metaDevices = GetDeviceChoice(fullConfig);
+    auto optCap = GetOptimizationCapabilities();
 
-    // FIXME: always select CPU device now
-    DeviceInformation selectedDevice = SelectDevice(metaDevices);
+    DeviceInformation selectedDevice = SelectDevice(network, metaDevices, optCap);
     IE::ExecutableNetwork executableNetwork;
     try {
         executableNetwork = GetCore()->LoadNetwork(network, selectedDevice.deviceName, selectedDevice.config);
@@ -179,6 +231,23 @@ std::vector<DeviceInformation> AutoInferencePlugin::GetDeviceChoice(
     return metaDevices;
 }
 
+std::vector<std::string> AutoInferencePlugin::GetOptimizationCapabilities() const {
+    // FIXME: workaround to get devicelist.
+    std::unordered_set<std::string> capabilities;
+    std::vector<std::string> queryDeviceLists{"CPU", "GPU"};
+    for (auto &item : queryDeviceLists) {
+        try {
+            std::vector<std::string> device_cap =
+                GetCore()->GetMetric(item, METRIC_KEY(OPTIMIZATION_CAPABILITIES));
+            for (auto &dc : device_cap) {
+                capabilities.insert(dc);
+            }
+        } catch (...) {
+        }
+    }
+    return {capabilities.begin(), capabilities.end()};
+}
+
 //////////////////////////////////// private & protected functions ///////////////////
 ConfigType AutoInferencePlugin::GetSupportedConfig(const ConfigType& config, const std::string& deviceName) const {
diff --git a/inference-engine/src/auto_plugin/auto_plugin.hpp b/inference-engine/src/auto_plugin/auto_plugin.hpp
index a2a885d4f1fc51..fe9af5ff30af27 100644
--- a/inference-engine/src/auto_plugin/auto_plugin.hpp
+++ b/inference-engine/src/auto_plugin/auto_plugin.hpp
@@ -29,6 +29,7 @@ class AutoInferencePlugin : public IE::InferencePluginInternal {
 private:
     std::vector<DeviceInformation> GetDeviceChoice(const ConfigType& config) const;
+    std::vector<std::string> GetOptimizationCapabilities() const;
 
 protected:
     ConfigType GetSupportedConfig(const ConfigType& config, const AutoPlugin::DeviceName & deviceName) const;
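Note (illustration only, not part of the patch): the selection policy introduced above can be read in isolation. The sketch below restates it as one self-contained C++ function, with DeviceInformation reduced to a bare device name and the METRIC_KEY(OPTIMIZATION_CAPABILITIES) metric replaced by a plain list of strings. It assumes the merged capability strings embed the device name (e.g. "GPU: FP32,FP16"), since that is what the getCap("GPU") substring lookup in the patch implies; that format, the DeviceInfo type, and SelectDeviceSketch itself are hypothetical stand-ins.

    // Minimal sketch of the AUTO device-selection policy, assuming capability
    // strings of the hypothetical form "GPU: FP32,FP16" (see note above).
    #include <algorithm>
    #include <stdexcept>
    #include <string>
    #include <vector>

    struct DeviceInfo {                 // stand-in for AutoPlugin::DeviceInformation
        std::string deviceName;
    };

    DeviceInfo SelectDeviceSketch(const std::string& networkPrecision,
                                  const std::vector<DeviceInfo>& metaDevices,
                                  const std::vector<std::string>& optCap) {
        if (metaDevices.empty())
            throw std::runtime_error("No available device to select");
        if (metaDevices.size() == 1)    // a single candidate short-circuits everything
            return metaDevices.front();

        // Bucket candidates by name prefix, as the patch does with find(...) == 0.
        std::vector<DeviceInfo> cpu, gpu;
        for (const auto& d : metaDevices) {
            if (d.deviceName.rfind("CPU", 0) == 0) cpu.push_back(d);
            else if (d.deviceName.rfind("GPU", 0) == 0) gpu.push_back(d);
        }
        if (cpu.empty() && gpu.empty())
            throw std::runtime_error("No available device found");

        // Descending name sort mirrors the patch: it prefers e.g. "GPU.1"
        // (often the discrete GPU) over "GPU.0".
        std::sort(gpu.begin(), gpu.end(),
                  [](const DeviceInfo& a, const DeviceInfo& b) { return b.deviceName < a.deviceName; });

        // Take a GPU only when a GPU capability entry mentions the network precision.
        auto gpuCap = std::find_if(optCap.begin(), optCap.end(),
                                   [](const std::string& c) { return c.find("GPU") != std::string::npos; });
        if (!gpu.empty() && gpuCap != optCap.end() &&
            gpuCap->find(networkPrecision) != std::string::npos)
            return gpu.front();

        // Otherwise fall back to CPU, or give up.
        if (cpu.empty())
            throw std::runtime_error("Cannot select any device");
        return cpu.front();
    }

With metaDevices = {CPU, GPU.0, GPU.1}, optCap = {"GPU: FP32,FP16"} and an FP16 network, this returns GPU.1; remove FP16 from the capability entry and selection falls back to CPU.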