Skip to content

Commit

Permalink
Removed QueryNetworkResult from new API (#7507)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov authored Sep 15, 2021
1 parent bdaa44d commit 7654789
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,13 @@ namespace runtime {
 * @brief This type of map is commonly used to pass a set of parameters
*/
using ConfigMap = std::map<std::string, std::string>;

/**
 * @brief This type of map is used for the result of Core::query_model
* - `key` means operation name
* - `value` means device name supporting this operation
*/
using SupportedOpsMap = std::map<std::string, std::string>;

} // namespace runtime
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ class INFERENCE_ENGINE_API_CLASS(Core) {
* @param deviceName A name of a device to query
* @param network Network object to query
* @param config Optional map of pairs: (config parameter name, config parameter value)
* @return An object containing a map of pairs a layer name -> a device name supporting this layer.
 * @return An object containing a map of pairs: an operation name -> a device name supporting this operation.
*/
ie::QueryNetworkResult query_model(const std::shared_ptr<const ov::Function>& network,
const std::string& deviceName,
const ConfigMap& config = {}) const;
SupportedOpsMap query_model(const std::shared_ptr<const ov::Function>& network,
const std::string& deviceName,
const ConfigMap& config = {}) const;

/**
* @brief Sets configuration for device, acceptable keys can be found in ie_plugin_config.hpp
Expand Down
11 changes: 7 additions & 4 deletions inference-engine/src/inference_engine/src/ie_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1306,11 +1306,14 @@ ExecutableNetwork Core::import_model(std::istream& networkModel,
return {exec._so, exec._ptr};
}

ie::QueryNetworkResult Core::query_model(const std::shared_ptr<const ngraph::Function>& network,
const std::string& deviceName,
const ConfigMap& config) const {
return _impl->QueryNetwork(ie::CNNNetwork(std::const_pointer_cast<ngraph::Function>(network)), deviceName, config);
SupportedOpsMap Core::query_model(const std::shared_ptr<const ngraph::Function>& network,
const std::string& deviceName,
const ConfigMap& config) const {
auto cnnNet = ie::CNNNetwork(std::const_pointer_cast<ngraph::Function>(network));
auto qnResult = _impl->QueryNetwork(cnnNet, deviceName, config);
return qnResult.supportedLayersMap;
}

void Core::set_config(const ConfigMap& config, const std::string& deviceName) {
// HETERO case
if (deviceName.find("HETERO:") == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,7 @@ TEST_P(OVClassNetworkTestP, QueryNetworkWithKSO) {
ov::runtime::Core ie = createCoreWithTemplate();

try {
auto rres = ie.query_model(ksoNetwork, deviceName);
auto rl_map = rres.supportedLayersMap;
auto rl_map = ie.query_model(ksoNetwork, deviceName);
auto func = ksoNetwork;
for (const auto& op : func->get_ops()) {
if (!rl_map.count(op->get_friendly_name())) {
Expand Down Expand Up @@ -556,8 +555,7 @@ TEST_P(OVClassNetworkTestP, SetAffinityWithConstantBranches) {
func = std::make_shared<ngraph::Function>(results, params);
}

auto rres = ie.query_model(func, deviceName);
auto rl_map = rres.supportedLayersMap;
auto rl_map = ie.query_model(func, deviceName);
for (const auto& op : func->get_ops()) {
if (!rl_map.count(op->get_friendly_name())) {
FAIL() << "Op " << op->get_friendly_name() << " is not supported by " << deviceName;
Expand All @@ -579,8 +577,7 @@ TEST_P(OVClassNetworkTestP, SetAffinityWithKSO) {
ov::runtime::Core ie = createCoreWithTemplate();

try {
auto rres = ie.query_model(ksoNetwork, deviceName);
auto rl_map = rres.supportedLayersMap;
auto rl_map = ie.query_model(ksoNetwork, deviceName);
auto func = ksoNetwork;
for (const auto& op : func->get_ops()) {
if (!rl_map.count(op->get_friendly_name())) {
Expand All @@ -601,10 +598,10 @@ TEST_P(OVClassNetworkTestP, SetAffinityWithKSO) {
TEST_P(OVClassNetworkTestP, QueryNetworkHeteroActualNoThrow) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
ov::runtime::Core ie = createCoreWithTemplate();
QueryNetworkResult res;
ov::runtime::SupportedOpsMap res;
ASSERT_NO_THROW(
res = ie.query_model(actualNetwork, CommonTestUtils::DEVICE_HETERO, {{"TARGET_FALLBACK", deviceName}}));
ASSERT_LT(0, res.supportedLayersMap.size());
ASSERT_LT(0, res.size());
}

TEST_P(OVClassNetworkTestP, QueryNetworkMultiThrows) {
Expand Down Expand Up @@ -1408,15 +1405,15 @@ TEST_P(OVClassLoadNetworkTest, QueryNetworkHETEROWithMULTINoThrow_V10) {
for (auto&& node : function->get_ops()) {
expectedLayers.emplace(node->get_friendly_name());
}
QueryNetworkResult result;
ov::runtime::SupportedOpsMap result;
std::string targetFallback(CommonTestUtils::DEVICE_MULTI + std::string(",") + deviceName);
ASSERT_NO_THROW(result = ie.query_model(
multinputNetwork,
CommonTestUtils::DEVICE_HETERO,
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES), devices}, {"TARGET_FALLBACK", targetFallback}}));

std::unordered_set<std::string> actualLayers;
for (auto&& layer : result.supportedLayersMap) {
for (auto&& layer : result) {
actualLayers.emplace(layer.first);
}
ASSERT_EQ(expectedLayers, actualLayers);
Expand Down Expand Up @@ -1444,14 +1441,14 @@ TEST_P(OVClassLoadNetworkTest, QueryNetworkMULTIWithHETERONoThrow_V10) {
for (auto&& node : function->get_ops()) {
expectedLayers.emplace(node->get_friendly_name());
}
QueryNetworkResult result;
ov::runtime::SupportedOpsMap result;
ASSERT_NO_THROW(result = ie.query_model(multinputNetwork,
CommonTestUtils::DEVICE_MULTI,
{{MULTI_CONFIG_KEY(DEVICE_PRIORITIES), devices},
{"TARGET_FALLBACK", deviceName + "," + deviceName}}));

std::unordered_set<std::string> actualLayers;
for (auto&& layer : result.supportedLayersMap) {
for (auto&& layer : result) {
actualLayers.emplace(layer.first);
}
ASSERT_EQ(expectedLayers, actualLayers);
Expand Down

0 comments on commit 7654789

Please sign in to comment.