diff --git a/src/common/snippets/include/snippets/pass/tokenization.hpp b/src/common/snippets/include/snippets/pass/tokenization.hpp
index 24efcceec71a24..0bf78e5a0cd662 100644
--- a/src/common/snippets/include/snippets/pass/tokenization.hpp
+++ b/src/common/snippets/include/snippets/pass/tokenization.hpp
@@ -66,10 +66,10 @@ class SnippetsTokenization : public ov::pass::ModelPass {
      */
     struct Config {
         Config(size_t concurrency, size_t data_ptr_gpr_count, bool split_m_dimension, bool enable_transpose_on_output,
-               bool dyn_mha_token, std::set<size_t> mha_transpose_ranks)
+               bool dyn_mha_token, std::set<size_t> mha_transpose_ranks, bool enable_bf16 = false)
             : m_concurrency(concurrency), m_data_ptr_gpr_count(data_ptr_gpr_count), m_split_m_dimension(split_m_dimension),
               m_mha_token_enable_transpose_on_output(enable_transpose_on_output), m_is_dynamic_mha_token_enabled(dyn_mha_token),
-              m_mha_supported_transpose_ranks(std::move(mha_transpose_ranks)) {
+              m_mha_supported_transpose_ranks(std::move(mha_transpose_ranks)), m_enable_bf16(enable_bf16) {
             OPENVINO_ASSERT(concurrency > 0, "Concurrency should be greater than 0");
             OPENVINO_ASSERT(data_ptr_gpr_count > 0, "data_ptr_gpr_count should be greater than 0");
         }
@@ -102,6 +102,10 @@ class SnippetsTokenization : public ov::pass::ModelPass {
             return m_mha_supported_transpose_ranks;
         }
 
+        bool is_bf16_enabled() const {
+            return m_enable_bf16;
+        }
+
     private:
         size_t m_concurrency = 0;
         // The number of gpr that can be used as data pointers for data nodes (Parameter (and non-Scalar Constants),
@@ -121,6 +125,7 @@ class SnippetsTokenization : public ov::pass::ModelPass {
         // Note that in general Snippets support Transpose of any ranks.
         // But at the moment Transpose is used only in MHA pattern where 3D and 4D tensors are supported.
         std::set<size_t> m_mha_supported_transpose_ranks = { 3, 4 };
+        bool m_enable_bf16 = false;
     };
 
     OPENVINO_RTTI("SnippetsTokenization", "0");
diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp
index beb465ab3a3fbe..b3883c7feb9b37 100644
--- a/src/common/snippets/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/src/pass/mha_tokenization.cpp
@@ -355,7 +355,16 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
         // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
         const auto is_transposed_b_0 = matmul0->get_transpose_b();
         bool has_matmul0_has_ops_on_input = false;
-        while (is_supported_intermediate_op(parent)) {
+
+        // Note: this is a temporary WA, avoiding MatMul B input tokenization in the cases that the CPU plugin doesn't support.
+        // It will be removed when a plugin-specific SubgraphPass is implemented.
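+        // The predicate below allows B input tokenization only when both MatMul inputs share one
+        // precision and that precision is either f32 (with bf16 execution disabled) or i8;
+        // otherwise the ops on the B input branch are left outside the tokenized Subgraph.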
+        auto check_matmul_b_input_tokenization = [&config](const std::shared_ptr<ov::opset1::MatMul>& matmul) {
+            return matmul->get_input_element_type(0) == matmul->get_input_element_type(1) &&
+                   ((matmul->get_input_element_type(0) == ov::element::f32 && !config.is_bf16_enabled()) ||
+                    (matmul->get_input_element_type(0) == ov::element::i8));
+        };
+        const bool support_mm0_b_input_tokenization = check_matmul_b_input_tokenization(matmul0);
+        while (support_mm0_b_input_tokenization && is_supported_intermediate_op(parent)) {
             // All supported ops have only one output port
             if (parent->get_output_target_inputs(0).size() != 1)
                 break;
@@ -404,12 +413,16 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
             }
         };
-        const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
         const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
-        const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
-        tokenize_transpose(transpose1, is_transposed_b_0, get_decomposed_transpose_order(pattern_rank), ordered_ops.begin());
         tokenize_transpose(transpose0, matmul0->get_transpose_a(), get_fusion_transpose_order(pattern_rank), ordered_ops.begin());
-        tokenize_transpose(transpose2, matmul1->get_transpose_b(), get_fusion_transpose_order(pattern_rank), ordered_ops.end());
+        if (support_mm0_b_input_tokenization) {
+            const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+            tokenize_transpose(transpose1, is_transposed_b_0, get_decomposed_transpose_order(pattern_rank), ordered_ops.begin());
+        }
+        if (check_matmul_b_input_tokenization(matmul1)) {
+            const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
+            tokenize_transpose(transpose2, matmul1->get_transpose_b(), get_fusion_transpose_order(pattern_rank), ordered_ops.end());
+        }
 
         ordered_ops.push_back(matmul1);
 
         bool are_ops_after_matmul1 = false;
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index 9dd1da2d471e5a..9afae7dd86a8cf 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -955,7 +955,7 @@ void Transformations::MainSnippets(void) {
     std::set<size_t> mha_supported_transpose_ranks = { 4 };
     snippets::pass::SnippetsTokenization::Config tokenization_config(concurrency, data_ptr_gpr_count, split_m_dimension,
                                                                      mha_token_enable_transpose_on_output, is_dynamic_mha_token_enabled,
-                                                                     mha_supported_transpose_ranks);
+                                                                     mha_supported_transpose_ranks, config.inferencePrecision == ov::element::bf16);
 
     ov::pass::Manager snippetsManager("CPU:Snippets");
     snippetsManager.set_per_pass_validation(false);
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp
index 8517612a348f68..62b7a3390879e1 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp
@@ -3,9 +3,9 @@
 //
 #include "common_test_utils/common_utils.hpp"
-#include "common_test_utils/ov_tensor_utils.hpp"
 #include "common_test_utils/node_builders/constant.hpp"
 #include "common_test_utils/node_builders/fake_quantize.hpp"
+#include "common_test_utils/ov_tensor_utils.hpp"
 #include "internal_properties.hpp"
 #include "shared_test_classes/base/ov_subgraph.hpp"
 #include "utils/cpu_test_utils.hpp"
@@ -666,15 +666,27 @@ std::vector<std::vector<ElementType>> matMulIn0PrecisionsQuant = {
     {ElementType::i8, ElementType::u8},
 };
 
-INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0,
+INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0_i8i8,
                          MHAQuantTest,
                          ::testing::Combine(::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)),
                                             ::testing::ValuesIn(inputPrecisionsQuant),
-                                            ::testing::ValuesIn(matMulIn0PrecisionsQuant),
+                                            ::testing::Values(std::vector<ElementType>{ElementType::i8, ElementType::i8}),
                                             ::testing::Values(0),
                                             ::testing::Values(ExpectedNodes{
                                                 {"Subgraph", 5},     // FQs on inputs x 3 + MHA + Deq Mul
-                                                {"Transpose", 1}}),  // Transpose between MHA and Deq Mul
+                                                {"Transpose", 1}}),  // Transpose between MHA and Deq Mul (B input of 2nd MM is tokenized in the i8/i8 case)
+                         ::testing::Values(ov::test::utils::DEVICE_CPU)),
+                         MHAQuantTest::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0_i8u8,
+                         MHAQuantTest,
+                         ::testing::Combine(::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)),
+                                            ::testing::ValuesIn(inputPrecisionsQuant),
+                                            ::testing::Values(std::vector<ElementType>{ElementType::i8, ElementType::u8}),
+                                            ::testing::Values(0),
+                                            ::testing::Values(ExpectedNodes{
+                                                {"Subgraph", 5},     // FQs on inputs x 3 + MHA + Deq Mul
+                                                {"Transpose", 2}}),  // Transpose between MHA and Deq Mul + Extracted transpose on B input of 2nd MM
                          ::testing::Values(ov::test::utils::DEVICE_CPU)),
                          MHAQuantTest::getTestCaseName);
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_quantized.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_quantized.cpp
index 0c731b74565863..0a12e0a36a3621 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_quantized.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_quantized.cpp
@@ -48,7 +48,7 @@ INSTANTIATE_TEST_SUITE_P(
         ::testing::Values(ov::element::f32),
         ::testing::Values(false),  // The graph doesn't contain Multiply
         ::testing::Values(MHA::default_thread_count),
-        ::testing::Values(6),  // FQx3 on inputs + MHA + Transpose on output + Deq Mul
+        ::testing::Values(7),  // FQx3 on inputs + MHA + Transpose on output + Transpose on MatMul's B input + Deq Mul
         ::testing::Values(5),  // FQx3 on inputs + MHA + Deq Mul
         ::testing::Values(ov::test::utils::DEVICE_CPU),
         ::testing::Values(CPUTestUtils::empty_plugin_config)),
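
Reviewer note: below is a minimal, self-contained sketch of the precision gate that check_matmul_b_input_tokenization introduces above, useful for sanity-checking the node counts expected by the new tests. The Precision enum, is_b_input_tokenizable, and the main driver are illustrative stand-ins for the OpenVINO types, not part of this patch.

#include <iostream>

// Illustrative stand-in for the ov::element precisions involved in the gate.
enum class Precision { f32, bf16, i8, u8 };

// Mirrors check_matmul_b_input_tokenization: the B input of a MatMul is tokenized
// only when both inputs share one precision and that precision is f32 (with bf16
// execution disabled) or i8.
bool is_b_input_tokenizable(Precision in0, Precision in1, bool enable_bf16) {
    return in0 == in1 &&
           ((in0 == Precision::f32 && !enable_bf16) || in0 == Precision::i8);
}

int main() {
    // f32/f32, bf16 disabled: tokenized (prints 1)
    std::cout << is_b_input_tokenizable(Precision::f32, Precision::f32, false) << '\n';
    // f32/f32, bf16 enabled: B input branch stays outside the Subgraph (prints 0)
    std::cout << is_b_input_tokenizable(Precision::f32, Precision::f32, true) << '\n';
    // i8/i8: tokenized (prints 1), matching the smoke_MHAQuant_Pattern0_i8i8 expectations
    std::cout << is_b_input_tokenizable(Precision::i8, Precision::i8, false) << '\n';
    // i8/u8: not tokenized (prints 0), hence the extra Transpose in smoke_MHAQuant_Pattern0_i8u8
    std::cout << is_b_input_tokenizable(Precision::i8, Precision::u8, false) << '\n';
    return 0;
}

With bf16 inference enabled, the f32/f32 case flips to 0, which is why transformation_pipeline.cpp now forwards config.inferencePrecision == ov::element::bf16 into the tokenization Config.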