Skip to content

Commit

Permalink
Caching: pass global CACHE_DIR setting to plugin (openvinotoolkit#5893)
Browse files Browse the repository at this point in the history
* Caching: pass global CACHE_DIR setting to plugin

This can be helpful for GPU - it doesn't support Import/Export but can
significantly speed up load time when CACHE_DIR is set for device only

* Ignore exception in 'DeviceSupportsConfigKey' if plugin doesn't support GetMetric at all
  • Loading branch information
nosovmik authored and yekruglov committed Jun 7, 2021
1 parent 2b702e4 commit 65afdb4
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 3 deletions.
38 changes: 37 additions & 1 deletion inference-engine/src/inference_engine/ie_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,15 @@ class Core::Impl : public ICore {
class CoreConfig final {
public:
struct CacheConfig {
std::string _cacheDir;
std::shared_ptr<ICacheManager> _cacheManager;
};

void setAndUpdate(std::map<std::string, std::string>& config) {
auto it = config.find(CONFIG_KEY(CACHE_DIR));
if (it != config.end()) {
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
_cacheConfig._cacheDir = it->second;
if (!it->second.empty()) {
FileUtils::createDirectoryRecursive(it->second);
_cacheConfig._cacheManager = std::make_shared<FileStorageCacheManager>(std::move(it->second));
Expand Down Expand Up @@ -241,6 +243,27 @@ class Core::Impl : public ICore {
return supported;
}

bool DeviceSupportsCacheDir(const InferencePlugin& plugin) const {
return DeviceSupportsConfigKey(plugin, CONFIG_KEY(CACHE_DIR));
}

bool DeviceSupportsConfigKey(const InferencePlugin& plugin, const std::string& key) const {
bool supported = false;
std::vector<std::string> supportedMetricKeys;
try {
// If plugin doesn't support 'SUPPORTED_METRICS' - treat it as config is not supported as well
supportedMetricKeys =
plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {}).as<std::vector<std::string>>();
} catch(...) {}
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
METRIC_KEY(SUPPORTED_CONFIG_KEYS));
if (it != supportedMetricKeys.end()) {
std::vector<std::string> configKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), {});
supported = std::find(configKeys.begin(), configKeys.end(), key) != configKeys.end();
}
return supported;
}

SoExecutableNetworkInternal LoadNetworkImpl(const CNNNetwork& network,
InferencePlugin& plugin,
const std::map<std::string, std::string>& parsedConfig,
Expand Down Expand Up @@ -700,6 +723,12 @@ class Core::Impl : public ICore {

// configuring
{
if (DeviceSupportsCacheDir(plugin)) {
auto cacheConfig = coreConfig.getCacheConfig();
if (cacheConfig._cacheManager) {
desc.defaultConfig[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
}
}
allowNotImplemented([&]() {
plugin.SetConfig(desc.defaultConfig);
});
Expand Down Expand Up @@ -816,7 +845,14 @@ class Core::Impl : public ICore {
for (auto& plugin : plugins) {
if (deviceName.empty() || deviceName == plugin.first) {
allowNotImplemented([&]() {
plugin.second.SetConfig(config);
auto configCopy = config;
if (DeviceSupportsCacheDir(plugin.second)) {
auto cacheConfig = coreConfig.getCacheConfig();
if (cacheConfig._cacheManager) {
configCopy[CONFIG_KEY(CACHE_DIR)] = cacheConfig._cacheDir;
}
}
plugin.second.SetConfig(configCopy);
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class MockCachingInferencePlugin : public MockCachingInferencePluginBase {
const std::map<std::string, std::string>& config));

MOCK_CONST_METHOD2(GetMetric, Parameter(const std::string& name, const std::map<std::string, Parameter>& options));
MOCK_METHOD1(SetConfig, void(const std::map<std::string, std::string>& options));
MOCK_METHOD1(GetDefaultContext, RemoteContext::Ptr(const ParamMap& params));
};

Expand Down Expand Up @@ -362,6 +363,11 @@ class CachingTest : public ::testing::TestWithParam<std::tuple<TestParam, std::s
return res;
}));

EXPECT_CALL(plugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>) {
throw InferenceEngine::NotImplemented("Not implemented");
}));

EXPECT_CALL(*net, GetInputsInfo()).Times(AnyNumber())
.WillRepeatedly(Return(ConstInputsDataMap{}));
EXPECT_CALL(*net, GetOutputsInfo()).Times(AnyNumber())
Expand Down Expand Up @@ -567,6 +573,93 @@ TEST_P(CachingTest, TestNoCacheMetricSupported) {
}
}

TEST_P(CachingTest, TestNoCacheMetric_hasCacheDirConfig) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));

{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
ASSERT_NO_THROW(
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
}));
}
}

TEST_P(CachingTest, TestCacheEnabled_noConfig) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{}));
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));

{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, OnLoadNetworkFromFile()).Times(m_type == TestLoadType::EModelName ? 1 : 0);
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
});
}
}


TEST_P(CachingTest, TestNoCacheMetric_configThrow) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AtLeast(1)).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_GT(config.count(CONFIG_KEY(CACHE_DIR)), 0);
throw InferenceEngine::GeneralError("Error occurred");
}));

ASSERT_ANY_THROW(
testLoad([&](Core &ie) {
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
m_testFunction(ie);
}));
}

TEST_P(CachingTest, TestNoCacheEnabled_cacheDirConfig) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
.Times(AnyNumber()).WillRepeatedly(
Return(std::vector<std::string>{METRIC_KEY(SUPPORTED_CONFIG_KEYS)}));
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _))
.Times(AtLeast(1)).WillRepeatedly(Return(std::vector<std::string>{CONFIG_KEY(CACHE_DIR)}));
EXPECT_CALL(*mockPlugin, SetConfig(_)).Times(AnyNumber()).WillRepeatedly(
Invoke([](const std::map<std::string, std::string>& config) {
ASSERT_EQ(config.count(CONFIG_KEY(CACHE_DIR)), 0);
}));

{
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
testLoad([&](Core &ie) {
m_testFunction(ie);
});
}
}

TEST_P(CachingTest, TestLoadChangeCacheDir) {
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ MockPlugin::MockPlugin(InferenceEngine::IInferencePlugin *target) {
_target = target;
}

void MockPlugin::SetConfig(const std::map<std::string, std::string>& config) {
this->config = config;
void MockPlugin::SetConfig(const std::map<std::string, std::string>& _config) {
this->config = _config;
if (_target) {
_target->SetConfig(config);
}
}

Parameter MockPlugin::GetMetric(const std::string& name, const std::map<std::string, InferenceEngine::Parameter>& options) const {
Expand Down

0 comments on commit 65afdb4

Please sign in to comment.