From 0fb5b64c10398f3e9c2d80c00b7538f43629fb08 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Tue, 19 Nov 2024 10:23:47 +0100 Subject: [PATCH] Correct Transpose tokenization in tests --- src/common/snippets/src/pass/collapse_subgraph.cpp | 5 ++++- .../tests/functional/shared_tests_instances/snippets/mha.cpp | 2 +- .../shared_tests_instances/snippets/mha_with_dyn_mul.cpp | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 0f0cc225173479..6348f89598523d 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -51,9 +51,12 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { const auto parent = transpose->get_input_node_shared_ptr(0); const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); auto is_brgemm_case = ov::is_type(parent) || ov::is_type(child); + auto decomposition_case = true; // Check for Transpose parent is MatMul inside Subgraph if (const auto subgraph = ov::as_type_ptr(parent)) { if (GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed) { + // Transpose decomposition is supported only for Transpose nodes right after Subgraph's parameters + decomposition_case = false; const auto body = subgraph->body_ptr(); const auto subgraph_output = body->get_results()[transpose->input_value(0).get_index()]->get_input_node_shared_ptr(0); is_brgemm_case = is_brgemm_case || ov::is_type(subgraph_output); @@ -63,7 +66,7 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { const auto& order = as_type_ptr(n->get_input_node_shared_ptr(1)); if (order) { const auto order_value = order->cast_vector(); - return (TransposeDecomposition::is_supported_transpose_order(order_value)) || + return (decomposition_case && TransposeDecomposition::is_supported_transpose_order(order_value)) || (is_brgemm_case && FuseTransposeBrgemm::is_supported_transpose_order(order_value)); } } diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 63f5176684ccc1..45bb055d086910 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -131,7 +131,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, ::testing::Values(ov::element::bf16), ::testing::ValuesIn({false}), ::testing::Values(MHA::default_thread_count), - ::testing::Values(7), + ::testing::Values(9), ::testing::Values(6), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)), diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_with_dyn_mul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_with_dyn_mul.cpp index 7876d737af2281..ccd23dd6833f98 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_with_dyn_mul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha_with_dyn_mul.cpp @@ -56,7 +56,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(precision_f32(5)), ::testing::Values(ov::element::bf16), ::testing::Values(MHA::default_thread_count), - ::testing::Values(8), // MHA + 1 Transpose on output + 6 Converts around + ::testing::Values(10), // MHA + 1 Transpose on output + 6 Converts around + 2 Transposes on Matmul's B inputs ::testing::Values(7), // MHA + 6 Converts around ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::Values(CPUTestUtils::empty_plugin_config)),