diff --git a/src/common/snippets/include/snippets/pass/mha_tokenization.hpp b/src/common/snippets/include/snippets/pass/mha_tokenization.hpp
index 60afc27a149750..7c161e8447e9b8 100644
--- a/src/common/snippets/include/snippets/pass/mha_tokenization.hpp
+++ b/src/common/snippets/include/snippets/pass/mha_tokenization.hpp
@@ -19,6 +19,7 @@ namespace pass {
  */
 class TokenizeMHASnippets: public ngraph::pass::MatcherPass {
 public:
+    OPENVINO_RTTI("TokenizeMHASnippets", "0");
     TokenizeMHASnippets();
 };
 
diff --git a/src/common/snippets/src/pass/convert_constants.cpp b/src/common/snippets/src/pass/convert_constants.cpp
index 3c2e8cee2a7a6f..37cf0f85266434 100644
--- a/src/common/snippets/src/pass/convert_constants.cpp
+++ b/src/common/snippets/src/pass/convert_constants.cpp
@@ -24,7 +24,9 @@ ngraph::snippets::pass::ConvertConstantsToScalars::ConvertConstantsToScalars() {
             return false;
         // Note that all Constants {1,1,1,1} are converted to Scalar {1} here
         // This is needed to simplify shape inference, otherwise {1,1,1,1} Constants can increase output rank
-        auto scalar = std::make_shared<ngraph::snippets::op::Scalar>(ov::op::v0::Constant(*constant, ov::Shape{1}));
+        // Also some operations support only scalar shapes, so we need to distinguish scalars from shape [1]
+        const auto shape = constant->get_output_shape(0).size() == 0 ? ov::Shape{} : ov::Shape{1};
+        auto scalar = std::make_shared<ngraph::snippets::op::Scalar>(ov::op::v0::Constant(*constant, shape));
         scalar->set_friendly_name(constant->get_friendly_name());
         ngraph::copy_runtime_info(constant, scalar);
         ngraph::replace_node(constant, scalar);
diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp
index 8de5f29eee9b45..69a166140b4093 100644
--- a/src/common/snippets/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/src/pass/mha_tokenization.cpp
@@ -241,18 +241,24 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
         }
 
         auto transpose1 = ngraph::as_type_ptr<ngraph::opset1::Transpose>(parent);
-        if (!matmul0->get_transpose_b() && is_valid_transpose(transpose1, {0, 2, 3, 1})) {
-            ordered_ops.insert(ordered_ops.begin(), transpose1);
-        } else if (matmul0->get_transpose_b() && is_valid_transpose(transpose1, {0, 2, 1, 3})) {
-            // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
-            // only if these ops have scalar shapes on other inputs.
-            // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
-            // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
-            if (are_weights_scalar) {
-                ordered_ops.insert(ordered_ops.begin(), transpose1);
+        if (matmul0->get_transpose_b()) {
+            if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
+                // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
+                // only if these ops have scalar shapes on other inputs.
+                // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
+                // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
+                if (are_weights_scalar) {
+                    ordered_ops.insert(ordered_ops.begin(), transpose1);
+                } else {
+                    return false;
+                }
             } else {
                 return false;
             }
+        } else {
+            if (is_valid_transpose(transpose1, {0, 2, 3, 1})) {
+                ordered_ops.insert(ordered_ops.begin(), transpose1);
+            }
         }
 
         // TODO: Add Reshape Support for all Transposes
diff --git a/src/common/snippets/src/pass/tokenization.cpp b/src/common/snippets/src/pass/tokenization.cpp
index 0b5e3636dcec77..4744b73b88295e 100644
--- a/src/common/snippets/src/pass/tokenization.cpp
+++ b/src/common/snippets/src/pass/tokenization.cpp
@@ -53,7 +53,7 @@ bool EnumerateNodes::run_on_model(const std::shared_ptr<ov::Model> &m) {
 
 bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
     RUN_ON_FUNCTION_SCOPE(SnippetsTokenization);
-    ngraph::pass::Manager manager;
+    ngraph::pass::Manager manager(get_pass_config());
     manager.set_per_pass_validation(false);
 
     manager.register_pass<EnumerateNodes>();
diff --git a/src/plugins/intel_cpu/src/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformation_pipeline.cpp
index b44e6142a8a42f..0f331dced13544 100644
--- a/src/plugins/intel_cpu/src/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformation_pipeline.cpp
@@ -544,10 +544,9 @@ void Transformations::PostLpt() {
         });
 
     // Float MHA is supported by snippets now
-    auto postLPTPassConfig = postLPTPassManager.get_pass_config();
     if (!enableBF16) {
-        postLPTPassConfig->disable<MHAFloatFusion>();
-        postLPTPassConfig->disable<MHAFloatFusion2>();
+        postLPTPassManager.get_pass_config()->disable<MHAFloatFusion>();
+        postLPTPassManager.get_pass_config()->disable<MHAFloatFusion2>();
     }
 
     // Execute before snippets. Otherwise FQ will be converted to Subgraph
@@ -567,8 +566,7 @@ void Transformations::MainSnippets(void) {
 
     if (enableBF16) {
         // TODO: Need to add BF16 support for MHA in Snippets
-        const auto snippetsConfig = snippetsManager.get_pass_config();
-        snippetsConfig->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
+        snippetsManager.get_pass_config()->disable<ngraph::snippets::pass::TokenizeMHASnippets>();
     }
     if (snippetsMode != Config::SnippetsMode::IgnoreCallback) {
         snippetsManager.get_pass_config()->set_callback<ngraph::snippets::pass::TokenizeMHASnippets>(
@@ -590,9 +588,8 @@
                 // - parallelism support on JIT level
                 const auto needed_num_of_threads = 12lu;
                 const auto l2_cache_size = dnnl::utils::get_cache_size(2, true);
-                const auto is_unsupported_parallel_work_amount = IMPLICATION(
-                    parallel_get_num_threads() / 2 > parallel_work_amount,
-                    parallel_work_amount < needed_num_of_threads);
+                const auto is_unsupported_parallel_work_amount = parallel_get_num_threads() / 2 > parallel_work_amount &&
+                                                                 parallel_work_amount < needed_num_of_threads;
                 const auto is_unsupported_kernel_work_amount = kernel_buffer_size > l2_cache_size;
                 return is_unsupported_parallel_work_amount || is_unsupported_kernel_work_amount;
             });
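Note on the last hunk above: replacing IMPLICATION(...) with an explicit && changes the meaning of the heuristic, it does not merely restyle it. A minimal sketch of the difference, assuming OpenVINO's IMPLICATION(cause, cond) helper expands to (!(cause) || (cond)); the function and variable names below are illustrative only and not taken from the sources:

    #include <iostream>

    // Old heuristic: IMPLICATION(threads / 2 > work_amount, work_amount < needed)
    static bool old_is_unsupported(int threads, int work_amount, int needed) {
        return !(threads / 2 > work_amount) || (work_amount < needed);
    }

    // New heuristic: reject tokenization only when the parallel work amount is both
    // smaller than half the thread count and below the minimum needed thread count.
    static bool new_is_unsupported(int threads, int work_amount, int needed) {
        return (threads / 2 > work_amount) && (work_amount < needed);
    }

    int main() {
        // Example: 16 threads, parallel work amount of 100, needed_num_of_threads = 12.
        // The premise (16 / 2 > 100) is false, so the old implication is vacuously true
        // and a large work amount was flagged as unsupported; the new conjunction is false.
        std::cout << old_is_unsupported(16, 100, 12) << " "
                  << new_is_unsupported(16, 100, 12) << "\n";  // prints "1 0"
        return 0;
    }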