diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index aa917c89dcb016..e46de866990005 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -224,9 +224,12 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression // Quick validation check: Should we check that port is really Brgemm port? // If BrgemmCopyB in the Loop by M -> first input port will be BrgemmCopyB with `incremented=false` // to avoid extra checks, we validate only first input port - OPENVINO_ASSERT(in_ports.size() > 1 && in_ports.front().is_incremented && in_ports.front().dim_idx == 1 && - out_ports.size() == 1 && out_ports.front().is_incremented && out_ports.front().dim_idx == 1, - "Incorrect Loop by Brgemm dimension N"); + // Note: `is_incremented` is checked only for ports that must not be incremented because + // `is_incremented = true` can be changed by the `CleanRepeatedDataPointerShifts` optimization + auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 1; }; + OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) && + out_ports.size() == 1 && check_port(out_ports.back()), + "Incorrect Loop by Brgemm dimension M"); M = current_expanded_loop_info->get_increment(); input_pds[0]->set_subtensor_dim(1, M); output_pds[0]->set_subtensor_dim(1, M); @@ -240,8 +243,11 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression const auto& in_ports = current_expanded_loop_info->get_input_ports(); const auto& out_ports = current_expanded_loop_info->get_output_ports(); // Quick validation check: Should we check that port is really Brgemm port? 
- OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented && in_ports.back().is_incremented && in_ports.back().dim_idx == 0 && - out_ports.size() == 1 && out_ports.front().is_incremented && out_ports.front().dim_idx == 0, + // Note: `is_incremented` is checked only for ports that must not be incremented because + // `is_incremented = true` can be changed by the `CleanRepeatedDataPointerShifts` optimization + auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 0; }; + OPENVINO_ASSERT(in_ports.size() == 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) && + out_ports.size() == 1 && check_port(out_ports.back()), "Incorrect Loop by Brgemm dimension N"); N = current_expanded_loop_info->get_increment(); input_pds[1]->set_subtensor_dim(0, N); @@ -261,8 +267,9 @@ void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::Expression const auto& in_ports = current_expanded_loop_info->get_input_ports(); const auto& out_ports = current_expanded_loop_info->get_output_ports(); // Quick validation check: Should we check that port is really Brgemm port? 
- OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().is_incremented && in_ports.front().dim_idx == 0 && - in_ports.back().is_incremented && in_ports.back().dim_idx == 1 && + // Note: `is_incremented` is checked only for ports that must not be incremented because + // `is_incremented = true` can be changed by the `CleanRepeatedDataPointerShifts` optimization + OPENVINO_ASSERT(in_ports.size() == 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 && out_ports.size() == 1 && !out_ports.front().is_incremented, "Incorrect Loop by Brgemm dimension K"); K = current_expanded_loop_info->get_increment(); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index c5ee28757b0e04..380967a44f5fa6 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -25,7 +25,8 @@ const auto& inputShapes_4D = STATIC_SHAPES( const auto& inputShapes_3D = STATIC_SHAPES( {{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, - {{68, 6, 92}, {68, 6, 92}, {1, 68, 68}, {68, 6, 92}}); + {{68, 6, 92}, {68, 6, 92}, {1, 68, 68}, {68, 6, 92}}, + {{16, 2, 92}, {68, 2, 92}, {1, 16, 68}, {68, 2, 92}}); static inline bool is_bf16_supported() { return ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16();