diff --git a/src/common/snippets/include/snippets/lowered/pass/pass.hpp b/src/common/snippets/include/snippets/lowered/pass/pass.hpp
index 5833b695b0bba8..446f96d30a27cf 100644
--- a/src/common/snippets/include/snippets/lowered/pass/pass.hpp
+++ b/src/common/snippets/include/snippets/lowered/pass/pass.hpp
@@ -20,7 +20,7 @@ namespace pass {
  * @brief Base class for transformations on linear IR
  * @ingroup snippets
  */
-class PassBase {
+class PassBase : public std::enable_shared_from_this<PassBase> {
 public:
     PassBase() = default;
     virtual ~PassBase() = default;
diff --git a/src/common/snippets/include/snippets/utils/utils.hpp b/src/common/snippets/include/snippets/utils/utils.hpp
index e9a7e7d1ca9523..22665f41827ce3 100644
--- a/src/common/snippets/include/snippets/utils/utils.hpp
+++ b/src/common/snippets/include/snippets/utils/utils.hpp
@@ -124,6 +124,11 @@ std::string vector2str(const std::vector<T>& values) {
 bool broadcast_merge_dim(size_t& dst, const size_t& d1, const size_t& d2);
 
+// If one of the dims is dynamic, return the other dim (might also be dynamic)
+// If both dims are static, they must be equal - this is the difference from the utility above
+// Can be used in SpecificLoopIterationHandlers
+bool merge_dynamic_dim(size_t& dst, const size_t& d1, const size_t& d2);
+
 VectorDims pshape_to_vdims(const PartialShape&);
 ov::PartialShape vdims_to_pshape(const VectorDims&);
diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp
index 8b30872e0b1328..3c4391da3a7250 100644
--- a/src/common/snippets/src/lowered/expression.cpp
+++ b/src/common/snippets/src/lowered/expression.cpp
@@ -175,11 +175,11 @@ ExpressionPtr Expression::clone_with_new_inputs(const ExpressionMap& expr_map,
 }
 
 ExpressionPort Expression::get_input_port(size_t i) {
-    return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Input, i);
+    return ExpressionPort(shared_from_this(), ExpressionPort::Type::Input, i);
 }
 
 ExpressionPort Expression::get_output_port(size_t i) {
-    return ExpressionPort(this->shared_from_this(), ExpressionPort::Type::Output, i);
+    return ExpressionPort(shared_from_this(), ExpressionPort::Type::Output, i);
 }
 
 std::vector<ExpressionPort> Expression::get_input_ports() {
diff --git a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
index d689b183456bc1..090175c7466be3 100644
--- a/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
+++ b/src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
@@ -102,8 +102,10 @@ std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const
     const auto& m = *++out_preordered_dims.rbegin();
     const auto& n = *out_preordered_dims.rbegin();
 
-    const auto& k = *in_0_planar_dims.rbegin();
-    OPENVINO_ASSERT(k == *++in_1_planar_dims.rbegin(), "Brgemm input descriptors have different K dimension value.");
+    const auto& k0 = *in_0_planar_dims.rbegin();
+    const auto& k1 = *++in_1_planar_dims.rbegin();
+    size_t k = 0;
+    OPENVINO_ASSERT(utils::merge_dynamic_dim(k, k0, k1), "Brgemm input descriptors have incompatible K dimension value.");
 
     // Ticket: 113745
     // TODO: extend block size selection heuristics
diff --git a/src/common/snippets/src/lowered/pass/iter_handler.cpp b/src/common/snippets/src/lowered/pass/iter_handler.cpp
index 3e035628df476f..9c9ab69351c7bd 100644
--- a/src/common/snippets/src/lowered/pass/iter_handler.cpp
+++ b/src/common/snippets/src/lowered/pass/iter_handler.cpp
@@ -49,13 +49,13 @@ bool UpdateMemoryAccessCounts::run(LinearIR& linear_ir, LinearIR::constExprIt be
 }
 
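The new merge_dynamic_dim helper is only declared in utils.hpp; its definition in utils.cpp is not shown in this excerpt. A minimal sketch of the documented semantics, assuming the existing is_dynamic_value helper from the same utils header (the actual implementation in the patch may differ):

bool merge_dynamic_dim(size_t& dst, const size_t& d1, const size_t& d2) {
    // Equal dims (both static or both dynamic) or a dynamic d2: keep d1.
    if (d1 == d2 || is_dynamic_value(d2)) {
        dst = d1;
    } else if (is_dynamic_value(d1)) {
        // d1 is dynamic: take the other dim (which is static on this branch).
        dst = d2;
    } else {
        // Two different static dims cannot be merged; unlike broadcast_merge_dim,
        // a static 1 is not treated as broadcastable here.
        return false;
    }
    return true;
}

This relaxation is what the BrgemmBlocking hunk above and the SpecificLoopIterationHandler merge() changes below rely on: a dynamic dimension on one side no longer fails the check as long as the other side is compatible.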
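For context on the pass.hpp change: deriving PassBase from std::enable_shared_from_this is what lets the merge() overloads below return the already-existing pass via shared_from_this() instead of allocating a copy. The usual standard-library caveat applies (not something this patch changes): shared_from_this() is only valid on an object that is already owned by a std::shared_ptr, otherwise it throws std::bad_weak_ptr. A condensed, runnable stand-in for the real classes:

#include <memory>

// Condensed stand-in for the real PassBase/handler hierarchy.
struct PassBase : std::enable_shared_from_this<PassBase> {
    virtual ~PassBase() = default;

    std::shared_ptr<PassBase> merge(const std::shared_ptr<PassBase>& other) {
        if (!other)
            return shared_from_this();  // valid only because *this is owned by a shared_ptr
        return nullptr;                 // real handlers also try to merge their state here
    }
};

int main() {
    auto pass = std::make_shared<PassBase>();
    auto merged = pass->merge(nullptr);           // merged aliases pass, no new allocation
    // PassBase on_stack; on_stack.merge(nullptr); // would throw std::bad_weak_ptr
}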
std::shared_ptr UpdateMemoryAccessCounts::merge(const std::shared_ptr& other) { - const auto merged_pass = std::make_shared(m_count); - if (other == nullptr) - return merged_pass; + if (!other) + return shared_from_this(); const auto casted_pass = ov::as_type_ptr(other); - if (!casted_pass || m_count != casted_pass->m_count) + size_t merged_count; + if (!casted_pass || !ov::snippets::utils::merge_dynamic_dim(merged_count, m_count, casted_pass->m_count)) return nullptr; - return merged_pass; + return std::make_shared(merged_count); } SetFillOffset::SetFillOffset(size_t offset) : RangedPass(), m_offset(offset) {} @@ -71,13 +71,13 @@ bool SetFillOffset::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Linear } std::shared_ptr SetFillOffset::merge(const std::shared_ptr& other) { - const auto merged_pass = std::make_shared(m_offset); - if (other == nullptr) - return merged_pass; + if (!other) + return shared_from_this(); const auto casted_pass = ov::as_type_ptr(other); - if (!casted_pass || m_offset != casted_pass->m_offset) + size_t merged_offset; + if (!casted_pass || !ov::snippets::utils::merge_dynamic_dim(merged_offset, m_offset, casted_pass->m_offset)) return nullptr; - return merged_pass; + return std::make_shared(merged_offset); } bool SetLoopIncrementOne::run(LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) { diff --git a/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp b/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp index c89274a728c4c9..53fb344f4d9e8a 100644 --- a/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp +++ b/src/common/snippets/src/lowered/pass/propagate_subtensors.cpp @@ -175,13 +175,13 @@ bool UpdateSubtensors::run(LinearIR& linear_ir, LinearIR::constExprIt begin, Lin } std::shared_ptr UpdateSubtensors::merge(const std::shared_ptr& other) { - const auto merged_pass = std::make_shared(m_tail_size); - if (other == nullptr) - return merged_pass; + if (!other) + return shared_from_this(); const auto casted_pass = ov::as_type_ptr(other); - if (!casted_pass || m_tail_size != casted_pass->m_tail_size) + size_t merged_size; + if (!casted_pass || !ov::snippets::utils::merge_dynamic_dim(merged_size, m_tail_size, casted_pass->m_tail_size)) return nullptr; - return merged_pass; + return std::make_shared(merged_size); } } // namespace pass diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp index 483cdcd8564265..d1d8526534f62d 100644 --- a/src/common/snippets/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/src/pass/mha_tokenization.cpp @@ -472,9 +472,25 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken // TODO [75567]: move this plugin-specific constraint to the plugin callback const auto last_node = ordered_ops.back(); - if (potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count + uniqie_buffer_reg_group_count > 11) { + const auto io_count = potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count; + const auto data_count = io_count + uniqie_buffer_reg_group_count; + auto available_regs = config.get_data_ptr_gpr_count(); + // [150148, 150149] Currently Snippets don't have mechanism of spilling registers on stack. + // Due to this limitation we have to skip tokenization of some subgraphs + // if we need more registers than we have on the target machine. 
+ // `config.get_data_ptr_gpr_count()` provides available data registers count (including parameters, results and buffers) + // after excluding 2 registers for work amounts. + // However, MHA Subgraph has `SplitLoops` optimization which adds outermost blocked Loop by M. This Loop requires + // the separate own register for `work_amount` also. Thus, we have to decrement `available_regs` count in MHA case. + // Need to notice that in general we have enough count of available registers. + // But in rare cases (when there are a lot of parameters/results, the heuristic value of their number is `5`) + // the count of available registers might be not enough and we have to not tokenize these subgraphs. + // So only for these rare cases we decrement `available_regs` value. + if (io_count > 5) + available_regs--; + + if (data_count > available_regs) return false; - } // If backend doesn't enable dynamic MHA tokenization, return false if (!config.is_dynamic_mha_token_enabled()) { diff --git a/src/common/snippets/src/utils/utils.cpp b/src/common/snippets/src/utils/utils.cpp index cd64958207d958..e7381fe6754758 100644 --- a/src/common/snippets/src/utils/utils.cpp +++ b/src/common/snippets/src/utils/utils.cpp @@ -103,10 +103,21 @@ auto get_non_scalar_constant_count_for_fq(const std::shared_ptr + + +namespace ov { +namespace test { +namespace snippets { + +// D1, D2, Result +using BroadcastMergeDimParams = std::tuple; + +class BroadcastMergeDimTest : public testing::TestWithParam { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + void SetUp() override; + BroadcastMergeDimParams m_dims = {}; +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/common/snippets/tests/src/utils/broadcast_merge_dim.cpp b/src/common/snippets/tests/src/utils/broadcast_merge_dim.cpp new file mode 100644 index 00000000000000..52ed6116822a9d --- /dev/null +++ b/src/common/snippets/tests/src/utils/broadcast_merge_dim.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "utils/broadcast_dim_merge.hpp" + +#include "common_test_utils/ov_test_utils.hpp" +#include "snippets/utils/utils.hpp" + +namespace ov { +namespace test { +namespace snippets { + +std::string BroadcastMergeDimTest::getTestCaseName(testing::TestParamInfo obj) { + BroadcastMergeDimParams params = obj.param; + std::ostringstream result; + result << "D0=" << ov::snippets::utils::value2str(std::get<0>(params)) << "_"; + result << "D1=" << ov::snippets::utils::value2str(std::get<1>(params)) << "_"; + result << "DST=" << ov::snippets::utils::value2str(std::get<2>(params)); + return result.str(); +} + +void BroadcastMergeDimTest::SetUp() { + m_dims = this->GetParam(); +} + +TEST_P(BroadcastMergeDimTest, BrodcastMergeDim) { + size_t d1, d2, dst, result; + std::tie(d1, d2, dst) = this->m_dims; + ASSERT_TRUE(ov::snippets::utils::broadcast_merge_dim(result, d1, d2)); + ASSERT_EQ(result, dst); +} + +namespace BrodcastMergeDimInstantiation { + +constexpr size_t dynamic = ov::snippets::utils::get_dynamic_value(); + +const std::vector dimension_cases = { + {10, 10, 10}, + {10, 1, 10}, + {1, 10, 10}, + {dynamic, 10, 10}, + {10, dynamic, 10}, + {dynamic, dynamic, dynamic}, + {dynamic, 1, dynamic}, + {1, dynamic, dynamic}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_BrodcastMergeDim, BroadcastMergeDimTest, + ::testing::ValuesIn(dimension_cases), + BroadcastMergeDimTest::getTestCaseName); + +} // namespace BrodcastMergeDimInstantiation +} // 
namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/node.h b/src/plugins/intel_cpu/src/node.h index 2095907b860508..f3eb606bf3e322 100644 --- a/src/plugins/intel_cpu/src/node.h +++ b/src/plugins/intel_cpu/src/node.h @@ -773,7 +773,7 @@ class Node { NameFromType(getType())); } - MemoryPtr getScratchPadMem(const DnnlMemoryDescPtr& desc) { + MemoryPtr getScratchPadMem(const MemoryDescPtr& desc) { if (!scratchpadMem || !scratchpadMem->getDesc().isCompatible(*desc)) { scratchpadMem = context->getScratchPad(curNumaNode)->createScratchPadMem(desc); } diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 25ea10ce805622..b676f54e27c2d0 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -72,16 +72,15 @@ class SubgraphStaticExecutor : public Subgraph::SubgraphExecutor { const std::shared_ptr& snippet, const std::vector& start_offset_in, const std::vector& start_offset_out, - const std::shared_ptr& snippet_config) - : SubgraphExecutor(snippet_attrs, snippet, start_offset_in, start_offset_out) { - init_runtime_params(snippet_config); - } + const std::shared_ptr& snippet_config, + const BufferScratchpadAllocator& allocator) + : SubgraphExecutor(snippet_attrs, snippet, start_offset_in, start_offset_out, snippet_config, allocator) {} void exec(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override { const auto& callable = m_schedule->get_callable(); - auto initializer = [&](jit_snippets_call_args& call_args) { - init_call_args(call_args, inMemPtrs, outMemPtrs); + auto initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { + init_call_args(call_args, inMemPtrs, outMemPtrs, ithr); }; auto caller = [&](jit_snippets_call_args& call_args, const size_t* indexes) { callable(&call_args, indexes); @@ -97,17 +96,15 @@ class SubgraphStaticExecutor : public Subgraph::SubgraphExecutor { protected: typedef void (*kernel)(const void*, const void*); - inline void init_call_args(jit_snippets_call_args& call_args, const std::vector& srcMemPtrs, const std::vector& dstMemPtrs) { + inline void init_call_args(jit_snippets_call_args& call_args, const std::vector& srcMemPtrs, + const std::vector& dstMemPtrs, size_t ithr) { for (size_t i = 0; i < srcMemPtrs.size(); i++) call_args.src_ptrs[i] = srcMemPtrs[i]->getDataAs() + m_start_offset_in[i]; for (size_t i = 0; i < dstMemPtrs.size(); i++) call_args.dst_ptrs[i] = dstMemPtrs[i]->getDataAs() + m_start_offset_out[i]; - if (m_buffer_scratchpad_size > 0) { - call_args.buffer_scratchpad_ptr = - reinterpret_cast(m_buffer_scratchpad.data()) + parallel_get_thread_num() * m_buffer_scratchpad_size; - } + update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); } }; @@ -118,9 +115,13 @@ class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { const std::shared_ptr& snippet, const std::vector& start_offset_in, const std::vector& start_offset_out, - const std::shared_ptr& snippet_config) - : SubgraphExecutor(snippet_attrs, snippet, start_offset_in, start_offset_out) { - init_runtime_params(snippet_config); + const std::shared_ptr& snippet_config, + const BufferScratchpadAllocator& allocator) + : SubgraphExecutor(snippet_attrs, snippet, start_offset_in, start_offset_out, snippet_config, allocator) { + buffer_offsets = snippet_config->buffer_cluster_offsets; + data_offsets = snippet_config->io_data_offsets; + loop_args = snippet_config->loop_args; + 
reset_exec_table_state = snippet_config->kernel_executor_table->get_state_reset(); } void exec(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override { @@ -137,8 +138,8 @@ class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { std::vector dst_ptrs; init_original_ptrs(inMemPtrs, outMemPtrs, src_ptrs, dst_ptrs); - auto initializer = [&](jit_snippets_call_args& call_args) { - init_call_args(call_args); + auto initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { + init_call_args(call_args, ithr); }; auto caller = [&](jit_snippets_call_args& call_args, const size_t* indexes) { update_ptrs(call_args, src_ptrs, dst_ptrs, indexes); @@ -155,13 +156,11 @@ class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { protected: typedef void (*dynamic_kernel)(const void *); - inline void init_call_args(jit_snippets_call_args& call_args) { + inline void init_call_args(jit_snippets_call_args& call_args, size_t ithr) { call_args.register_loops(loop_args); std::copy(buffer_offsets.cbegin(), buffer_offsets.cend(), call_args.buffer_offsets); - if (m_buffer_scratchpad_size > 0) - call_args.buffer_scratchpad_ptr = - reinterpret_cast(m_buffer_scratchpad.data()) + parallel_get_thread_num() * m_buffer_scratchpad_size; + update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); } inline void init_original_ptrs(const std::vector& srcMemPtrs, const std::vector& dstMemPtrs, @@ -196,14 +195,6 @@ class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { } } - void init_runtime_params(const std::shared_ptr& snippet_config) override { - SubgraphExecutor::init_runtime_params(snippet_config); - buffer_offsets = snippet_config->buffer_cluster_offsets; - data_offsets = snippet_config->io_data_offsets; - loop_args = snippet_config->loop_args; - reset_exec_table_state = snippet_config->kernel_executor_table->get_state_reset(); - }; - std::vector buffer_offsets = {}; std::vector> data_offsets = {}; std::vector loop_args = {}; @@ -757,10 +748,15 @@ void Subgraph::optimizeIR() { } void Subgraph::prepareParams() { - const auto cache = context->getParamsCache(); + const auto& cache = context->getParamsCache(); - auto builder = [this, cache](const SubgraphKey& key) -> std::shared_ptr { + auto builder = [this, &cache](const SubgraphKey& key) -> std::shared_ptr { const auto& snippet = subgraph_attrs->snippet; + + SubgraphExecutor::BufferScratchpadAllocator allocator = [this](size_t size) { + return getScratchPadMem(std::make_shared(ov::element::u8, intel_cpu::Shape{size})); + }; + if (is_dynamic) { // Dynamic case: // 1. Generate JIT code if needed @@ -777,7 +773,7 @@ void Subgraph::prepareParams() { snippet->get_runtime_configurator()->set_kernel_executor_table(code_gen->get()->lowering_result.kernel_executor_table); } const auto& snippet_config = ov::as_type_ptr(snippet->update_runtime_config()); - return std::make_shared(key.attrs, code_gen, start_offset_in, start_offset_out, snippet_config); + return std::make_shared(key.attrs, code_gen, start_offset_in, start_offset_out, snippet_config, allocator); } else { // Static case: // 1. 
Update runtime config to get static scheduling data (io data offsets, parallel domain) which will be compiled in JIT code @@ -788,7 +784,7 @@ void Subgraph::prepareParams() { [&snippet_config](const SubgraphCodeGeneratorKey& key) -> std::shared_ptr { return std::make_shared(key.attrs, snippet_config); }); - return std::make_shared(key.attrs, code_gen_result.first, start_offset_in, start_offset_out, snippet_config); + return std::make_shared(key.attrs, code_gen_result.first, start_offset_in, start_offset_out, snippet_config, allocator); } }; @@ -875,26 +871,25 @@ Subgraph::SubgraphCodeGenerator::SubgraphCodeGenerator(const std::shared_ptr& snippet_attrs, const std::shared_ptr& snippet, const std::vector& start_offset_in, - const std::vector& start_offset_out) + const std::vector& start_offset_out, + const std::shared_ptr& snippet_config, + const BufferScratchpadAllocator& allocator) : m_schedule(snippet->get()), m_start_offset_in(start_offset_in), m_start_offset_out(start_offset_out) { OPENVINO_ASSERT(m_schedule, "Schedule is empty!"); -#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS) - const auto target = std::dynamic_pointer_cast(snippet_attrs->snippet->get_generator()->get_target_machine()); - enabled_segfault_detector = target && target->debug_config.enable_segfault_detector; -#endif -} - -void Subgraph::SubgraphExecutor::init_runtime_params(const std::shared_ptr& snippet_config) { OPENVINO_ASSERT(snippet_config, "Runtime Config is empty!"); - - m_buffer_scratchpad_size = snippet_config->buffer_scratchpad_size; - OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(m_buffer_scratchpad_size), "Undefined buffer scratchpad size!"); - m_buffer_scratchpad.resize(m_buffer_scratchpad_size * parallel_get_max_threads(), 0); - init_parallel_domain(snippet_config, m_parallel_exec_domain); m_harness_work_amount = std::accumulate(m_parallel_exec_domain.cbegin(), m_parallel_exec_domain.cend(), size_t(1), std::multiplies()); m_nthreads = std::min(parallel_get_max_threads(), static_cast(m_harness_work_amount)); + + m_buffer_scratchpad_size = snippet_config->buffer_scratchpad_size; + OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(m_buffer_scratchpad_size), "Undefined buffer scratchpad size!"); + m_buffer_scratchpad = allocator(static_cast(m_nthreads) * m_buffer_scratchpad_size); + +#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS) + const auto target = std::dynamic_pointer_cast(snippet_attrs->snippet->get_generator()->get_target_machine()); + enabled_segfault_detector = target && target->debug_config.enable_segfault_detector; +#endif } #if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS) @@ -914,7 +909,7 @@ void Subgraph::SubgraphExecutor::segfault_detector() { } #endif -void Subgraph::SubgraphExecutor::parallel_for6d(const std::function& initializer, +void Subgraph::SubgraphExecutor::parallel_for6d(const std::function& initializer, const std::function& caller) { const auto& dom = m_parallel_exec_domain; @@ -924,7 +919,7 @@ void Subgraph::SubgraphExecutor::parallel_for6d(const std::function& initializer, +void Subgraph::SubgraphExecutor::parallel_forNd(const std::function& initializer, const std::function& caller) { const auto& dom = m_parallel_exec_domain; @@ -948,7 +943,7 @@ void Subgraph::SubgraphExecutor::parallel_forNd(const std::function; + SubgraphExecutor(const std::shared_ptr& snippet_attrs, const std::shared_ptr& snippet, const std::vector& start_offset_in, - const 
std::vector& start_offset_out); + const std::vector& start_offset_out, + const std::shared_ptr& snippet_config, + const BufferScratchpadAllocator& allocator); virtual ~SubgraphExecutor() = default; virtual void exec(const std::vector& inMemPtrs, const std::vector& outMemPtrs) = 0; protected: - void parallel_for6d(const std::function& initializer, + void parallel_for6d(const std::function& initializer, const std::function& caller); - void parallel_forNd(const std::function& initializer, + void parallel_forNd(const std::function& initializer, const std::function& caller); - virtual void init_runtime_params(const std::shared_ptr& snippet_config); + inline void update_scratchpad_ptr(void*& scratchpad_ptr, size_t ithr) const { + if (m_buffer_scratchpad_size > 0) + scratchpad_ptr = m_buffer_scratchpad->getDataAs() + ithr * m_buffer_scratchpad_size; + } std::shared_ptr m_schedule; // Holds index of output used as in execution domain @@ -142,7 +149,7 @@ class Subgraph::SubgraphExecutor { size_t m_harness_work_amount = 0; // Buffer scratchpad - std::vector m_buffer_scratchpad = {}; + MemoryPtr m_buffer_scratchpad = nullptr; size_t m_buffer_scratchpad_size = 0; const size_t rank6D = 6; diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 98da1be4c74876..43862022462ada 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -893,17 +893,16 @@ void Transformations::MainSnippets(void) { // ARM has 32 gprs. After excluding 2 registers for work amounts, 1 register for runtime parameters, 1 platform register, // 3 registers for temporary use, and 2 stack related registers, it has 23 remaining registers. size_t data_ptr_gpr_count = 23; + bool is_dynamic_mha_token_enabled = false; #else // X64 has 16 gprs. After excluding 2 registers for work amounts, 1 register for runtime parameters, // and 2 stack related registers, it has 11 remaining registers. size_t data_ptr_gpr_count = 11; + bool is_dynamic_mha_token_enabled = true; #endif // The optimization "SplitDimensionM" depends on target machine (thread count). // To avoid uncontrolled behavior in tests, we disabled the optimization when there is Config::SnippetsMode::IgnoreCallback bool split_m_dimension = !ignoreCallback; - // [113198] Add dynamic Subgraph with MHA pattern inside execution support - // To enable dynamic MHA in tests, this flag is on when there is Config::SnippetsMode::IgnoreCallback - bool is_dynamic_mha_token_enabled = ignoreCallback; // [122706] Some 3D MHA Patterns have perf regressions when Transpose op is tokenized std::set mha_supported_transpose_ranks = { 4 }; snippets::pass::SnippetsTokenization::Config tokenization_config(concurrency, data_ptr_gpr_count, split_m_dimension,
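Tying the two register-budget changes together (the available_regs check added in mha_tokenization.cpp above and the data_ptr_gpr_count passed into the tokenization config here), the skip condition works out as in the sketch below. The data_ptr_gpr_count value of 11 is the x64 figure from the comment above; all other counts are assumed, illustrative numbers, not values taken from the patch:

#include <cstddef>
#include <iostream>

int main() {
    // Hypothetical counts for one x64 MHA tokenization candidate.
    std::size_t data_ptr_gpr_count = 11;            // from the x64 branch above
    std::size_t potential_body_params_count = 4;    // assumed
    std::size_t last_node_output_size = 1;          // assumed
    std::size_t hidden_virtual_ports_count = 1;     // assumed
    std::size_t unique_buffer_reg_group_count = 5;  // assumed

    std::size_t io_count = potential_body_params_count + last_node_output_size + hidden_virtual_ports_count;  // 6
    std::size_t data_count = io_count + unique_buffer_reg_group_count;                                        // 11
    std::size_t available_regs = data_ptr_gpr_count;                                                          // 11
    if (io_count > 5)
        --available_regs;  // SplitLoops' outer blocked M-loop needs its own work-amount register
    std::cout << (data_count > available_regs ? "skip tokenization" : "tokenize") << "\n";  // 11 > 10 -> skip
}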
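On the subgraph executor changes above: the node now hands the executor a BufferScratchpadAllocator callback (built on getScratchPadMem with a u8 memory descriptor; the descriptor's concrete type is not visible in this excerpt), the constructor makes a single allocation of m_nthreads * m_buffer_scratchpad_size bytes, and update_scratchpad_ptr offsets into it by the thread index that is now threaded through the initializer callbacks. A generic, self-contained illustration of that per-thread layout, not the plugin code itself:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    // One arena sized for all workers, mirroring allocator(m_nthreads * m_buffer_scratchpad_size).
    const std::size_t per_thread_size = 4096;  // hypothetical buffer_scratchpad_size
    const std::size_t n_threads = 8;           // hypothetical m_nthreads
    std::vector<std::uint8_t> arena(n_threads * per_thread_size);

    for (std::size_t ithr = 0; ithr < n_threads; ++ithr) {
        // Equivalent of update_scratchpad_ptr(): each worker gets a disjoint slice.
        std::uint8_t* scratchpad_ptr = arena.data() + ithr * per_thread_size;
        (void)scratchpad_ptr;  // would be stored into call_args.buffer_scratchpad_ptr
    }
    return 0;
}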