From eb3e6a65eb372990a446d0c70120f3d5b967373e Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Thu, 8 Jun 2023 12:05:14 +0400 Subject: [PATCH] [Snippets] Add support of MHA Tokenization for different precisions (#15647) --- .../snippets/include/snippets/op/brgemm.hpp | 2 + .../snippets/include/snippets/op/buffer.hpp | 6 +- .../snippets/include/snippets/op/subgraph.hpp | 1 - .../snippets/pass/mha_tokenization.hpp | 26 +- .../include/snippets/pass/tokenization.hpp | 29 +- .../src/lowered/pass/allocate_buffers.cpp | 4 +- .../src/lowered/pass/assign_registers.cpp | 6 +- src/common/snippets/src/op/brgemm.cpp | 27 +- src/common/snippets/src/op/buffer.cpp | 20 +- src/common/snippets/src/op/subgraph.cpp | 38 +- .../snippets/src/pass/collapse_subgraph.cpp | 27 +- .../snippets/src/pass/mha_tokenization.cpp | 237 +++++++---- src/common/snippets/src/pass/tokenization.cpp | 22 +- .../snippets/tests/src/lowering_utils.cpp | 1 + .../tests/src/pass/collapse_subgraph.cpp | 1 + .../tests/src/pass/mha_tokenization.cpp | 14 +- .../emitters/x64/jit_snippets_emitters.cpp | 17 +- .../emitters/x64/jit_snippets_emitters.hpp | 4 +- .../x64/pass/snippets_mark_skipped.cpp | 10 +- .../transformation_pipeline.cpp | 64 ++- .../skip_tests_config.cpp | 4 + .../snippets/matmul.cpp | 27 +- .../shared_tests_instances/snippets/mha.cpp | 192 +++++++-- .../functional/subgraph_tests/src/mha.cpp | 2 +- .../plugin/shared/include/snippets/mha.hpp | 20 +- .../plugin/shared/src/snippets/mha.cpp | 52 ++- .../include/subgraph_matmul.hpp | 15 +- .../include/subgraph_mha.hpp | 136 ++++++- .../src/subgraph_matmul.cpp | 2 +- .../src/subgraph_mha.cpp | 376 ++++++++++++++++-- 30 files changed, 1105 insertions(+), 277 deletions(-) diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index d83f18c69c98eb..a037f6669df083 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -28,6 +28,8 @@ class Brgemm : public MemoryAccess { size_t get_offset_b() const { return get_input_offset(1); } size_t get_offset_c() const { return get_output_offset(0); } + static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1); + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; diff --git a/src/common/snippets/include/snippets/op/buffer.hpp b/src/common/snippets/include/snippets/op/buffer.hpp index 6b8ec7b5fd31b7..7a644644dd7417 100644 --- a/src/common/snippets/include/snippets/op/buffer.hpp +++ b/src/common/snippets/include/snippets/op/buffer.hpp @@ -29,7 +29,7 @@ class Buffer : public ov::op::Op { public: OPENVINO_OP("Buffer", "SnippetsOpset"); Buffer() = default; - Buffer(const ov::Shape& shape, size_t id = 0); + Buffer(const ov::Shape& shape, ov::element::Type element_type = ov::element::u8, size_t id = 0); Buffer(const ov::Output& arg, const ov::Shape& shape, size_t id = 0); Buffer(const ov::Output& arg, int32_t allocation_rank = -1, size_t id = 0); @@ -48,9 +48,10 @@ class Buffer : public ov::op::Op { int64_t get_offset() const { return m_offset; } void set_id(size_t id) { m_id = id; } void set_offset(int64_t offset) { m_offset = offset; } - size_t get_byte_size() const; + void set_element_type(ov::element::Type element_type); + bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; } bool is_new_memory() const { return m_type == Type::NewMemory; } @@ -59,6 +60,7 @@ class Buffer : public ov::op::Op { ov::Shape m_shape = {}; int64_t m_offset = 0; size_t m_id = 0; // Default ID - 0. All Buffers are from the same set + ov::element::Type m_element_type = ov::element::u8; // u8 - default 1 byte }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index 9d63bcba1367a6..615facf773f149 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -136,7 +136,6 @@ class Subgraph : public ov::op::util::SubGraphOp { // should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast) // This check returns True if Constant op which is input of this op should be inside Subgraph body static auto constant_input_should_be_inside_body(const std::shared_ptr& node) -> bool; - static bool check_broadcast(const std::shared_ptr& node) noexcept; // Return estimated unique buffer count (upper bound). It's needed for tokenization static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t; diff --git a/src/common/snippets/include/snippets/pass/mha_tokenization.hpp b/src/common/snippets/include/snippets/pass/mha_tokenization.hpp index c1a1700b1da7eb..eaf57881d316a5 100644 --- a/src/common/snippets/include/snippets/pass/mha_tokenization.hpp +++ b/src/common/snippets/include/snippets/pass/mha_tokenization.hpp @@ -6,6 +6,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pattern/matcher.hpp" +#include "snippets/pass/tokenization.hpp" namespace ov { namespace snippets { @@ -14,13 +15,34 @@ namespace pass { /** * @interface TokenizeMHASnippets * @brief The pass tokenizes MHA-pattern into Subgraph - * TODO: Write pattern + * Pattern: Transpose1 + * | + * Transpose0 [Eltwise, Select] + * \ / + * MatMul0 + * | + * [Eltwise, Select, Reshape] + * | + * Softmax + * | + * [Eltwise, Select, Reshape] Transpose2 + * \ / + * MatMul1 + * | + * [Eltwise, Select, Transpose3] + * Notes: + * - Transposes can be missed + * - Transpose0, Transpose2 and Transpose3 may have only [0,2,1,3] order + * - Transpose1 may have only [0,2,3,1] order + * - [...] means any count of different nodes from list. But: + * * Reshapes can be only explicitly around Softmax (Reshape -> Softmax -> Reshape) + * * After MatMul1 may be only Transpose3 or any count of Eltwise, Select ops. * @ingroup snippets */ class TokenizeMHASnippets: public ov::pass::MatcherPass { public: OPENVINO_RTTI("TokenizeMHASnippets", "0"); - TokenizeMHASnippets(); + TokenizeMHASnippets(const SnippetsTokenization::Config& config = {}); }; } // namespace pass diff --git a/src/common/snippets/include/snippets/pass/tokenization.hpp b/src/common/snippets/include/snippets/pass/tokenization.hpp index a1f4bb4f2f8d6e..151e49bb00d0a5 100644 --- a/src/common/snippets/include/snippets/pass/tokenization.hpp +++ b/src/common/snippets/include/snippets/pass/tokenization.hpp @@ -7,8 +7,7 @@ #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pattern/matcher.hpp" -#include "snippets/pass/mha_tokenization.hpp" -#include "snippets/pass/collapse_subgraph.hpp" +#include "snippets/op/subgraph.hpp" namespace ov { namespace snippets { @@ -19,8 +18,16 @@ namespace pass { SkippedByPlugin - indicate that snippets can't include this node in subgraph. Can be set by Plugin via SetSnippetsNodeType(...). */ enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin}; +/* + NotSet - default value returned if the subgraph wasn't marked and snippets can include nodes in this subgraph + Completed - indicate that snippets can't include any nodes in this subgraph. + It's used in separate tokenization pass, for example, tokenization by matcher (MHA Tokenization). + */ +enum class SnippetsSubgraphType : int64_t {NotSet, Completed}; void SetSnippetsNodeType(const std::shared_ptr&, SnippetsNodeType); +void SetSnippetsSubgraphType(const std::shared_ptr&, SnippetsSubgraphType); SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr&); +SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr&); void SetTopologicalOrder(const std::shared_ptr&, int64_t); int64_t GetTopologicalOrder(const std::shared_ptr&); @@ -48,8 +55,26 @@ class EnumerateNodes : public ov::pass::ModelPass { */ class SnippetsTokenization : public ov::pass::ModelPass { public: + /** + * @interface Config + * @brief Allow to adjust tokenization passes + * @ingroup snippets + */ + struct Config { + Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {} + + // False if all Transposes aren't tokenized in MHA Tokenization. + // Otherwise, they may be fused into Subgraph if possible + // TODO [106921]: Remove please when the ticket 106921 is implemented + bool mha_token_enable_transpose = true; + }; + OPENVINO_RTTI("SnippetsTokenization", "0"); + SnippetsTokenization(const Config& config) : m_config(config) {} bool run_on_model(const std::shared_ptr& m) override; + +private: + Config m_config{}; }; diff --git a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp index 12608d6013f768..6be51814112889 100644 --- a/src/common/snippets/src/lowered/pass/allocate_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/allocate_buffers.cpp @@ -124,8 +124,8 @@ bool AllocateBuffers::run(LinearIR& linear_ir) { const auto current_allocated_memory_size = m_buffer_scratchpad_size - offset; if (buffer_size > current_allocated_memory_size) { - m_buffer_scratchpad_size += (buffer_size - current_allocated_memory_size); - // Note: we don't update offset because we just add memory to needed size + allocate(buffer, expr, buffer_size); + continue; } propagate_offset(linear_ir, *expr_it, offset); allocated_buffers.insert(expr); diff --git a/src/common/snippets/src/lowered/pass/assign_registers.cpp b/src/common/snippets/src/lowered/pass/assign_registers.cpp index 293d80437ce1d1..bdfec24b7e2b8c 100644 --- a/src/common/snippets/src/lowered/pass/assign_registers.cpp +++ b/src/common/snippets/src/lowered/pass/assign_registers.cpp @@ -100,9 +100,9 @@ bool AssignRegisters::run(LinearIR& linear_ir) { // Otherwise WIN build fails with "IS_MANUALLY_ALLOCATED_REG cannot be implicitly captured because no default capture mode has been specified" // the same problem with all the other lambdas in this file auto enumerate_out_tensors = [=] (const ExpressionPtr& expr, - decltype(regs_vec)& reg_map, - const std::map& manually_assigned_regs, - size_t& counter) { + decltype(regs_vec)& reg_map, + const std::map& manually_assigned_regs, + size_t& counter) { for (const auto& out_tensor : expr->get_output_port_connectors()) { // Note that some ops might have identical input&output tensors (Result and Tile* for ex.) // so we have to check that the tensor has not been enumerated already diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index e02e0699a80b53..88952ef7fcfe6f 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -62,22 +62,29 @@ std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout()); } -ov::element::Type Brgemm::get_output_type() const { - const auto element_type_a = get_input_element_type(0); - const auto element_type_b = get_input_element_type(1); - const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b); - const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8; - const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b); +ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) { + const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1); + const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8; + const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1); if (is_f32 || is_bf16) { - return element::f32; + return element::f32; } else if (is_int8) { return element::i32; } else { + return element::undefined; + } +} + +ov::element::Type Brgemm::get_output_type() const { + auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1)); + if (output_type == element::undefined) { OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " + - element_type_a.get_type_name() + - " and " + - element_type_b.get_type_name()); + get_input_element_type(0).get_type_name() + + " and " + + get_input_element_type(1).get_type_name()); } + + return output_type; } std::vector Brgemm::get_planar_input_shapes(const std::vector>& inputs) const { diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index c1cecddd86228d..8b703fa0c29a16 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -14,8 +14,8 @@ namespace snippets { namespace op { -Buffer::Buffer(const ov::Shape& shape, size_t id) - : Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id) { +Buffer::Buffer(const ov::Shape& shape, ov::element::Type element_type, size_t id) + : Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id), m_element_type(std::move(element_type)) { constructor_validate_and_infer_types(); } @@ -40,26 +40,25 @@ bool Buffer::visit_attributes(AttributeVisitor& visitor) { visitor.on_attribute("allocation_shape", m_shape); visitor.on_attribute("offset", m_offset); visitor.on_attribute("id", m_id); + visitor.on_attribute("element_type", m_element_type); return true; } void Buffer::validate_and_infer_types() { INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types); - ov::element::Type output_type; ov::Shape output_shape; if (m_type == Type::NewMemory) { OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must to not have arguments!"); output_shape = m_shape; - output_type = ov::element::u8; // 1Byte } else if (m_type == Type::IntermediateMemory) { const auto& input_shape = get_input_partial_shape(0); OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape"); - output_type = get_input_element_type(0); + m_element_type = get_input_element_type(0); output_shape = input_shape.get_shape(); } else { OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory"); } - set_output_type(0, output_type, output_shape); + set_output_type(0, m_element_type, output_shape); } std::shared_ptr Buffer::clone_with_new_inputs(const OutputVector& new_args) const { @@ -67,7 +66,7 @@ std::shared_ptr Buffer::clone_with_new_inputs(const OutputVector& new_args check_new_args_count(this, new_args); std::shared_ptr new_buffer = nullptr; if (m_type == Type::NewMemory) { - new_buffer = std::make_shared(m_shape, m_id); + new_buffer = std::make_shared(m_shape, m_element_type, m_id); } else if (m_type == Type::IntermediateMemory) { new_buffer = std::make_shared(new_args.at(0), m_shape, m_id); } else { @@ -82,6 +81,13 @@ size_t Buffer::get_byte_size() const { return ov::shape_size(shape) * get_element_type().size(); } +void Buffer::set_element_type(ov::element::Type element_type) { + OPENVINO_ASSERT(is_new_memory(), "Only Buffer with NewMemory can change his output precision!"); + m_element_type = std::move(element_type); + // Apply the change + validate_and_infer_types(); +} + } // namespace op } // namespace snippets } // namespace ov diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 1c221601ae0ebe..bfea82eac08d5e 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -14,7 +14,6 @@ #include "snippets/pass/convert_constants.hpp" #include "snippets/pass/convert_power_to_powerstatic.hpp" #include "snippets/pass/transpose_decomposition.hpp" -#include "snippets/pass/transform_convert.hpp" #include "snippets/pass/matmul_to_brgemm.hpp" #include "snippets/pass/fuse_transpose_brgemm.hpp" #include "snippets/pass/set_softmax_ports.hpp" @@ -75,12 +74,11 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptrget_ops(); for (const auto& op : ops) { - config.m_is_quantized = config.m_is_quantized || - ov::is_type(op); - config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops || - is_domain_sensitive_op(op); + update(config.m_is_quantized, ov::is_type(op)); + update(config.m_has_domain_sensitive_ops, is_domain_sensitive_op(op)); } } @@ -93,6 +91,13 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op // and where will be Loops - we can just predict. // Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32) std::vector used_precision_size; + + auto push_prc_size = [&used_precision_size](size_t precision_size) { + if (used_precision_size.empty() || used_precision_size.back() != precision_size) { + used_precision_size.push_back(precision_size); + } + }; + for (const auto& op : ops) { if (const auto transpose = ov::as_type_ptr(op)) { // At the moment Transposes are supported only on Results and Parameters but @@ -106,34 +111,23 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op }) || !ov::is_type(transpose->get_input_node_shared_ptr(0)); if (are_prev_or_next_ops) { - const auto prc_size = transpose->get_element_type().size(); - if (used_precision_size.empty() || used_precision_size.back() != prc_size) { - used_precision_size.push_back(prc_size); - } + push_prc_size(transpose->get_element_type().size()); } } else if (ov::is_type(op) || ov::is_type(op)) { - // Softmax always uses 2 FP32 Buffers - const auto prc_size = ov::element::f32.size(); - if (used_precision_size.empty() || used_precision_size.back() != prc_size) { - used_precision_size.push_back(prc_size); - } + // Softmax always uses 2 FP32 Buffers after decomposition. + // They are inplace and the same so we can push precision size only once + push_prc_size(ov::element::f32.size()); } else if (const auto matmul = ov::as_type_ptr(op)) { // First input check is enough because MatMul requires the same prc size on inputs if (!ov::is_type(matmul->get_input_node_shared_ptr(0)) || !ov::is_type(matmul->get_input_node_shared_ptr(1))) { - const auto prc_size = matmul->get_input_element_type(0).size(); - if (used_precision_size.empty() || used_precision_size.back() != prc_size) { - used_precision_size.push_back(prc_size); - } + push_prc_size(matmul->get_input_element_type(0).size()); } const auto consumers = matmul->get_output_target_inputs(0); if (std::none_of(consumers.begin(), consumers.end(), [](const ov::Input& in) { return ov::is_type(in.get_node()); })) { - const auto prc_size = matmul->get_element_type().size(); - if (used_precision_size.empty() || used_precision_size.back() != prc_size) { - used_precision_size.push_back(prc_size); - } + push_prc_size(matmul->get_element_type().size()); } } } diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 27bc8cd02d06e3..acb5ccac513d39 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -63,11 +63,21 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { const auto& transpose = as_type_ptr(n); const auto& out_shape = n->get_output_partial_shape(0); if (transpose && out_shape.is_static()) { + const auto parent = transpose->get_input_node_shared_ptr(0); + const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + auto is_brgemm_case = ov::is_type(parent) || ov::is_type(child); + // Check for Transpose parent is MatMul inside Subgraph + if (const auto subgraph = ov::as_type_ptr(parent)) { + const auto body = subgraph->body_ptr(); + const auto subgraph_output = body->get_results()[transpose->input_value(0).get_index()]->get_input_node_shared_ptr(0); + is_brgemm_case = is_brgemm_case || ov::is_type(subgraph_output); + } + const auto& order = as_type_ptr(n->get_input_node_shared_ptr(1)); if (order) { const auto order_value = order->cast_vector(); - return TransposeDecomposition::supported_cases.count(order_value) != 0 || - FuseTransposeBrgemm::supported_cases.count(order_value) != 0; + return (TransposeDecomposition::supported_cases.count(order_value) != 0) || + (is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0); } } return false; @@ -337,7 +347,7 @@ TokenizeSnippets::TokenizeSnippets() { for (const auto& input_node : ov::as_node_vector(input_values)) { if (auto subgraph = ov::as_type_ptr(input_node)) { - if (!clones.count(input_node)) { + if (!clones.count(input_node) && GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed) { auto f = subgraph->body().clone(); f->set_friendly_name(subgraph->body_ptr()->get_friendly_name()); clones[input_node] = f; @@ -524,15 +534,18 @@ TokenizeSnippets::TokenizeSnippets() { ResultVector body_results; std::vector>> subgraph_result_inputs; - ov::NodeVector new_body_ops; + ov::NodeVector ops_for_buffer_count; for (auto subgraph : input_subgraphs) { // we should summurize additional needed data count (non-scalar Constants and Buffers) from all input subgraphs // because we will collapse them with our node and we should get total count const auto subgraph_ptr = ov::as_type_ptr(subgraph); hidden_data_count += subgraph_ptr->get_virtual_port_count(); + // Buffers can be existed only in Subgraphs with domain sensetive ops which + // requires intermediate memory for data repacking + // To avoid load time regressions, we verify only these Subgraph with domain sensetive ops if (subgraph_ptr->has_domain_sensitive_ops()) { const auto ops = subgraph_ptr->body_ptr()->get_ordered_ops(); - new_body_ops.insert(new_body_ops.end(), ops.begin(), ops.end()); + ops_for_buffer_count.insert(ops_for_buffer_count.end(), ops.begin(), ops.end()); } for (auto output : subgraph->outputs()) { @@ -566,7 +579,7 @@ TokenizeSnippets::TokenizeSnippets() { } if (op::Subgraph::is_domain_sensitive_op(node)) { - new_body_ops.push_back(node); + ops_for_buffer_count.push_back(node); } for (auto output : node->outputs()) { @@ -582,7 +595,7 @@ TokenizeSnippets::TokenizeSnippets() { // At the moment, CPU Plugin has limitation for GPR registers: there are only 12 available registers. // This limitation will be resolved once generator supports gprs spills [75622]. // TODO [75567]: move this plugin-specific constraint to the plugin callback - const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(new_body_ops); + const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(ops_for_buffer_count); if (body_parameters.size() + body_results.size() + hidden_data_count + unique_buffer_count > 12) { const std::string message_reset = "new subgraph is created. Impossible to schedule subgraph with " + std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " + diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp index fbca7b47b62d31..864341dc417e53 100644 --- a/src/common/snippets/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/src/pass/mha_tokenization.cpp @@ -6,9 +6,10 @@ #include "snippets/itt.hpp" -#include "snippets/utils.hpp" -#include "snippets/pass/tokenization.hpp" +#include "snippets/pass/collapse_subgraph.hpp" #include "snippets/op/subgraph.hpp" +#include "snippets/op/brgemm.hpp" +#include "snippets/utils.hpp" #include "openvino/core/rt_info.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" @@ -17,17 +18,14 @@ namespace { auto is_supported_tensor(const ov::descriptor::Tensor& t) -> bool { - // TODO: Add support of all supported by common tokenization element types - // return ov::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0; - return t.get_element_type() == ngraph::element::f32 && - t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu); + return t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu); } -// TODO: Add support of FQ, Reshape? auto is_supported_intermediate_op(const std::shared_ptr& node) -> bool { const auto is_intermediate_op = [](const std::shared_ptr& node) { return ov::is_type(node) || ov::is_type(node) || + ov::is_type(node) || ov::is_type(node); }; return is_intermediate_op(node) && ov::snippets::pass::TokenizeSnippets::AppropriateForSubgraph(node); @@ -40,9 +38,12 @@ auto is_valid_transpose(const std::shared_ptr& node, std: return false; return transpose_pattern->cast_vector() == expected_order; }; + auto is_supported_transpose_tensor = [](const ov::descriptor::Tensor& t) { + return is_supported_tensor(t) && ov::snippets::pass::TokenizeSnippets::supported_element_types.count(t.get_element_type()) != 0; + }; return node && node->get_output_target_inputs(0).size() == 1 && node->get_shape().size() == 4 && - valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_tensor(node->get_input_tensor(0)); + valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_transpose_tensor(node->get_input_tensor(0)); } auto tokenize_broadcast(const std::shared_ptr& interm_op, ov::NodeVector& ordered_ops) -> void { @@ -98,14 +99,15 @@ auto tokenize_reshape_around_softmax(std::shared_ptr& interm_op, ov::NodeVector& ordered_ops) -> bool { reshape = ov::as_type_ptr(interm_op); if (reshape) { - const auto shape = reshape->get_input_shape(0); - if (shape.back() != reshape->get_output_shape(0).back() || reshape->get_output_target_inputs(0).size() != 1) + const auto in_shape = reshape->get_input_shape(0); + const auto out_shape = reshape->get_output_shape(0); + if (in_shape.back() != out_shape.back() || reshape->get_output_target_inputs(0).size() != 1) return false; ordered_ops.push_back(reshape); interm_op = reshape->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); } return true; -}; +} auto get_potential_body_params(const std::shared_ptr& op) -> size_t { size_t count = 0; @@ -124,43 +126,50 @@ auto get_potential_body_params(const std::shared_ptr& op) -> size_t { auto update_intermediate_supported_ops(std::shared_ptr& interm_op, ov::NodeVector& ordered_ops, size_t& hidden_virtual_ports_count, size_t& potential_body_params_count) -> bool { - // TODO: Add Reshape, FQ support while (is_supported_intermediate_op(interm_op)) { // All supported intermediate ops have only one output port - // To verify output element type is enough because all supported intermediate ops have the same output element type as input type - if (interm_op->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(interm_op->get_output_tensor(0))) + if (interm_op->get_output_target_inputs(0).size() != 1) return false; - // Check for supported Broadcast op + // Check for supported ops on branches: Broadcast/Elementwise (for example, dequantize ops) if (interm_op->get_input_size() > 1) { tokenize_broadcast(interm_op, ordered_ops); - } - - auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr& op) { - return is_supported_intermediate_op(op) && - ov::snippets::pass::GetSnippetsNodeType(op) != ov::snippets::pass::SnippetsNodeType::SkippedByPlugin && - std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end(); - }; - for (size_t i = 0; i < interm_op->get_input_size(); ++i) { - const size_t shift = ordered_ops.size(); - auto parent = interm_op->get_input_node_shared_ptr(i); - while (is_supported_branch_op(parent)) { - // All supported ops have only one output port - if (parent->get_output_target_inputs(0).size() != 1) - break; + // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation) + // we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body. + if (const auto fq_node = ov::as_type_ptr(interm_op)) { + hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node); + } - // Add node only if there are scalar constants on inputs because of plugin-specific limitation - bool are_weights_scalar = true; - const auto parent_count = parent->get_input_size(); - for (size_t i = 1; i < parent_count; ++i) { - are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1; + auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr& op) { + return is_supported_intermediate_op(op) && + ov::snippets::pass::GetSnippetsNodeType(op) != ov::snippets::pass::SnippetsNodeType::SkippedByPlugin && + std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end(); + }; + + for (size_t i = 0; i < interm_op->get_input_size(); ++i) { + const size_t shift = ordered_ops.size(); + auto parent = interm_op->get_input_node_shared_ptr(i); + while (is_supported_branch_op(parent)) { + // All supported ops have only one output port + if (parent->get_output_target_inputs(0).size() != 1) + break; + + // Add node only if there are scalar constants on inputs because of plugin-specific limitation + bool are_weights_scalar = true; + const auto parent_count = parent->get_input_size(); + for (size_t i = 1; i < parent_count; ++i) { + are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1; + } + if (!are_weights_scalar) + break; + + ordered_ops.insert(ordered_ops.begin() + shift, parent); + // TODO [107731]: We think that sequence of ops goes through input port 0 + // But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way? + if (parent->get_input_size() > 0) + parent = parent->get_input_node_shared_ptr(0); } - - ordered_ops.insert(ordered_ops.begin() + shift, parent); - // We think that sequence of ops goes through input port 0 - // But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way? - parent = parent->get_input_node_shared_ptr(0); } } @@ -173,7 +182,7 @@ auto update_intermediate_supported_ops(std::shared_ptr& interm_op, ov: }; } // namespace -ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { +ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsTokenization::Config& config) { MATCHER_SCOPE(TokenizeMHASnippets); auto m_matmul0 = std::make_shared(ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()), @@ -184,14 +193,13 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::TokenizeMHASnippets") auto& pattern_to_output = m.get_pattern_value_map(); + // Queries + Key + Values = 3 standard inputs of MHA + size_t potential_body_params_count = 3; // After some transformations, a different number of Constants for some operations may be created // than the actual number of Constants during tokenization. // To avoid unsupported number of non-scalar Constants in the future (plugin specific limitation) // we should calculate potential number of non-scalar Constants that will be moved up from body. - // TODO: Need update this variable when FQ will be supported size_t hidden_virtual_ports_count = 0; - // Queries + Key + Values = 3 standard inputs of MHA - size_t potential_body_params_count = 3; // The count of potential unique Buffers - it's hidden virtual ports as well // We should go through Subgraph and calculate potential non-inplace Buffers count. // Example: @@ -231,10 +239,20 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { !is_supported_tensor(matmul0->get_input_tensor(0)) || !is_supported_tensor(matmul0->get_input_tensor(1))) return false; - if (transformation_callback(matmul0)) { + const auto matmul0_prc = op::Brgemm::get_output_type(matmul0->get_input_element_type(0), + matmul0->get_input_element_type(1)); + if (matmul0_prc == element::undefined) { return false; } + // Between MatMul0 and Softmax will be the one Loop because of LoopFusing optimization. + // The Loop will have one Buffer with the same shape both on input and output. + // Need to check for precision to get if we need one more register for Buffer + if (matmul0_prc.size() != ov::element::f32.size()) { + if (buffer_count < 2) + buffer_count++; + } + ordered_ops.push_back(matmul0); auto interm_op = matmul0->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); @@ -276,10 +294,28 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { return false; const auto matmul1 = ov::as_type_ptr(interm_op); - if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 || matmul1->get_transpose_a() || matmul1->get_transpose_b() || - !is_supported_tensor(matmul1->get_input_tensor(0)) || !is_supported_tensor(matmul1->get_input_tensor(1))) + if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 || + matmul1->get_transpose_a() || matmul1->get_transpose_b()) + return false; + + const auto matmul1_out_type = op::Brgemm::get_output_type(matmul1->get_input_element_type(0), + matmul1->get_input_element_type(1)); + if (matmul1_out_type == element::undefined || + !is_supported_tensor(matmul1->get_input_tensor(0)) || + !is_supported_tensor(matmul1->get_input_tensor(1))) return false; + if (transformation_callback(matmul0)) { + return false; + } + + // Between Softmax and MatMul1 will be the one Loop because of LoopFusing optimization. + // The Loop will have one Buffer with the same shape both on input and output. + // Need to check for precision to get if we need one more register for Buffer + if (matmul1->get_input_element_type(0).size() != ov::element::f32.size()) { + buffer_count++; + } + /***********************/ /***** Transposes *****/ @@ -287,29 +323,51 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { * We can add them into Subgraph body */ + auto tokenize_transpose = [config](const std::shared_ptr& node) -> std::shared_ptr { + return config.mha_token_enable_transpose ? ov::as_type_ptr(node) + : nullptr; + }; + // First input branch of MatMul0 should be executed before second input branch of MatMul0, // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose1 bool are_weights_scalar = true; + // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order (or without this Transpose1) + // only if these ops have scalar shapes on other inputs. + // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false). + // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching + const auto is_transposed_b_0 = matmul0->get_transpose_b(); auto parent = matmul0->get_input_node_shared_ptr(1); while (is_supported_intermediate_op(parent)) { // All supported ops have only one output port - // To verify output element type is enough because all supported ops have the same output element type as input type - if (parent->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(parent->get_output_tensor(0))) + if (parent->get_output_target_inputs(0).size() != 1) break; - const auto parent_count = parent->inputs().size(); - for (size_t i = 1; i < parent_count; ++i) { - are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1; + // Only if MatMul0 has transposed_b, we have to tokenize scalar ops + // to move explicit Transpose from MatMul0 input_1 to Parameter of Subgraph body + if (is_transposed_b_0) { + const auto parent_count = parent->get_input_size(); + bool are_weights_scalar = true; + for (size_t i = 1; i < parent_count; ++i) { + are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1; + } + if (!are_weights_scalar) { + break; + } + } + + // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation) + // we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body. + if (const auto fq_node = ov::as_type_ptr(parent)) { + hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node); } potential_body_params_count += get_potential_body_params(parent); ordered_ops.insert(ordered_ops.begin(), parent); - // We think that sequence of ops goes through input port 0 - // But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way? + // TODO [107731] To go always through 0-th port - is it safe? parent = parent->get_input_node_shared_ptr(0); } - auto transpose1 = ov::as_type_ptr(parent); - if (matmul0->get_transpose_b()) { + const auto transpose1 = tokenize_transpose(parent); + if (is_transposed_b_0) { if (is_valid_transpose(transpose1, {0, 2, 1, 3})) { // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order // only if these ops have scalar shapes on other inputs. @@ -329,31 +387,63 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { } } - // TODO: Add Reshape Support for all Transposes - // Add 3D support for all Transposes - const auto transpose0 = ov::as_type_ptr(matmul0->get_input_node_shared_ptr(0)); + if (transpose1) { + // Between Transpose1 and MatMul0 will be the one Loop because of LoopFusing optimization. + // The Loop will have one Buffer with the same shape both on input and output. + // Need to check for precision to get if we need one more register for Buffer + if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) { + buffer_count++; + } + } + + const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0)); if (is_valid_transpose(transpose0, {0, 2, 1, 3})) { ordered_ops.insert(ordered_ops.begin(), transpose0); - } else if (matmul0->get_transpose_b()) { + } else if (matmul0->get_transpose_a()) { return false; } - const auto transpose2 = ov::as_type_ptr(matmul1->get_input_node_shared_ptr(1)); + const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1)); if (is_valid_transpose(transpose2, {0, 2, 1, 3})) { ordered_ops.push_back(transpose2); } ordered_ops.push_back(matmul1); + bool are_ops_after_matmul1 = false; auto child = matmul1->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); - // TODO: Add support Eltwises between MatMul1 and Transpose - // status = update_intermediate_supported_ops(child, ordered_ops); - // if (!status) { - // ordered_ops.push_back(child); - // } - - auto transpose3 = ov::as_type_ptr(child); - if (is_valid_transpose(transpose3, {0, 2, 1, 3})) { - ordered_ops.push_back(transpose3); + while (is_supported_intermediate_op(child)) { + are_ops_after_matmul1 = true; + // All supported ops have only one output port + if (child->get_output_target_inputs(0).size() != 1) + break; + + // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation) + // we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body. + if (const auto fq_node = ov::as_type_ptr(child)) { + hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node); + } + potential_body_params_count += get_potential_body_params(child); + + // TODO [75567]: move this plugin-specific constraint to the plugin callback + // We cannot collapse op to Subgraph if count of potential Parameter and Result count is higher 12 + if (potential_body_params_count + child->get_output_target_inputs(0).size() + hidden_virtual_ports_count + buffer_count > 12) { + break; + } + + ordered_ops.push_back(child); + child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + } + + // At the moment Snippets don't support nodes between MatMul1 and Transpose3 due to Loop and strided calculations limitations + // MatMul1 + // + // Transpose3 + if (!are_ops_after_matmul1) { + auto transpose3 = tokenize_transpose(child); + if (is_valid_transpose(transpose3, {0, 2, 1, 3}) && + transpose3->get_input_element_type(0) == matmul1_out_type) { // To avoid Convert between MatMul1 and Transpose3 + ordered_ops.push_back(transpose3); + } } /**********************/ @@ -362,7 +452,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { /* ====== Subgraph creation ======= */ - // TODO: move this plugin-specific constraint to the plugin callback + // TODO [75567]: move this plugin-specific constraint to the plugin callback const auto last_node = ordered_ops.back(); if (potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count + buffer_count > 12) { return false; @@ -378,7 +468,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { const auto input = node->input(i); const auto parent = input.get_source_output().get_node_shared_ptr(); const auto constant = ov::as_type_ptr(parent); - if (constant && (ov::shape_size(input.get_shape()) == 1 || op::Subgraph::constant_input_should_be_inside_body(node))) { + if (constant && (ov::shape_size(input.get_shape()) == 1 || + ov::is_type(node) || + op::Subgraph::constant_input_should_be_inside_body(node))) { // If Constant has one consumer - target node, we add Constant to body_inputs // If Constant has several consumers, we should check that all these consumers are inside Subgraph body // and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we should @@ -454,6 +546,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { subgraph->get_rt_info()["originalLayersNames"] = fused_names; subgraph->set_virtual_port_count(hidden_virtual_ports_count); + // mark the Subgraph as Completed to not allow Snippets to include any nodes into the MHA Subgraph in common Tokenization + SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed); + return true; /* ================================ */ diff --git a/src/common/snippets/src/pass/tokenization.cpp b/src/common/snippets/src/pass/tokenization.cpp index bdf684ef3fba6c..13346efabef091 100644 --- a/src/common/snippets/src/pass/tokenization.cpp +++ b/src/common/snippets/src/pass/tokenization.cpp @@ -7,6 +7,8 @@ #include "snippets/pass/tokenization.hpp" #include "snippets/pass/common_optimizations.hpp" #include "openvino/pass/manager.hpp" +#include "snippets/pass/mha_tokenization.hpp" +#include "snippets/pass/collapse_subgraph.hpp" namespace ov { @@ -18,6 +20,13 @@ void SetSnippetsNodeType(const std::shared_ptr &node, SnippetsNodeType nod rt["SnippetsNodeType"] = nodeType; } +void SetSnippetsSubgraphType(const std::shared_ptr &node, SnippetsSubgraphType nodeType) { + if (node) { + auto &rt = node->get_rt_info(); + rt["SnippetsSubgraphType"] = nodeType; + } +} + SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr &node) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsNodeType") auto& rt = node->get_rt_info(); @@ -27,6 +36,17 @@ SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr &node) { return rinfo->second.as(); } +SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr &node) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsSubgraphType") + if (!node) + return SnippetsSubgraphType::NotSet; + auto &rt = node->get_rt_info(); + const auto rinfo = rt.find("SnippetsSubgraphType"); + if (rinfo == rt.end()) + return SnippetsSubgraphType::NotSet; + return rinfo->second.as(); +} + void SetTopologicalOrder(const std::shared_ptr &node, int64_t order) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetTopologicalOrder") auto& rt = node->get_rt_info(); @@ -58,7 +78,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr& m) { manager.set_per_pass_validation(false); manager.register_pass(); - manager.register_pass(); + manager.register_pass(m_config); manager.register_pass(); manager.register_pass(); manager.run_passes(m); diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index ba3a4f91d43e33..2fc7868b90182f 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -5,6 +5,7 @@ #include #include "lowering_utils.hpp" #include "snippets/pass/tokenization.hpp" +#include "snippets/pass/collapse_subgraph.hpp" namespace ov { diff --git a/src/common/snippets/tests/src/pass/collapse_subgraph.cpp b/src/common/snippets/tests/src/pass/collapse_subgraph.cpp index 48ce19052827c8..356f76c9eeffee 100644 --- a/src/common/snippets/tests/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/tests/src/pass/collapse_subgraph.cpp @@ -8,6 +8,7 @@ #include #include #include "snippets/pass/tokenization.hpp" +#include "snippets/pass/collapse_subgraph.hpp" namespace ov { namespace test { diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp index 19e1453c463825..68956a2a626105 100644 --- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp @@ -6,6 +6,7 @@ #include #include #include "snippets/pass/tokenization.hpp" +#include "snippets/pass/mha_tokenization.hpp" #include "snippets/pass/explicit_transpose_matmul_inputs.hpp" namespace ov { @@ -20,14 +21,23 @@ void TokenizeMHASnippetsTests::run() { } TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) { - const auto& f = MHAFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}); + const auto &f = MHAFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, + std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32})); function = f.getOriginal(); function_ref = f.getReference(); run(); } TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_MatMul0_Transpose) { - const auto& f = MHAMatMul0TransposeFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}); + const auto &f = MHAMatMul0TransposeFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, + std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32})); + function = f.getOriginal(); + function_ref = f.getReference(); + run(); +} + +TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) { + const auto &f = MHAINT8MatMulTypeRelaxedFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}); function = f.getOriginal(); function_ref = f.getReference(); run(); diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp index 6c4717699b47f2..c339b72cfd17f5 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp @@ -172,7 +172,7 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: general_exprs.emplace_back(expr); } } - num_unique_buffer = unique_buffers.size(); + num_unique_buffers = unique_buffers.size(); // Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two // regs are used to calculate offsets for the data pointers @@ -198,15 +198,16 @@ void KernelEmitter::validate_arguments(const std::vector &in, IE_THROW() << "KernelEmitter got invalid number of inputs. Expected 0, got " << in.size(); if (!out.empty()) IE_THROW() << "KernelEmitter got invalid number of outputs. Expected 0, got " << out.size(); - const auto num_params = num_inputs + num_outputs + num_unique_buffer; + const auto num_params = num_inputs + num_outputs + num_unique_buffers; // The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount if (data_ptr_regs_idx.size() != num_params) - IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsisnent with the number of allocated registers" + IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsistent with the number of allocated registers " << num_params << " data_ptr_regs_idx.size() = " << data_ptr_regs_idx.size(); } -void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, size_t num_buffer, - const Reg64& reg_indexes, const Reg64& reg_const_params, const std::vector& data_ptr_regs) const { +void KernelEmitter::init_data_pointers(const Xbyak::Reg64& reg_indexes, const Xbyak::Reg64& reg_const_params, + const std::vector& data_ptr_regs) const { + const auto num_params = num_inputs + num_outputs; // Note that we don't need offset for the last dim, since it's handled directly by Tile emitter const size_t offset_rank = jcp.master_shape.size() - 1; std::vector> data_offsets(num_params, std::vector{}); @@ -267,7 +268,9 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, siz // Vector "data_ptr_regs" is sorted by abstract regs. // It means that the vector contains the physical registers in order [src, .., src, dst, .., dst, buffer] // So we can initialize buffer register firstly as last value of vector "data_ptr_regs" - for (size_t i = 0; i < num_buffer; ++i) { + // NOTE: Snippets Buffer Scratchpad has the common data pointer for all Buffers (even with different ID). + // The accessing memory is covered by correct offsets in each Buffer and the corresponding MemoryAccess ops + for (size_t i = 0; i < num_unique_buffers; ++i) { h->mov(data_ptr_regs[num_params + i], h->ptr[reg_const_params + GET_OFF(buffer_scratchpad_ptr)]); } size_t i = 0; @@ -299,7 +302,7 @@ void KernelEmitter::emit_impl(const std::vector& in, std::vector data_ptr_regs; transform_idxs_to_regs(data_ptr_regs_idx, data_ptr_regs); - init_data_pointers(num_inputs, num_inputs + num_outputs, num_unique_buffer, reg_indexes, reg_const_params, data_ptr_regs); + init_data_pointers(reg_indexes, reg_const_params, data_ptr_regs); for (const auto& expression : body) { const auto& emitter = expression->get_emitter(); std::vector in_regs, out_regs; diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp index f41340a9223fc7..cc4ba3a55f830d 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.hpp @@ -87,13 +87,13 @@ class KernelEmitter : public jit_container_emitter { const std::vector &out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; - void init_data_pointers(size_t, size_t, size_t, const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector&) const; + void init_data_pointers(const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector&) const; jit_snippets_compile_args jcp; std::vector gp_regs_pool; size_t num_inputs; size_t num_outputs; - size_t num_unique_buffer; + size_t num_unique_buffers; // Vector of indices (lenght = input tensor rank) per every input and output that describes in which order // corresponding tensor dimensions are accessed (default: consecutive dense, e.g. 0,1,2,3 for 4D tensor). // Needed to calc i/o offsets. diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp index 6eca0a514533d2..1ad4cb62b834bc 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/snippets_mark_skipped.cpp @@ -427,7 +427,7 @@ void MarkSubgraphOpAsSkipped(const std::shared_ptr &node) { bool isSuitableConvert(const std::shared_ptr& node) { if (!ov::is_type(node)) return false; - auto hasResult = [](const std::shared_ptr& node){ + auto hasResult = [](const std::shared_ptr& node) { auto consumers = node->output(0).get_target_inputs(); bool findResult = false; if (consumers.size() == 1) { @@ -449,13 +449,19 @@ bool isSuitableConvert(const std::shared_ptr& node) { return false; } } + +auto is_skipped_op(const std::shared_ptr& op) -> bool { + return ov::is_type(op) || + ov::is_type(op) || + ov::is_type(op); +} } // namespace bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) { RUN_ON_MODEL_SCOPE(SnippetsMarkSkipped); int channelAxis = DEFAULT_AXIS; for (auto &node : m->get_ordered_ops()) { - if (ov::is_type(node) || ov::is_type(node)) + if (is_skipped_op(node)) continue; if (isSuitableConvolutionParent(node)) { // Initiate fusing chain diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 5b3ff344266108..49970a8fb646d9 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -108,6 +108,8 @@ // Snippets #include "snippets/pass/tokenization.hpp" +#include "snippets/pass/mha_tokenization.hpp" +#include "snippets/pass/collapse_subgraph.hpp" #include "snippets/pass/common_optimizations.hpp" // Misc @@ -616,22 +618,58 @@ void Transformations::MainSnippets(void) { !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions) return; + // At the moment Snippets supports Transposes in MHA pattern only in FP32 case since + // - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage + // - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage + // Because of that Transposes won't be fused into Brgemm + // TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs + ov::snippets::pass::SnippetsTokenization::Config tokenization_config; + tokenization_config.mha_token_enable_transpose = !enableBF16; + ngraph::pass::Manager snippetsManager; snippetsManager.set_per_pass_validation(false); if (snippetsMode != Config::SnippetsMode::IgnoreCallback) CPU_REGISTER_PASS_X64(snippetsManager, SnippetsMarkSkipped, enableBF16); - CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization); + CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config); + // Tokenize MHA in quantized model or with BF16 only in tests. + // TODO [106921]: Please enable the tokenization when the ticket 106921 with blocking support for BRGEMM will be implemented + const bool onlyFloatSupported = snippetsMode != Config::SnippetsMode::IgnoreCallback; const bool isMHASupported = - !enableBF16 && // TODO: Need to add BF16 support for MHA in Snippets + IMPLICATION(enableBF16, !onlyFloatSupported) && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // MHA has BRGEMM that is supported only on AVX512 platforms if (!isMHASupported) { CPU_DISABLE_PASS_X64(snippetsManager, snippets::pass::TokenizeMHASnippets); } + +#if defined(OPENVINO_ARCH_X86_64) + auto is_supported_matmul = [onlyFloatSupported](const std::shared_ptr& n) { + const auto matmul = ov::as_type_ptr(n); + if (!matmul) + return false; + if (matmul->get_input_element_type(1) == ov::element::i8) + return !onlyFloatSupported && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni); + if (matmul->get_input_element_type(0) == ov::element::bf16 && + matmul->get_input_element_type(1) == ov::element::bf16) + return !onlyFloatSupported && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16); + return true; + }; +#endif // OPENVINO_ARCH_X86_64 + if (snippetsMode != Config::SnippetsMode::IgnoreCallback) { CPU_SET_CALLBACK_X64(snippetsManager, - [](const std::shared_ptr& n) -> bool { - const auto pshape = n->get_output_partial_shape(0); + [&](const std::shared_ptr& n) -> bool { + // Tranformation callback is called on MatMul0 + if (!is_supported_matmul(n)) + return true; + // Search for MatMul1 + auto child = n->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + while (!ov::is_type(child)) { + child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + } + if (!is_supported_matmul(child)) + return true; + const auto pshape = child->get_input_partial_shape(0); const auto shape = pshape.get_shape(); const auto parallel_work_amount = std::accumulate(shape.rbegin() + 2, shape.rend(), 1, std::multiplies()); @@ -662,18 +700,18 @@ void Transformations::MainSnippets(void) { // todo: general tokenization flow is not currently supported for these operations. // they can be tokenized only as a part of complex patterns const bool is_disabled_tokenization = (ov::is_type(n) || - ov::is_type(n) || - ov::is_type(n) || - ov::is_type(n) || - ov::is_type(n) || - ov::is_type(n)); + ov::is_type(n) || + ov::is_type(n) || + ov::is_type(n) || + ov::is_type(n) || + ov::is_type(n)); const auto& inputs = n->inputs(); // todo: clarify whether we can evaluate snippets on const paths const bool has_only_const_inputs = std::all_of(inputs.begin(), inputs.end(), - [](const ov::Input& in) { - return ov::is_type( - in.get_source_output().get_node_shared_ptr()); - }); + [](const ov::Input& in) { + return ov::is_type( + in.get_source_output().get_node_shared_ptr()); + }); // todo: clarify whether we can evaluate snippets on inputs with larger ranks auto rank_is_too_large = [](const ov::descriptor::Tensor& t) { // callback is called has_supported_in_out(), so it's safe to assume that the shapes are static diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 104d44dd79fc0e..363cb45861b13a 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -246,6 +246,9 @@ std::vector disabledTestPatterns() { if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { // MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)"); + retVector.emplace_back(R"(.*Snippets.*MatMul.*Quantized.*)"); + retVector.emplace_back(R"(.*Snippets.*MHAFQ.*)"); + retVector.emplace_back(R"(.*Snippets.*MHAINT8.*)"); } if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) //TODO: Issue 92895 @@ -254,6 +257,7 @@ std::vector disabledTestPatterns() { if (!InferenceEngine::with_cpu_x86_avx512_core_amx_bf16() && !InferenceEngine::with_cpu_x86_bfloat16()) { // ignored for not supported bf16 platforms retVector.emplace_back(R"(.*smoke_Snippets_EnforcePrecision_bf16.*)"); + retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16.*)"); } return retVector; diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 59807c50c9df9b..ad94c9b6284790 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -19,16 +19,24 @@ std::vector> input_shapes{ {{1, 1, 37, 23}, {1, 2, 23, 33}}, {{1, 16, 384, 64}, {1, 16, 64, 384}} }; + +static inline std::vector> quantized_precisions() { + std::vector> prc = {}; + // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms + if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { + prc.emplace_back(std::vector{element::i8, element::i8}); + prc.emplace_back(std::vector{element::u8, element::i8}); + } + return prc; +} + static inline std::vector> precisions(bool only_fp32 = true) { std::vector> prc = { {element::f32, element::f32}, }; if (!only_fp32) { - // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms - if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { - prc.emplace_back(std::vector{element::i8, element::i8}); - prc.emplace_back(std::vector{element::u8, element::i8}); - } + auto quant = quantized_precisions(); + std::copy(quant.begin(), quant.end(), std::back_inserter(prc)); // In Snippets MatMul BF16 is supported only on bf16/AMX platforms if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { prc.emplace_back(std::vector{element::bf16, element::bf16}); @@ -36,15 +44,6 @@ static inline std::vector> precisions(bool only_fp32 } return prc; } -static inline std::vector> quantized_precisions() { - std::vector> prc = {}; - // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms - if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { - prc.emplace_back(std::vector{element::i8, element::i8}); - prc.emplace_back(std::vector{element::u8, element::i8}); - } - return prc; -} INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, ::testing::Combine( diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 96baa1f9fb6ed5..bd5d1473240ad8 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -4,7 +4,9 @@ #include "snippets/mha.hpp" #include "common_test_utils/test_constants.hpp" +#include "test_utils/cpu_test_utils.hpp" #include "ie_plugin_config.hpp" +#include "ie_system_conf.h" namespace ov { namespace test { @@ -15,22 +17,52 @@ namespace { const std::vector> inputShapes = { {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}, }; +static inline bool is_bf16_supported() { + return InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16(); +} + +static inline std::vector> precision_f32(size_t count) { + std::vector> prc; + prc.emplace_back(std::vector(count, element::f32)); + return prc; +} + +static inline std::vector> precision_bf16(size_t count) { + std::vector> prc; + if (is_bf16_supported()) + prc.emplace_back(std::vector(count, element::bf16)); + return prc; +} + INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA, - ::testing::Combine( - ::testing::ValuesIn(inputShapes), - ::testing::ValuesIn({false, true}), - ::testing::Values(ov::element::f32), - ::testing::Values(1), - ::testing::Values(1), - ::testing::Values(CommonTestUtils::DEVICE_CPU), - ::testing::Values(std::map{})), - MHA::getTestCaseName); + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn({false, true}), + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHA, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(precision_bf16(4)), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn({false, true}), + ::testing::Values(7), // MHA + 5 Converts + 1 Transpose on output + ::testing::Values(6), // MHA + 5 Converts on inputs and output + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); const std::vector> inputShapeSelect = { // without broadcast @@ -44,64 +76,142 @@ const std::vector> inputShapeSelect = { INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect, ::testing::Combine( ::testing::ValuesIn(inputShapeSelect), - ::testing::Values(false), // Need to support True for graph builder in tests + ::testing::ValuesIn(precision_f32(6)), ::testing::Values(ov::element::f32), + ::testing::Values(false), // Need to support True for graph builder in tests ::testing::Values(2), // Less + MHA ::testing::Values(2), ::testing::Values(CommonTestUtils::DEVICE_CPU), - ::testing::Values(std::map{})), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), MHA::getTestCaseName); +const std::vector> inputShapesWOTranspose_4D = { + {{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}}, + {{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}} +}; +const std::vector> inputShapesWOTranspose_3D = { + {{12, 197, 64}, {12, 64, 197}, {12, 197, 64}}, + {{12, 128, 100}, {12, 100, 128}, {12, 128, 100}} +}; -static std::vector> inputShapesWOTranspose(bool supports_3d = false) { - std::vector> shapes = { - {{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}}, - {{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}} - }; - if (supports_3d) { - std::vector> shapes_3d = { - {{12, 197, 64}, {12, 64, 197}, {12, 197, 64}}, - {{12, 128, 100}, {12, 100, 128}, {12, 128, 100}} - }; - shapes.insert(shapes.end(), shapes_3d.begin(), shapes_3d.end()); - } - return shapes; -} +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs_4D, MHAWOTransposeOnInputs, + ::testing::Combine( + ::testing::ValuesIn(inputShapesWOTranspose_4D), + ::testing::Values(std::vector{}), + ::testing::Values(ov::element::f32), + ::testing::Values(true), // Need to support False for graph builder in tests + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose_4D, MHAWOTranspose, ::testing::Combine( - ::testing::ValuesIn(inputShapesWOTranspose()), + ::testing::ValuesIn(inputShapesWOTranspose_4D), + ::testing::ValuesIn(precision_f32(3)), + ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(1), + ::testing::Values(1), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose_3D, MHAWOTranspose, + ::testing::Combine( + ::testing::ValuesIn(inputShapesWOTranspose_3D), + ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::f32), + ::testing::ValuesIn({true}), // Need to support False for graph builder in tests ::testing::Values(1), ::testing::Values(1), ::testing::Values(CommonTestUtils::DEVICE_CPU), - ::testing::Values(std::map{})), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), MHA::getTestCaseName); -const std::map cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, - InferenceEngine::PluginConfigParams::YES } }; +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeBF16_4D, MHAWOTranspose, + ::testing::Combine( + ::testing::ValuesIn(inputShapesWOTranspose_4D), + ::testing::ValuesIn(precision_bf16(3)), + ::testing::Values(ov::element::f32), + ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeBF16_3D, MHAWOTranspose, ::testing::Combine( - ::testing::ValuesIn(inputShapesWOTranspose(true)), + ::testing::ValuesIn(inputShapesWOTranspose_3D), + ::testing::ValuesIn(precision_bf16(3)), + ::testing::Values(ov::element::f32), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeEnforceBF16_4D, MHAWOTranspose, + ::testing::Combine( + ::testing::ValuesIn(inputShapesWOTranspose_4D), + ::testing::ValuesIn(precision_f32(3)), ::testing::Values(ov::element::bf16), - ::testing::Values(3), - ::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16 + ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output ::testing::Values(CommonTestUtils::DEVICE_CPU), - ::testing::Values(cpuBF16PluginConfig)), + ::testing::Values(CPUTestUtils::cpuBF16PluginConfig)), MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose, MHAWOTranspose, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeEnforceBF16_3D, MHAWOTranspose, ::testing::Combine( - ::testing::ValuesIn(inputShapesWOTranspose(true)), + ::testing::ValuesIn(inputShapesWOTranspose_3D), + ::testing::ValuesIn(precision_f32(3)), + ::testing::Values(ov::element::bf16), ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(5), // MHA + 4 extra Converts on inputs and output + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuBF16PluginConfig)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAINT8MatMul, MHAINT8MatMul, + ::testing::Combine( + ::testing::ValuesIn(std::vector>(inputShapes.begin(), inputShapes.begin() + 2)), + ::testing::Values(std::vector{}), ::testing::Values(ov::element::f32), - ::testing::Values(1), - ::testing::Values(1), + ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul + ::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQAfterMatMul, MHAFQAfterMatMul, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::Values(std::vector{}), + ::testing::Values(ov::element::f32), + ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(3), // MHA + Transpose on output + Deq Mul + ::testing::Values(2), // MHA + Deq Mul + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ, + ::testing::Combine( + ::testing::Values(std::vector{{1, 64, 12, 64}, {1, 64, 12, 64}, {1, 1, 1, 64}, {1, 64, 12, 64}}), + ::testing::Values(std::vector{}), + ::testing::Values(ov::element::f32), + ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(7), // Transposex2 + Subgraphsx5 + ::testing::Values(5), // MHA + Deq Mul on output + Deqs on inputs + 2 xFQ on inputs ::testing::Values(CommonTestUtils::DEVICE_CPU), - ::testing::Values(std::map{})), + ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)), MHA::getTestCaseName); diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp index 91ae0af5fa5663..4f7a10056ccb39 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp @@ -575,7 +575,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest, ::testing::ValuesIn(inputPrecisionsQuant), ::testing::ValuesIn(matMulIn0PrecisionsQuant), ::testing::ValuesIn(patternTypesQuant), - ::testing::Values("MHA"), // Snippets don't support Quantized MHA pattern yet + ::testing::Values("MHA"), ::testing::Values(CommonTestUtils::DEVICE_CPU)), MHAQuantTest::getTestCaseName); diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index 8c15adbc8c3fc4..10fb316a8f7624 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -12,8 +12,9 @@ namespace snippets { typedef std::tuple< std::vector, // Input shapes - bool, // With Multiply + std::vector, // Input Element types ov::element::Type, // Inference precision + bool, // With Multiply size_t, // Expected num nodes size_t, // Expected num subgraphs std::string, // Target Device @@ -32,6 +33,7 @@ class MHA : public testing::WithParamInterface, virtual void init_subgraph(); bool m_with_mul = false; + std::vector m_input_types; }; class MHASelect : public MHA { @@ -46,6 +48,22 @@ class MHAWOTransposeOnInputs : public MHA { }; class MHAWOTranspose : public MHA { +protected: + void init_subgraph() override; +}; + +class MHAINT8MatMul : public MHA { +protected: + void init_subgraph() override; +}; + +class MHAFQAfterMatMul : public MHA { +protected: + void init_subgraph() override; +}; + +class MHAFQ : public MHA { +protected: void init_subgraph() override; }; diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 2f5d17dbd8159a..877217d85ee1a4 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -15,16 +15,19 @@ namespace snippets { std::string MHA::getTestCaseName(testing::TestParamInfo obj) { std::vector inputShapes; - bool withMul; + std::vector elem_types; ov::element::Type prc; + bool withMul; std::string targetDevice; size_t num_nodes, num_subgraphs; std::map additionalConfig; - std::tie(inputShapes, withMul, prc, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param; + std::tie(inputShapes, elem_types, prc, withMul, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param; std::ostringstream result; for (size_t i = 0; i < inputShapes.size(); ++i) result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_"; + for (size_t i = 0; i < elem_types.size(); i++) + result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "Mul=" << withMul << "_"; result << "PRC=" << prc << "_"; result << "#N=" << num_nodes << "_"; @@ -45,13 +48,13 @@ void MHA::SetUp() { std::vector inputShapes; ov::element::Type prc; std::map additionalConfig; - std::tie(inputShapes, m_with_mul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); + std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); init_subgraph(); configuration.insert(additionalConfig.begin(), additionalConfig.end()); - if (additionalConfig.empty() && !configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); } @@ -59,7 +62,7 @@ void MHA::SetUp() { setInferenceType(prc); inType = outType = prc; if (prc == ov::element::bf16) - abs_threshold = 0.3; + rel_threshold = 0.05f; } void MHA::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -68,13 +71,13 @@ void MHA::generate_inputs(const std::vector& targetInputStaticSha for (int i = 0; i < model_inputs.size(); ++i) { const auto& model_input = model_inputs[i]; ov::Tensor tensor; - tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(model_input.get_element_type(), targetInputStaticShapes[i], 1.0f, 0.5f); + tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 2, -1, 256); inputs.insert({model_input.get_node_shared_ptr(), tensor}); } } void MHA::init_subgraph() { - auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_with_mul); + auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul); function = f.getOriginal(); } @@ -90,14 +93,14 @@ void MHASelect::generate_inputs(const std::vector& targetInputSta tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 5 + seed, -2, 10, seed); seed++; } else { - tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(model_input.get_element_type(), model_input.get_shape(), 1.0f, 0.5f); + tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 2, -1, 256); } inputs.insert({node_input, tensor}); } } void MHASelect::init_subgraph() { - auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes); + auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types); function = f.getOriginal(); } @@ -107,7 +110,22 @@ void MHAWOTransposeOnInputs::init_subgraph() { } void MHAWOTranspose::init_subgraph() { - auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes); + auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types); + function = f.getOriginal(); +} + +void MHAINT8MatMul::init_subgraph() { + auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes); + function = f.getOriginal(); +} + +void MHAFQAfterMatMul::init_subgraph() { + auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes); + function = f.getOriginal(); +} + +void MHAFQ::init_subgraph() { + auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes); function = f.getOriginal(); } @@ -134,6 +152,20 @@ TEST_P(MHAWOTranspose, CompareWithRefImpl) { validateNumSubgraphs(); } +TEST_P(MHAINT8MatMul, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(MHAFQ, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} } // namespace snippets } // namespace test diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp index 55799aa0cecf84..e5019a39ba8df2 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp @@ -26,9 +26,9 @@ class MatMulFunction : public SnippetsFunctionBase { explicit MatMulFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes"); - verify_precisions(precisions); + validate_precisions(precisions); } - static void verify_precisions(const std::vector& precisions) { + static void validate_precisions(const std::vector& precisions) { NGRAPH_CHECK(precisions.size() == 2, "Got invalid number of input element types"); const bool is_f32 = ov::snippets::utils::everyone_is(element::f32, precisions[0], precisions[1]); const bool is_int8 = ov::snippets::utils::one_of(precisions[0], element::i8, element::u8) && precisions[1] == element::i8; @@ -62,7 +62,7 @@ class MatMulBiasFunction : public SnippetsFunctionBase { explicit MatMulBiasFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); - MatMulFunction::verify_precisions(precisions); + MatMulFunction::validate_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; @@ -70,7 +70,6 @@ class MatMulBiasFunction : public SnippetsFunctionBase { std::vector precisions; }; - // Quantized MatMul // FQ[I8] // Add @@ -79,7 +78,7 @@ class MatMulBiasQuantizedFunction : public SnippetsFunctionBase { explicit MatMulBiasQuantizedFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); - MatMulFunction::verify_precisions(precisions); + MatMulFunction::validate_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; @@ -97,7 +96,7 @@ class MatMulsQuantizedFunction : public SnippetsFunctionBase { explicit MatMulsQuantizedFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); - MatMulFunction::verify_precisions(precisions); + MatMulFunction::validate_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; @@ -121,7 +120,7 @@ class Transpose0213MatMulFunction : public SnippetsFunctionBase { NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4, "Only rank 4 input shapes are supported by this test"); NGRAPH_CHECK(transpose_position >=0 && transpose_position <= 2, "Got invalid transpose position"); - MatMulFunction::verify_precisions(precisions); + MatMulFunction::validate_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; @@ -166,7 +165,7 @@ class MatMulsQuantizedSoftmaxFunction : public SnippetsFunctionBase { explicit MatMulsQuantizedSoftmaxFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); - MatMulFunction::verify_precisions(precisions); + MatMulFunction::validate_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp index 5f4ceebf59991f..499b4a4b0f18bd 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp @@ -43,15 +43,17 @@ namespace snippets { */ class MHAFunction : public SnippetsFunctionBase { public: - explicit MHAFunction(const std::vector& inputShapes, bool with_mul = true) - : SnippetsFunctionBase(inputShapes), with_mul(with_mul) { + explicit MHAFunction(const std::vector& inputShapes, const std::vector& precisions, bool with_mul = true) + : SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; std::shared_ptr initReference() const override; bool with_mul = true; + std::vector precisions; }; /* Graph: @@ -71,13 +73,16 @@ class MHAFunction : public SnippetsFunctionBase { */ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase { public: - explicit MHAMatMul0TransposeFunction(const std::vector& inputShapes) - : SnippetsFunctionBase(inputShapes) { + explicit MHAMatMul0TransposeFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; std::shared_ptr initReference() const override; + + std::vector precisions; }; /* Graph: @@ -97,11 +102,15 @@ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase { */ class MHASelectFunction : public SnippetsFunctionBase { public: - explicit MHASelectFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + explicit MHASelectFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes"); + NGRAPH_CHECK(precisions.size() == 6, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; + + std::vector precisions; }; /* Graph: @@ -137,13 +146,128 @@ class MHAWOTransposeOnInputsFunction : public SnippetsFunctionBase { */ class MHAWOTransposeFunction : public SnippetsFunctionBase { public: - explicit MHAWOTransposeFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + explicit MHAWOTransposeFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + NGRAPH_CHECK(precisions.size() == 3, "Got invalid number of input precisions"); + } +protected: + std::shared_ptr initOriginal() const override; + + std::vector precisions; +}; + +/* Graph: + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] + * \ / + * MatMul0 + * FakeQuantize i8 + * \ / + * Add + * Reshape0 + * Softmax + * Reshape1 Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize i8 + * Transpose3[0,2,1,3] + */ +class MHAFQAfterMatMulFunction : public SnippetsFunctionBase { +public: + explicit MHAFQAfterMatMulFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +/* Graph: + * FakeQuantize i8 FakeQuantize i8 + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] + * \ / + * MatMul0 + * FakeQuantize i8 + * \ / + * Add + * Reshape0 + * Softmax + * Reshape1 FakeQuantize i8 + * FakeQuantize u8 Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize i8 + * Transpose3[0,2,1,3] + */ +class MHAINT8MatMulFunction : public SnippetsFunctionBase { +public: + explicit MHAINT8MatMulFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +/* Graph: + * Constant + * FakeQuantize u8 FakeQuantize u8 Convert + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] Multiply + * \ \ / + * \ Multiply + * \ FakeQuantize f32 + * \ / + * MatMul0 + * FakeQuantize f32 FakeQuantize u8 + * \ / + * Add + * Softmax Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize u8 + * Transpose3[0,2,1,3] + * Note: Check a lot of different FQ (the both quantized and floating) - buffers with different size and precision + */ +class MHAFQFunction : public SnippetsFunctionBase { +public: + explicit MHAFQFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; }; +// Only for tokenization! The graph is after LPT: contains TypeRelaxed ops +/* Graph: + * FakeQuantize i8 FakeQuantize i8 + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] + * \ / + * MatMul0 + * FakeQuantize i8 + * \ / + * Add + * Mul (DeQuantize) + * Reshape0 + * Softmax + * Reshape1 FakeQuantize i8 + * FakeQuantize u8 Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize i8 + * Transpose3[0,2,1,3] + */ +class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase { +public: + explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; + std::shared_ptr initReference() const override; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp index c5086525ec1e52..01d0758a160fc4 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp @@ -230,4 +230,4 @@ std::shared_ptr MatMulsQuantizedSoftmaxFunction::initOriginal() const } // namespace snippets } // namespace test -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index 0f64f9cf5616d5..9d7860d2d52c3c 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -7,16 +7,18 @@ #include "common_test_utils/data_utils.hpp" #include #include "ngraph_functions/builders.hpp" +#include "ov_ops/type_relaxed.hpp" +#include "lpt_ngraph_functions/common/builders.hpp" namespace ov { namespace test { namespace snippets { std::shared_ptr MHAFunction::initOriginal() const { - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; std::vector constantShapes; @@ -51,7 +53,7 @@ std::shared_ptr MHAFunction::initOriginal() const { std::shared_ptr matmul_parent1 = transpose1; if (with_mul) { std::vector mulConstData(ngraph::shape_size(constantShapes[2])); - auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true); + auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true); matmul_parent1 = std::make_shared(transpose1, mulConst); } const auto matMul0 = std::make_shared(transpose0, matmul_parent1, transA, transB); @@ -67,17 +69,17 @@ std::shared_ptr MHAFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } std::shared_ptr MHAFunction::initReference() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto data2 = std::make_shared(precision, input_shapes[2]); - auto data3 = std::make_shared(precision, input_shapes[3]); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precisions[2], input_shapes[2]); + auto data3 = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3}; NodeVector subgraph_inputs = {data0, data1, data2, data3}; - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); std::vector constantShapes; constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); @@ -113,8 +115,8 @@ std::shared_ptr MHAFunction::initReference() const { std::shared_ptr matmul_parent1 = transpose1; if (with_mul) { std::vector mulConstData(ngraph::shape_size(constantShapes[2])); - auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true); - auto mulParam = std::make_shared(precision, mulConst->get_shape()); + auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true); + auto mulParam = std::make_shared(precisions[1], mulConst->get_shape()); matmul_parent1 = std::make_shared(transpose1, mulParam); subgraph_params = {transpose0Param, transpose1Param, mulParam, addParam, transpose2Param}; subgraph_inputs = {data0, data1, mulConst, data2, data3}; @@ -135,10 +137,10 @@ std::shared_ptr MHAFunction::initReference() const { } std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; std::vector constantShapes; @@ -157,7 +159,7 @@ std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], order); std::vector mulConstData(1); - auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true); + auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true); std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), @@ -188,16 +190,16 @@ std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } std::shared_ptr MHAMatMul0TransposeFunction::initReference() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto data2 = std::make_shared(precision, input_shapes[2]); - auto data3 = std::make_shared(precision, input_shapes[3]); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precisions[2], input_shapes[2]); + auto data3 = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3}; - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); std::vector constantShapes; constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); @@ -214,7 +216,7 @@ std::shared_ptr MHAMatMul0TransposeFunction::initReference() const { auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], std::vector{0, 2, 1, 3}); std::vector mulConstData(1); - auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true); + auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true); ngraph::ParameterVector subgraphParams = {transpose0Param, transpose1Param, addParam, transpose2Param}; std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * @@ -250,12 +252,12 @@ std::shared_ptr MHAMatMul0TransposeFunction::initReference() const { } std::shared_ptr MHASelectFunction::initOriginal() const { - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto less0Param = std::make_shared(precision, input_shapes[3]); - auto less1Param = std::make_shared(precision, input_shapes[4]); - auto transpose2Param = std::make_shared(precision, input_shapes[5]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto less0Param = std::make_shared(precisions[3], input_shapes[3]); + auto less1Param = std::make_shared(precisions[4], input_shapes[4]); + auto transpose2Param = std::make_shared(precisions[5], input_shapes[5]); ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, less0Param, less1Param, transpose2Param}; std::vector constantShapes; @@ -288,7 +290,7 @@ std::shared_ptr MHASelectFunction::initOriginal() const { static_cast(input_shapes[0].get_shape()[1])}; auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[4], reshape1ConstData); // Value is equal to '1' - to avoid situation e^(-1000) / (sum(e^(-1000)) = 0/0 = NAN - auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector{1}); + auto selectConst = ngraph::builder::makeConstant(precisions[2], ov::Shape{1}, std::vector{1}); float transA = false; float transB = false; @@ -344,9 +346,9 @@ std::shared_ptr MHAWOTransposeOnInputsFunction::initOriginal() const } std::shared_ptr MHAWOTransposeFunction::initOriginal() const { - auto param0 = std::make_shared(precision, input_shapes[0]); - auto param1 = std::make_shared(precision, input_shapes[1]); - auto param2 = std::make_shared(precision, input_shapes[2]); + auto param0 = std::make_shared(precisions[0], input_shapes[0]); + auto param1 = std::make_shared(precisions[1], input_shapes[1]); + auto param2 = std::make_shared(precisions[2], input_shapes[2]); ngraph::ParameterVector ngraphParam = {param0, param1, param2}; float transA = false; @@ -359,6 +361,302 @@ std::shared_ptr MHAWOTransposeFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } + +std::shared_ptr MHAFQAfterMatMulFunction::initOriginal() const { + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto transpose2Param = std::make_shared(precision, input_shapes[3]); + ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; + + const auto shape_rank = input_shapes[0].get_shape().size(); + auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 3, 1}); + auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + + std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * + input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(input_shapes[0].get_shape()[0]), + static_cast(input_shapes[0].get_shape()[2]), + static_cast(input_shapes[0].get_shape()[1]), + static_cast(input_shapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(transpose0Param, transpose0Const); + const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); + const auto matMul0 = std::make_shared(transpose0, transpose1, transA, transB); + auto fq0 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + const auto add = std::make_shared(fq0, addParam); + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); + const auto matMul1 = std::make_shared(reshape1, transpose2, transA, transB); + auto fq1 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + const auto transpose3 = std::make_shared(fq1, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} +std::shared_ptr MHAINT8MatMulFunction::initOriginal() const { + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto transpose2Param = std::make_shared(precision, input_shapes[3]); + ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; + + const auto shape_rank = input_shapes[0].get_shape().size(); + auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 3, 1}); + auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + + std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * + input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(input_shapes[0].get_shape()[0]), + static_cast(input_shapes[0].get_shape()[2]), + static_cast(input_shapes[0].get_shape()[1]), + static_cast(input_shapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData); + + auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + auto fq2 = ngraph::builder::makeFakeQuantize(transpose2Param, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(fq0, transpose0Const); + const auto transpose1 = std::make_shared(fq1, transpose1Const); + const auto matMul0 = std::make_shared(transpose0, transpose1, transA, transB); + auto fq3 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + const auto add = std::make_shared(fq3, addParam); + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + auto fq4 = ngraph::builder::makeFakeQuantize(reshape1, ov::element::f32, 256, {1}, + {0}, {0.820726}, {0}, {0.820726}); + const auto transpose2 = std::make_shared(fq2, transpose2Const); + const auto matMul1 = std::make_shared(fq4, transpose2, transA, transB); + auto fq5 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1}, + {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + const auto transpose3 = std::make_shared(fq5, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} +std::shared_ptr MHAFQFunction::initOriginal() const { + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto transpose2Param = std::make_shared(precision, input_shapes[3]); + ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; + + const auto shape_rank = input_shapes[0].get_shape().size(); + auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 3, 1}); + auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + + const auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1}, + {-5.217694}, {6.661877}, {-5.217694}, {6.661877}); + const auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1}, + {-6.40245}, {6.45286}, {-6.40245}, {6.45286}); + const auto fq_add = ngraph::builder::makeFakeQuantize(addParam, ov::element::f32, 256, {1}, + {-1000}, {0}, {-1000}, {0}); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(fq0, transpose0Const); + const auto transpose1 = std::make_shared(fq1, transpose1Const); + const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); + const auto mul_const = ngraph::builder::makeConstant(ov::element::i8, ov::Shape{1}, std::vector{127}); + const auto convert = std::make_shared(mul_const, ov::element::f32); + const auto mul_deq_const = ngraph::builder::makeConstant(ov::element::f32, ov::Shape{1}, std::vector{0.00098425}); + const auto mul_deq = std::make_shared(convert, mul_deq_const); + const auto mul = std::make_shared(transpose1, mul_deq); + auto fq1_1 = ngraph::builder::makeFakeQuantize(mul, ov::element::f32, 256, {1}, + {-0.8003067}, {0.8066083}, {-0.8003067}, {0.8066083}); + const auto matMul0 = std::make_shared(transpose0, fq1_1, transA, transB); + auto fq2 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1}, + {-14.50351}, {17.65645}, {-14.50351}, {17.65645}); + const auto add = std::make_shared(fq2, fq_add); + const auto softMax = std::make_shared(add, 3); + const auto matMul1 = std::make_shared(softMax, transpose2, transA, transB); + auto fq3 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1}, + {-1.895786}, {2.0028071}, {-1.895786}, {2.0028071}); + const auto transpose3 = std::make_shared(fq3, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} +std::shared_ptr MHAINT8MatMulTypeRelaxedFunction::initOriginal() const { + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto transpose2Param = std::make_shared(precision, input_shapes[3]); + ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; + + const auto shape_rank = input_shapes[0].get_shape().size(); + auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 3, 1}); + auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + + std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * + input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(input_shapes[0].get_shape()[0]), + static_cast(input_shapes[0].get_shape()[2]), + static_cast(input_shapes[0].get_shape()[1]), + static_cast(input_shapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData); + + const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8); + const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose0Param, ov::element::i8, fq_signed_params); + const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose1Param, ov::element::i8, fq_signed_params); + const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose2Param, ov::element::i8, fq_signed_params); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(fq0, transpose0Const); + const auto transpose1 = std::make_shared(fq1, transpose1Const); + const auto matMul0 = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB); + + const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params); + const auto add = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(), + ov::op::TemporaryReplaceOutputType(addParam, element::f32).get()); + const auto deq = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{0.1122}); + const auto deq_mul = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(add, element::f32).get(), + ov::op::TemporaryReplaceOutputType(deq, element::f32).get()); + + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + + const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8); + const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params); + + const auto transpose2 = std::make_shared(fq2, transpose2Const); + const auto matMul1 = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(), + ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB); + const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params); + const auto transpose3 = std::make_shared(fq5, transpose3Const); + + ngraph::ResultVector results{std::make_shared(transpose3)}; + return std::make_shared(results, ngraphParam, "mha"); +} +std::shared_ptr MHAINT8MatMulTypeRelaxedFunction::initReference() const { + auto data0 = std::make_shared(precision, input_shapes[0]); + auto data1 = std::make_shared(precision, input_shapes[1]); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto data3 = std::make_shared(precision, input_shapes[3]); + ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3}; + + const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8); + const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data0, ov::element::i8, fq_signed_params); + const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data1, ov::element::i8, fq_signed_params); + const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data3, ov::element::i8, fq_signed_params); + NodeVector subgraph_inputs = {fq0, fq1, data2, fq2}; + + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto transpose2Param = std::make_shared(precision, input_shapes[3]); + ov::ParameterVector subgraph_params = {transpose0Param, transpose1Param, addParam, transpose2Param}; + + const auto shape_rank = input_shapes[0].get_shape().size(); + auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 3, 1}); + auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector{0, 2, 1, 3}); + + std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * + input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), + -1}; + auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData); + + std::vector reshape1ConstData = {static_cast(input_shapes[0].get_shape()[0]), + static_cast(input_shapes[0].get_shape()[2]), + static_cast(input_shapes[0].get_shape()[1]), + static_cast(input_shapes[0].get_shape()[1])}; + auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData); + + float transA = false; + float transB = false; + const auto transpose0 = std::make_shared(transpose0Param, transpose0Const); + const auto transpose1 = std::make_shared(transpose1Param, transpose1Const); + const auto matMul0 = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB); + + const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params); + const auto add = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(), + ov::op::TemporaryReplaceOutputType(addParam, element::f32).get()); + const auto deq = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{0.1122}); + const auto deq_mul = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(add, element::f32).get(), + ov::op::TemporaryReplaceOutputType(deq, element::f32).get()); + + const auto reshape0 = std::make_shared(add, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + + const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8); + const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params); + + const auto transpose2 = std::make_shared(transpose2Param, transpose2Const); + const auto matMul1 = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(), + ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB); + const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params); + + auto subgraph = std::make_shared(subgraph_inputs, + std::make_shared(NodeVector{fq5}, subgraph_params)); + // TODO: At the moment Snippets don't support explicitly Transpose. + // So we cannot collapse Transpose into Subgraph if there are ops between MatMul2 and Transpose3 + auto transpose3 = std::make_shared(subgraph, transpose3Const); + + return std::make_shared(NodeVector{transpose3}, ngraphParams); +} + } // namespace snippets } // namespace test } // namespace ov