diff --git a/inference-engine/src/inference_engine/ie_core.cpp b/inference-engine/src/inference_engine/ie_core.cpp index cf769aa6978be8..6e23d7683a1e61 100644 --- a/inference-engine/src/inference_engine/ie_core.cpp +++ b/inference-engine/src/inference_engine/ie_core.cpp @@ -180,6 +180,7 @@ class Core::Impl : public ICore { class CoreConfig final { public: struct CacheConfig { + std::string _cacheDir; std::shared_ptr _cacheManager; }; @@ -187,6 +188,7 @@ class Core::Impl : public ICore { auto it = config.find(CONFIG_KEY(CACHE_DIR)); if (it != config.end()) { std::lock_guard lock(_cacheConfigMutex); + _cacheConfig._cacheDir = it->second; if (!it->second.empty()) { FileUtils::createDirectoryRecursive(it->second); _cacheConfig._cacheManager = std::make_shared(std::move(it->second)); @@ -241,6 +243,27 @@ class Core::Impl : public ICore { return supported; } + bool DeviceSupportsCacheDir(const InferencePlugin& plugin) const { + return DeviceSupportsConfigKey(plugin, CONFIG_KEY(CACHE_DIR)); + } + + bool DeviceSupportsConfigKey(const InferencePlugin& plugin, const std::string& key) const { + bool supported = false; + std::vector supportedMetricKeys; + try { + // If plugin doesn't support 'SUPPORTED_METRICS' - treat it as config is not supported as well + supportedMetricKeys = + plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {}).as>(); + } catch(...) {} + auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(), + METRIC_KEY(SUPPORTED_CONFIG_KEYS)); + if (it != supportedMetricKeys.end()) { + std::vector configKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), {}); + supported = std::find(configKeys.begin(), configKeys.end(), key) != configKeys.end(); + } + return supported; + } + SoExecutableNetworkInternal LoadNetworkImpl(const CNNNetwork& network, InferencePlugin& plugin, const std::map& parsedConfig, @@ -700,6 +723,12 @@ class Core::Impl : public ICore { // configuring { + if (DeviceSupportsCacheDir(plugin)) { + auto cacheConfig = coreConfig.getCacheConfig(); + if (cacheConfig._cacheManager) { + desc.defaultConfig[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir; + } + } allowNotImplemented([&]() { plugin.SetConfig(desc.defaultConfig); }); @@ -816,7 +845,14 @@ class Core::Impl : public ICore { for (auto& plugin : plugins) { if (deviceName.empty() || deviceName == plugin.first) { allowNotImplemented([&]() { - plugin.second.SetConfig(config); + auto configCopy = config; + if (DeviceSupportsCacheDir(plugin.second)) { + auto cacheConfig = coreConfig.getCacheConfig(); + if (cacheConfig._cacheManager) { + configCopy[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir; + } + } + plugin.second.SetConfig(configCopy); }); } } diff --git a/inference-engine/tests/functional/inference_engine/caching_test.cpp b/inference-engine/tests/functional/inference_engine/caching_test.cpp index 8aefaec7d224f6..dd19dd3815d70c 100644 --- a/inference-engine/tests/functional/inference_engine/caching_test.cpp +++ b/inference-engine/tests/functional/inference_engine/caching_test.cpp @@ -111,6 +111,7 @@ class MockCachingInferencePlugin : public MockCachingInferencePluginBase { const std::map& config)); MOCK_CONST_METHOD2(GetMetric, Parameter(const std::string& name, const std::map& options)); + MOCK_METHOD1(SetConfig, void(const std::map& options)); MOCK_METHOD1(GetDefaultContext, RemoteContext::Ptr(const ParamMap& params)); }; @@ -362,6 +363,11 @@ class CachingTest : public ::testing::TestWithParam) { + throw InferenceEngine::NotImplemented("Not implemented"); + })); + EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber()) .WillRepeatedly(Return(ConstInputsDataMap{})); EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber()) @@ -567,6 +573,93 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) { } } +TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)) + .Times(AnyNumber()).WillRepeatedly( + Return(std::vector{METRIC_KEY(SUPPORTED_CONFIG_KEYS)})); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)) + .Times(AtLeast(1)).WillRepeatedly(Return(std::vector{CONFIG_KEY(CACHE_DIR)})); + EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly( + Invoke([](const std::map& config) { + ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0); + })); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0); + ASSERT_NO_THROW( + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + })); + } +} + +TEST_P(CachingTest, TestCacheEnabled_noConfig) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)) + .Times(AnyNumber()).WillRepeatedly( + Return(std::vector{METRIC_KEY(SUPPORTED_CONFIG_KEYS)})); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)) + .Times(AtLeast(1)).WillRepeatedly(Return(std::vector{})); + EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly( + Invoke([](const std::map& config) { + ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0); + })); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + + +TEST_P(CachingTest, TestNoCacheMetric_configThrow) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)) + .Times(AnyNumber()).WillRepeatedly( + Return(std::vector{METRIC_KEY(SUPPORTED_CONFIG_KEYS)})); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)) + .Times(AtLeast(1)).WillRepeatedly(Return(std::vector{CONFIG_KEY(CACHE_DIR)})); + EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly( + Invoke([](const std::map& config) { + ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0); + throw InferenceEngine::GeneralError("Error occurred"); + })); + + ASSERT_ANY_THROW( + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + })); +} + +TEST_P(CachingTest, TestNoCacheEnabled_cacheDirConfig) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)) + .Times(AnyNumber()).WillRepeatedly( + Return(std::vector{METRIC_KEY(SUPPORTED_CONFIG_KEYS)})); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)) + .Times(AtLeast(1)).WillRepeatedly(Return(std::vector{CONFIG_KEY(CACHE_DIR)})); + EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly( + Invoke([](const std::map& config) { + ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0); + })); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + testLoad([&](Core &ie) { + m_testFunction(ie); + }); + } +} + TEST_P(CachingTest, TestLoadChangeCacheDir) { EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp index a2b845b0c4acc3..cd2e7b95f469d8 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp @@ -17,8 +17,11 @@ MockPlugin::MockPlugin(InferenceEngine::IInferencePlugin *target) { _target = target; } -void MockPlugin::SetConfig(const std::map& config) { - this->config = config; +void MockPlugin::SetConfig(const std::map& _config) { + this->config = _config; + if (_target) { + _target->SetConfig(config); + } } Parameter MockPlugin::GetMetric(const std::string& name, const std::map& options) const {