Skip to content

Commit

Permalink
Address reviewer's comment: clean and simplify code
Browse files Browse the repository at this point in the history
Signed-off-by: Shoujiang Ma <[email protected]>
  • Loading branch information
mashoujiang committed May 19, 2021
1 parent d45a27a commit 04874aa
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 151 deletions.
11 changes: 2 additions & 9 deletions inference-engine/src/auto_plugin/auto_exec_network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,8 @@

namespace AutoPlugin {

using namespace InferenceEngine;

AutoExecutableNetwork::AutoExecutableNetwork(const SoExecutableNetworkInternal& network,
const DeviceInformation& deviceInfo,
const bool needPerfCounters) :
_deviceInfo(deviceInfo),
_network(network),
_config(deviceInfo.config.begin(), deviceInfo.config.end()),
_needPerfCounters(needPerfCounters) {
// Wraps an executable network that has already been loaded on the device the
// AUTO plugin selected; this class only forwards calls to the wrapped network.
AutoExecutableNetwork::AutoExecutableNetwork(const SoExecutableNetworkInternal& network) :
_network(network) {
}

AutoExecutableNetwork::~AutoExecutableNetwork() = default;
Expand Down
10 changes: 3 additions & 7 deletions inference-engine/src/auto_plugin/auto_exec_network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ class AutoExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSaf
public:
using Ptr = std::shared_ptr<AutoExecutableNetwork>;

AutoExecutableNetwork(const InferenceEngine::SoExecutableNetworkInternal& network,
const DeviceInformation& deviceInfo,
const bool needPerfCounters = false);
explicit AutoExecutableNetwork(const InferenceEngine::SoExecutableNetworkInternal& network);

void Export(std::ostream& networkModel) override;
InferenceEngine::RemoteContext::Ptr GetContext() const override;
Expand All @@ -42,10 +40,8 @@ class AutoExecutableNetwork : public InferenceEngine::ExecutableNetworkThreadSaf
InferenceEngine::OutputsDataMap networkOutputs) override;
~AutoExecutableNetwork() override;

DeviceInformation _deviceInfo;
InferenceEngine::SoExecutableNetworkInternal _network;
std::unordered_map<std::string, InferenceEngine::Parameter> _config;
bool _needPerfCounters = false;
private:
InferenceEngine::SoExecutableNetworkInternal _network;
};

} // namespace AutoPlugin
139 changes: 11 additions & 128 deletions inference-engine/src/auto_plugin/auto_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@

namespace AutoPlugin {
namespace {
// Overlays `local` on top of the base `config`; keys present in both maps
// take the value from `local`. Returns the merged copy.
ConfigType mergeConfigs(ConfigType config, const ConfigType& local) {
    for (auto it = local.begin(); it != local.end(); ++it) {
        config[it->first] = it->second;
    }
    return config;
}

std::string GetNetworkPrecision(const InferenceEngine::CNNNetwork &network) {
auto nGraphFunc = network.getFunction();
for (auto & node : nGraphFunc->get_ordered_ops()) {
Expand Down Expand Up @@ -56,93 +49,17 @@ AutoInferencePlugin::AutoInferencePlugin() {

IE::IExecutableNetworkInternal::Ptr AutoInferencePlugin::LoadNetwork(const std::string& fileName,
const ConfigType& config) {
if (GetCore() == nullptr) {
IE_THROW() << "Please, work with AUTO device via InferencEngine::Core object";
}

auto fullConfig = mergeConfigs(_config, config);
auto metaDevices = GetDeviceChoice(fullConfig);

DeviceInformation selectedDevice;
IE::SoExecutableNetworkInternal executableNetwork;
while (!metaDevices.empty()) {
selectedDevice = SelectDevice(metaDevices);
try {
executableNetwork = GetCore()->LoadNetwork(fileName, selectedDevice.deviceName, config);
break;
} catch (...) {
auto eraseDevice = std::find_if(metaDevices.begin(), metaDevices.end(),
[=](const DeviceInformation &d) -> bool { return d.deviceName == selectedDevice.deviceName; });
if (eraseDevice == metaDevices.end()) {
IE_THROW() << "Didn't find the selected device name";
}
metaDevices.erase(eraseDevice);
executableNetwork = {};
}
}
if (!executableNetwork) {
IE_THROW() << "Failed to load network by AUTO plugin";
}

bool enablePerfCounters = false;
try {
enablePerfCounters =
executableNetwork->GetConfig(IE::PluginConfigParams::KEY_PERF_COUNT).as<std::string>() ==
IE::PluginConfigParams::YES;
}
catch (...) {
}

return std::make_shared<AutoExecutableNetwork>(executableNetwork,
selectedDevice,
enablePerfCounters);
return LoadNetworkImpl(fileName, config);
}

IE::ExecutableNetworkInternal::Ptr AutoInferencePlugin::LoadExeNetworkImpl(const IE::CNNNetwork& network,
const ConfigType& config) {
if (GetCore() == nullptr) {
IE_THROW() << "Please, work with AUTO device via InferencEngine::Core object";
}

if (network.getFunction() == nullptr) {
IE_THROW() << "AUTO device supports just ngraph network representation";
}

auto fullConfig = mergeConfigs(_config, config);
auto metaDevices = GetDeviceChoice(fullConfig);
DeviceInformation selectedDevice;
IE::SoExecutableNetworkInternal executableNetwork;
while (!metaDevices.empty()) {
selectedDevice = SelectDevice(network, metaDevices);
try {
executableNetwork = GetCore()->LoadNetwork(
network, selectedDevice.deviceName, selectedDevice.config);
break;
} catch (...) {
auto eraseDevice = std::find_if(metaDevices.begin(), metaDevices.end(),
[=](const DeviceInformation& d)->bool{return d.deviceName == selectedDevice.deviceName;});
if (eraseDevice == metaDevices.end()) {
IE_THROW() << "Didn't find the selected device name";
}
metaDevices.erase(eraseDevice);
executableNetwork = {};
}
}
if (!executableNetwork) {
IE_THROW() << "Failed to load network by AUTO plugin";
}

bool enablePerfCounters = false;
try {
enablePerfCounters =
executableNetwork->GetConfig(IE::PluginConfigParams::KEY_PERF_COUNT).as<std::string>() ==
IE::PluginConfigParams::YES;
} catch (...) {
}

return std::make_shared<AutoExecutableNetwork>(executableNetwork,
selectedDevice,
enablePerfCounters);
auto networkPrecision = GetNetworkPrecision(network);
return LoadNetworkImpl(network, config, networkPrecision);
}

IE::QueryNetworkResult AutoInferencePlugin::QueryNetwork(const IE::CNNNetwork& network, const ConfigType& config) const {
Expand Down Expand Up @@ -293,55 +210,14 @@ ConfigType AutoInferencePlugin::GetSupportedConfig(const ConfigType& config,
return supportedConfig;
}

// Picks the device AUTO should load the network on: prefer a GPU (highest
// numbered name first, e.g. GPU.2 > GPU.1 > GPU.0 > GPU), otherwise a CPU.
// Throws when no suitable device is available.
DeviceInformation AutoInferencePlugin::SelectDevice(const std::vector<DeviceInformation>& metaDevices) {
    if (metaDevices.empty()) {
        IE_THROW(NotFound) << "No available device to select in AUTO plugin";
    }
    if (metaDevices.size() == 1) {
        return metaDevices.at(0);
    }

    std::vector<DeviceInformation> cpuDevices;
    std::vector<DeviceInformation> gpuDevices;
    for (const auto& device : metaDevices) {
        if (device.deviceName.find("CPU") == 0) {
            cpuDevices.push_back(device);
        } else if (device.deviceName.find("GPU") == 0) {
            gpuDevices.push_back(device);
        }
    }

    if (cpuDevices.empty() && gpuDevices.empty()) {
        IE_THROW(NotFound) << "No available device found";
    }

    if (!gpuDevices.empty()) {
        // Sort GPU names in descending order so the best (highest-numbered)
        // candidate ends up at the front.
        std::sort(gpuDevices.begin(), gpuDevices.end(),
                  [](const DeviceInformation& a, const DeviceInformation& b) -> bool {
                      return b.deviceName < a.deviceName;
                  });
        return gpuDevices[0];
    }
    if (cpuDevices.empty()) {
        IE_THROW() << "Cannot select any device";
    }
    return cpuDevices[0];
}

DeviceInformation AutoInferencePlugin::SelectDevice(const InferenceEngine::CNNNetwork& network,
const std::vector<DeviceInformation>& metaDevices) {
DeviceInformation AutoInferencePlugin::SelectDevice(const std::vector<DeviceInformation>& metaDevices, const std::string& networkPrecision) {
if (metaDevices.empty()) {
IE_THROW(NotFound) << "No available device to select in AUTO plugin";
}
if (metaDevices.size() == 1) {
return metaDevices.at(0);
}

auto networkPrecision = GetNetworkPrecision(network);

std::vector<DeviceInformation> CPU;
std::vector<DeviceInformation> GPU;

Expand Down Expand Up @@ -377,6 +253,13 @@ DeviceInformation AutoInferencePlugin::SelectDevice(const InferenceEngine::CNNNe
return CPU[0];
}

// Merges two configuration maps: entries from `local` override entries with
// the same key in `config`; the merged map is returned by value.
ConfigType AutoInferencePlugin::mergeConfigs(ConfigType config, const ConfigType& local) {
    for (const auto& entry : local) {
        config[entry.first] = entry.second;
    }
    return config;
}

// define CreatePluginEngine to create plugin instance
static const IE::Version version = {{2, 1}, CI_BUILD_NUMBER, "AutoPlugin"};
IE_DEFINE_PLUGIN_CREATE_FUNCTION(AutoInferencePlugin, version)
Expand Down
40 changes: 33 additions & 7 deletions inference-engine/src/auto_plugin/auto_plugin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ class AutoInferencePlugin : public IE::InferencePluginInternal {
public:
AutoInferencePlugin();
~AutoInferencePlugin() = default;
IE::ExecutableNetwork LoadNetwork(const std::string &modelPath,
const std::map<std::string, std::string> &config) override;
IE::ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(const IE::CNNNetwork& network, const ConfigType& config) override;
IE::IExecutableNetworkInternal::Ptr LoadNetwork(const std::string& fileName, const ConfigType& config) override;
IE::QueryNetworkResult QueryNetwork(const IE::CNNNetwork& network, const ConfigType& config) const override;
Expand All @@ -33,12 +31,40 @@ class AutoInferencePlugin : public IE::InferencePluginInternal {
private:
std::vector<AutoPlugin::DeviceInformation> GetDeviceChoice(const ConfigType& config) const;
std::vector<std::string> GetOptimizationCapabilities() const;
DeviceInformation SelectDevice(const std::vector<DeviceInformation>& metaDevices);
DeviceInformation SelectDevice(const InferenceEngine::CNNNetwork& network,
const std::vector<DeviceInformation>& metaDevices);

protected:
DeviceInformation SelectDevice(const std::vector<DeviceInformation>& metaDevices, const std::string& networkPrecision = METRIC_VALUE(FP32));
ConfigType GetSupportedConfig(const ConfigType& config, const AutoPlugin::DeviceName & deviceName) const;
static ConfigType mergeConfigs(ConfigType config, const ConfigType& local);

// Shared implementation behind LoadNetwork(file path) and
// LoadExeNetworkImpl(CNNNetwork): selects the best device for the given
// network precision and loads on it, falling back to the next-best candidate
// whenever loading fails on the selected one.
//
// @param param             either a model file path or a CNNNetwork,
//                          forwarded to Core::LoadNetwork.
// @param config            user configuration, merged over the plugin config.
// @param networkPrecision  precision hint used by device selection.
// @return the wrapped executable network.
// @throws if no core is set, a selected device disappears from the candidate
//         list, or every candidate device fails to load the network.
template <typename T>
IE::ExecutableNetworkInternal::Ptr LoadNetworkImpl(const T &param, const ConfigType &config, const std::string &networkPrecision = METRIC_VALUE(FP32)) {
    if (GetCore() == nullptr) {
        IE_THROW() << "Please, work with AUTO device via InferenceEngine::Core object";
    }

    auto fullConfig = mergeConfigs(_config, config);
    auto metaDevices = GetDeviceChoice(fullConfig);
    DeviceInformation selectedDevice;
    // NOTE(review): AutoExecutableNetwork's constructor takes an
    // SoExecutableNetworkInternal while this local is IE::ExecutableNetwork —
    // confirm the intended type; they must be convertible for this to build.
    IE::ExecutableNetwork executableNetwork;
    while (!metaDevices.empty()) {
        selectedDevice = SelectDevice(metaDevices, networkPrecision);
        try {
            executableNetwork = GetCore()->LoadNetwork(param, selectedDevice.deviceName, selectedDevice.config);
            break;
        } catch (...) {
            // Intentional best-effort: drop the failing device from the
            // candidate list and retry with the next-best one.
            auto eraseDevice = std::find_if(metaDevices.begin(), metaDevices.end(),
                [&](const DeviceInformation& d) -> bool { return d.deviceName == selectedDevice.deviceName; });
            if (eraseDevice == metaDevices.end()) {
                IE_THROW() << "Didn't find the selected device name";
            }
            metaDevices.erase(eraseDevice);
            executableNetwork = {};
        }
    }
    if (!executableNetwork) {
        IE_THROW() << "Failed to load network by AUTO plugin";
    }
    return std::make_shared<AutoExecutableNetwork>(executableNetwork);
}
};

} // namespace AutoPlugin

0 comments on commit 04874aa

Please sign in to comment.