From c0875bba2c7f7280b6b9cd66069a7772a36a7609 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Thu, 26 Jan 2023 18:50:37 +0400 Subject: [PATCH] [Snippets] Added MHA I8 tokenization --- .../snippets/include/snippets/op/brgemm.hpp | 2 + .../snippets/include/snippets/op/buffer.hpp | 15 +- .../snippets/include/snippets/op/loop.hpp | 1 + .../snippets/include/snippets/op/subgraph.hpp | 9 +- .../snippets/pass/buffer_identification.hpp | 30 ++ .../include/snippets/pass/tokenization.hpp | 9 + src/common/snippets/src/generator.cpp | 1 + src/common/snippets/src/op/brgemm.cpp | 29 +- src/common/snippets/src/op/buffer.cpp | 19 +- src/common/snippets/src/op/loop.cpp | 9 + src/common/snippets/src/op/subgraph.cpp | 96 ++++- .../snippets/src/pass/assign_registers.cpp | 93 +++-- .../src/pass/buffer_identification.cpp | 174 +++++++++ .../snippets/src/pass/collapse_subgraph.cpp | 42 +- .../snippets/src/pass/mha_tokenization.cpp | 231 ++++++++--- src/common/snippets/src/pass/reset_buffer.cpp | 10 +- src/common/snippets/src/pass/tokenization.cpp | 18 + .../tests/src/pass/mha_tokenization.cpp | 13 +- .../src/emitters/jit_snippets_emitters.cpp | 23 +- .../src/emitters/jit_snippets_emitters.hpp | 7 +- .../snippets_mark_skipped.cpp | 12 +- .../intel_cpu/src/transformation_pipeline.cpp | 29 +- .../skip_tests_config.cpp | 3 + .../snippets/matmul.cpp | 46 ++- .../shared_tests_instances/snippets/mha.cpp | 59 ++- .../functional/subgraph_tests/src/mha.cpp | 2 +- .../plugin/shared/include/snippets/matmul.hpp | 15 + .../plugin/shared/include/snippets/mha.hpp | 16 + .../plugin/shared/src/snippets/matmul.cpp | 60 +++ .../plugin/shared/src/snippets/mha.cpp | 82 +++- .../include/subgraph_matmul.hpp | 52 +++ .../include/subgraph_mha.hpp | 127 +++++- .../src/subgraph_matmul.cpp | 60 ++- .../src/subgraph_mha.cpp | 369 ++++++++++++++++-- 34 files changed, 1536 insertions(+), 227 deletions(-) create mode 100644 src/common/snippets/include/snippets/pass/buffer_identification.hpp create mode 100644 src/common/snippets/src/pass/buffer_identification.cpp diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index 58c70f164799a6..443edb19d8cef6 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -27,6 +27,8 @@ class Brgemm : public MemoryAccess { size_t get_offset_b() const { return get_input_offset(1); } size_t get_offset_c() const { return get_output_offset(0); } + static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1); + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; diff --git a/src/common/snippets/include/snippets/op/buffer.hpp b/src/common/snippets/include/snippets/op/buffer.hpp index 8c6f98ac894e93..14e39b0e9b13be 100644 --- a/src/common/snippets/include/snippets/op/buffer.hpp +++ b/src/common/snippets/include/snippets/op/buffer.hpp @@ -19,15 +19,18 @@ namespace op { * - All buffers in a graph have the same memory pointer. 
So if we have a few buffers, * each the corresponding MemoryAccess op for Buffer should have offset for common memory pointer of this Buffer * - Buffer should be a single consumer for operation output port + * @param m_type - type of Buffer: IntermediateMemory/NewMemory + * @param m_shape - output allocation shape for Buffer with type NewMemory + * @param m_id - Buffer ID in common Buffer system * @ingroup snippets */ class Buffer : public ngraph::op::Op { public: OPENVINO_OP("Buffer", "SnippetsOpset"); Buffer() = default; - Buffer(const ov::Shape& shape); - Buffer(const ov::Output& arg, const ov::Shape& shape); - Buffer(const ov::Output& arg, int32_t allocation_rank = -1); + Buffer(const ov::Shape& shape, size_t id = 0); + Buffer(const ov::Output& arg, const ov::Shape& shape, size_t id = 0); + Buffer(const ov::Output& arg, int32_t allocation_rank = -1, size_t id = 0); bool visit_attributes(AttributeVisitor& visitor) override; void validate_and_infer_types() override; @@ -38,9 +41,12 @@ class Buffer : public ngraph::op::Op { IntermediateMemory }; + void set_id(size_t id) { m_id = id; } + + size_t get_id() const { return m_id; } + size_t get_byte_size() const; Type get_type() const { return m_type; } ov::Shape get_allocation_shape() const { return m_shape; } - size_t get_byte_size() const; bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; } bool is_new_memory() const { return m_type == Type::NewMemory; } @@ -48,6 +54,7 @@ class Buffer : public ngraph::op::Op { private: Type m_type = Type::IntermediateMemory; ov::Shape m_shape = {}; + size_t m_id = 0; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/loop.hpp b/src/common/snippets/include/snippets/op/loop.hpp index 89cf0abd5173ff..756cc084b6da98 100644 --- a/src/common/snippets/include/snippets/op/loop.hpp +++ b/src/common/snippets/include/snippets/op/loop.hpp @@ -83,6 +83,7 @@ class LoopEnd : public LoopBase { std::vector ptr_increments, std::vector finalization_offsets); LoopEnd() = default; std::shared_ptr get_loop_begin(); + bool visit_attributes(AttributeVisitor& visitor) override; void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; const std::vector& get_finalization_offsets() const; diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index 46e6633f61b8aa..93e7879b14fa7f 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -97,7 +97,6 @@ class Subgraph : public ov::op::util::SubGraphOp { size_t get_buffer_scratchpad_size() const { return m_buffer_scratchpad; } size_t get_virtual_port_count() const { return m_virtual_port_count; } - bool is_buffer_needed() const { return m_buffer_needed; } bool is_quantized() const { return config.m_is_quantized; } bool has_type_relaxed_ops() const { return config.m_has_type_relaxed_ops; } bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; } @@ -122,7 +121,6 @@ class Subgraph : public ov::op::util::SubGraphOp { void set_generator(std::shared_ptr generator); void set_tile_rank(size_t newRank) {tileRank = newRank;} void set_virtual_port_count(const size_t count); - void set_buffer_needed(const bool need); void print() const; void print_statistics(bool verbose); @@ -137,8 +135,10 @@ class Subgraph : public ov::op::util::SubGraphOp { // should have explicit Constants even if they're non-scalar (Reshape, 
Transpose, Broadcast)
     // This check returns True if the Constant op that is an input of this op should be inside the Subgraph body
     static auto constant_input_should_be_inside_body(const std::shared_ptr& node) -> bool;
 
-
     static bool check_broadcast(const std::shared_ptr& node) noexcept;
+    // Returns the estimated unique buffer count (an upper bound). It's needed for tokenization
+    static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
+    static auto is_domain_sensitive_op(const std::shared_ptr& op) -> bool;
 
 private:
     void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
@@ -147,12 +147,9 @@ class Subgraph : public ov::op::util::SubGraphOp {
     void initialize_buffer_scratchpad_size();
     // Count of Subgraph virtual ports:
     //  - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
-    // Need Buffer op or not
-    //  - Buffers. All Buffers are considered as one common additional virtual port. So we cannot summarize them as potential non-scalar Constants
     // NOTE: To avoid overheads in each calculation of this count (for example, in validate_and_type_infer()),
     //       we should MANUALLY calculate it where it is needed.
     size_t m_virtual_port_count = 0;
-    bool m_buffer_needed = false;
     size_t m_buffer_scratchpad = 0lu;
     Shape exec_domain = {};
     std::shared_ptr m_generator = nullptr;
diff --git a/src/common/snippets/include/snippets/pass/buffer_identification.hpp b/src/common/snippets/include/snippets/pass/buffer_identification.hpp
new file mode 100644
index 00000000000000..170fadd79b3688
--- /dev/null
+++ b/src/common/snippets/include/snippets/pass/buffer_identification.hpp
@@ -0,0 +1,30 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include
+#include
+
+namespace ngraph {
+namespace snippets {
+namespace pass {
+
+/**
+ * @interface BufferIdentification
+ * @brief The pass sets identifiers for Buffers in the common Buffer system
+ *        Note: should be called before the ResetBuffer() pass to have correct offsets
+ * @ingroup snippets
+ */
+class BufferIdentification: public ngraph::pass::FunctionPass {
+public:
+    OPENVINO_RTTI("BufferIdentification", "0");
+    BufferIdentification() = default;
+
+    bool run_on_model(const std::shared_ptr& m) override;
+};
+
+} // namespace pass
+} // namespace snippets
+} // namespace ngraph
diff --git a/src/common/snippets/include/snippets/pass/tokenization.hpp b/src/common/snippets/include/snippets/pass/tokenization.hpp
index 19b776ec25751d..910289458d83fd 100644
--- a/src/common/snippets/include/snippets/pass/tokenization.hpp
+++ b/src/common/snippets/include/snippets/pass/tokenization.hpp
@@ -9,6 +9,7 @@
 
 #include "snippets/pass/mha_tokenization.hpp"
 #include "snippets/pass/collapse_subgraph.hpp"
+#include "snippets/op/subgraph.hpp"
 
 namespace ngraph {
 namespace snippets {
@@ -19,8 +20,16 @@ namespace pass {
         SkippedByPlugin - indicate that snippets can't include this node in a subgraph. Can be set by Plugin via SetSnippetsNodeType(...).
 */
 enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
+/*
+ NotSet - default value returned if the subgraph wasn't marked, and snippets can include nodes in this subgraph
+ Completed - indicates that snippets can't include any more nodes in this subgraph.
+             It's used in a separate tokenization pass, for example, tokenization by matcher (MHA Tokenization).
+ */ +enum class SnippetsSubgraphType : int64_t {NotSet, Completed}; void SetSnippetsNodeType(const std::shared_ptr&, SnippetsNodeType); +void SetSnippetsSubgraphType(const std::shared_ptr&, SnippetsSubgraphType); SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr&); +SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr&); void SetTopologicalOrder(const std::shared_ptr&, int64_t); int64_t GetTopologicalOrder(const std::shared_ptr&); diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index dba0f139fda495..4b040f9be62899 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -10,6 +10,7 @@ #include "snippets/op/subgraph.hpp" #include "snippets/op/kernel.hpp" #include +#include #include #include diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 743653099b8601..a36ae1ad10a9b4 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -17,7 +17,7 @@ Brgemm::Brgemm(const Output& A, const Output& B, set_output_size(1); set_input_offset(offset_a, 0); set_input_offset(offset_b, 1); - set_output_offset(offset_a, 0); + set_output_offset(offset_c, 0); constructor_validate_and_infer_types(); } @@ -45,22 +45,29 @@ std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args return std::make_shared(new_args.at(0), new_args.at(1), get_offset_a(), get_offset_b(), get_offset_c()); } -ov::element::Type Brgemm::get_output_type() const { - const auto element_type_a = get_input_element_type(0); - const auto element_type_b = get_input_element_type(1); - const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b); - const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8; - const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b); +ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) { + const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1); + const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8; + const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1); if (is_f32 || is_bf16) { - return element::f32; + return element::f32; } else if (is_int8) { return element::i32; } else { + return element::undefined; + } +} + +ov::element::Type Brgemm::get_output_type() const { + auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1)); + if (output_type == element::undefined) { throw ngraph_error("BrgemmCPU node has incompatible input element types: " + - element_type_a.get_type_name() + - " and " + - element_type_b.get_type_name()); + get_input_element_type(0).get_type_name() + + " and " + + get_input_element_type(1).get_type_name()); } + + return output_type; } ov::PartialShape Brgemm::get_output_partial_shape(const std::vector& input_shapes) const { diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index 8a3963119b832b..5a0ff8d16864cf 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -12,22 +12,18 @@ using namespace std; using namespace ngraph; -auto normalize_rank(int32_t allocation_rank, const size_t 
shape_rank) -> int32_t {
-    return allocation_rank < 0 ? allocation_rank + static_cast(shape_rank) : allocation_rank;
-}
-
-snippets::op::Buffer::Buffer(const ov::Shape& shape)
-    : Op(), m_type(Type::NewMemory), m_shape(shape) {
+snippets::op::Buffer::Buffer(const ov::Shape& shape, size_t id)
+    : Op(), m_type(Type::NewMemory), m_shape(shape), m_id(id) {
     constructor_validate_and_infer_types();
 }
 
-snippets::op::Buffer::Buffer(const ov::Output& arg, const ov::Shape& shape)
-    : Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape) {
+snippets::op::Buffer::Buffer(const ov::Output& arg, const ov::Shape& shape, size_t id)
+    : Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape), m_id(id) {
     constructor_validate_and_infer_types();
 }
 
-snippets::op::Buffer::Buffer(const ov::Output& arg, int32_t allocation_rank)
-    : Op({arg}), m_type(Type::IntermediateMemory) {
+snippets::op::Buffer::Buffer(const ov::Output& arg, int32_t allocation_rank, size_t id)
+    : Op({arg}), m_type(Type::IntermediateMemory), m_id(id) {
     const auto pshape = arg.get_partial_shape();
     OPENVINO_ASSERT(pshape.is_static(), "Buffer supports only static input shape");
     const auto shape = pshape.get_shape();
@@ -40,6 +36,7 @@ snippets::op::Buffer::Buffer(const ov::Output& arg, int32_t allocation
 bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) {
     INTERNAL_OP_SCOPE(Buffer_visit_attributes);
     visitor.on_attribute("allocation_shape", m_shape);
+    visitor.on_attribute("id", m_id);
     return true;
 }
diff --git a/src/common/snippets/src/op/loop.cpp b/src/common/snippets/src/op/loop.cpp
index c8c704fd350913..c1489d9f92bdcd 100644
--- a/src/common/snippets/src/op/loop.cpp
+++ b/src/common/snippets/src/op/loop.cpp
@@ -181,6 +181,15 @@ void LoopEnd::validate_and_infer_types() {
     get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr());
 }
 
+bool LoopEnd::visit_attributes(AttributeVisitor& visitor) {
+    LoopBase::visit_attributes(visitor);
+    for (size_t i = 0; i < ptr_increments.size(); ++i) {
+        visitor.on_attribute("ptr_increment_" + std::to_string(i), ptr_increments[i]);
+        visitor.on_attribute("finalization_offsets_" + std::to_string(i), finalization_offsets[i]);
+    }
+    return true;
+}
+
 } // namespace op
 } // namespace snippets
 } // namespace ngraph
\ No newline at end of file
diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp
index f8953745520aff..ef25391bb705ea 100644
--- a/src/common/snippets/src/op/subgraph.cpp
+++ b/src/common/snippets/src/op/subgraph.cpp
@@ -25,6 +25,7 @@
 #include "snippets/pass/reset_buffer.hpp"
 #include "snippets/pass/insert_buffer.hpp"
 #include "snippets/pass/loop_fusion.hpp"
+#include "snippets/pass/buffer_identification.hpp"
 #include "snippets/utils.hpp"
 
 #include "transformations/common_optimizations/nop_elimination.hpp"
@@ -51,8 +52,13 @@ void snippets::op::Subgraph::set_virtual_port_count(const size_t count) {
     m_virtual_port_count = count;
 }
 
-void snippets::op::Subgraph::set_buffer_needed(const bool need) {
-    m_buffer_needed = need;
+auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr& op) -> bool {
+    return ov::is_type(op) ||
+           ov::is_type(op) ||
+           ov::is_type(op) ||
+           ov::is_type(op) ||
+           ov::is_type(op) || // Broadcast is a domain-sensitive op because the output shape depends on
+           ov::is_type(op);   // both the input and the broadcast shapes (both are inputs of the op).
+                              // Note: it is used only in the MHA pattern
 }
 
 void snippets::op::Subgraph::init_config() {
                             ov::is_type(op);
         config.m_has_type_relaxed_ops = config.m_has_type_relaxed_ops ||
             std::dynamic_pointer_cast(op);
-        config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops ||
-            ov::is_type(op) ||
-            ov::is_type(op) ||
-            ov::is_type(op) ||
-            ov::is_type(op);
+        config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops || is_domain_sensitive_op(op);
     }
     // Domain sensitive ops are decomposed with explicit Loops. So, we should explicitly insert Loops in Subgraph if it contains these ops
     config.m_explicit_loop_insertion = config.m_has_domain_sensitive_ops;
 }
 
+auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
+    // The count of potential unique Buffers - they are hidden virtual ports as well
+    // We should go through the Subgraph and calculate the potential count of non-inplace Buffers.
+    // These Buffers can only be around Loops (for example, around MatMul they may be inplace), so we should
+    // check the element type sizes of the nodes that use Buffers to get an upper bound on the unique Buffer count.
+    // The count is an estimate because at this point we have only the original graph representation,
+    // and we can only predict where the Loops will be.
+    // Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32)
+    std::vector used_precision_size;
+    for (const auto& op : ops) {
+        if (const auto transpose = ov::as_type_ptr(op)) {
+            // At the moment Transposes are supported only on Results and Parameters, but
+            // then we should have different Buffers for Transpose as well (Transpose isn't inplace)
+            const auto consumers = transpose->get_output_target_inputs(0);
+            // If there is a Result right after a Transpose, there won't be a Buffer after the Transpose.
+ // The same case is for Parameter before Transpose + const auto are_prev_or_next_ops = std::none_of(consumers.begin(), consumers.end(), + [](const ov::Input& in) { + return ov::is_type(in.get_node()); + }) || + !ov::is_type(transpose->get_input_node_shared_ptr(0)); + if (are_prev_or_next_ops) { + const auto prc_size = transpose->get_element_type().size(); + if (used_precision_size.empty() || used_precision_size.back() != prc_size) { + used_precision_size.push_back(prc_size); + } + } + } else if (ov::is_type(op) || ov::is_type(op)) { + // Softmax always uses 2 FP32 Buffers + const auto prc_size = ov::element::f32.size(); + if (used_precision_size.empty() || used_precision_size.back() != prc_size) { + used_precision_size.push_back(prc_size); + } + } else if (const auto matmul = ov::as_type_ptr(op)) { + // First input check is enough because MatMul requires the same prc size on inputs + if (!ov::is_type(matmul->get_input_node_shared_ptr(0)) || + !ov::is_type(matmul->get_input_node_shared_ptr(1))) { + const auto prc_size = matmul->get_input_element_type(0).size(); + if (used_precision_size.empty() || used_precision_size.back() != prc_size) { + used_precision_size.push_back(prc_size); + } + } + + const auto consumers = matmul->get_output_target_inputs(0); + if (std::none_of(consumers.begin(), consumers.end(), + [](const ov::Input& in) { return ov::is_type(in.get_node()); })) { + const auto prc_size = matmul->get_element_type().size(); + if (used_precision_size.empty() || used_precision_size.back() != prc_size) { + used_precision_size.push_back(prc_size); + } + } + } + } + + return used_precision_size.size(); +} + snippets::op::Subgraph::Subgraph(const OutputVector& args, std::shared_ptr body) : SubGraphOp(args), m_generator(nullptr) { set_function(body); @@ -195,17 +254,11 @@ auto snippets::op::Subgraph::wrap_node_as_subgraph(const std::shared_ptrget_friendly_name(), body_results, body_parameters); auto subgraph = build_subgraph(node, subgraph_inputs, body); - bool need_buffer = false; size_t hidden_data_count = 0lu; if (auto fq_node = ov::as_type_ptr(node)) { hidden_data_count += utils::get_non_scalar_constant_count_for_fq(fq_node); - // Ops that requires Buffer - } else if (ov::is_type(node) || - ov::is_type(node)) { - need_buffer |= true; } subgraph->set_virtual_port_count(hidden_data_count); - subgraph->set_buffer_needed(need_buffer); for (size_t i = 0; i < body->get_parameters().size(); i++) { body->get_parameters()[i]->set_friendly_name(body_parameters[i]->get_friendly_name()); @@ -491,7 +544,7 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { if (buffer->is_intermediate_memory()) { // Transpose, MatMul and other non-decomposed ops should have different memories on inputs and outputs to avoid data corruption, - // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace inside Loop. + // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace inside Loop only at the moment. 
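+            // Editorial sketch (not from the original patch): the allocation policy described
+            // above and implemented below boils down to the following, where `needs_new_memory`
+            // is a hypothetical helper standing in for the parent-op and size checks:
+            //
+            //     size_t scratchpad = 0, offset = 0;
+            //     for (const auto& buffer : intermediate_buffers) {
+            //         if (needs_new_memory(buffer)) {    // e.g. parent is Brgemm or a Transpose Loop
+            //             offset = scratchpad;           // open a fresh region
+            //             scratchpad += buffer->get_byte_size();
+            //         }
+            //         propagate_offset(buffer, offset);  // inplace ops reuse the current region
+            //     }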
OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent"); const auto parent = buffer->get_input_node_shared_ptr(0); if (!ov::is_type(parent) || is_transpose_loop(parent)) { @@ -501,6 +554,15 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { continue; } + // If previous allocated memory is less than needed, we have to allocate new + const auto prev_alloc_size = m_buffer_scratchpad - offset; + if (prev_alloc_size < buffer_size) { + offset = m_buffer_scratchpad; + propagate_offset(buffer, offset); + m_buffer_scratchpad += buffer_size; + continue; + } + propagate_offset(buffer, offset); } else { // Single Buffer without input should allocate new memory @@ -581,7 +643,6 @@ void snippets::op::Subgraph::convert_to_snippet_dialect() { m_generator->get_target_machine()->get_lanes(), !config.m_explicit_loop_insertion); if (config.m_has_domain_sensitive_ops) { manager.register_pass(); - manager.register_pass(); } } manager.run_passes(body_ptr()); @@ -630,6 +691,11 @@ snippets::Schedule snippets::op::Subgraph::generate( post_precision.run_passes(body_ptr()); + ov::pass::Manager buffer_m; + buffer_m.register_pass(); + buffer_m.register_pass(); + buffer_m.run_passes(body_ptr()); + // After all passes, when all optimizations are completed and all MemoryAccess ops are inserted, // we can calculate common buffer scratchpad size and propagate offset from Buffer to the corresponding MemoryAccess ops if (config.m_has_domain_sensitive_ops) diff --git a/src/common/snippets/src/pass/assign_registers.cpp b/src/common/snippets/src/pass/assign_registers.cpp index c9af20443b8938..0f2de57a94bc22 100644 --- a/src/common/snippets/src/pass/assign_registers.cpp +++ b/src/common/snippets/src/pass/assign_registers.cpp @@ -13,28 +13,16 @@ #endif namespace { +using Reg = size_t; +using Tensor = std::shared_ptr; + constexpr size_t reg_count = 16lu; using opRegType = ngraph::snippets::Generator::opRegType; -} // namespace - -bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr& f) { - RUN_ON_MODEL_SCOPE(AssignRegisters); - OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::AssignRegisters") - using Reg = size_t; - using tensor = std::shared_ptr; - auto ops = f->get_ordered_ops(); - - std::vector>> typed_ops; - for (const auto& op : ops) { - typed_ops.emplace_back(std::make_pair(m_reg_type_mapper(op), op)); - } - size_t counter_vec = 0; - size_t counter_gpr = 0; - std::map regs_vec, regs_gpr; - // Define a set of immune tensors that will be ignored by auto reg allocation => their reg allocation is done manually - std::map manually_assigned_gprs, manually_assigned_vecs; - const auto IS_MANUALLY_ALLOCATED_REG = SIZE_MAX; +auto manual_assigning(const std::shared_ptr& f, + const ov::NodeVector& ops, + std::map& manually_assigned_gprs, + std::map& manually_assigned_vecs) -> void { const auto num_parameters = f->get_parameters().size(); const auto num_results = f->get_results().size(); auto accumulator_reg = 0lu; @@ -42,25 +30,26 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr if (const auto& param = ov::as_type_ptr(op)) { manually_assigned_gprs[op->output(0).get_tensor_ptr()] = static_cast(f->get_parameter_index(param)); - } else if (const auto& result = ov::as_type_ptr(op)) { + } else if (const auto& result = ov::as_type_ptr(op)) { // here we use the fact that Result input & output tensors are identical by construction manually_assigned_gprs[op->output(0).get_tensor_ptr()] = 
static_cast(f->get_result_index(result) + num_parameters);
-        } else if (const auto buffer = ov::as_type_ptr(op)) {
-            // All buffers have one common data pointer
+        } else if (const auto buffer = ov::as_type_ptr(op)) {
+            // All buffers with the same ID have one common data pointer
+            const auto buffer_id = buffer->get_id();
             if (buffer->is_intermediate_memory()) {
                 manually_assigned_gprs[op->input(0).get_tensor_ptr()] =
-                    static_cast(num_results + num_parameters);
+                    static_cast(num_results + num_parameters + buffer_id);
             }
             manually_assigned_gprs[op->output(0).get_tensor_ptr()] =
-                static_cast(num_results + num_parameters);
-        } else if (ov::is_type(op) || ov::is_type(op)) {
+                static_cast(num_results + num_parameters + buffer_id);
+        } else if (ov::is_type(op) || ov::is_type(op)) {
             // Only in SoftmaxDecomposition ReduceMax and ReduceSum use HorizonMax/HorizonSum and VectorBuffer.
             // We should manually assign one vector register to the VectorBuffer and the Max/Sum output to simulate an accumulator
             // TODO [96351]: We should rewrite the accumulator pattern using another way
             const auto input = op->get_input_node_shared_ptr(0); // input - it's the accumulator math op: Add or Max
             for (size_t i = 0; i < input->get_input_size(); ++i) {
-                if (ov::is_type(input->get_input_node_shared_ptr(i))) {
+                if (ov::is_type(input->get_input_node_shared_ptr(i))) {
                     manually_assigned_vecs[input->input(i).get_tensor_ptr()] =
                         static_cast(accumulator_reg);
                 }
@@ -71,21 +60,49 @@
             manually_assigned_vecs[op->output(0).get_tensor_ptr()] =
                 static_cast(accumulator_reg);
 
-            // If there is Broadcast, it should have the same register as Horizon op
-            // because it's a result of the accumulator as well
-            for (auto& out : op->output(0).get_target_inputs()) {
-                const auto child = out.get_node()->shared_from_this();
-                if (ov::is_type(child)) {
-                    manually_assigned_vecs[child->output(0).get_tensor_ptr()] =
+            auto target_inputs = op->output(0).get_target_inputs();
+            auto output = target_inputs.begin()->get_node()->shared_from_this();
+
+            // All operations `outside loop` after Horizon ops should have the same register to
+            // avoid using it in the next Loop
+            auto iter = output->get_rt_info().find("outside_loop");
+            while (iter != output->get_rt_info().end() && iter->second.as()) {
+                manually_assigned_vecs[output->output(0).get_tensor_ptr()] =
                         static_cast(accumulator_reg);
-                }
+
+                target_inputs = output->output(0).get_target_inputs();
+                output = target_inputs.begin()->get_node()->shared_from_this();
+                iter = output->get_rt_info().find("outside_loop");
             }
+
             accumulator_reg++;
         }
     }
+}
+
+} // namespace
+
+bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr& f) {
+    RUN_ON_MODEL_SCOPE(AssignRegisters);
+    OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::AssignRegisters")
+    auto ops = f->get_ordered_ops();
+
+    std::vector>> typed_ops;
+    for (const auto& op : ops) {
+        typed_ops.emplace_back(std::make_pair(m_reg_type_mapper(op), op));
+    }
+
+    size_t counter_vec = 0;
+    size_t counter_gpr = 0;
+    std::map regs_vec, regs_gpr;
+    // Define a set of immune tensors that will be ignored by auto reg allocation => their reg allocation is done manually
+    std::map manually_assigned_gprs, manually_assigned_vecs;
+    manual_assigning(f, ops, manually_assigned_gprs, manually_assigned_vecs);
+
+    const auto IS_MANUALLY_ALLOCATED_REG = SIZE_MAX;
     auto enumerate_out_tensors = [IS_MANUALLY_ALLOCATED_REG] (const std::shared_ptr& op, decltype(regs_vec)&
reg_map, - const std::map& manually_assigned_regs, + const std::map& manually_assigned_regs, size_t& counter) { for (const auto& output : op->outputs()) { const auto& t = output.get_tensor_ptr(); @@ -114,7 +131,7 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr std::vector> used_vec(ops.size(), std::set()); std::vector> defined_vec(ops.size(), std::set()); - auto tensor2reg = [IS_MANUALLY_ALLOCATED_REG] (const std::vector& tensors, const std::map& reg_map) { + auto tensor2reg = [IS_MANUALLY_ALLOCATED_REG] (const std::vector& tensors, const std::map& reg_map) { std::set result; for (const auto& t : tensors) { if (reg_map.count(t) == 0) @@ -127,7 +144,7 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr }; for (size_t i = 0; i < typed_ops.size(); i++) { const auto& t_op = typed_ops[i]; - std::vector used_tensors, defined_tensors; + std::vector used_tensors, defined_tensors; for (const auto& in : t_op.second->inputs()) { used_tensors.push_back(in.get_tensor_ptr()); } @@ -279,9 +296,9 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr auto unique2reused_map_vec = linescan_assign_registers(live_intervals_vec, vec_pool); auto unique2reused_map_gpr = linescan_assign_registers(live_intervals_gpr, gpr_pool); - std::map assigned_regs(std::move(manually_assigned_gprs)); + std::map assigned_regs(std::move(manually_assigned_gprs)); assigned_regs.insert(manually_assigned_vecs.begin(), manually_assigned_vecs.end()); - auto register_assigned_regs = [IS_MANUALLY_ALLOCATED_REG, &assigned_regs](const std::map& unique_regs, + auto register_assigned_regs = [IS_MANUALLY_ALLOCATED_REG, &assigned_regs](const std::map& unique_regs, const std::map& unique2reused) { for (const auto& reg : unique_regs) { if (reg.second == IS_MANUALLY_ALLOCATED_REG) diff --git a/src/common/snippets/src/pass/buffer_identification.cpp b/src/common/snippets/src/pass/buffer_identification.cpp new file mode 100644 index 00000000000000..1c7ea6363f480c --- /dev/null +++ b/src/common/snippets/src/pass/buffer_identification.cpp @@ -0,0 +1,174 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "snippets/pass/buffer_identification.hpp" +#include "snippets/snippets_isa.hpp" + +#include + +namespace ngraph { +namespace snippets { +namespace pass { + +namespace { +using BufferSet = std::vector>; + +auto is_intermediate_buffer(const std::shared_ptr& op) -> std::shared_ptr { + const auto buffer = ov::as_type_ptr(op); + return buffer && buffer->is_intermediate_memory() ? buffer : nullptr; +} + + +} // namespace + +auto create_adjacency_matrix(const BufferSet& buffers) -> std::vector { + // The sync point to check for adjency is Loop because only in Loop we increment pointers. 
+    // So if some Buffers in the one Loop have a conflict (cannot be inplace: the same ptr increment and finalization offset),
+    // they are considered adjacent
+    const auto size = buffers.size();
+    std::vector adj(size * size, false);
+    for (size_t i = 0; i < size; ++i)
+        adj[i + i * size] = true;
+
+    auto update_adj_matrix = [&](const std::shared_ptr& buffer, size_t buffer_index,
+                                 const std::shared_ptr& neighbour_buffer) {
+        if (neighbour_buffer) {
+            if (buffer->get_allocation_shape() != neighbour_buffer->get_allocation_shape() ||
+                buffer->get_element_type().size() != neighbour_buffer->get_element_type().size()) {
+                const auto iter = std::find(buffers.cbegin(), buffers.cend(), neighbour_buffer);
+                NGRAPH_CHECK(iter != buffers.cend(), "Buffer wasn't found in the Buffer system of the Subgraph");
+
+                const size_t adj_idx = std::distance(buffers.cbegin(), iter);
+                adj[buffer_index + adj_idx * size] = true;
+            }
+        }
+    };
+
+    for (size_t i = 0; i < buffers.size(); ++i) {
+        const auto buffer = buffers[i];
+
+        auto port = buffer->input_value(0).get_index();
+        auto parent = buffer->get_input_node_shared_ptr(0);
+        // We iterate in a while loop to check nested Loops
+        while (const auto loop_end = ov::as_type_ptr(parent)) {
+            const auto loop_begin = loop_end->get_loop_begin();
+            for (const auto& input_value : loop_begin->input_values()) {
+                auto loop_in = input_value.get_node_shared_ptr();
+                auto port_idx = input_value.get_index();
+                while (std::dynamic_pointer_cast(loop_in)) {
+                    const auto source_output = loop_in->input_value(port_idx);
+                    loop_in = source_output.get_node_shared_ptr();
+                    port_idx = source_output.get_index();
+                }
+
+                if (const auto neighbour_buffer = is_intermediate_buffer(loop_in)) {
+                    update_adj_matrix(buffer, i, neighbour_buffer);
+                }
+            }
+            for (const auto& output : loop_end->outputs()) {
+                // checking the first target input is enough for Buffer searching because ops can have only a single Buffer on each output port
+                const auto target_inputs = output.get_target_inputs();
+                auto consumer_in = *target_inputs.begin();
+                auto port_idx = consumer_in.get_index();
+                auto consumer = consumer_in.get_node()->shared_from_this();
+                while (std::dynamic_pointer_cast(consumer)) {
+                    const auto target_inputs = consumer->get_output_target_inputs(port_idx);
+                    auto consumer_in = *target_inputs.begin();
+                    port_idx = consumer_in.get_index();
+                    consumer = consumer_in.get_node()->shared_from_this();
+                }
+
+                if (buffer != consumer) {
+                    if (const auto neighbour_buffer = is_intermediate_buffer(consumer)) {
+                        update_adj_matrix(buffer, i, neighbour_buffer);
+                    }
+                }
+            }
+
+            parent = parent->get_input_node_shared_ptr(port);
+            port = parent->input_value(port).get_index();
+        }
+    }
+
+    return adj;
+}
+
+auto coloring(BufferSet& buffers, std::vector& adj) -> std::map {
+    size_t color = 0;
+    std::map color_groups;
+    const auto size = buffers.size();
+    for (size_t i = 0; i < size; i++) {
+        if (!buffers[i])
+            continue;
+
+        auto buffer = buffers[i];
+        color_groups[color].push_back(buffer);  // Add to the Color Group
+        buffers[i] = nullptr;  // Remove from the graph vertices
+
+        // while Buffer i still has non-coloured non-neighbours
+        // (row i contains a 0)
+        while (!std::accumulate(adj.begin() + i * size, adj.begin() + (i + 1) * size, true, std::logical_and())) {
+            size_t j = i + 1;
+            bool force_break = false;
+            for (; j < size; ++j) {
+                if (adj[j + i * size] && buffers[j]) {
+                    force_break = true;
+                    break;
+                }
+                if (!adj[j + i * size] && buffers[j])
+                    break;
+            }
+
+            if (force_break || j == size)
+                break;
+
+            auto neighbour_buffer =
buffers[j]; + color_groups[color].push_back(neighbour_buffer); // Add to Color Group + buffers[j] = nullptr; // Remove from graph vertices + std::transform(adj.begin() + i * size, adj.begin() + (i + 1) * size, adj.begin() + j * size, + adj.begin() + i * size, std::logical_or()); + } + + color++; + } + + return color_groups; +} + +bool BufferIdentification::run_on_model(const std::shared_ptr& model) { + RUN_ON_FUNCTION_SCOPE(BufferIdentification); + // Unite Buffers using Graph coloring algorithm. + // Notes: We identify only Buffer with Intermediate memory because Buffers with new memory are used only in Brgemm case + // so these Buffers are always IntermediateBuffer nonadjacent + BufferSet buffers; + + const auto ops = model->get_ordered_ops(); + for (const auto& op : ops) { + if (const auto buffer = is_intermediate_buffer(op)) { + buffers.push_back(buffer); + } + } + + // Creation of Adj matrix + auto adj = create_adjacency_matrix(buffers); + + // Graph coloring algorithm + const auto color_groups = coloring(buffers, adj); + + // FIXME: use const auto& [color, united_buffers] when C++17 is available + for (const auto& pair : color_groups) { + const auto color = pair.first; + const auto united_buffers = pair.second; + for (const auto& buffer : united_buffers) { + buffer->set_id(color); + } + } + + return true; +} + +} // namespace pass +} // namespace snippets +} // namespace ngraph diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index af962adaa64432..ebe5d162e7ba0a 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -64,11 +64,21 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { const auto& transpose = as_type_ptr(n); const auto& out_shape = n->get_output_partial_shape(0); if (transpose && out_shape.is_static()) { + const auto parent = transpose->get_input_node_shared_ptr(0); + const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); + auto is_brgemm_case = ov::is_type(parent) || ov::is_type(child); + // Check for Transpose parent is MatMul inside Subgraph + if (const auto subgraph = ov::as_type_ptr(parent)) { + const auto body = subgraph->body_ptr(); + const auto subgraph_output = body->get_results()[transpose->input_value(0).get_index()]->get_input_node_shared_ptr(0); + is_brgemm_case = is_brgemm_case || ov::is_type(subgraph_output); + } + const auto& order = as_type_ptr(n->get_input_node_shared_ptr(1)); if (order) { const auto order_value = order->cast_vector(); - return TransposeDecomposition::supported_cases.count(order_value) != 0 || - FuseTransposeBrgemm::supported_cases.count(order_value) != 0; + return (TransposeDecomposition::supported_cases.count(order_value) != 0) || + (is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0); } } return false; @@ -336,7 +346,7 @@ TokenizeSnippets::TokenizeSnippets() { for (const auto &input_node : ngraph::as_node_vector(input_values)) { if (auto subgraph = ov::as_type_ptr(input_node)) { - if (!clones.count(input_node)) { + if (!clones.count(input_node) && GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed) { auto f = subgraph->body().clone(); f->set_friendly_name(subgraph->body_ptr()->get_friendly_name()); clones[input_node] = f; @@ -512,23 +522,23 @@ TokenizeSnippets::TokenizeSnippets() { // To avoid unsupported number of non-scalar Constants in the future (plugin specific limitation) // we should calculate 
potential number of non-scalar Constants that will be moved up from body.
         size_t hidden_data_count = 0;
-        bool need_buffer = false;
         if (const auto fq_node = ov::as_type_ptr(node)) {
             hidden_data_count += ngraph::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
-        // Ops require a Buffer
-        } else if (ov::is_type(node) ||
-                   ov::is_type(node)) {
-            need_buffer |= true;
         }
 
         ResultVector body_results;
         std::vector>> subgraph_result_inputs;
+        ov::NodeVector new_body_ops;
 
         for (auto subgraph : input_subgraphs) {
             // we should summarize the additional needed data count (non-scalar Constants and Buffers) from all input subgraphs
             // because we will collapse them with our node, and we should get the total count
-            hidden_data_count += ov::as_type_ptr(subgraph)->get_virtual_port_count();
-            need_buffer |= ov::as_type_ptr(subgraph)->is_buffer_needed();
+            const auto subgraph_ptr = ov::as_type_ptr(subgraph);
+            hidden_data_count += subgraph_ptr->get_virtual_port_count();
+            if (subgraph_ptr->has_domain_sensitive_ops()) {
+                const auto ops = subgraph_ptr->body_ptr()->get_ordered_ops();
+                new_body_ops.insert(new_body_ops.end(), ops.begin(), ops.end());
+            }
 
             for (auto output : subgraph->outputs()) {
                 bool first_side_consumer = true;
@@ -559,6 +569,10 @@
             }
         }
 
+        if (op::Subgraph::is_domain_sensitive_op(node)) {
+            new_body_ops.push_back(node);
+        }
+
         for (auto output : node->outputs()) {
             body_results.push_back(std::make_shared(body_node->output(output.get_index())));
             subgraph_result_inputs.push_back(output.get_target_inputs());
         }
 
         // todo: move this plugin-specific constraint to the plugin callback
-        if (body_parameters.size() + body_results.size() + hidden_data_count + static_cast(need_buffer) > 12) {
+        const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(new_body_ops);
+        if (body_parameters.size() + body_results.size() + hidden_data_count + unique_buffer_count > 12) {
             const std::string message_reset = "new subgraph is created. Impossible to schedule subgraph with " +
                 std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " +
-                std::to_string(hidden_data_count) + " non-scalar constants and " + std::to_string(need_buffer) + "buffers.";
+                std::to_string(hidden_data_count) + " non-scalar constants and " + std::to_string(unique_buffer_count) + " buffers.";
             const std::string message_abort = "failed to continue subgraph. Impossible to schedule subgraph with " +
                 std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " +
-                std::to_string(hidden_data_count) + " non-scalar constants and " + std::to_string(need_buffer) + "buffers.";
+                std::to_string(hidden_data_count) + " non-scalar constants and " + std::to_string(unique_buffer_count) + " buffers.";
             return abort_with_strategy(message_reset, message_abort);
         }
@@ -612,7 +627,6 @@
         }
         subgraph->get_rt_info()["originalLayersNames"] = fusedNames;
         subgraph->set_virtual_port_count(hidden_data_count);
-        subgraph->set_buffer_needed(need_buffer);
 
         remark(1) << "Replacement (merge) done for: "
                   << subgraph->get_friendly_name()
diff --git a/src/common/snippets/src/pass/mha_tokenization.cpp b/src/common/snippets/src/pass/mha_tokenization.cpp
index 69a166140b4093..25e2cc397985c0 100644
--- a/src/common/snippets/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/src/pass/mha_tokenization.cpp
@@ -7,6 +7,8 @@
 #include "snippets/pass/mha_tokenization.hpp"
 #include "snippets/pass/tokenization.hpp"
 #include "snippets/op/subgraph.hpp"
+#include "snippets/op/brgemm.hpp"
+#include "snippets/utils.hpp"
 
 #include
 #include
@@ -16,18 +18,19 @@
 namespace {
 auto is_supported_tensor(const ngraph::descriptor::Tensor& t) -> bool {
-    // TODO: Add support of all supported by common tokenization element types
-    // return ngraph::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0;
-    // Also only 4D is supported at the moment
-    return t.get_element_type() == ngraph::element::f32 && t.get_partial_shape().is_static() && t.get_shape().size() == 4;
+    // TODO: Add support of non-4D tensors
+    return t.get_partial_shape().is_static() && t.get_shape().size() == 4;
 }
 
-// TODO: Add support of FQ, Reshape?
-auto is_supported_op(const std::shared_ptr& node) -> bool {
-    return ngraph::snippets::pass::TokenizeSnippets::AppropriateForSubgraph(node) &&
-           (ngraph::is_type(node) ||
-            ngraph::is_type(node) ||
-            ngraph::is_type(node));
-}
+// TODO: Add support of Reshape?
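+// Illustrative note (editorial sketch, not part of the original patch): with the relaxed
+// check above, is_supported_tensor() now accepts any static 4D tensor regardless of element
+// type, which is what enables the I8 MHA tokenization, e.g.:
+//
+//     is_supported_tensor(matmul->get_input_tensor(0));  // true for f32/bf16/i8/u8 [1, 16, 128, 64]
+//                                                        // false for dynamic shapes or rank != 4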
+auto is_supported_intermediate_op(const std::shared_ptr& node) -> bool { + const auto is_intermediate_op = [](const std::shared_ptr& node) { + return ngraph::is_type(node) || + ngraph::is_type(node) || + ngraph::is_type(node) || + ngraph::is_type(node); + }; + return ngraph::snippets::pass::TokenizeSnippets::AppropriateForSubgraph(node) && is_intermediate_op(node); } auto is_valid_transpose(const std::shared_ptr& node, std::vector expected_order) -> bool { @@ -37,9 +40,12 @@ auto is_valid_transpose(const std::shared_ptr& node, return false; return transpose_pattern->cast_vector() == expected_order; }; + auto is_supported_transpose_tensor = [](const ngraph::descriptor::Tensor& t) { + return is_supported_tensor(t) && ngraph::snippets::pass::TokenizeSnippets::supported_element_types.count(t.get_element_type()) != 0; + }; return node && node->get_output_target_inputs(0).size() == 1 && node->get_shape().size() == 4 && - valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_tensor(node->get_input_tensor(0)); + valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_transpose_tensor(node->get_input_tensor(0)); } auto tokenize_broadcast(const std::shared_ptr& interm_op, ov::NodeVector& ordered_ops) -> void { @@ -95,26 +101,78 @@ auto tokenize_reshape_around_softmax(std::shared_ptr& interm_op, ngraph::NodeVector& ordered_ops) -> bool { reshape = ngraph::as_type_ptr(interm_op); if (reshape) { - const auto shape = reshape->get_input_shape(0); - if (shape.back() != reshape->get_output_shape(0).back() || reshape->get_output_target_inputs(0).size() != 1) + const auto in_shape = reshape->get_input_shape(0); + const auto out_shape = reshape->get_output_shape(0); + if (in_shape.back() != out_shape.back() || reshape->get_output_target_inputs(0).size() != 1) return false; ordered_ops.push_back(reshape); interm_op = reshape->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); } return true; -}; +} + +auto get_potential_body_params(const std::shared_ptr& op) -> size_t { + size_t count = 0; + for (size_t i = 1; i < op->get_input_size(); ++i) { + const auto input = op->input_value(i); + const auto parent = input.get_node_shared_ptr(); + const auto constant = ov::as_type_ptr(parent); + if (!(constant && (ngraph::shape_size(input.get_shape()) == 1 || + ov::is_type(op)|| + ngraph::snippets::op::Subgraph::constant_input_should_be_inside_body(op)))) { + count++; + } + } + return count; +} -auto update_intermediate_supported_ops(std::shared_ptr& interm_op, ngraph::NodeVector& ordered_ops) -> bool { - // TODO: Add Reshape, FQ support - while (is_supported_op(interm_op)) { +auto update_intermediate_supported_ops(std::shared_ptr& interm_op, ngraph::NodeVector& ordered_ops, + size_t& hidden_virtual_ports_count, size_t& potential_body_params_count) -> bool { + // TODO: Add Reshape support + while (is_supported_intermediate_op(interm_op)) { // All supported intermediate ops have only one output port - // To verify output element type is enough because all supported intermediate ops have the same output element type as input type - if (interm_op->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(interm_op->get_output_tensor(0))) + if (interm_op->get_output_target_inputs(0).size() != 1) return false; - // Check for supported Broadcast op + // Check for supported ops on branches: Broadcast/Elementwise (for example, dequantize ops) if (interm_op->get_input_size() > 1) { tokenize_broadcast(interm_op, ordered_ops); + + // To avoid unsupported number of 
non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation),
+            // we should calculate the potential number of non-scalar Constants for FakeQuantize that will be moved up from body.
+            if (const auto fq_node = ov::as_type_ptr(interm_op)) {
+                hidden_virtual_ports_count += ngraph::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
+            }
+
+            auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr& op) {
+                return is_supported_intermediate_op(op) &&
+                       ngraph::snippets::pass::GetSnippetsNodeType(op) != ngraph::snippets::pass::SnippetsNodeType::SkippedByPlugin &&
+                       std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end();
+            };
+
+            for (size_t i = 0; i < interm_op->get_input_size(); ++i) {
+                const size_t shift = ordered_ops.size();
+                auto parent = interm_op->get_input_node_shared_ptr(i);
+                while (is_supported_branch_op(parent)) {
+                    // All supported ops have only one output port
+                    if (parent->get_output_target_inputs(0).size() != 1)
+                        break;
+
+                    // Add the node only if there are scalar constants on its inputs because of the plugin-specific limitation
+                    bool are_weights_scalar = true;
+                    const auto parent_count = parent->get_input_size();
+                    for (size_t i = 1; i < parent_count; ++i) {
+                        are_weights_scalar = are_weights_scalar && ngraph::shape_size(parent->get_input_shape(i)) == 1;
+                    }
+
+                    ordered_ops.insert(ordered_ops.begin() + shift, parent);
+                    // We assume that the sequence of ops goes through input port 0
+                    // But could there be a Select here? If so, the parent shouldn't be on input port 0. Need another way?
+                    parent = parent->get_input_node_shared_ptr(0);
+                }
+            }
+
+            potential_body_params_count += get_potential_body_params(interm_op);
         }
 
         ordered_ops.push_back(interm_op);
@@ -135,14 +193,29 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
     OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::TokenizeMHASnippets")
     auto& pattern_to_output = m.get_pattern_value_map();
 
+    // Queries + Key + Values = 3 standard inputs of MHA
+    size_t potential_body_params_count = 3;
     // After some transformations, a different number of Constants for some operations may be created
     // than the actual number of Constants during tokenization.
     // To avoid unsupported number of non-scalar Constants in the future (plugin specific limitation)
    // we should calculate potential number of non-scalar Constants that will be moved up from body.
-    // TODO: Need update this variable when FQ will be supported
     size_t hidden_virtual_ports_count = 0;
-    // Default value is True because MHA pattern always requires Buffer op
-    bool need_buffer = true;
+    // The count of potential unique Buffers - they are hidden virtual ports as well
+    // We should go through the Subgraph and calculate the potential count of non-inplace Buffers.
+    // Example:
+    //     Buffer - i32 [32, 128] -> ~ Loop ~ -> Buffer - i8 [32, 128]
+    // After each Loop iteration we should increment the pointers of the Buffers: by 4 bytes and by 1 byte accordingly in the scalar case.
+    // It means that these Buffers cannot be inplace => each Buffer should have its own register
+    // For that we can just check the following "branches":
+    //   - Between MatMul0 and MatMul1 - Softmax is a sync point. The operations between MatMul0 -> Softmax and Softmax -> MatMul1
+    //     will be fused into one loop after conversion to the snippet dialect (because they are just FQ and Eltwise nodes)
+    //   - Between MatMul0 and Transpose1 - At the moment operations after Transpose1 cannot be fused into the Transpose Loop (to avoid performance regressions).
+ // But operations after Transpose1 and before MatMul0 will be fused into one loop as well (look at first point) + // Note: If the pass is updated, need to check the new possible branches for potential non-inplace Buffers! + // Default value is 1 because + // - Firstly Softmax always need to have Buffers + // - Secondly Softmax need 2 Buffer but they can be inplace - One virtual port is enough for Softmax + size_t buffer_count = 1; std::string fused_names; ngraph::NodeVector ordered_ops; @@ -166,15 +239,24 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { !is_supported_tensor(matmul0->get_input_tensor(0)) || !is_supported_tensor(matmul0->get_input_tensor(1))) return false; - if (transformation_callback(matmul0)) { + const auto matmul0_prc = op::Brgemm::get_output_type(matmul0->get_input_element_type(0), matmul0->get_input_element_type(1)); + if (matmul0_prc == element::undefined) { return false; } + // Between MatMul0 and Softmax will be the one Loop because of LoopFusing optimization. + // The Loop will have one Buffer with the same shape both on input and output. + // Need to check for precision to get if we need one more register for Buffer + if (matmul0_prc.size() != ov::element::f32.size()) { + if (buffer_count < 2) + buffer_count++; + } + ordered_ops.push_back(matmul0); auto interm_op = matmul0->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); // Add supported operations which are between MatMul0 and Softmax to ordered_ops - if (!update_intermediate_supported_ops(interm_op, ordered_ops)) + if (!update_intermediate_supported_ops(interm_op, ordered_ops, hidden_virtual_ports_count, potential_body_params_count)) return false; std::shared_ptr reshape0 = nullptr; @@ -205,14 +287,26 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { return false; // Add supported operations which are between Softmax and MatMul1 to ordered_ops - if (!update_intermediate_supported_ops(interm_op, ordered_ops)) + if (!update_intermediate_supported_ops(interm_op, ordered_ops, hidden_virtual_ports_count, potential_body_params_count)) return false; const auto matmul1 = ngraph::as_type_ptr(interm_op); if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 || matmul1->get_transpose_a() || matmul1->get_transpose_b() || + op::Brgemm::get_output_type(matmul1->get_input_element_type(0), matmul1->get_input_element_type(1)) == element::undefined || !is_supported_tensor(matmul1->get_input_tensor(0)) || !is_supported_tensor(matmul1->get_input_tensor(1))) return false; + if (transformation_callback(matmul1)) { + return false; + } + + // Between Softmax and MatMul1 will be the one Loop because of LoopFusing optimization. + // The Loop will have one Buffer with the same shape both on input and output. 
+ // Need to check for precision to get if we need one more register for Buffer + if (matmul1->get_input_element_type(0).size() != ov::element::f32.size()) { + buffer_count++; + } + /***********************/ /***** Transposes *****/ @@ -224,16 +318,21 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose1 bool are_weights_scalar = true; auto parent = matmul0->get_input_node_shared_ptr(1); - while (is_supported_op(parent)) { + while (is_supported_intermediate_op(parent)) { // All supported ops have only one output port - // To verify output element type is enough because all supported ops have the same output element type as input type - if (parent->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(parent->get_output_tensor(0))) + if (parent->get_output_target_inputs(0).size() != 1) break; - const auto parent_count = parent->inputs().size(); + const auto parent_count = parent->get_input_size(); for (size_t i = 1; i < parent_count; ++i) { are_weights_scalar = are_weights_scalar && ngraph::shape_size(parent->get_input_shape(i)) == 1; } + // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation) + // we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body. + if (const auto fq_node = ov::as_type_ptr(parent)) { + hidden_virtual_ports_count += ngraph::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node); + } + potential_body_params_count += get_potential_body_params(parent); ordered_ops.insert(ordered_ops.begin(), parent); // We think that sequence of ops goes through input port 0 // But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way? @@ -261,6 +360,15 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { } } + if (transpose1) { + // Between Transpose1 and MatMul0 will be the one Loop because of LoopFusing optimization. + // The Loop will have one Buffer with the same shape both on input and output. + // Need to check for precision to get if we need one more register for Buffer + if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) { + buffer_count++; + } + } + // TODO: Add Reshape Support for all Transposes // Add 3D support for all Transposes const auto transpose0 = ngraph::as_type_ptr(matmul0->get_input_node_shared_ptr(0)); @@ -276,16 +384,41 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() { } ordered_ops.push_back(matmul1); + bool are_ops_after_matmul2 = false; auto child = matmul1->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); - // TODO: Add support Eltwises between MatMul1 and Transpose - // status = update_intermediate_supported_ops(child, ordered_ops); - // if (!status) { - // ordered_ops.push_back(child); - // } - - auto transpose3 = ngraph::as_type_ptr(child); - if (is_valid_transpose(transpose3, {0, 2, 1, 3})) { - ordered_ops.push_back(transpose3); + while (is_supported_intermediate_op(child)) { + are_ops_after_matmul2 = true; + // All supported ops have only one output port + if (child->get_output_target_inputs(0).size() != 1) + break; + + // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation) + // we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body. 
+        if (const auto fq_node = ov::as_type_ptr(child)) {
+            hidden_virtual_ports_count += ngraph::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
+        }
+        potential_body_params_count += get_potential_body_params(child);
+
+        // TODO: move this plugin-specific constraint to the plugin callback
+        // We cannot collapse ops into a Subgraph if the total count of potential Parameters and Results is higher than 12
+        if (potential_body_params_count + child->get_output_target_inputs(0).size() + hidden_virtual_ports_count + buffer_count > 12) {
+            break;
+        }
+
+        ordered_ops.push_back(child);
+        child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
+    }
+
+    // TODO: Add full support of Transpose to cover cases where there are nodes between MatMul1 and Transpose3:
+    //           MatMul1
+    //
+    //          Transpose3
+    // TODO: Add check for precision of MatMul (we cannot collapse Transpose into an I8/BF16 MatMul)
+    if (!are_ops_after_matmul1) {
+        auto transpose3 = ngraph::as_type_ptr(child);
+        if (is_valid_transpose(transpose3, {0, 2, 1, 3})) {
+            ordered_ops.push_back(transpose3);
+        }
     }
 
     /**********************/
@@ -294,6 +427,12 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
 
     /* ====== Subgraph creation ======= */
 
+    // TODO: move this plugin-specific constraint to the plugin callback
+    const auto last_node = ordered_ops.back();
+    if (potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count + buffer_count > 12) {
+        return false;
+    }
+
     ngraph::OutputVector body_inputs, subgraph_inputs;
     ngraph::ParameterVector body_parameters;
     ngraph::ResultVector body_results;
@@ -304,7 +443,9 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
             const auto input = node->input(i);
             const auto parent = input.get_source_output().get_node_shared_ptr();
             const auto constant = ov::as_type_ptr(parent);
-            if (constant && (ngraph::shape_size(input.get_shape()) == 1 || op::Subgraph::constant_input_should_be_inside_body(node))) {
+            if (constant && (ngraph::shape_size(input.get_shape()) == 1 ||
+                             ov::is_type(node) ||
+                             op::Subgraph::constant_input_should_be_inside_body(node))) {
                 // If Constant has one consumer - target node, we add Constant to body_inputs
                 // If Constant has several consumers, we should check that all these consumers are inside Subgraph body
                 // and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we should
@@ -347,7 +488,6 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
         fused_names += op->get_friendly_name() + ",";
     }
 
-    const auto last_node = ordered_ops.back();
     for (const auto& output : last_node->outputs()) {
         subgraph_result_inputs.push_back(output.get_target_inputs());
     }
@@ -359,11 +499,6 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
         throw ngraph_error("body results and node results size mismatch during subgraph collapse");
     }
 
-    // todo: move this plugin-specific constraint to the plugin callback
-    if (body_parameters.size() + body_results.size() + hidden_virtual_ports_count > 12) {
-        return false;
-    }
-
     auto body = op::create_body(last_node->get_friendly_name(), body_results, body_parameters);
     auto subgraph = std::make_shared(subgraph_inputs, body);
     // Copy runtime info from last node to subgraph - to copy topological order
@@ -385,7 +520,9 @@ ngraph::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
     }
     subgraph->get_rt_info()["originalLayersNames"] = fused_names;
     subgraph->set_virtual_port_count(hidden_virtual_ports_count);
-    subgraph->set_buffer_needed(need_buffer);
+
+    // Mark the Subgraph as Completed so that common Tokenization does not allow Snippets to include any other nodes into the MHA Subgraph
+    SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed);
 
     return true;
diff --git a/src/common/snippets/src/pass/reset_buffer.cpp b/src/common/snippets/src/pass/reset_buffer.cpp
index 54bdfef03f7f13..30ca6b932c98b0 100644
--- a/src/common/snippets/src/pass/reset_buffer.cpp
+++ b/src/common/snippets/src/pass/reset_buffer.cpp
@@ -13,15 +13,15 @@ namespace {
 void normalize_ptr_and_offsets(const ov::NodeVector &io, std::vector &ptr_increments, std::vector &finalization_offsets) {
-    bool there_is_buffer = false;
+    std::set buffers;
     // Iterations are from end because before we correct finalization offsets for Loop outputs (io = inputs + outputs)
     for (int i = static_cast(io.size()) - 1; i >= 0; --i) {
-        if (ov::is_type(io[i])) {
-            if (there_is_buffer) {
+        if (const auto buffer = ov::as_type_ptr(io[i])) {
+            if (buffers.count(buffer->get_id()) > 0) {
                 ptr_increments[i] = 0;
                 finalization_offsets[i] = 0;
             } else {
-                there_is_buffer = true;
+                buffers.insert(buffer->get_id());
             }
         }
     }
@@ -86,7 +86,7 @@ ngraph::snippets::pass::ResetBufferState::ResetBufferState() {
             auto loop_index = 0lu;
             auto loop = loop_end->input_value(i).get_node_shared_ptr();
             auto port_idx = loop_end->input_value(i).get_index();
-            while (std::dynamic_pointer_cast(loop)) {
+            while (ov::is_type(loop)) {
                 const auto source_output = loop->input_value(port_idx);
                 loop = source_output.get_node_shared_ptr();
                 port_idx = source_output.get_index();
diff --git a/src/common/snippets/src/pass/tokenization.cpp b/src/common/snippets/src/pass/tokenization.cpp
index 4744b73b88295e..ec373bd4a82b74 100644
--- a/src/common/snippets/src/pass/tokenization.cpp
+++ b/src/common/snippets/src/pass/tokenization.cpp
@@ -17,6 +17,13 @@ void SetSnippetsNodeType(const std::shared_ptr &node, SnippetsNodeType nod
     rt["SnippetsNodeType"] = nodeType;
 }
 
+void SetSnippetsSubgraphType(const std::shared_ptr &node, SnippetsSubgraphType nodeType) {
+    if (node) {
+        auto &rt = node->get_rt_info();
+        rt["SnippetsSubgraphType"] = nodeType;
+    }
+}
+
 SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr &node) {
     OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsNodeType")
     auto &rt = node->get_rt_info();
@@ -26,6 +33,17 @@ SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr &node) {
     return rinfo->second.as();
 }
 
+SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr &node) {
+    OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsSubgraphType")
+    if (!node)
+        return SnippetsSubgraphType::NotSet;
+    auto &rt = node->get_rt_info();
+    const auto rinfo = rt.find("SnippetsSubgraphType");
+    if (rinfo == rt.end())
+        return SnippetsSubgraphType::NotSet;
+    return rinfo->second.as();
+}
+
 void SetTopologicalOrder(const std::shared_ptr &node, int64_t order) {
     OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::SetTopologicalOrder")
     auto &rt = node->get_rt_info();
diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp
index 4c3d967be5f310..b2fafb7f1ec03b 100644
--- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp
+++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp
@@ -20,14 +20,23 @@ void TokenizeMHASnippetsTests::run() {
 }
 
 TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) {
-    const auto &f = MHAFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
+    const auto &f = MHAFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
+                                std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}));
     function = f.getOriginal();
     function_ref = f.getReference();
     run();
 }
 
 TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_MatMul0_Transpose) {
-    const auto &f = MHAMatMul0TransposeFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
+    const auto &f = MHAMatMul0TransposeFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
+                                                std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}));
     function = f.getOriginal();
     function_ref = f.getReference();
     run();
 }
 
+TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
+    const auto &f = MHAINT8MatMulTypeRelaxedFunction(std::vector{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
+    function = f.getOriginal();
+    function_ref = f.getReference();
+    run();
+}
+
diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp
index 338cb62dcec39b..a9169d0b52bc7d 100644
--- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp
+++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp
@@ -139,8 +139,12 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
     auto results = model->get_results();
     num_inputs = params.size();
     num_outputs = results.size();
-    is_buffer_needed = std::any_of(ops.begin(), ops.end(),
-        [](const std::shared_ptr& node) { return ov::is_type(node); } );
+    std::set unique_buffers;
+    for (const auto& op : ops) {
+        if (const auto buffer = ov::as_type_ptr(op))
+            unique_buffers.insert(buffer->get_id());
+    }
+    num_unique_buffer = unique_buffers.size();
     NodeVector io_nodes;
     std::copy(params.begin(), params.end(), std::back_inserter(io_nodes));
     std::copy(results.begin(), results.end(), std::back_inserter(io_nodes));
@@ -216,15 +220,16 @@ void KernelEmitter::validate_arguments(const std::vector &in,
         IE_THROW() << "KernelEmitter got invalid number of inputs. Expected 0, got " << in.size();
     if (!out.empty())
         IE_THROW() << "KernelEmitter got invalid number of outputs. Expected 0, got " << out.size();
-    const auto num_params = num_inputs + num_outputs + static_cast(is_buffer_needed);
+    const auto num_params = num_inputs + num_outputs + num_unique_buffer;
     // The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount
     if (data_ptr_regs_idx.size() != num_params)
-        IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsistent with the number of allocated registers"
+        IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsistent with the number of allocated registers "
                    << num_params << " data_ptr_regs_idx.size() = " << data_ptr_regs_idx.size();
 }
 
-void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, bool is_buffer_needed,
-                                       const Reg64& reg_indexes, const Reg64& reg_const_params, const std::vector& data_ptr_regs) const {
+void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, size_t num_unique_buffer,
+                                       const Xbyak::Reg64& reg_indexes, const Xbyak::Reg64& reg_const_params,
+                                       const std::vector& data_ptr_regs) const {
     // Note that we don't need offset for the last dim, since it's handled directly by Tile emitter
     const size_t offset_rank = jcp.master_shape.size() - 1;
     //const size_t tile_rank = jcp.tile_rank;
@@ -287,8 +292,8 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, boo
     // Vector "data_ptr_regs" is sorted by abstract regs.
     // It means that the vector contains the physical registers in order [src, .., src, dst, .., dst, buffer]
     // So we can initialize buffer register firstly as last value of vector "data_ptr_regs"
-    if (is_buffer_needed) {
-        h->mov(data_ptr_regs[num_params], h->ptr[reg_const_params + GET_OFF(buffer_scratchpad_ptr)]);
+    for (size_t i = 0; i < num_unique_buffer; ++i) {
+        h->mov(data_ptr_regs[num_params + i], h->ptr[reg_const_params + GET_OFF(buffer_scratchpad_ptr)]);
     }
     size_t i = 0;
     for (; i < num_params - last_iter_explicitly; i++) {
@@ -319,7 +324,7 @@ void KernelEmitter::emit_impl(const std::vector& in,
     std::vector data_ptr_regs;
     transform_idxs_to_regs(data_ptr_regs_idx, data_ptr_regs);
 
-    init_data_pointers(num_inputs, num_inputs + num_outputs, is_buffer_needed, reg_indexes, reg_const_params, data_ptr_regs);
+    init_data_pointers(num_inputs, num_inputs + num_outputs, num_unique_buffer, reg_indexes, reg_const_params, data_ptr_regs);
     for (const auto& c : body) {
         const auto& emitter = c.first;
         std::vector in_regs, out_regs;
diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
index 0f00eb6f7048b8..9befa79eda2470 100644
--- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
+++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp
@@ -83,15 +83,14 @@ class KernelEmitter : public jit_container_emitter {
     using jit_emitter::emit_code;
 
     void validate_arguments(const std::vector &in, const std::vector &out) const override;
-    void emit_impl(const std::vector& in,
-                   const std::vector& out) const override;
-    void init_data_pointers(size_t, size_t, bool, const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector&) const;
+    void emit_impl(const std::vector& in, const std::vector& out) const override;
+    void init_data_pointers(size_t, size_t, size_t, const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector&) const;
 
     jit_snippets_compile_args jcp;
     std::vector gp_regs_pool;
     size_t num_inputs;
     size_t num_outputs;
-    bool is_buffer_needed;
+    size_t num_unique_buffer;
     // Vector of indices (length = input tensor rank) per every input and output that describes in which order
     // corresponding tensor dimensions are accessed (default: consecutive dense, e.g. 0,1,2,3 for 4D tensor).
     // Needed to calc i/o offsets.
diff --git a/src/plugins/intel_cpu/src/ngraph_transformations/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/ngraph_transformations/snippets_mark_skipped.cpp
index 221b0145e08f0d..7afcd4f36650a8 100644
--- a/src/plugins/intel_cpu/src/ngraph_transformations/snippets_mark_skipped.cpp
+++ b/src/plugins/intel_cpu/src/ngraph_transformations/snippets_mark_skipped.cpp
@@ -427,7 +427,7 @@ void MarkSubgraphOpAsSkipped(const std::shared_ptr &node) {
 bool isSuitableConvert(const std::shared_ptr& node) {
     if (!ov::is_type(node))
         return false;
-    auto hasResult = [](const std::shared_ptr& node){
+    auto hasResult = [](const std::shared_ptr &node) {
         auto consumers = node->output(0).get_target_inputs();
         bool findResult = false;
         if (consumers.size() == 1) {
@@ -449,13 +449,21 @@ bool isSuitableConvert(const std::shared_ptr& node) {
         return false;
     }
 }
+
+auto is_skipped_op(const std::shared_ptr& op) -> bool {
+    return ngraph::op::is_constant(op) ||
+           ov::is_type(op) ||
+           ov::is_type(op) ||
+           ov::is_type(op) ||
+           ov::is_type(op);
+}
 }  // namespace
 
 bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr &m) {
     RUN_ON_MODEL_SCOPE(SnippetsMarkSkipped);
     int channelAxis = DEFAULT_AXIS;
     for (auto &node : m->get_ordered_ops()) {
-        if (ngraph::op::is_constant(node) || ov::is_type(node))
+        if (is_skipped_op(node))
             continue;
         if (isSuitableConvolutionParent(node)) {
             // Initiate fusing chain
diff --git a/src/plugins/intel_cpu/src/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformation_pipeline.cpp
index 2bedc4d32df2e2..67d3152ec21703 100644
--- a/src/plugins/intel_cpu/src/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformation_pipeline.cpp
@@ -558,7 +558,6 @@ void Transformations::MainSnippets(void) {
     if (snippetsMode == Config::SnippetsMode::Disable ||
         !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2))  // snippets are implemented only for relevant platforms (avx2+ extensions)
         return;
-
     ngraph::pass::Manager snippetsManager;
     snippetsManager.set_per_pass_validation(false);
     if (snippetsMode != Config::SnippetsMode::IgnoreCallback)
@@ -571,10 +570,34 @@ void Transformations::MainSnippets(void) {
     if (!isMHASupported) {
         snippetsManager.get_pass_config()->disable();
     }
+
+    auto is_supported_matmul = [](const std::shared_ptr& matmul) {
+        if (!matmul)
+            return false;
+        if (matmul->get_input_element_type(1) == ov::element::i8)
+            return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni);
+        if (matmul->get_input_element_type(0) == ov::element::bf16 &&
+            matmul->get_input_element_type(1) == ov::element::bf16)
+            return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
+        return true;
+    };
+
     if (snippetsMode != Config::SnippetsMode::IgnoreCallback) {
         snippetsManager.get_pass_config()->set_callback(
-            [](const std::shared_ptr& n) -> bool {
-                const auto pshape = n->get_output_partial_shape(0);
+            [this, is_supported_matmul](const std::shared_ptr& n) -> bool {
+                if (this->enableLpt) {
+                    // The transformation callback is called on MatMul1
+                    if (!is_supported_matmul(ov::as_type_ptr(n)))
+                        return true;
+                    // Search for MatMul0
+                    auto parent = n->get_input_node_shared_ptr(0);
+                    while (!ov::is_type(parent)) {
+                        parent = parent->get_input_node_shared_ptr(0);
+                    }
+                    if (!is_supported_matmul(ov::as_type_ptr(parent)))
+                        return true;
+                }
+                const auto pshape =
n->get_input_partial_shape(0); const auto shape = pshape.get_shape(); const auto parallel_work_amount = std::accumulate(shape.rbegin() + 2, shape.rend(), 1, std::multiplies()); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 5e246017855e49..0d4d0c7522976b 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -210,6 +210,9 @@ std::vector disabledTestPatterns() { if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { // MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)"); + retVector.emplace_back(R"(.*Snippets.*MatMul.*Quantized.*)"); + retVector.emplace_back(R"(.*Snippets.*MHAFQ.*)"); + retVector.emplace_back(R"(.*Snippets.*MHAINT8.*)"); } if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) //TODO: Issue 92895 diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 9d792f35264066..920fa3199e5b49 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -19,16 +19,24 @@ std::vector> input_shapes{ {{1, 1, 37, 23}, {1, 2, 23, 33}}, {{1, 16, 384, 64}, {1, 16, 64, 384}} }; + +static inline std::vector> quantized_precisions() { + std::vector> prc = {}; + // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms + if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { + prc.emplace_back(std::vector{element::i8, element::i8}); + prc.emplace_back(std::vector{element::u8, element::i8}); + } + return prc; +} + static inline std::vector> precisions(bool only_fp32 = true) { std::vector> prc = { {element::f32, element::f32}, }; if (!only_fp32) { - // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms - if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { - prc.emplace_back(std::vector{element::i8, element::i8}); - prc.emplace_back(std::vector{element::u8, element::i8}); - } + auto quant = quantized_precisions(); + std::copy(quant.begin(), quant.end(), std::back_inserter(prc)); // In Snippets MatMul BF16 is supported only on bf16/AMX platforms if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { prc.emplace_back(std::vector{element::bf16, element::bf16}); @@ -63,6 +71,34 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, ::testing::Values(CommonTestUtils::DEVICE_CPU)), MatMul::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBiasQuantized, MatMulBiasQuantized, + ::testing::Combine( + ::testing::ValuesIn(std::vector>{ + std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 2, 1, 1}}, + std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 2, 69, 49}}}), + ::testing::ValuesIn(quantized_precisions()), + ::testing::Values(1), // Subgraph + ::testing::Values(1), // Tokenized MatMul+Bias + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + 
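+// A minimal usage sketch (illustrative only): on machines without VNNI/AMX-int8 support,
+// quantized_precisions() above returns an empty vector, so the quantized suites are
+// parameterized with an empty precision list and produce no test instances; they are
+// additionally filtered out by the patterns added to skip_tests_config.cpp:
+//     ::testing::ValuesIn(quantized_precisions())  // {} -> zero parameter sets on non-VNNI machines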
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulsQuantized, MatMulsQuantized, + ::testing::Combine( + ::testing::Values(std::vector{{1, 16, 128, 64}, {1, 16, 64, 128}, {128, 64}}), + ::testing::ValuesIn(quantized_precisions()), + ::testing::Values(3), // Subgraph + Reshape + Subgraph + ::testing::Values(2), // Tokenized [MatMul+FQ+Matmul] and [FQ] + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulsQuantizedSoftmax, MatMulsQuantizedSoftmax, + ::testing::Combine( + ::testing::Values(std::vector{{1, 16, 128, 64}, {1, 16, 64, 128}, {128, 64}}), + ::testing::ValuesIn(quantized_precisions()), + ::testing::Values(3), // Subgraph + Reshape + Subgraph + ::testing::Values(2), // Tokenized [MatMul+FQ+Matmul] and [FQ] + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); } // namespace } // namespace snippets } // namespace test diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 11aeaebdcc21b6..a7a64560ea2aa7 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -4,6 +4,7 @@ #include "snippets/mha.hpp" #include "common_test_utils/test_constants.hpp" +#include "ie_system_conf.h" namespace ov { namespace test { @@ -14,15 +15,27 @@ namespace { const std::vector> inputShapes = { {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}, }; +static inline std::vector> precisions() { + std::vector> prc = { + {element::f32, element::f32, element::f32, element::f32}, + }; + // In Snippets MatMul BF16 is supported only on bf16/AMX platforms + if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16, element::bf16, element::bf16}); + } + return prc; +} + INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA, ::testing::Combine( ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn(precisions()), ::testing::ValuesIn({false, true}), ::testing::Values(1), ::testing::Values(1), @@ -38,9 +51,21 @@ const std::vector> inputShapeSelect = { {{2, 52, 6, 102}, {2, 52, 6, 102}, {1, 6, 52, 52}, {1, 6, 1, 1}, {1, 6, 1, 1}, {2, 52, 6, 102}} }; +static inline std::vector> precisionsSelect() { + std::vector> prc = { + {element::f32, element::f32, element::f32, element::f32, element::f32, element::f32}, + }; + // In Snippets MatMul BF16 is supported only on bf16/AMX platforms + if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16, element::bf16, element::bf16, element::bf16, element::bf16}); + } + return prc; +} + INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect, ::testing::Combine( ::testing::ValuesIn(inputShapeSelect), + ::testing::ValuesIn(precisionsSelect()), ::testing::Values(false), // Need to support True for graph builder in tests ::testing::Values(2), // Less + MHA ::testing::Values(2), @@ 
-54,12 +79,42 @@ const std::vector> inputShapesWOTranspose = { INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs, ::testing::Combine( ::testing::ValuesIn(inputShapesWOTranspose), - ::testing::ValuesIn({true}), // Need to support False for graph builder in tests + ::testing::Values(std::vector{}), + ::testing::Values(true), // Need to support False for graph builder in tests ::testing::Values(1), ::testing::Values(1), ::testing::Values(CommonTestUtils::DEVICE_CPU)), MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAINT8MatMul, MHAINT8MatMul, + ::testing::Combine( + ::testing::ValuesIn(std::vector>(inputShapes.begin(), inputShapes.begin() + 2)), + ::testing::Values(std::vector{}), + ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul + ::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQAfterMatMul, MHAFQAfterMatMul, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::Values(std::vector{}), + ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(3), // MHA + Transpose on output + Deq Mul + ::testing::Values(2), // MHA + Deq Mul + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ, + ::testing::Combine( + ::testing::Values(std::vector{{1, 64, 12, 64}, {1, 64, 12, 64}, {1, 1, 1, 64}, {1, 64, 12, 64}}), + ::testing::Values(std::vector{}), + ::testing::Values(false), // The graph doesn't contain Multiply + ::testing::Values(7), // Transposex2 + Subgraphsx5 + ::testing::Values(5), // MHA + Deq Mul on output + Deqs on inputs + 2 xFQ on inputs + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MHA::getTestCaseName); } // namespace } // namespace snippets diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp index 4222cb9b97507d..01b457809505ce 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp @@ -553,7 +553,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest, ::testing::ValuesIn(inputPrecisionsQuant), ::testing::ValuesIn(matMulIn0PrecisionsQuant), ::testing::ValuesIn(patternTypesQuant), - ::testing::Values("MHA"), // Snippets don't support Quantized MHA pattern yet + ::testing::Values("MHA"), ::testing::Values(CommonTestUtils::DEVICE_CPU)), MHAQuantTest::getTestCaseName); diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp index 3e2a0ab015e988..921585f0976418 100644 --- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp @@ -37,6 +37,21 @@ class MatMulBias : public MatMul { void SetUp() override; }; +class MatMulBiasQuantized : public MatMul { +protected: + void SetUp() override; +}; + +class MatMulsQuantized : public MatMul { +protected: + void SetUp() override; +}; + +class MatMulsQuantizedSoftmax : public MatMul { +protected: + void SetUp() override; +}; + } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp 
b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index 9f95dcc30acde8..94147f86b19cd8 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -12,6 +12,7 @@ namespace snippets { typedef std::tuple< std::vector, // Input shapes + std::vector, // Input Element types bool, // With Multiply size_t, // Expected num nodes size_t, // Expected num subgraphs @@ -42,6 +43,21 @@ class MHAWOTransposeOnInputs : public MHA { void SetUp() override; }; +class MHAINT8MatMul : public MHA { +protected: + void SetUp() override; +}; + +class MHAFQAfterMatMul : public MHA { +protected: + void SetUp() override; +}; + +class MHAFQ : public MHA { +protected: + void SetUp() override; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp index 06a37e2fd1ffed..10e567292f167a 100644 --- a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp @@ -71,6 +71,48 @@ void MatMulBias::SetUp() { } } +void MatMulBiasQuantized::SetUp() { + std::vector input_shapes; + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::MatMulBiasQuantizedFunction(input_shapes, elem_types); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void MatMulsQuantized::SetUp() { + std::vector input_shapes; + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::MatMulsQuantizedFunction(input_shapes, elem_types); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void MatMulsQuantizedSoftmax::SetUp() { + std::vector input_shapes; + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::MatMulsQuantizedSoftmaxFunction(input_shapes, elem_types); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + TEST_P(MatMul, CompareWithRefImpl) { SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); @@ -89,6 +131,24 @@ TEST_P(MatMulBias, CompareWithRefImpl) { validateNumSubgraphs(); } +TEST_P(MatMulBiasQuantized, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(MatMulsQuantized, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + 
validateNumSubgraphs(); +} + +TEST_P(MatMulsQuantizedSoftmax, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index cf0075906c058a..bafe21d1281987 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -15,14 +15,17 @@ namespace snippets { std::string MHA::getTestCaseName(testing::TestParamInfo obj) { std::vector inputShapes; + std::vector elem_types; bool withMul; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(inputShapes, withMul, num_nodes, num_subgraphs, targetDevice) = obj.param; + std::tie(inputShapes, elem_types, withMul, num_nodes, num_subgraphs, targetDevice) = obj.param; std::ostringstream result; for (size_t i = 0; i < inputShapes.size(); ++i) result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_"; + for (size_t i = 0; i < elem_types.size(); i++) + result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "Mul=" << withMul << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; @@ -32,11 +35,12 @@ std::string MHA::getTestCaseName(testing::TestParamInfo inputShapes; + std::vector elem_types; bool withMul; - std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::tie(inputShapes, elem_types, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); - auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, withMul); + auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, elem_types, withMul); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { @@ -58,11 +62,12 @@ void MHA::generate_inputs(const std::vector& targetInputStaticSha void MHASelect::SetUp() { std::vector inputShapes; + std::vector elem_types; bool withMul; - std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::tie(inputShapes, elem_types, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); - auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes); + auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, elem_types); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { @@ -91,8 +96,9 @@ void MHASelect::generate_inputs(const std::vector& targetInputSta void MHAWOTransposeOnInputs::SetUp() { std::vector inputShapes; + std::vector elem_types; bool withMul; - std::tie(inputShapes, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::tie(inputShapes, elem_types, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes); @@ -104,6 +110,56 @@ void MHAWOTransposeOnInputs::SetUp() { } } +void MHAINT8MatMul::SetUp() { + std::vector inputShapes; + std::vector elem_types; + bool withMul; + std::tie(inputShapes, elem_types, withMul, ref_num_nodes, 
ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); + + auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes); + function = f.getOriginal(); + + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } + + // Todo: need to investigate + abs_threshold = 0.3; +} + +void MHAFQAfterMatMul::SetUp() { + std::vector inputShapes; + std::vector elem_types; + bool withMul; + std::tie(inputShapes, elem_types, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); + + auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes); + function = f.getOriginal(); + + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void MHAFQ::SetUp() { + std::vector inputShapes; + std::vector elem_types; + bool withMul; + std::tie(inputShapes, elem_types, withMul, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); + + auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes); + function = f.getOriginal(); + + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} TEST_P(MHA, CompareWithRefImpl) { run(); @@ -120,6 +176,20 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) { validateNumSubgraphs(); } +TEST_P(MHAINT8MatMul, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} + +TEST_P(MHAFQ, CompareWithRefImpl) { + run(); + validateNumSubgraphs(); +} } // namespace snippets } // namespace test diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp index 15954605e69fdd..755dbf9aa81b4e 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp @@ -70,6 +70,40 @@ class MatMulBiasFunction : public SnippetsFunctionBase { std::vector precisions; }; +// Quantized MatMul +// FQ[I8] +// Add +class MatMulBiasQuantizedFunction : public SnippetsFunctionBase { +public: + explicit MatMulBiasQuantizedFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + MatMulFunction::verify_precisions(precisions); + } +protected: + std::shared_ptr initOriginal() const override; + + std::vector precisions; +}; + +// Quantized MatMul FQ[I8] +// FQ[U8] Reshape <- To have only one sequence in Subgraph: MatMuL->FQ[U8]->MatMul->FQ[I8] +// \ / +// MatMul +// FQ[I8] +class MatMulsQuantizedFunction : public SnippetsFunctionBase { +public: + explicit 
MatMulsQuantizedFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + MatMulFunction::verify_precisions(precisions); + } +protected: + std::shared_ptr initOriginal() const override; + + std::vector precisions; +}; + /// Minimal graph to test MatMul+Transpose combinations. Transpose location is specified via the position argument: /// 0 - before the first MatMul input; 1 - before the second MatMul input; 2 - after the MatMul output. /// Tokenized simply by starting subgraph, @@ -121,6 +155,24 @@ class TransposeMulMatMulBiasFunction : public SnippetsFunctionBase { std::shared_ptr initOriginal() const override; }; +// Quantized MatMul FQ[I8] +// Softmax Reshape <- To have only one sequence in Subgraph: MatMuL->Softmax>FQ[U8]->MatMul->FQ[I8] +// FQ[U8] / +// MatMul +// FQ[I8] +class MatMulsQuantizedSoftmaxFunction : public SnippetsFunctionBase { +public: + explicit MatMulsQuantizedSoftmaxFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { + NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + MatMulFunction::verify_precisions(precisions); + } +protected: + std::shared_ptr initOriginal() const override; + + std::vector precisions; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp index 309a32e914558a..ba078ced4cf0a3 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp @@ -43,8 +43,8 @@ namespace snippets { */ class MHAFunction : public SnippetsFunctionBase { public: - explicit MHAFunction(const std::vector& inputShapes, bool with_mul = true) - : SnippetsFunctionBase(inputShapes), with_mul(with_mul) { + explicit MHAFunction(const std::vector& inputShapes, const std::vector& precisions, bool with_mul = true) + : SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: @@ -52,6 +52,7 @@ class MHAFunction : public SnippetsFunctionBase { std::shared_ptr initReference() const override; bool with_mul = true; + std::vector precisions; }; /* Graph: @@ -71,13 +72,15 @@ class MHAFunction : public SnippetsFunctionBase { */ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase { public: - explicit MHAMatMul0TransposeFunction(const std::vector& inputShapes) - : SnippetsFunctionBase(inputShapes) { + explicit MHAMatMul0TransposeFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; std::shared_ptr initReference() const override; + + std::vector precisions; }; /* Graph: @@ -97,11 +100,14 @@ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase { */ class MHASelectFunction : public SnippetsFunctionBase { public: - explicit MHASelectFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { + explicit MHASelectFunction(const std::vector& inputShapes, const std::vector& precisions) 
+ : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; + + std::vector precisions; }; /* Graph: @@ -126,6 +132,117 @@ class MHAWOTransposeOnInputsFunction : public SnippetsFunctionBase { std::shared_ptr initOriginal() const override; }; +/* Graph: + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] + * \ / + * MatMul0 + * FakeQuantize i8 + * \ / + * Add + * Reshape0 + * Softmax + * Reshape1 Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize i8 + * Transpose3[0,2,1,3] + */ +class MHAFQAfterMatMulFunction : public SnippetsFunctionBase { +public: + explicit MHAFQAfterMatMulFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +/* Graph: + * FakeQuantize i8 FakeQuantize i8 + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] + * \ / + * MatMul0 + * FakeQuantize i8 + * \ / + * Add + * Reshape0 + * Softmax + * Reshape1 FakeQuantize i8 + * FakeQuantize u8 Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize i8 + * Transpose3[0,2,1,3] + */ +class MHAINT8MatMulFunction : public SnippetsFunctionBase { +public: + explicit MHAINT8MatMulFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +/* Graph: + * Constant + * FakeQuantize u8 FakeQuantize u8 Convert + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] Multiply + * \ \ / + * \ Multiply + * \ FakeQuantize f32 + * \ / + * MatMul0 + * FakeQuantize f32 FakeQuantize u8 + * \ / + * Add + * Softmax Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize u8 + * Transpose3[0,2,1,3] + * Note: Check a lot of different FQ (the both quantized and floating) - buffers with different size and precision + */ +class MHAFQFunction : public SnippetsFunctionBase { +public: + explicit MHAFQFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; +}; + +// Only for tokenization! 
The graph is after LPT: contains TypeRelaxed ops +/* Graph: + * FakeQuantize i8 FakeQuantize i8 + * Transpose0[0,2,1,3] Transpose1[0,2,3,1] + * \ / + * MatMul0 + * FakeQuantize i8 + * \ / + * Add + * Mul (DeQuantize) + * Reshape0 + * Softmax + * Reshape1 FakeQuantize i8 + * FakeQuantize u8 Transpose2[0,2,1,3] + * \ / + * MatMul1 + * FakeQuantize i8 + * Transpose3[0,2,1,3] + */ +class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase { +public: + explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector& inputShapes) + : SnippetsFunctionBase(inputShapes) { + NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + } +protected: + std::shared_ptr initOriginal() const override; + std::shared_ptr initReference() const override; +}; + } // namespace snippets } // namespace test } // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp index b213c66eccacc6..d4f1d4aca6232a 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp @@ -91,6 +91,41 @@ std::shared_ptr MatMulBiasFunction::initOriginal() const { auto bias = std::make_shared(matmul, data2); return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2}); } +std::shared_ptr MatMulBiasQuantizedFunction::initOriginal() const { + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + auto fq2 = ngraph::builder::makeFakeQuantize(matmul, ov::element::f32, 256, {1}, {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + auto bias = std::make_shared(fq2, data2); + return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2}); +} +std::shared_ptr MatMulsQuantizedFunction::initOriginal() const { + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto matmul0 = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + auto fq0 = ngraph::builder::makeFakeQuantize(matmul0, ov::element::f32, 256, {1}, {0}, {0.820726}, {0}, {0.820726}); + auto fq2 = ngraph::builder::makeFakeQuantize(data2, ov::element::f32, 256, {1}, {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + auto new_shape = std::make_shared(ov::element::u64, ov::Shape{4}, + std::vector{1, 1, input_shapes[2].get_shape()[0], input_shapes[2].get_shape()[1]}); + auto reshape = std::make_shared(fq2, new_shape, false); + auto matmul1 = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(fq0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(reshape, element::f32).get()); + auto fq3 = ngraph::builder::makeFakeQuantize(matmul1, ov::element::f32, 256, {1}, {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + return 
std::make_shared(NodeVector{fq3}, ParameterVector{data0, data1, data2}); +} std::shared_ptr Transpose0213MatMulFunction::initOriginal() const { auto data0 = std::make_shared(precisions[0], input_shapes[0]); auto data1 = std::make_shared(precisions[1], input_shapes[1]); @@ -169,7 +204,30 @@ std::shared_ptr TransposeMulMatMulBiasFunction::initOriginal() const auto bias = std::make_shared(matmul, data3); return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2, data3}); } +std::shared_ptr MatMulsQuantizedSoftmaxFunction::initOriginal() const { + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precision, input_shapes[2]); + auto matmul0 = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + auto softmax = std::make_shared(matmul0, -1); + auto fq0 = ngraph::builder::makeFakeQuantize(softmax, ov::element::f32, 256, {1}, {0}, {0.820726}, {0}, {0.820726}); + auto fq2 = ngraph::builder::makeFakeQuantize(data2, ov::element::f32, 256, {1}, {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + auto new_shape = std::make_shared(ov::element::u64, ov::Shape{4}, + std::vector{1, 1, input_shapes[2].get_shape()[0], input_shapes[2].get_shape()[1]}); + auto reshape = std::make_shared(fq2, new_shape, false); + auto matmul1 = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(fq0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(reshape, element::f32).get()); + auto fq3 = ngraph::builder::makeFakeQuantize(matmul1, ov::element::f32, 256, {1}, {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294}); + return std::make_shared(NodeVector{fq3}, ParameterVector{data0, data1, data2}); +} } // namespace snippets } // namespace test -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index ac38ea47624eba..532ffafb3d5f24 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -7,16 +7,18 @@ #include "common_test_utils/data_utils.hpp" #include #include "ngraph_functions/builders.hpp" +#include "ov_ops/type_relaxed.hpp" +#include "lpt_ngraph_functions/common/builders.hpp" namespace ov { namespace test { namespace snippets { std::shared_ptr MHAFunction::initOriginal() const { - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; std::vector constantShapes; @@ -51,7 +53,7 @@ std::shared_ptr MHAFunction::initOriginal() const { std::shared_ptr 
matmul_parent1 = transpose1; if (with_mul) { std::vector mulConstData(ngraph::shape_size(constantShapes[2])); - auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true); + auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true); matmul_parent1 = std::make_shared(transpose1, mulConst); } const auto matMul0 = std::make_shared(transpose0, matmul_parent1, transA, transB); @@ -67,17 +69,17 @@ std::shared_ptr MHAFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } std::shared_ptr MHAFunction::initReference() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto data2 = std::make_shared(precision, input_shapes[2]); - auto data3 = std::make_shared(precision, input_shapes[3]); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precisions[2], input_shapes[2]); + auto data3 = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3}; NodeVector subgraph_inputs = {data0, data1, data2, data3}; - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); std::vector constantShapes; constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); @@ -113,8 +115,8 @@ std::shared_ptr MHAFunction::initReference() const { std::shared_ptr matmul_parent1 = transpose1; if (with_mul) { std::vector mulConstData(ngraph::shape_size(constantShapes[2])); - auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true); - auto mulParam = std::make_shared(precision, mulConst->get_shape()); + auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true); + auto mulParam = std::make_shared(precisions[1], mulConst->get_shape()); matmul_parent1 = std::make_shared(transpose1, mulParam); subgraph_params = {transpose0Param, transpose1Param, mulParam, addParam, transpose2Param}; subgraph_inputs = {data0, data1, mulConst, data2, data3}; @@ -135,10 +137,10 @@ std::shared_ptr MHAFunction::initReference() const { } std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param}; std::vector constantShapes; @@ -157,7 +159,7 @@ 
std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], order); std::vector mulConstData(1); - auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true); + auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true); std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]), @@ -188,16 +190,16 @@ std::shared_ptr MHAMatMul0TransposeFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } std::shared_ptr MHAMatMul0TransposeFunction::initReference() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto data2 = std::make_shared(precision, input_shapes[2]); - auto data3 = std::make_shared(precision, input_shapes[3]); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto data2 = std::make_shared(precisions[2], input_shapes[2]); + auto data3 = std::make_shared(precisions[3], input_shapes[3]); ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3}; - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto transpose2Param = std::make_shared(precision, input_shapes[3]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto transpose2Param = std::make_shared(precisions[3], input_shapes[3]); std::vector constantShapes; constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()})); @@ -214,7 +216,7 @@ std::shared_ptr MHAMatMul0TransposeFunction::initReference() const { auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], std::vector{0, 2, 1, 3}); std::vector mulConstData(1); - auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true); + auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true); ngraph::ParameterVector subgraphParams = {transpose0Param, transpose1Param, addParam, transpose2Param}; std::vector reshape0ConstData = {static_cast(input_shapes[0].get_shape()[0] * @@ -250,12 +252,12 @@ std::shared_ptr MHAMatMul0TransposeFunction::initReference() const { } std::shared_ptr MHASelectFunction::initOriginal() const { - auto transpose0Param = std::make_shared(precision, input_shapes[0]); - auto transpose1Param = std::make_shared(precision, input_shapes[1]); - auto addParam = std::make_shared(precision, input_shapes[2]); - auto less0Param = std::make_shared(precision, input_shapes[3]); - auto less1Param = std::make_shared(precision, input_shapes[4]); - auto transpose2Param = std::make_shared(precision, input_shapes[5]); + auto transpose0Param = std::make_shared(precisions[0], input_shapes[0]); + auto transpose1Param = std::make_shared(precisions[1], input_shapes[1]); + auto addParam = std::make_shared(precisions[2], input_shapes[2]); + auto less0Param = std::make_shared(precisions[3], input_shapes[3]); + auto less1Param = std::make_shared(precisions[4], input_shapes[4]); + auto transpose2Param = 
     ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, less0Param, less1Param, transpose2Param};
 
     std::vector<ov::Shape> constantShapes;
@@ -288,7 +290,7 @@ std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
                                               static_cast<int64_t>(input_shapes[0].get_shape()[1])};
     auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[4], reshape1ConstData);
     // Value is equal to '1' - to avoid the situation e^(-1000) / sum(e^(-1000)) = 0/0 = NaN
-    auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});
+    auto selectConst = ngraph::builder::makeConstant(precisions[2], ov::Shape{1}, std::vector<float>{1});
 
     float transA = false;
     float transB = false;
@@ -343,6 +345,301 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
     return std::make_shared<ov::Model>(results, ngraphParam, "mha");
 }
 
+std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
+    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
+    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
+
+    const auto shape_rank = input_shapes[0].get_shape().size();
+    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
+    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+
+    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
+                                                                   input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
+                                              -1};
+    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
+
+    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
+    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
+
+    bool transA = false;
+    bool transB = false;
+    const auto transpose0 = std::make_shared<ngraph::opset1::Transpose>(transpose0Param, transpose0Const);
+    const auto transpose1 = std::make_shared<ngraph::opset1::Transpose>(transpose1Param, transpose1Const);
+    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
+    auto fq0 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    const auto add = std::make_shared<ngraph::opset3::Add>(fq0, addParam);
+    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
+    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
+    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
+    const auto transpose2 = std::make_shared<ngraph::opset1::Transpose>(transpose2Param, transpose2Const);
+    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2, transA, transB);
+    auto fq1 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    const auto transpose3 = std::make_shared<ngraph::opset1::Transpose>(fq1, transpose3Const);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
+    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
+}
+
+std::shared_ptr<ov::Model> MHAINT8MatMulFunction::initOriginal() const {
+    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
+    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
+
+    const auto shape_rank = input_shapes[0].get_shape().size();
+    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
+    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+
+    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
+                                                                   input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
+                                              -1};
+    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
+
+    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
+    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
+
+    auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    auto fq2 = ngraph::builder::makeFakeQuantize(transpose2Param, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    bool transA = false;
+    bool transB = false;
+    const auto transpose0 = std::make_shared<ngraph::opset1::Transpose>(fq0, transpose0Const);
+    const auto transpose1 = std::make_shared<ngraph::opset1::Transpose>(fq1, transpose1Const);
+    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
+    auto fq3 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    const auto add = std::make_shared<ngraph::opset3::Add>(fq3, addParam);
+    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
+    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
+    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
+    auto fq4 = ngraph::builder::makeFakeQuantize(reshape1, ov::element::f32, 256, {1},
+                                                 {0}, {0.820726}, {0}, {0.820726});
+    const auto transpose2 = std::make_shared<ngraph::opset1::Transpose>(fq2, transpose2Const);
+    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(fq4, transpose2, transA, transB);
+    auto fq5 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
+                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
+    const auto transpose3 = std::make_shared<ngraph::opset1::Transpose>(fq5, transpose3Const);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
+    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
+}
+
+std::shared_ptr<ov::Model> MHAFQFunction::initOriginal() const {
+    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
+    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
+
+    const auto shape_rank = input_shapes[0].get_shape().size();
+    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
+    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+
+    const auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1},
+                                                       {-5.217694}, {6.661877}, {-5.217694}, {6.661877});
+    const auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1},
+                                                       {-6.40245}, {6.45286}, {-6.40245}, {6.45286});
+    const auto fq_add = ngraph::builder::makeFakeQuantize(addParam, ov::element::f32, 256, {1},
+                                                          {-1000}, {0}, {-1000}, {0});
+
+    bool transA = false;
+    bool transB = false;
+    const auto transpose0 = std::make_shared<ngraph::opset1::Transpose>(fq0, transpose0Const);
+    const auto transpose1 = std::make_shared<ngraph::opset1::Transpose>(fq1, transpose1Const);
+    const auto transpose2 = std::make_shared<ngraph::opset1::Transpose>(transpose2Param, transpose2Const);
+    const auto mul_const = ngraph::builder::makeConstant(ov::element::i8, ov::Shape{1}, std::vector<int8_t>{127});
+    const auto convert = std::make_shared<ngraph::opset1::Convert>(mul_const, ov::element::f32);
+    const auto mul_deq_const = ngraph::builder::makeConstant(ov::element::f32, ov::Shape{1}, std::vector<float>{0.00098425});
+    const auto mul_deq = std::make_shared<ngraph::opset3::Multiply>(convert, mul_deq_const);
+    const auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, mul_deq);
+    auto fq1_1 = ngraph::builder::makeFakeQuantize(mul, ov::element::f32, 256, {1},
+                                                   {-0.8003067}, {0.8066083}, {-0.8003067}, {0.8066083});
+    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fq1_1, transA, transB);
+    auto fq2 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
+                                                 {-14.50351}, {17.65645}, {-14.50351}, {17.65645});
+    const auto add = std::make_shared<ngraph::opset3::Add>(fq2, fq_add);
+    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
+    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
+    auto fq3 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
+                                                 {-1.895786}, {2.0028071}, {-1.895786}, {2.0028071});
+    const auto transpose3 = std::make_shared<ngraph::opset1::Transpose>(fq3, transpose3Const);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
+    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
+}
+
+std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initOriginal() const {
+    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
+    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
+
+    const auto shape_rank = input_shapes[0].get_shape().size();
+    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
+    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+
+    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
+                                                                   input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
+                                              -1};
+    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
+
+    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
+    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
+
+    const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
+    const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose0Param, ov::element::i8, fq_signed_params);
+    const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose1Param, ov::element::i8, fq_signed_params);
+    const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose2Param, ov::element::i8, fq_signed_params);
+
+    bool transA = false;
+    bool transB = false;
+    const auto transpose0 = std::make_shared<ngraph::opset1::Transpose>(fq0, transpose0Const);
+    const auto transpose1 = std::make_shared<ngraph::opset1::Transpose>(fq1, transpose1Const);
+    const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);
+
+    const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
+    const auto add = std::make_shared<op::TypeRelaxed<op::v1::Add>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(addParam, element::f32).get());
+    const auto deq = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.1122});
+    const auto deq_mul = std::make_shared<op::TypeRelaxed<op::v1::Multiply>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(deq, element::f32).get());
+
+    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
+    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
+    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
+
+    const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
+    const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);
+
+    const auto transpose2 = std::make_shared<ngraph::opset1::Transpose>(fq2, transpose2Const);
+    const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
+    const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
+    const auto transpose3 = std::make_shared<ngraph::opset1::Transpose>(fq5, transpose3Const);
+
+    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
+    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
+}
+
+std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() const {
+    auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
+    ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};
+
+    const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
+    const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data0, ov::element::i8, fq_signed_params);
+    const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data1, ov::element::i8, fq_signed_params);
+    const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data3, ov::element::i8, fq_signed_params);
+    NodeVector subgraph_inputs = {fq0, fq1, data2, fq2};
+
+    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
+    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
+    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
+    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
+    ov::ParameterVector subgraph_params = {transpose0Param, transpose1Param, addParam, transpose2Param};
+
+    const auto shape_rank = input_shapes[0].get_shape().size();
+    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
+    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
+
+    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
+                                                                   input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
+                                              -1};
+    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
+
+    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
+                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
+    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
+
+    bool transA = false;
+    bool transB = false;
+    const auto transpose0 = std::make_shared<ngraph::opset1::Transpose>(transpose0Param, transpose0Const);
+    const auto transpose1 = std::make_shared<ngraph::opset1::Transpose>(transpose1Param, transpose1Const);
+    const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);
+
+    const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
+    const auto add = std::make_shared<op::TypeRelaxed<op::v1::Add>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(addParam, element::f32).get());
+
+    const auto deq = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.1122});
+    const auto deq_mul = std::make_shared<op::TypeRelaxed<op::v1::Multiply>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(deq, element::f32).get());
+
+    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
+    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
+    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
+
+    const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
+    const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);
+
+    const auto transpose2 = std::make_shared<ngraph::opset1::Transpose>(transpose2Param, transpose2Const);
+    const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
+            std::vector<element::Type>{ element::f32, element::f32 },
+            std::vector<element::Type>{ element::f32 },
+            ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
+            ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
+    const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
+
+    auto subgraph = std::make_shared<ngraph::snippets::op::Subgraph>(subgraph_inputs,
+            std::make_shared<ov::Model>(NodeVector{fq5}, subgraph_params));
+    // TODO: At the moment Snippets don't support Transpose explicitly,
+    //       so we cannot collapse Transpose into the Subgraph if there are ops between MatMul2 and Transpose3
+    auto transpose3 = std::make_shared<ngraph::opset1::Transpose>(subgraph, transpose3Const);
+
+    return std::make_shared<ov::Model>(NodeVector{transpose3}, ngraphParams);
+}
+
 }  // namespace snippets
 }  // namespace test
 }  // namespace ov
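
Note on the FakeQuantize constants used throughout the new test builders: they follow OpenVINO's FakeQuantize-1 semantics. The following standalone sketch (illustrative only, not part of the patch; the helper name fake_quantize_ref and the main() driver are hypothetical) reproduces the elementwise arithmetic of a 256-level FakeQuantize such as the {-35.0172004}/{34.7436294} intervals above:

#include <cmath>
#include <cstdio>

// Reference for FakeQuantize-1 with the given input/output ranges and levels.
// Values are clipped to [in_low, in_high], snapped to one of `levels` evenly
// spaced points, then linearly mapped into [out_low, out_high].
float fake_quantize_ref(float x, float in_low, float in_high,
                        float out_low, float out_high, int levels) {
    if (x <= in_low)
        return out_low;
    if (x > in_high)
        return out_high;
    const float q = std::round((x - in_low) / (in_high - in_low) * (levels - 1));
    return q / (levels - 1) * (out_high - out_low) + out_low;
}

int main() {
    // With identical input and output ranges, as in the test constants above,
    // FakeQuantize degenerates to rounding onto the 256 representable values.
    const float y = fake_quantize_ref(1.234f, -35.0172004f, 34.7436294f,
                                      -35.0172004f, 34.7436294f, 256);
    std::printf("%.7f\n", y);
    return 0;
}

This also shows why fq_signed_params and fq_unsigned_params use output bounds {-128}/{127} and {0}/{255}: with 256 levels the quantized points land exactly on the integer codes of the i8 and u8 output precisions, so the subsequent conversion to i8/u8 is lossless.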