Skip to content

Commit

Permalink
[Core] pass loaded_from_cache to import_model
Browse files Browse the repository at this point in the history
  • Loading branch information
riverlijunjie committed Jan 15, 2024
1 parent f784931 commit 477a1d3
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 10 deletions.
6 changes: 4 additions & 2 deletions src/inference/src/dev/core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1498,8 +1498,10 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(
throw HeaderException();
}

compiled_model = context ? plugin.import_model(networkStream, context, config)
: plugin.import_model(networkStream, config);
ov::AnyMap update_config = config;
update_config[ov::loaded_from_cache.name()] = true;
compiled_model = context ? plugin.import_model(networkStream, context, update_config)
: plugin.import_model(networkStream, update_config);
if (auto wrapper = std::dynamic_pointer_cast<InferenceEngine::ICompiledModelWrapper>(compiled_model._ptr)) {
wrapper->get_executable_network()->loadedFromCache();
}
Expand Down
14 changes: 12 additions & 2 deletions src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,12 +940,22 @@ std::shared_ptr<ov::ICompiledModel> Engine::import_model(std::istream& networkMo

Config conf = engConfig;
Config::ModelType modelType = getModelType(model);
conf.readProperties(config, modelType);

// Check for the ov::loaded_from_cache property and erase it so that readProperties() does not throw on an unknown key.
auto _config = config;
const auto& it = _config.find(ov::loaded_from_cache.name());
bool loaded_from_cache = false;
if (it != _config.end()) {
loaded_from_cache = it->second.as<bool>();
_config.erase(it);
}
conf.readProperties(_config, modelType);

// import config props from caching model
calculate_streams(conf, model, true);

auto compiled_model = std::make_shared<CompiledModel>(model, shared_from_this(), conf, extensionManager, true);
auto compiled_model =
std::make_shared<CompiledModel>(model, shared_from_this(), conf, extensionManager, loaded_from_cache);
return compiled_model;
}
} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,8 @@ namespace {
CompileModelCacheRuntimePropertiesTestBase,
::testing::Combine(::testing::ValuesIn(TestCpuTargets), ::testing::ValuesIn(CpuConfigs)),
CompileModelCacheRuntimePropertiesTestBase::getTestCaseName);
// Register the loaded_from_cache property test for all CPU targets/configs.
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_CPU,
                         CompileModelLoadFromCacheTest,
                         ::testing::Combine(::testing::ValuesIn(TestCpuTargets), ::testing::ValuesIn(CpuConfigs)),
                         CompileModelLoadFromCacheTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class CompiledModel : public ov::ICompiledModel {
CompiledModel(std::shared_ptr<ov::Model> model,
const std::shared_ptr<const ov::IPlugin>& plugin,
RemoteContextImpl::Ptr context,
const ExecutionConfig& config);
const ExecutionConfig& config,
const bool loaded_from_cache = false);
CompiledModel(cldnn::BinaryInputBuffer& ib,
const std::shared_ptr<const ov::IPlugin>& plugin,
RemoteContextImpl::Ptr context,
const ExecutionConfig& config);
const ExecutionConfig& config,
const bool loaded_from_cache = false);

std::shared_ptr<ov::IAsyncInferRequest> create_infer_request() const override;
std::shared_ptr<ov::ISyncInferRequest> create_sync_infer_request() const override;
Expand Down
10 changes: 6 additions & 4 deletions src/plugins/intel_gpu/src/plugin/compiled_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ std::shared_ptr<ov::threading::ITaskExecutor> create_task_executor(const std::sh
CompiledModel::CompiledModel(std::shared_ptr<ov::Model> model,
const std::shared_ptr<const ov::IPlugin>& plugin,
RemoteContextImpl::Ptr context,
const ExecutionConfig& config)
const ExecutionConfig& config,
const bool loaded_from_cache)
: ov::ICompiledModel(model,
plugin,
context,
Expand All @@ -69,7 +70,7 @@ CompiledModel::CompiledModel(std::shared_ptr<ov::Model> model,
, m_model_name(model->get_friendly_name())
, m_inputs(ov::ICompiledModel::inputs())
, m_outputs(ov::ICompiledModel::outputs())
, m_loaded_from_cache(false) {
, m_loaded_from_cache(loaded_from_cache) {
auto graph_base = std::make_shared<Graph>(model, m_context, m_config, 0);
for (uint16_t n = 0; n < m_config.get_property(ov::num_streams); n++) {
auto graph = n == 0 ? graph_base : std::make_shared<Graph>(graph_base, n);
Expand All @@ -80,7 +81,8 @@ CompiledModel::CompiledModel(std::shared_ptr<ov::Model> model,
CompiledModel::CompiledModel(cldnn::BinaryInputBuffer& ib,
const std::shared_ptr<const ov::IPlugin>& plugin,
RemoteContextImpl::Ptr context,
const ExecutionConfig& config)
const ExecutionConfig& config,
const bool loaded_from_cache)
: ov::ICompiledModel(nullptr,
plugin,
context,
Expand All @@ -90,7 +92,7 @@ CompiledModel::CompiledModel(cldnn::BinaryInputBuffer& ib,
, m_config(config)
, m_wait_executor(std::make_shared<ov::threading::CPUStreamsExecutor>(ov::threading::IStreamsExecutor::Config{"Intel GPU plugin wait executor"}))
, m_model_name("")
, m_loaded_from_cache(true) {
, m_loaded_from_cache(loaded_from_cache) {
{
size_t num_params;
ib >> num_params;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,9 @@ namespace {
::testing::Combine(::testing::Values(ov::test::utils::DEVICE_GPU),
::testing::ValuesIn(GPULoadFromFileConfigs)),
CompileModelLoadFromMemoryTestBase::getTestCaseName);
// Register the loaded_from_cache property test for the GPU device with the
// load-from-file configurations.
INSTANTIATE_TEST_SUITE_P(smoke_CachingSupportCase_GPU,
                         CompileModelLoadFromCacheTest,
                         ::testing::Combine(::testing::Values(ov::test::utils::DEVICE_GPU),
                                            ::testing::ValuesIn(GPULoadFromFileConfigs)),
                         CompileModelLoadFromCacheTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,24 @@ class CompileModelCacheRuntimePropertiesTestBase
void run() override;
};

// Test parameters: target device name plus the device configuration map that
// is forwarded to compile_model / import_model.
using CompileModelLoadFromCacheParams = std::tuple<std::string,  // device name
                                                   ov::AnyMap    // device configuration
                                                   >;
// Checks that a compiled model reports the ov::loaded_from_cache property
// correctly for the three load paths exercised in run(): a fresh compile
// (false), an explicit import_model from a stream (false), and a compile that
// is served from the model cache (true). SetUp() serializes a small IR to
// disk; TearDown() removes every generated file.
class CompileModelLoadFromCacheTest : public testing::WithParamInterface<CompileModelLoadFromCacheParams>,
                                      virtual public SubgraphBaseTest,
                                      virtual public OVPluginTestBase {
    std::string m_cacheFolderName;  // per-test cache directory; name also encodes the config
    std::string m_modelName;        // path of the serialized IR .xml file
    std::string m_weightsName;      // path of the serialized IR .bin file

public:
    static std::string getTestCaseName(testing::TestParamInfo<CompileModelLoadFromCacheParams> obj);

    void SetUp() override;
    void TearDown() override;
    void run() override;
};

using compileModelLoadFromMemoryParams = std::tuple<std::string, // device name
ov::AnyMap // device configuration
>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,68 @@ TEST_P(CompileModelCacheRuntimePropertiesTestBase, CanLoadFromFileWithoutExcepti
run();
}

// Builds a human-readable test-case name from the device name (':' replaced
// by '.' so the name stays filesystem/gtest friendly) followed by every
// configuration key/value pair.
std::string CompileModelLoadFromCacheTest::getTestCaseName(
    testing::TestParamInfo<CompileModelLoadFromCacheParams> obj) {
    std::string device;
    ov::AnyMap config;
    std::tie(device, config) = obj.param;
    std::replace(device.begin(), device.end(), ':', '.');
    std::ostringstream name;
    name << "device_name=" << device << "_";
    for (auto& entry : config) {
        name << "_" << entry.first << "_" << entry.second.as<std::string>() << "_";
    }
    return name.str();
}

// Serializes a small reference model (conv+pool+relu) to a uniquely named IR
// on disk and derives a per-test cache folder name from the generated file
// prefix plus the device configuration. Caching is left disabled here; run()
// enables it explicitly.
void CompileModelLoadFromCacheTest::SetUp() {
    // NOTE: removed unused local `ovModelWithName funcPair;` — it was never
    // read or written in this function.
    std::tie(targetDevice, configuration) = GetParam();
    target_device = targetDevice;
    APIBaseTest::SetUp();
    std::stringstream ss;
    std::string filePrefix = ov::test::utils::generateTestFilePrefix();
    ss << "testCache_" << filePrefix;
    m_modelName = ss.str() + ".xml";
    m_weightsName = ss.str() + ".bin";
    // Fold the configuration into the cache folder name so instances with
    // different configs never share a cache directory.
    for (auto& iter : configuration) {
        ss << "_" << iter.first << "_" << iter.second.as<std::string>() << "_";
    }
    m_cacheFolderName = ss.str();
    core->set_property(ov::cache_dir());  // reset cache_dir (caching off)
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::Serialize>(m_modelName, m_weightsName);
    manager.run_passes(ov::test::utils::make_conv_pool_relu({1, 3, 227, 227}, ov::element::f32));
}

// Removes every artifact the test produced — cached blobs, the serialized IR
// files, and the cache folder itself — then resets cache_dir so subsequent
// tests start with caching disabled.
void CompileModelLoadFromCacheTest::TearDown() {
    ov::test::utils::removeFilesWithExt(m_cacheFolderName, "blob");
    ov::test::utils::removeFilesWithExt(m_cacheFolderName, "cl_cache");  // GPU kernel cache files
    ov::test::utils::removeIRFiles(m_modelName, m_weightsName);
    // NOTE(review): std::remove on a directory only succeeds once it is empty
    // — the removeFilesWithExt calls above must run first.
    std::remove(m_cacheFolderName.c_str());
    core->set_property(ov::cache_dir());
    APIBaseTest::TearDown();
}

// Exercises the three load paths and asserts the ov::loaded_from_cache value
// each must report:
//   1. first compile_model populates the cache            -> false
//   2. explicit import_model from an exported stream      -> false
//   3. second compile_model of the same model/config/dir  -> true (cache hit)
void CompileModelLoadFromCacheTest::run() {
    SKIP_IF_CURRENT_TEST_IS_DISABLED();
    core->set_property(ov::cache_dir(m_cacheFolderName));
    compiledModel = core->compile_model(m_modelName, targetDevice, configuration);
    EXPECT_EQ(false, compiledModel.get_property(ov::loaded_from_cache.name()).as<bool>());

    std::stringstream strm;
    compiledModel.export_model(strm);
    // An explicit import is not a cache load, so the property stays false.
    ov::CompiledModel importedCompiledModel = core->import_model(strm, target_device, configuration);
    EXPECT_EQ(false, importedCompiledModel.get_property(ov::loaded_from_cache.name()).as<bool>());

    // Same model, config, and cache_dir: this compile is served from the
    // cache written by the first compile and must report true.
    compiledModel = core->compile_model(m_modelName, targetDevice, configuration);
    EXPECT_EQ(true, compiledModel.get_property(ov::loaded_from_cache.name()).as<bool>());
}

TEST_P(CompileModelLoadFromCacheTest, CanGetCorrectLoadedFromCacheProperty) {
    run();
}

std::string CompileModelLoadFromMemoryTestBase::getTestCaseName(
testing::TestParamInfo<compileModelLoadFromMemoryParams> obj) {
auto param = obj.param;
Expand Down

0 comments on commit 477a1d3

Please sign in to comment.