diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index a3d7143340f44c..b05da86fc3515d 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -11,6 +11,7 @@ #include "snippets_isa.hpp" #include "snippets/lowered/linear_ir.hpp" +#include "snippets/kernel_executor_table.hpp" #include "snippets/shape_types.hpp" #include "target_machine.hpp" @@ -32,7 +33,8 @@ class LoweringResult { std::vector> m_saved_emitters{}; public: - std::shared_ptr compiled_snippet = nullptr; + CompiledSnippetPtr compiled_snippet = nullptr; + KernelExecutorTablePtr kernel_executor_table = nullptr; }; /** diff --git a/src/common/snippets/include/snippets/kernel_executor_table.hpp b/src/common/snippets/include/snippets/kernel_executor_table.hpp index af797e4c80422a..2d4b1185ffc5d7 100644 --- a/src/common/snippets/include/snippets/kernel_executor_table.hpp +++ b/src/common/snippets/include/snippets/kernel_executor_table.hpp @@ -43,7 +43,7 @@ class KernelExecutorBase { * @brief Update current kernel config in accordance with the passed expression. Corresponding kernel is recompiled if necessary. * This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr */ - virtual void update_by_expression(const lowered::ExpressionPtr& expr, const lowered::LinearIRPtr& linear_ir) = 0; + virtual void update_by_expression(const lowered::ExpressionPtr& expr, const lowered::LinearIRCPtr& linear_ir) = 0; /** * @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary. * This method should be called to restore a saved state of the executor, that was configured using update_by_expression(). 
@@ -70,7 +70,7 @@ class KernelExecutor : public KernelExecutorBase { explicit KernelExecutor(Conf c) : KernelExecutorBase(), m_config{std::move(c)} {} // Note: override when final is redundant, but needed to avoid warnings on some compilers - void update_by_expression(const lowered::ExpressionPtr& expr, const lowered::LinearIRPtr& linear_ir) override final { // NOLINT + void update_by_expression(const lowered::ExpressionPtr& expr, const lowered::LinearIRCPtr& linear_ir) override final { // NOLINT update_config(expr, linear_ir, m_config); OPENVINO_ASSERT(m_config.is_completed(), "Failed to update kernel config in update_by_expression"); update_kernel(m_config, m_kernel); @@ -103,7 +103,7 @@ class KernelExecutor : public KernelExecutorBase { protected: /*** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */ - virtual void update_config(const lowered::ExpressionPtr& expr, const lowered::LinearIRPtr& linear_ir, Conf& config) const = 0; + virtual void update_config(const lowered::ExpressionPtr& expr, const lowered::LinearIRCPtr& linear_ir, Conf& config) const = 0; /*** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is * performed if necessary. */ virtual void update_kernel(const Conf& c, std::shared_ptr& kernel) const = 0; @@ -122,17 +122,26 @@ class KernelExecutorTable { typename std::enable_if::value, bool>::type = true> std::shared_ptr register_kernel(const lowered::ExpressionPtr& expr, C... 
args) { const auto& instance = std::make_shared(args...); - OPENVINO_ASSERT(m_table.insert({expr, instance}).second, "This expression already has an alterable kernel"); + OPENVINO_ASSERT(m_table.insert({expr->get_exec_num(), instance}).second, "This expression execution number already has an alterable kernel"); return instance; } - const std::shared_ptr& get_kernel_executor(const lowered::ExpressionPtr& expr) const { - OPENVINO_ASSERT(m_table.count(expr), "This expression doesn't have a registered kernel executor"); - return m_table.at(expr); + + const std::shared_ptr& get_kernel_executor(const lowered::ExpressionPtr& expr) const { + return get_kernel_executor(expr->get_exec_num()); + } + const std::shared_ptr& get_kernel_executor(double expr_exec_num) const { + OPENVINO_ASSERT(m_table.count(expr_exec_num), "This expression execution number doesn't have a registered kernel executor"); + return m_table.at(expr_exec_num); } + /*** Updates every registered KernelExecutor in accordance with the corresponding expression */ - void update_state(const lowered::LinearIRPtr& linear_ir) const { - for (const auto& record : m_table) - record.second->update_by_expression(record.first, linear_ir); + void update_state(const lowered::LinearIRCPtr& linear_ir) const { + for (const auto& expr : *linear_ir) { + const auto& found = m_table.find(expr->get_exec_num()); + if (found != m_table.end()) { + found->second->update_by_expression(expr, linear_ir); + } + } } /*** Returns lambda function that contains current state of the table, and restores this state when called */ @@ -141,19 +150,12 @@ class KernelExecutorTable { return [=]() { reset_state(current_state); }; } - /** - * @brief Replace originally registered ExpressionPtr with a new value. - * Note that code emission is performed on a copy of LIR, so all expression pointers visible from emitters won't - * be accessible from RuntimeConfigurator. 
In order to replace these cloned ExpressionPtrs with the original ones, - * we need to call this method. - */ - void replace_key_expression(const lowered::ExpressionPtr& from, const lowered::ExpressionPtr& to); - virtual ~KernelExecutorTable() = default; protected: - std::unordered_map> m_table{}; - typedef std::vector>> ExecTableState; + std::unordered_map> m_table {}; + + typedef std::vector>> ExecTableState; /*** Restore the table state previously obtained by get_state() */ void reset_state(const ExecTableState& state); diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index f2e45f8af68e17..55afd2c9ccd7ab 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -284,6 +284,7 @@ class LinearIR { size_t m_static_buffer_scratchpad_size = 0; }; using LinearIRPtr = std::shared_ptr; +using LinearIRCPtr = std::shared_ptr; template iterator LinearIR::find(iterator begin, iterator end, const ExpressionPtr& target) const { diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index 7837625f6e3e3c..84b66ce4d5306c 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -116,6 +116,7 @@ class Subgraph : public ov::op::util::SubGraphOp { std::shared_ptr clone() const; + const std::shared_ptr& get_runtime_configurator() const; const std::shared_ptr& update_runtime_config() const; static auto wrap_node_as_subgraph(const std::shared_ptr& node) -> std::shared_ptr; diff --git a/src/common/snippets/include/snippets/runtime_configurator.hpp b/src/common/snippets/include/snippets/runtime_configurator.hpp index 058eca59716d1b..a0c7d8336c5cd1 100644 --- a/src/common/snippets/include/snippets/runtime_configurator.hpp +++ b/src/common/snippets/include/snippets/runtime_configurator.hpp @@ 
-61,28 +61,36 @@ class RuntimeConfigurator { * @param linear_ir LinearIR * @return updated config */ - const std::shared_ptr& get_updated_config(const lowered::LinearIRPtr& linear_ir); - /*** Returns pointer to KernelExecutorTable owned by the config */ + const std::shared_ptr& get_updated_config(const lowered::LinearIRCPtr& linear_ir); + /** + * @brief Returns pointer to KernelExecutorTable owned by the config + * @return updated KernelExecutorTable + */ const std::shared_ptr& get_kernel_executor_table() const { return m_config->kernel_executor_table; } + /** + * @brief Set new KernelExecutorTable to the config + * @param table new KernelExecutorTable + */ + void set_kernel_executor_table(std::shared_ptr table) const; protected: /** * @brief Update RuntimeConfig based on LinearIR * @param linear_ir LinearIR */ - virtual void update(const lowered::LinearIRPtr& linear_ir); + virtual void update(const lowered::LinearIRCPtr& linear_ir); /** * @brief Allocate and intialize fields in RuntimeConfig and RuntimeConfigurator * @param linear_ir LinearIR */ - virtual void initialization(const lowered::LinearIRPtr& linear_ir); + virtual void initialization(const lowered::LinearIRCPtr& linear_ir); /** * @brief Initializes input and data information of LinearIR: * descriptors (that contains shapes and layouts) and data_sizes * @param linear_ir LinearIR */ - void init_data_info(const lowered::LinearIRPtr& linear_ir); + void init_data_info(const lowered::LinearIRCPtr& linear_ir); /** * @brief Initializes information of buffers: * - static buffer_scratchpad_size @@ -90,23 +98,23 @@ class RuntimeConfigurator { * - clusters with dynamic buffers (`m_dynamic_buffer_clusters`) for the quick access in `update()` * @param linear_ir LinearIR */ - void init_buffer_info(const lowered::LinearIRPtr& linear_ir); + void init_buffer_info(const lowered::LinearIRCPtr& linear_ir); /** * @brief Initializes tensor rank of config * @param linear_ir LinearIR */ - virtual void init_tensor_rank(const 
lowered::LinearIRPtr& linear_ir) const; + virtual void init_tensor_rank(const lowered::LinearIRCPtr& linear_ir) const; /** * @brief Update Loop informations in LinearIR: Unified and ExpandedLoopInfo * @param linear_ir LinearIR */ - void update_loop_info(const lowered::LinearIRPtr& linear_ir) const; + void update_loop_info(const lowered::LinearIRCPtr& linear_ir) const; /** * @brief Update Buffer scratchpad size and offsets if needed * Note: `update_loop_info` must be called before * @param linear_ir LinearIR */ - void update_buffer_scratchpad_size(const lowered::LinearIRPtr& linear_ir) const; + void update_buffer_scratchpad_size(const lowered::LinearIRCPtr& linear_ir) const; /** * @brief Calculate data offsets of LinearIR and update these values in RuntimeConfig */ diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 29d9e066b153af..c01685e6531eb6 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -5,6 +5,7 @@ #include "snippets/generator.hpp" #include "snippets/itt.hpp" +#include "snippets/runtime_configurator.hpp" #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/expression.hpp" #include "snippets/op/kernel.hpp" @@ -46,6 +47,7 @@ LoweringResult Generator::generate(lowered::LinearIR& linear_ir, const void* com result.m_saved_emitters.emplace_back(emitter); } result.compiled_snippet = target->get_snippet(); + result.kernel_executor_table = target->get_runtime_configurator()->get_kernel_executor_table(); return result; } diff --git a/src/common/snippets/src/kernel_executor_table.cpp b/src/common/snippets/src/kernel_executor_table.cpp index 964ed736f13dd0..9b43c901f55edb 100644 --- a/src/common/snippets/src/kernel_executor_table.cpp +++ b/src/common/snippets/src/kernel_executor_table.cpp @@ -7,21 +7,13 @@ namespace ov { namespace snippets { -void KernelExecutorTable::replace_key_expression(const snippets::lowered::ExpressionPtr& from, const 
snippets::lowered::ExpressionPtr& to) { - const auto& found = m_table.find(from); - if (found != m_table.end()) { - OPENVINO_ASSERT(m_table.count(to) == 0, "Attempt to replace a value that is already in the KernelExecutorTable"); - m_table.insert({to, found->second}); - m_table.erase(found); - } -} - void KernelExecutorTable::reset_state(const ExecTableState& state) { OPENVINO_ASSERT(state.size() == m_table.size(), "Invalid state in restore_state: size mismatch"); auto state_it = state.begin(); for (const auto& table_record : m_table) { const auto& state_record = *state_it++; - OPENVINO_ASSERT(table_record.first == state_record.first, "Invalid state in restore_state: expressions mismatch"); + OPENVINO_ASSERT(table_record.first == state_record.first, + "Invalid state in restore_state: expression execution numbers mismatched"); table_record.second->update_by_config(*state_record.second); } } diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 4ede0b58a66cf0..55fd4acb2fa315 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -544,22 +544,21 @@ snippets::Schedule Subgraph::generate(const void* compile_params) const { } auto lowering_result = m_generator->generate(linear_ir, compile_params); - - // Note: Since the code emission is performed on a copy of LIR, but RuntimeConfigurator works with the initial instance, - // we need to replace cloned expression pointers to original ones in the KernelExecutorTable. Ticket: 129772 - const auto& exec_table = m_generator->get_target_machine()->get_runtime_configurator()->get_kernel_executor_table(); - for (const auto& expr : *m_linear_ir) - exec_table->replace_key_expression(expression_map.at(expr.get()), expr); // Some kernel executors might've been registered during code emission. // We need to update them, so appropriate kernels will be compiled. 
+ const auto& exec_table = get_runtime_configurator()->get_kernel_executor_table(); exec_table->update_state(m_linear_ir); return {std::move(lowering_result)}; } -const std::shared_ptr& Subgraph::update_runtime_config() const { +const std::shared_ptr& Subgraph::get_runtime_configurator() const { OPENVINO_ASSERT(m_generator, "Generator has not been inited!"); + return m_generator->get_target_machine()->get_runtime_configurator(); +} + +const std::shared_ptr& Subgraph::update_runtime_config() const { OPENVINO_ASSERT(m_linear_ir, "LoweredLinearIR has not been inited!"); - return m_generator->get_target_machine()->get_runtime_configurator()->get_updated_config(m_linear_ir); + return get_runtime_configurator()->get_updated_config(m_linear_ir); } void Subgraph::print() const { diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp index ec1db44f074766..062b3a2d86fbb2 100644 --- a/src/common/snippets/src/runtime_configurator.cpp +++ b/src/common/snippets/src/runtime_configurator.cpp @@ -35,7 +35,7 @@ RuntimeConfigurator::RuntimeConfigurator(std::shared_ptr c) : OPENVINO_ASSERT(m_config, "Runtime config is nullptr!"); } -const std::shared_ptr& RuntimeConfigurator::get_updated_config(const lowered::LinearIRPtr& linear_ir) { +const std::shared_ptr& RuntimeConfigurator::get_updated_config(const lowered::LinearIRCPtr& linear_ir) { // First initialization if (m_io_num == 0) initialization(linear_ir); @@ -44,7 +44,7 @@ const std::shared_ptr& RuntimeConfigurator::get_updated_config(co return m_config; } -void RuntimeConfigurator::initialization(const lowered::LinearIRPtr& linear_ir) { +void RuntimeConfigurator::initialization(const lowered::LinearIRCPtr& linear_ir) { init_data_info(linear_ir); init_tensor_rank(linear_ir); init_buffer_info(linear_ir); @@ -55,7 +55,7 @@ void RuntimeConfigurator::initialization(const lowered::LinearIRPtr& linear_ir) m_config->tile_rank = linear_ir->get_config().m_loop_depth; } -void 
RuntimeConfigurator::update(const lowered::LinearIRPtr& linear_ir) { +void RuntimeConfigurator::update(const lowered::LinearIRCPtr& linear_ir) { if (linear_ir->is_dynamic()) { update_loop_info(linear_ir); update_buffer_scratchpad_size(linear_ir); @@ -67,11 +67,11 @@ void RuntimeConfigurator::update(const lowered::LinearIRPtr& linear_ir) { update_latest_shapes(); } -void RuntimeConfigurator::init_tensor_rank(const lowered::LinearIRPtr& linear_ir) const { +void RuntimeConfigurator::init_tensor_rank(const lowered::LinearIRCPtr& linear_ir) const { m_config->tensor_rank = linear_ir->get_master_shape().size(); } -void RuntimeConfigurator::init_data_info(const lowered::LinearIRPtr& linear_ir) { +void RuntimeConfigurator::init_data_info(const lowered::LinearIRCPtr& linear_ir) { const auto& parameters = linear_ir->get_parameters(); const auto& results = linear_ir->get_results(); m_in_num = parameters.size(); @@ -113,7 +113,7 @@ void RuntimeConfigurator::init_data_info(const lowered::LinearIRPtr& linear_ir) } } -void RuntimeConfigurator::init_buffer_info(const lowered::LinearIRPtr& linear_ir) { +void RuntimeConfigurator::init_buffer_info(const lowered::LinearIRCPtr& linear_ir) { std::map> dynamic_buffer_clusters, static_buffer_clusters; // All needed checks are in Validate pass @@ -143,7 +143,7 @@ void RuntimeConfigurator::init_buffer_info(const lowered::LinearIRPtr& linear_ir m_dynamic_buffer_clusters = std::move(dynamic_buffer_clusters); } -void RuntimeConfigurator::update_loop_info(const lowered::LinearIRPtr& linear_ir) const { +void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_ir) const { // Initialized UnifiedLoopInfo struct CurrentUnifiedLoopInfo { size_t current_work_amount = 0; @@ -202,7 +202,7 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRPtr& linear_ir } } -void RuntimeConfigurator::update_buffer_scratchpad_size(const lowered::LinearIRPtr& linear_ir) const { +void 
RuntimeConfigurator::update_buffer_scratchpad_size(const lowered::LinearIRCPtr& linear_ir) const { const auto& loop_manager = linear_ir->get_loop_manager(); m_config->buffer_scratchpad_size = linear_ir->get_static_buffer_scratchpad_size(); @@ -278,5 +278,10 @@ void RuntimeConfigurator::update_latest_shapes() { } } +void RuntimeConfigurator::set_kernel_executor_table(std::shared_ptr table) const { + OPENVINO_ASSERT(table, "Failed to update Kernel Executor Table: passed table is missing"); + m_config->kernel_executor_table = std::move(table); +} + } // namespace snippets } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp index 925a6d28697d41..1387992792e0a0 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp @@ -14,7 +14,7 @@ namespace intel_cpu { CPURuntimeConfigurator::CPURuntimeConfigurator() : ov::snippets::RuntimeConfigurator(std::make_shared()) { } -void CPURuntimeConfigurator::update(const ov::snippets::lowered::LinearIRPtr& linear_ir) { +void CPURuntimeConfigurator::update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) { if (linear_ir->is_dynamic()) { update_loop_info(linear_ir); update_loop_args(linear_ir); @@ -30,11 +30,11 @@ void CPURuntimeConfigurator::update(const ov::snippets::lowered::LinearIRPtr& li update_latest_shapes(); } -void CPURuntimeConfigurator::init_tensor_rank(const ov::snippets::lowered::LinearIRPtr& linear_ir) const { +void CPURuntimeConfigurator::init_tensor_rank(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const { m_config->tensor_rank = std::max(linear_ir->get_master_shape().size(), rank6D); } -void CPURuntimeConfigurator::update_loop_args(const ov::snippets::lowered::LinearIRPtr& linear_ir) const { +void CPURuntimeConfigurator::update_loop_args(const
ov::snippets::lowered::LinearIRCPtr& linear_ir) const { const auto& cpu_config = ov::as_type_ptr(m_config); OPENVINO_ASSERT(cpu_config, "CPURuntimeConfigurator expects CPURuntimeConfig"); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp index f1a21e5982aa1c..93cbb6b598146c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp @@ -29,17 +29,17 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator { * @brief Update RuntimeConfig based on LinearIR * @param linear_ir LinearIR */ - void update(const ov::snippets::lowered::LinearIRPtr& linear_ir) override; + void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override; /** * @brief Initializes tensor rank of config * @param linear_ir LinearIR */ - void init_tensor_rank(const ov::snippets::lowered::LinearIRPtr& linear_ir) const override; + void init_tensor_rank(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const override; /** * @brief Calculate Loop parameters of Loop emitters and update these values in CPURuntimeConfig * @param linear_ir LinearIR */ - void update_loop_args(const ov::snippets::lowered::LinearIRPtr& linear_ir) const; + void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const; const size_t rank6D = 6; }; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index 920f95f0c8bc37..aa917c89dcb016 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -184,7 +184,7 @@ float BrgemmKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr return 0; } void 
BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, - const ov::snippets::lowered::LinearIRPtr& linear_ir, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmKernelConfig& config) const { const auto& input_pds = expr->get_input_port_descriptors(); const auto& output_pds = expr->get_output_port_descriptors(); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index b673c61d6d0aef..2549580c1a176c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -100,7 +100,7 @@ class BrgemmKernelExecutor : public CPUKernelExecutor compile_kernel(const BrgemmKernelConfig& c) const override; void update_config(const ov::snippets::lowered::ExpressionPtr& expr, - const ov::snippets::lowered::LinearIRPtr& linear_ir, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmKernelConfig& config) const override; static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id, diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index d6d127eb6981e4..86896ad3b4ca5f 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -746,15 +746,34 @@ void Subgraph::prepareParams() { const auto cache = context->getParamsCache(); auto builder = [this, cache](const SubgraphKey& key) -> std::shared_ptr { - const auto& snippet_config = ov::as_type_ptr(subgraph_attrs->snippet->update_runtime_config()); - // Firstly, find the schedule in the cache - const auto code_gen_result = cache->getOrCreate(SubgraphCodeGeneratorKey(subgraph_attrs, getBroadcastingMask(in_shapes)), - [&snippet_config](const SubgraphCodeGeneratorKey& key) -> std::shared_ptr { - return std::make_shared(key.attrs, 
snippet_config); - }); + const auto& snippet = subgraph_attrs->snippet; if (is_dynamic) { - return std::make_shared(key.attrs, code_gen_result.first, start_offset_in, start_offset_out, snippet_config); + // Dynamic case: + // 1. Generate JIT code if needed + // 2. Update runtime config with dynamic values + // If JIT code has been taken from cache, need to set cached kernel executor table for the configuration + // 3. Create SubgraphDynamicSpecializedExecutor + const auto code_gen_result = cache->getOrCreate(SubgraphCodeGeneratorKey(subgraph_attrs, getBroadcastingMask(in_shapes)), + [](const SubgraphCodeGeneratorKey& key) -> std::shared_ptr { + return std::make_shared(key.attrs, std::make_shared()); + }); + const auto& code_gen = code_gen_result.first; + // [148644] : Update Kernel table from SubgraphCodeGenerator when JIT code was already generated with specific Kernel table + if (code_gen_result.second == CacheEntryBase::LookUpStatus::Hit) { + snippet->get_runtime_configurator()->set_kernel_executor_table(code_gen->get()->lowering_result.kernel_executor_table); + } + const auto& snippet_config = ov::as_type_ptr(snippet->update_runtime_config()); + return std::make_shared(key.attrs, code_gen, start_offset_in, start_offset_out, snippet_config); } else { + // Static case: + // 1. Update runtime config to get static scheduling data (io data offsets, parallel domain) which will be compiled in JIT code + // 2. Generate JIT code with this static data if needed + // 3. 
Create SubgraphStaticExecutor + const auto& snippet_config = ov::as_type_ptr(snippet->update_runtime_config()); + const auto code_gen_result = cache->getOrCreate(SubgraphCodeGeneratorKey(subgraph_attrs, getBroadcastingMask(in_shapes)), + [&snippet_config](const SubgraphCodeGeneratorKey& key) -> std::shared_ptr { + return std::make_shared(key.attrs, snippet_config); + }); return std::make_shared(key.attrs, code_gen_result.first, start_offset_in, start_offset_out, snippet_config); } }; diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/subgraph_caching.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/subgraph_caching.cpp new file mode 100644 index 00000000000000..f9f17154dcca68 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/subgraph_caching.cpp @@ -0,0 +1,125 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +// Motivation: +// In a dynamic scenario, depending on the input shapes for the current node, +// - we can either generate a new jit kernel or get an existing one from the cache +// - we can either make shape inference or get existing output shapes from the cache +// But the current single layer tests do not allow checking the case when the same kernel can be used for different nodes. 
+// We check 2 Subgraphs with MatMuls inside to validate Kernel Executor table also + +// ----------- ----------- ----------- ----------- +// |input 0.0| |input 0.1| |input 1.0| |input 1.1| +// ----------- ----------- ----------- ----------- +// | | | | +// ------------------------------------ ------------------------------------ +// | MatMul 0 | | Matmul 1 | +// ------------------------------------ ------------------------------------ +// | | +// ------------------------------------ ------------------------------------ +// | Add 0 | | Add 1 | +// ------------------------------------ ------------------------------------ +// | | +// ---------------------------------------------------------------------------- +// | concat | +// ---------------------------------------------------------------------------- +// | +// -------- +// |output| +// -------- + +#include "snippets/op/subgraph.hpp" +#include "common_test_utils/common_utils.hpp" +#include "common_test_utils/ov_tensor_utils.hpp" +#include "common_test_utils/node_builders/eltwise.hpp" +#include "common_test_utils/node_builders/constant.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "utils/cpu_test_utils.hpp" +#include "internal_properties.hpp" + +namespace ov { +namespace test { +using namespace ov::test::utils; + +typedef std::tuple< + std::vector, // Input Shapes + ElementType // Input precisions +> SubgraphCacheTestParams; + +class SubgraphCacheTest : public testing::WithParamInterface, + virtual public SubgraphBaseTest { +public: + static std::string getTestCaseName(const testing::TestParamInfo &obj) { + std::vector inputShapes; + ElementType inputPrecision; + std::tie(inputShapes, inputPrecision) = obj.param; + + std::ostringstream results; + + for (size_t i = 0; i < inputShapes.size(); i++) { + results << "IS[" << i << "]=" << inputShapes[i]; + } + + results << "InPRC" << "=" << inputPrecision << "_"; + + return results.str(); + } + +protected: + void SetUp() override { + targetDevice = 
ov::test::utils::DEVICE_CPU; + + std::vector inputShapes; + ElementType inputPrecision; + std::tie(inputShapes, inputPrecision) = this->GetParam(); + + init_input_shapes(inputShapes); + + // Enable Snippets + configuration.insert(ov::intel_cpu::snippets_mode(ov::intel_cpu::SnippetsMode::IGNORE_CALLBACK)); + + ov::ParameterVector paramVec; + for (size_t i = 0; i < inputDynamicShapes.size(); i++) { + paramVec.push_back(std::make_shared(inputPrecision, inputDynamicShapes[i])); + } + + auto matmul0 = std::make_shared(paramVec[0], paramVec[1]); + auto matmul1 = std::make_shared(paramVec[2], paramVec[3]); + + auto const0 = utils::make_constant(matmul0->get_output_element_type(0), ov::Shape{1}); + auto const1 = utils::make_constant(matmul1->get_output_element_type(0), ov::Shape{1}); + + auto add0 = std::make_shared(matmul0, const0); + auto add1 = std::make_shared(matmul1, const1); + + auto concat = std::make_shared(ov::NodeVector{add0, add1}, -1); + function = std::make_shared(concat, paramVec, "Subgraph"); + } +}; + +TEST_P(SubgraphCacheTest, CompareWithRefs) { + run(); + + CPUTestUtils::CheckNumberOfNodesWithType(compiledModel, "MatMul", 0); + CPUTestUtils::CheckNumberOfNodesWithType(compiledModel, "Subgraph", 2); +} + +namespace { + +std::vector inputShapes { + {{1, 2, -1, -1}, {{1, 2, 10, 3}, {1, 2, 10, 3}, {1, 2, 10, 8}, {1, 2, 10, 3}}}, + {{1, 2, -1, -1}, {{1, 2, 3, 12}, {1, 2, 3, 12}, {1, 2, 8, 9}, {1, 2, 3, 12}}}, + {{1, 2, -1, -1}, {{1, 2, 10, 8}, {1, 2, 10, 3}, {1, 2, 10, 3}, {1, 2, 10, 8}}}, + {{1, 2, -1, -1}, {{1, 2, 8, 9}, {1, 2, 3, 12}, {1, 2, 3, 12}, {1, 2, 8, 9}}}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_SubgraphCache, SubgraphCacheTest, + ::testing::Combine( + ::testing::Values(inputShapes), + ::testing::Values(ElementType::f32)), + SubgraphCacheTest::getTestCaseName); + +} // namespace +} // namespace test +} // namespace ov \ No newline at end of file