From 0ac860895cdc75171c3f7af821ea84791d86e723 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Fri, 19 May 2023 10:32:44 +0400 Subject: [PATCH] Applied comments by Dmitry --- .../include/snippets/lowered/linear_ir.hpp | 36 +++++----- .../lowered/pass/allocate_buffers.hpp | 7 +- .../lowered/pass/insert_load_store.hpp | 1 + .../lowered/pass/vector_to_scalar.hpp | 48 ------------- .../snippets/lowered/port_descriptor.hpp | 4 +- .../snippets/include/snippets/op/subgraph.hpp | 2 +- src/common/snippets/src/generator.cpp | 2 +- .../snippets/src/lowered/expression.cpp | 4 +- src/common/snippets/src/lowered/linear_ir.cpp | 34 ++++----- .../src/lowered/pass/insert_buffers.cpp | 4 +- .../src/lowered/pass/insert_load_store.cpp | 18 +++-- .../load_movebroadcast_to_broadcastload.cpp | 2 +- .../lowered/pass/softmax_decomposition.cpp | 1 + .../src/lowered/pass/vector_to_scalar.cpp | 50 ------------- .../snippets/src/lowered/port_descriptor.cpp | 18 ++--- src/common/snippets/src/op/brgemm.cpp | 6 +- src/common/snippets/src/op/subgraph.cpp | 18 ++--- .../snippets/src/pass/collapse_subgraph.cpp | 5 +- .../src/pass/fuse_transpose_brgemm.cpp | 8 +-- .../snippets/src/pass/matmul_to_brgemm.cpp | 4 +- .../snippets/src/pass/set_softmax_ports.cpp | 4 +- .../src/pass/transpose_decomposition.cpp | 9 +-- src/common/snippets/src/utils.cpp | 6 +- .../snippets/tests/src/lowering_utils.cpp | 2 +- .../emitters/x64/jit_snippets_emitters.cpp | 33 ++++----- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 6 +- .../snippets/x64/op/brgemm_copy_b.cpp | 2 +- .../snippets/x64/op/brgemm_cpu.cpp | 12 ++-- .../x64/pass/brgemm_to_brgemm_cpu.cpp | 10 +-- .../lowered/fuse_load_store_and_convert.cpp | 4 +- .../plugin/shared/include/snippets/matmul.hpp | 14 ++-- .../plugin/shared/include/snippets/mha.hpp | 10 +-- .../plugin/shared/src/snippets/matmul.cpp | 71 ++++--------------- .../plugin/shared/src/snippets/mha.cpp | 68 ++++-------------- .../src/subgraph_lowered.cpp | 22 +++--- 35 files changed, 191 
insertions(+), 354 deletions(-) delete mode 100644 src/common/snippets/include/snippets/lowered/pass/vector_to_scalar.hpp delete mode 100644 src/common/snippets/src/lowered/pass/vector_to_scalar.cpp diff --git a/src/common/snippets/include/snippets/lowered/linear_ir.hpp b/src/common/snippets/include/snippets/lowered/linear_ir.hpp index d725332566b546..ac42ce731bacaa 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir.hpp @@ -15,7 +15,7 @@ namespace lowered { class Config { public: // True if the lowered Emitters need to be accessed during runtime. Normally they're destroyed after code emission. - bool m_save_lowered_code = false; + bool m_save_expressions = false; // True if we should check runtime info for nodes to call specific needed transformations bool m_need_fill_tail_register = false; size_t m_loop_depth = 1; @@ -36,8 +36,8 @@ class LinearIR { static LinearIR::container deep_copy_range(LinearIR::container::const_iterator begin, LinearIR::container::const_iterator end); - const container& get_ops() const {return m_lowered_ops; } - const io_container& get_IO_ops() const {return m_io_lowered_ops; } + const container& get_ops() const {return m_expressions; } + const io_container& get_IO_ops() const {return m_io_expressions; } Config get_config() {return m_config; } const ExpressionPtr& get_expr_by_node(const std::shared_ptr& n) const; @@ -52,24 +52,24 @@ class LinearIR { */ void move(constExprIt from, constExprIt to); - bool empty() const noexcept {return m_lowered_ops.empty(); } + bool empty() const noexcept {return m_expressions.empty(); } void debug_print(bool tds_as_pointers = false) const; - container::reference back() noexcept {return m_lowered_ops.back();} - container::const_reference back() const noexcept {return m_lowered_ops.back();} - container::reference front() noexcept {return m_lowered_ops.front();} - container::const_reference front() const noexcept {return 
m_lowered_ops.front();} + container::reference back() noexcept {return m_expressions.back();} + container::const_reference back() const noexcept {return m_expressions.back();} + container::reference front() noexcept {return m_expressions.front();} + container::const_reference front() const noexcept {return m_expressions.front();} - exprIt begin() noexcept {return m_lowered_ops.begin();} - exprIt end() noexcept {return m_lowered_ops.end();} + exprIt begin() noexcept {return m_expressions.begin();} + exprIt end() noexcept {return m_expressions.end();} constExprIt begin() const noexcept {return cbegin();} constExprIt end() const noexcept {return cend();} - constExprIt cbegin() const noexcept {return m_lowered_ops.cbegin();} - constExprIt cend() const noexcept {return m_lowered_ops.cend();} - container::reverse_iterator rbegin() noexcept {return m_lowered_ops.rbegin();} - container::reverse_iterator rend() noexcept {return m_lowered_ops.rend();} - container::const_reverse_iterator crbegin() const noexcept {return m_lowered_ops.crbegin();} - container::const_reverse_iterator crend() const noexcept {return m_lowered_ops.crend();} + constExprIt cbegin() const noexcept {return m_expressions.cbegin();} + constExprIt cend() const noexcept {return m_expressions.cend();} + container::reverse_iterator rbegin() noexcept {return m_expressions.rbegin();} + container::reverse_iterator rend() noexcept {return m_expressions.rend();} + container::const_reverse_iterator crbegin() const noexcept {return m_expressions.crbegin();} + container::const_reverse_iterator crend() const noexcept {return m_expressions.crend();} exprIt insert(constExprIt pos, const ov::NodeVector& nodes); exprIt insert(constExprIt pos, const std::shared_ptr& n); @@ -97,9 +97,9 @@ class LinearIR { void register_expression(const ExpressionPtr& expr, bool io_allowed = false); void unregister_expression(const ExpressionPtr& expr); - container m_lowered_ops{}; + container m_expressions{}; std::unordered_map, 
std::shared_ptr> m_node2expression_map; - io_container m_io_lowered_ops; + io_container m_io_expressions; Config m_config{}; LoopManagerPtr m_loop_manager = nullptr; }; diff --git a/src/common/snippets/include/snippets/lowered/pass/allocate_buffers.hpp b/src/common/snippets/include/snippets/lowered/pass/allocate_buffers.hpp index c4b7530b951857..dd25b5872f5379 100644 --- a/src/common/snippets/include/snippets/lowered/pass/allocate_buffers.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/allocate_buffers.hpp @@ -14,7 +14,12 @@ namespace pass { /** * @interface AllocateBuffers - * @brief The pass calculation common size of buffer scratchpad and propagates Buffer offsets to connected MemoryAccess operations. + * @brief The pass calculates common size of buffer scratchpad and propagates Buffer offsets to connected MemoryAccess operations. + * Notes: + * - The pass implicitly regulates InPlace processing for some Buffers when it's possible. + * The pass don't allocate new memory for InPlace Buffers, we propagate the same offsets for them. + * - The pass should be splitted into two passes: ProcessInplace (markup of Buffers which can use the same memory) + * and AllocateBuffer (allocate memory for Buffers using MemorySolver which can optimally reuse memory). 
* @ingroup snippets */ diff --git a/src/common/snippets/include/snippets/lowered/pass/insert_load_store.hpp b/src/common/snippets/include/snippets/lowered/pass/insert_load_store.hpp index a5e489393aaed1..6b87b8dfa6b5fe 100644 --- a/src/common/snippets/include/snippets/lowered/pass/insert_load_store.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/insert_load_store.hpp @@ -33,6 +33,7 @@ class InsertLoadStore : public Pass { const ExpressionPort& actual_port, const std::vector& target_ports, bool is_entry = true); void update_loop(const LinearIR::LoopManager::LoopInfoPtr& loop_info, const ExpressionPort& actual_port, const std::vector& target_ports, bool is_entry = true); + size_t get_count(const PortDescriptorPtr& port_desc) const; size_t m_vector_size; }; diff --git a/src/common/snippets/include/snippets/lowered/pass/vector_to_scalar.hpp b/src/common/snippets/include/snippets/lowered/pass/vector_to_scalar.hpp deleted file mode 100644 index 4815c9fe524dd0..00000000000000 --- a/src/common/snippets/include/snippets/lowered/pass/vector_to_scalar.hpp +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (C) 2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "pass.hpp" - -namespace ov { -namespace snippets { -namespace lowered { -namespace pass { - -/** - * @interface SetScalarCountForLoadStore - * @brief Set count `1` for Load and Store to represent as ScalarLoad / ScalarStore - * The pass is used to change element count to loading to "1" to load or store scalar value - * Used for tail generation - * @ingroup snippets - */ - -// Note that, BrodacastMove is typically inserted right after the Load. Such cases are typical for -// simple subgraphs where one of the ov::op's inputs is broadcasted to match the larger one. However, BroadcastMove -// could also be inserted after the ov::op, if the op input don't need broadcasting, but the output does -// (for example, to match the larger output of a child node). 
In such cases, Loads (and Stores) should be replaced -// with ScalarLoads (ScalarStores) to avoid invalid read in vector Loop. Graph example: -// Parameter_0 Parameter_1 Parameter_2 -// [1,2,5,16] [1,2,5,1] [1,2,5,1] -// Load BroadcastLoad Load* Scalar -// Add Subtract -// \___________ ___________BroadcastMove -// \ / -// Multiply -// Store -// Result -// Note: Load* should be replaced with ScalarLoad in this example to avoid invalid read in vector Loop. - -class SetScalarCountForLoadStore : public Pass { -public: - explicit SetScalarCountForLoadStore(); - OPENVINO_RTTI("SetScalarCountForLoadStore", "Pass") - bool run(lowered::LinearIR& linear_ir) override; -}; - -} // namespace pass -} // namespace lowered -} // namespace snippets -} // namespace ov diff --git a/src/common/snippets/include/snippets/lowered/port_descriptor.hpp b/src/common/snippets/include/snippets/lowered/port_descriptor.hpp index 94a3a5fc526718..4d3100dd56182c 100644 --- a/src/common/snippets/include/snippets/lowered/port_descriptor.hpp +++ b/src/common/snippets/include/snippets/lowered/port_descriptor.hpp @@ -62,11 +62,11 @@ class PortDescriptor { std::vector m_layout{}; /// \brief Minimal tensor size that could be processed in one call std::vector m_subtensor_shape{}; - /// \brief The corresponding abstract register + /// \brief The corresponding abstract/physical register size_t m_reg = 0; }; -class PortManager { +class PortDescriptorUtils { public: static void set_port_descriptor_ptr(const ov::Input& n, const PortDescriptorPtr& desc); static void set_port_descriptor_ptr(const ov::Output& n, const PortDescriptorPtr& desc); diff --git a/src/common/snippets/include/snippets/op/subgraph.hpp b/src/common/snippets/include/snippets/op/subgraph.hpp index abea7a0a379ce0..9d63bcba1367a6 100644 --- a/src/common/snippets/include/snippets/op/subgraph.hpp +++ b/src/common/snippets/include/snippets/op/subgraph.hpp @@ -145,7 +145,7 @@ class Subgraph : public ov::op::util::SubGraphOp { private: void 
align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes); void data_flow_transformations(ov::pass::Manager& pre_common, ov::pass::Manager& post_common, ov::pass::Manager& post_precision); - void control_flow_transformations(lowered::LinearIR& linear_ir, lowered::pass::PassPipeline& target_pipeline, const lowered::Config& config); + void control_flow_transformations(lowered::LinearIR& linear_ir, lowered::pass::PassPipeline& target_pipeline); void init_config(); // Count of Subgraph virtual ports: // - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition) diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 8737911a7a8ce8..56747783303869 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -46,7 +46,7 @@ Generator::LoweringResult Generator::generate(lowered::LinearIR& linear_ir, cons // todo: we save lowered to access compiled brgemm kernels on execution time (normally lowered is destructed by then) // remove this when kernel caching is implemented. Don't forget to make generate const method. 
- if (config.m_save_lowered_code) + if (config.m_save_expressions) lowered_saved = linear_ir; return { target->get_snippet() }; diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp index 49089b04459fea..c10fd08598cba8 100644 --- a/src/common/snippets/src/lowered/expression.cpp +++ b/src/common/snippets/src/lowered/expression.cpp @@ -20,10 +20,10 @@ Expression::Expression(const std::shared_ptr& n) : m_source_node{n}, m_emi m_input_port_descriptors.reserve(n->get_input_size()); m_output_port_descriptors.reserve(n->get_output_size()); for (const auto& input : n->inputs()) { - m_input_port_descriptors.push_back(PortManager::get_port_descriptor_ptr(input)); + m_input_port_descriptors.push_back(PortDescriptorUtils::get_port_descriptor_ptr(input)); } for (const auto& output : n->outputs()) { - m_output_port_descriptors.push_back(PortManager::get_port_descriptor_ptr(output)); + m_output_port_descriptors.push_back(PortDescriptorUtils::get_port_descriptor_ptr(output)); } } diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 4fb370876d4dd1..0bc22204a54425 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -19,10 +19,10 @@ namespace snippets { namespace lowered { LinearIR::LinearIR(const std::shared_ptr& model, Config config) - : m_io_lowered_ops{}, m_config{std::move(config)}, m_loop_manager(std::make_shared()) { - constExprIt last_param = m_lowered_ops.end(); + : m_io_expressions{}, m_config{std::move(config)}, m_loop_manager(std::make_shared()) { + constExprIt last_param = m_expressions.end(); for (const auto& n : get_ordered_ops(model)) { - constExprIt insertion_pos = m_lowered_ops.end(); + constExprIt insertion_pos = m_expressions.end(); const auto expr = create_expression(n, model); // Scalar should be on the Linear IR beginning after Parameters to have valid expression order after Loop passes. 
@@ -33,10 +33,10 @@ LinearIR::LinearIR(const std::shared_ptr& model, Config config) } register_expression(expr, true); - const auto& it = m_lowered_ops.insert(insertion_pos, expr); + const auto& it = m_expressions.insert(insertion_pos, expr); if (const auto io_expr = std::dynamic_pointer_cast(expr)) { - m_io_lowered_ops.push_back(io_expr); + m_io_expressions.push_back(io_expr); if (ov::is_type(n)) last_param = it; } @@ -71,7 +71,7 @@ void LinearIR::serialize(const std::string& xml, const std::string& bin) { first_node->set_friendly_name("Start"); first_node->get_rt_info()["execTimeMcs"] = 0; std::shared_ptr body_node = first_node; - for (const auto& expr : m_lowered_ops) { + for (const auto& expr : m_expressions) { body_node = std::make_shared(body_node, expr); } auto last_node = std::make_shared(body_node); @@ -116,7 +116,7 @@ void LinearIR::debug_print(bool tds_as_pointers) const { std::map td2int; int td_counter = 0; int counter = 0; - for (const auto& expr : m_lowered_ops) { + for (const auto& expr : m_expressions) { const auto& node = expr->get_node(); std::cerr << counter++ << " : " << node->get_friendly_name() << " : "; @@ -148,7 +148,7 @@ void LinearIR::debug_print(bool tds_as_pointers) const { } void LinearIR::init_emitters(const std::shared_ptr& target) { - for (auto& expr : m_lowered_ops) { + for (auto& expr : m_expressions) { if (!expr->get_emitter()) expr->init_emitter(target); } @@ -206,12 +206,12 @@ void LinearIR::unregister_expression(const ExpressionPtr& expr) { LinearIR::exprIt LinearIR::insert(constExprIt pos, container::value_type&& value) { register_expression(value); - return m_lowered_ops.insert(pos, value); + return m_expressions.insert(pos, value); } LinearIR::exprIt LinearIR::insert(constExprIt pos, const container::value_type& value) { register_expression(value); - return m_lowered_ops.insert(pos, value); + return m_expressions.insert(pos, value); } LinearIR::exprIt LinearIR::insert(constExprIt pos, exprIt begin, exprIt end) { @@ -223,15 
+223,15 @@ LinearIR::exprIt LinearIR::insert(constExprIt pos, exprIt begin, exprIt end) { LinearIR::exprIt LinearIR::insert(constExprIt pos, constExprIt begin, constExprIt end) { for (auto b = begin; b != end; b++) register_expression(*b); - return m_lowered_ops.insert(pos, begin, end); + return m_expressions.insert(pos, begin, end); } LinearIR::exprIt LinearIR::insert(LinearIR::constExprIt pos, const NodeVector& nodes) { - auto ret = m_lowered_ops.end(); + auto ret = m_expressions.end(); for (const auto& n : nodes) { const auto& expr = create_expression(n); register_expression(expr); - ret = m_lowered_ops.insert(pos, expr); + ret = m_expressions.insert(pos, expr); } // Need to return iterator to the first of the inserted values return std::prev(ret, static_cast(nodes.size())); @@ -240,22 +240,22 @@ LinearIR::exprIt LinearIR::insert(LinearIR::constExprIt pos, const NodeVector& n LinearIR::exprIt LinearIR::insert(LinearIR::constExprIt pos, const std::shared_ptr& n) { const auto& expr = create_expression(n); register_expression(expr); - return m_lowered_ops.insert(pos, expr); + return m_expressions.insert(pos, expr); } LinearIR::exprIt LinearIR::erase(LinearIR::exprIt pos) { unregister_expression(*pos); - return m_lowered_ops.erase(pos); + return m_expressions.erase(pos); } LinearIR::exprIt LinearIR::erase(LinearIR::constExprIt pos) { unregister_expression(*pos); - return m_lowered_ops.erase(pos); + return m_expressions.erase(pos); } void LinearIR::move(LinearIR::constExprIt from, LinearIR::constExprIt to) { // Instead of `insert()` + `erase()`, we use `splice()` for the same list - m_lowered_ops.splice(to, m_lowered_ops, from); + m_expressions.splice(to, m_expressions, from); } }// namespace lowered diff --git a/src/common/snippets/src/lowered/pass/insert_buffers.cpp b/src/common/snippets/src/lowered/pass/insert_buffers.cpp index 830903887f4d4d..1da65bd31f7036 100644 --- a/src/common/snippets/src/lowered/pass/insert_buffers.cpp +++ 
b/src/common/snippets/src/lowered/pass/insert_buffers.cpp @@ -103,7 +103,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt // Need to insert between 2nd and 4th Loops - after 2nd Loop const auto pos = insertion_position(linear_ir, loop_manager, parent_expr, expr); const auto buffer = std::make_shared(parent->output(parent_port), m_buffer_allocation_rank); - PortManager::set_port_descriptor_ptr(buffer->output(0), parent_expr_output.get_descriptor_ptr()->clone()); + PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), parent_expr_output.get_descriptor_ptr()->clone()); // Output tensor is automatically filled from PortDescriptor const auto buffer_expr = linear_ir.create_expression(buffer, {input_tensor}); linear_ir.insert(pos, buffer_expr); @@ -178,7 +178,7 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt const auto pos = insertion_position(linear_ir, loop_manager, expr, (*potential_consumers.begin()).get_expr()); auto buffer = std::make_shared(node->output(port), m_buffer_allocation_rank); - PortManager::set_port_descriptor_ptr(buffer->output(0), exit_point.get_descriptor_ptr()->clone()); + PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), exit_point.get_descriptor_ptr()->clone()); // We cannot insert Node output tensor on Buffer output because not all consumers of Node needs Buffer // Example: // Add diff --git a/src/common/snippets/src/lowered/pass/insert_load_store.cpp b/src/common/snippets/src/lowered/pass/insert_load_store.cpp index 5e25bcfc314f32..ac025646c19cb6 100644 --- a/src/common/snippets/src/lowered/pass/insert_load_store.cpp +++ b/src/common/snippets/src/lowered/pass/insert_load_store.cpp @@ -50,6 +50,16 @@ void InsertLoadStore::update_loop(const LinearIR::LoopManager::LoopInfoPtr& loop ports.insert(port_it, target_ports.cbegin(), target_ports.cend()); } +size_t InsertLoadStore::get_count(const PortDescriptorPtr& port_desc) const { + const auto layout = 
port_desc->get_layout(); + const auto shape = port_desc->get_shape(); + // Find last dimension by layout + const auto last_dim_idx = std::find(layout.begin(), layout.end(), layout.size() - 1); + OPENVINO_ASSERT(last_dim_idx != layout.end(), "Load/Store expression have incorrect layout"); + const auto dim = shape[*last_dim_idx]; + return dim == 1 ? 1 : m_vector_size; +} + bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExprIt& data_expr_it) { const auto& loop_manager = linear_ir.get_loop_manager(); const auto& data_expr = *data_expr_it; @@ -71,8 +81,8 @@ bool InsertLoadStore::insert_load(LinearIR& linear_ir, const LinearIR::constExpr const auto inner_loop = get_inner_loop_id(loop_ids); OPENVINO_ASSERT(inner_loop != Expression::LOOP_NULL_ID, "Loop hasn't been found!"); - const auto load = std::make_shared(data_node->output(0), m_vector_size); - PortManager::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone()); + const auto load = std::make_shared(data_node->output(0), get_count(data_expr->get_output_port_descriptor(0))); + PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), consumer_input.get_descriptor_ptr()->clone()); const auto load_expr = linear_ir.create_expression(load, {output_tensor}); linear_ir.insert(std::find(data_expr_it, linear_ir.cend(), consumer_expr), load_expr); linear_ir.replace_input(consumer_input, load_expr->get_output_tensor(0)); @@ -106,8 +116,8 @@ bool InsertLoadStore::insert_store(LinearIR& linear_ir, const LinearIR::constExp const auto inner_loop = get_inner_loop_id(loop_ids); OPENVINO_ASSERT(inner_loop != Expression::LOOP_NULL_ID, "Loop hasn't been found!"); - const auto store = std::make_shared(parent->output(port), m_vector_size); - PortManager::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone()); + const auto store = std::make_shared(parent->output(port), get_count(data_expr->get_input_port_descriptor(0))); + 
PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), parent_output.get_descriptor_ptr()->clone()); const auto store_expr = linear_ir.create_expression(store, {input_tensor}); const auto& reverse_insertion_pos = std::find(std::reverse_iterator(data_expr_it), linear_ir.crend(), parent_expr); const auto& insertion_pos = reverse_insertion_pos.base(); diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index 22b3338c208df5..7d3f95380ba7fe 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -45,7 +45,7 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir) { const auto& outshape = move_broadcast->get_output_partial_shape(0); const auto broadcastload = std::make_shared(load->input_value(0), outshape, load->get_offset()); const auto move_consumers = expr->get_output_tensor(0)->get_consumers(); - PortManager::set_port_descriptor_ptr(broadcastload->output(0), expr->get_output_port(0).get_descriptor_ptr()->clone()); + PortDescriptorUtils::set_port_descriptor_ptr(broadcastload->output(0), expr->get_output_port(0).get_descriptor_ptr()->clone()); const auto broadcastload_expr = linear_ir.create_expression(broadcastload, { parent_expr->get_input_tensor(0) }); const auto mv_expr_it = expr_it; const auto insertion_pos = std::next(expr_it); diff --git a/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp b/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp index 9749977e3726c8..f1b5117e75da4b 100644 --- a/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp +++ b/src/common/snippets/src/lowered/pass/softmax_decomposition.cpp @@ -142,6 +142,7 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) { // For tail loop we should fill input of Max by float min and // input of Sum by zero to avoid 
math incorrect calculations + // TODO [111383]: It should be covered via general pipeline (for example, via analyze in InsertTailLoop?) max.second->input(0).get_rt_info()["set_fill"] = uint32_t(0xff7fffff); sum.second->input(0).get_rt_info()["set_fill"] = uint32_t(0x00000000); modified = true; diff --git a/src/common/snippets/src/lowered/pass/vector_to_scalar.cpp b/src/common/snippets/src/lowered/pass/vector_to_scalar.cpp deleted file mode 100644 index 8d776bad51108f..00000000000000 --- a/src/common/snippets/src/lowered/pass/vector_to_scalar.cpp +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (C) 2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "snippets/lowered/pass/vector_to_scalar.hpp" - -#include "snippets/snippets_isa.hpp" -#include "snippets/itt.hpp" - - -namespace ov { -namespace snippets { -namespace lowered { -namespace pass { - -SetScalarCountForLoadStore::SetScalarCountForLoadStore() {} - -bool SetScalarCountForLoadStore::run(LinearIR& linear_ir) { - OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetScalarCountForLoadStore") - bool modified = false; - for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { - const auto& expr = *expr_it; - const auto& op = expr->get_node(); - const auto load = ov::as_type_ptr(op); - const auto store = ov::as_type_ptr(op); - if (load || store) { - const auto& layout = load ? expr->get_input_port_descriptor(0)->get_layout() - : expr->get_output_port_descriptor(0)->get_layout(); - const auto& tensor_shape = load ? 
expr->get_input_port_descriptor(0)->get_shape() - : expr->get_output_port_descriptor(0)->get_shape(); - // Find last dimension by layout - const auto last_dim_idx = std::find(layout.begin(), layout.end(), layout.size() - 1); - OPENVINO_ASSERT(last_dim_idx != layout.end(), "Load/Store expression have incorrect layout"); - const auto dim = tensor_shape[*last_dim_idx]; - if (dim == 1) { - modified |= true; - if (load) load->set_count(1lu); - if (store) store->set_count(1lu); - } - } - } - return modified; -} - - - -} // namespace pass -} // namespace lowered -} // namespace snippets -} // namespace ov diff --git a/src/common/snippets/src/lowered/port_descriptor.cpp b/src/common/snippets/src/lowered/port_descriptor.cpp index ba838e8a068c60..719f77e7a56fb5 100644 --- a/src/common/snippets/src/lowered/port_descriptor.cpp +++ b/src/common/snippets/src/lowered/port_descriptor.cpp @@ -60,7 +60,9 @@ bool operator==(const PortDescriptor& lhs, const PortDescriptor& rhs) { lhs.m_subtensor_shape == rhs.m_subtensor_shape; } -void PortManager::init_default(std::vector& in_descs, std::vector& out_descs, const std::shared_ptr& node) { +void PortDescriptorUtils::init_default(std::vector& in_descs, + std::vector& out_descs, + const std::shared_ptr& node) { in_descs.resize(node->get_input_size()); out_descs.resize(node->get_output_size()); for (size_t i = 0; i < node->get_input_size(); ++i) { @@ -71,7 +73,7 @@ void PortManager::init_default(std::vector& in_descs, std::ve } } -void PortManager::set_port_descriptor_ptr(const ov::Input& in, const PortDescriptorPtr& desc) { +void PortDescriptorUtils::set_port_descriptor_ptr(const ov::Input& in, const PortDescriptorPtr& desc) { const auto& node = in.get_node()->shared_from_this(); auto& rt_info = node->get_rt_info(); const auto& key = PortDescriptorVectorAttribute::get_type_info_static(); @@ -89,7 +91,7 @@ void PortManager::set_port_descriptor_ptr(const ov::Input& in, const P } } -void PortManager::set_port_descriptor_ptr(const ov::Output& 
out, const PortDescriptorPtr& desc) { +void PortDescriptorUtils::set_port_descriptor_ptr(const ov::Output& out, const PortDescriptorPtr& desc) { const auto& node = out.get_node_shared_ptr(); auto& rt_info = node->get_rt_info(); const auto& key = PortDescriptorVectorAttribute::get_type_info_static(); @@ -107,10 +109,10 @@ void PortManager::set_port_descriptor_ptr(const ov::Output& out, const } } -PortDescriptorPtr PortManager::get_port_descriptor_ptr(const ov::Input& in) { +PortDescriptorPtr PortDescriptorUtils::get_port_descriptor_ptr(const ov::Input& in) { return get_port_descriptor_ptr(ov::Input(in.get_node(), in.get_index())); } -PortDescriptorPtr PortManager::get_port_descriptor_ptr(const ov::Input& in) { +PortDescriptorPtr PortDescriptorUtils::get_port_descriptor_ptr(const ov::Input& in) { const auto& node = in.get_node(); auto& rt_info = node->get_rt_info(); const auto& key = PortDescriptorVectorAttribute::get_type_info_static(); @@ -124,10 +126,10 @@ PortDescriptorPtr PortManager::get_port_descriptor_ptr(const ov::Input& out) { +PortDescriptorPtr PortDescriptorUtils::get_port_descriptor_ptr(const Output& out) { return get_port_descriptor_ptr(ov::Output(out.get_node(), out.get_index())); } -PortDescriptorPtr PortManager::get_port_descriptor_ptr(const Output& out) { +PortDescriptorPtr PortDescriptorUtils::get_port_descriptor_ptr(const Output& out) { const auto& node = out.get_node(); const auto& rt_info = node->get_rt_info(); const auto& key = PortDescriptorVectorAttribute::get_type_info_static(); @@ -141,7 +143,7 @@ PortDescriptorPtr PortManager::get_port_descriptor_ptr(const Output& node) { +void PortDescriptorUtils::clean(const std::shared_ptr& node) { auto& rt_info = node->get_rt_info(); rt_info.erase(PortDescriptorVectorAttribute::get_type_info_static()); } diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 4206d93568b76d..e02e0699a80b53 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ 
b/src/common/snippets/src/op/brgemm.cpp @@ -57,9 +57,9 @@ std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args check_new_args_count(this, new_args); return std::make_shared(new_args.at(0), new_args.at(1), get_offset_a(), get_offset_b(), get_offset_c(), - lowered::PortManager::get_port_descriptor_ptr(input(0))->get_layout(), - lowered::PortManager::get_port_descriptor_ptr(input(1))->get_layout(), - lowered::PortManager::get_port_descriptor_ptr(output(0))->get_layout()); + lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), + lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(), + lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout()); } ov::element::Type Brgemm::get_output_type() const { diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 91c68fd37ac7d6..feb52579a9243c 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -29,7 +29,6 @@ #include "snippets/lowered/pass/init_loops.hpp" #include "snippets/lowered/pass/insert_buffers.hpp" #include "snippets/lowered/pass/insert_load_store.hpp" -#include "snippets/lowered/pass/vector_to_scalar.hpp" #include "snippets/lowered/pass/load_movebroadcast_to_broadcastload.hpp" #include "snippets/lowered/pass/allocate_buffers.hpp" #include "snippets/lowered/pass/propagate_layout.hpp" @@ -40,7 +39,6 @@ #include "snippets/lowered/pass/clean_repeated_ptr_shifts.hpp" #include "snippets/lowered/pass/identify_buffers.hpp" -#include "transformations/common_optimizations/nop_elimination.hpp" #include "transformations/utils/utils.hpp" #include @@ -513,14 +511,12 @@ void snippets::op::Subgraph::data_flow_transformations(ov::pass::Manager& pre_co } void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir, - lowered::pass::PassPipeline& target_pipeline, - const lowered::Config& config) { + lowered::pass::PassPipeline& 
target_pipeline) { INTERNAL_OP_SCOPE(Subgraph); OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations") - linear_ir = lowered::LinearIR(body_ptr(), config); const size_t vector_size = get_generator()->get_target_machine()->get_lanes(); - const int32_t buffer_allocation_rank = static_cast(config.m_loop_depth); + const int32_t buffer_allocation_rank = static_cast(linear_ir.get_config().m_loop_depth); // Note: The pass InitLoops uses LoopInfo that contains entry and exit points of the corresponding Loop. // To avoid the Loop information corruption, we should call the passes with Load/Store work @@ -532,7 +528,6 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin common_pipeline.register_pass(); common_pipeline.register_pass(buffer_allocation_rank); common_pipeline.register_pass(vector_size); - common_pipeline.register_pass(); common_pipeline.register_pass(); common_pipeline.register_pass(); common_pipeline.register_pass(); @@ -589,14 +584,15 @@ snippets::Schedule snippets::op::Subgraph::generate( OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::generate") NGRAPH_CHECK(m_generator != nullptr, "generate is called while generator is not set"); - lowered::LinearIR linear_ir; + data_flow_transformations(pre_common, post_common, post_precision); + lowered::Config lowering_config; - lowering_config.m_save_lowered_code = config.m_has_domain_sensitive_ops; + lowering_config.m_save_expressions = config.m_has_domain_sensitive_ops; lowering_config.m_need_fill_tail_register = config.m_has_domain_sensitive_ops; lowering_config.m_loop_depth = tileRank; - data_flow_transformations(pre_common, post_common, post_precision); - control_flow_transformations(linear_ir, target_lowered_pipeline, lowering_config); + lowered::LinearIR linear_ir = lowered::LinearIR(body_ptr(), lowering_config); + control_flow_transformations(linear_ir, target_lowered_pipeline); // actual code emission 
const auto& lowering_result = m_generator->generate(linear_ir, lowering_config, compile_params); diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 43d87f57433e27..27bc8cd02d06e3 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -578,7 +578,10 @@ TokenizeSnippets::TokenizeSnippets() { OPENVINO_THROW("body results and node results size mismatch during subgraph collaps"); } - // todo: move this plugin-specific constraint to the plugin callback + // Each data node (Parameter (and non-Scalar Constant), Result, Buffer with the same ID) requires its own unique GPR. + // At the moment, CPU Plugin has a limitation on GPR registers: there are only 12 available registers. + // This limitation will be resolved once the generator supports gpr spills [75622]. + // TODO [75567]: move this plugin-specific constraint to the plugin callback const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(new_body_ops); if (body_parameters.size() + body_results.size() + hidden_data_count + unique_buffer_count > 12) { const std::string message_reset = "new subgraph is created.
Impossible to schedule subgraph with " + diff --git a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp index 672181064aeffa..24a4141916e189 100644 --- a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp +++ b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp @@ -26,8 +26,8 @@ bool FuseTransposeBrgemm::is_supported_transpose(const Output& transpose_p // if Transpose in and out layout is not empty => something was already fused on this port auto default_layout = std::vector(transpose_port.get_shape().size()); std::iota(default_layout.begin(), default_layout.end(), 0);// NCHW layout by default - if (lowered::PortManager::get_port_descriptor_ptr(transpose_port)->get_layout() != default_layout || - lowered::PortManager::get_port_descriptor_ptr(transpose_node->input_value(0))->get_layout() != default_layout) + if (lowered::PortDescriptorUtils::get_port_descriptor_ptr(transpose_port)->get_layout() != default_layout || + lowered::PortDescriptorUtils::get_port_descriptor_ptr(transpose_node->input_value(0))->get_layout() != default_layout) return false; const auto& transpose_order = constant->cast_vector(); // todo: this limitation is due to the fact that offsets are calculated in Kernel, and the only way @@ -65,7 +65,7 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { const auto& brgemm_out = brgemm->output(0); const auto& transpose_out = m.get_match_value(); const auto& const_order = ov::as_type_ptr(transpose_out.get_node_shared_ptr()->get_input_node_shared_ptr(1)); - const auto& original_port = ov::snippets::lowered::PortManager::get_port_descriptor_ptr(brgemm_out); + const auto& original_port = ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_out); original_port->set_shape(transpose_out.get_shape()); original_port->set_layout(const_order->cast_vector()); for (const auto& in : transpose_out.get_target_inputs()) @@ -79,7 +79,7 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { 
const auto& transpose = as_type_ptr(in_value.get_node_shared_ptr()); const auto& const_order = ov::as_type_ptr(transpose->get_input_node_shared_ptr(1)); brgemm->set_argument(i, transpose->input_value(0)); - const auto& original_port = ov::snippets::lowered::PortManager::get_port_descriptor_ptr(in); + const auto& original_port = ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(in); original_port->set_shape(transpose->get_input_shape(0)); original_port->set_layout(const_order->cast_vector()); } diff --git a/src/common/snippets/src/pass/matmul_to_brgemm.cpp b/src/common/snippets/src/pass/matmul_to_brgemm.cpp index dff0fe0689f828..ba2d8c6311abe6 100644 --- a/src/common/snippets/src/pass/matmul_to_brgemm.cpp +++ b/src/common/snippets/src/pass/matmul_to_brgemm.cpp @@ -22,11 +22,11 @@ void MatMulToBrgemm::init_ports(const std::shared_ptr& brgemm) const for (const auto& input : brgemm->inputs()) { const auto tensor = input.get_shape(); const auto subtensor = get_subtensor(tensor); - lowered::PortManager::set_port_descriptor_ptr(input, std::make_shared(tensor, subtensor)); + lowered::PortDescriptorUtils::set_port_descriptor_ptr(input, std::make_shared(tensor, subtensor)); } const auto tensor = brgemm->get_output_shape(0); const auto subtensor = get_subtensor(tensor); - lowered::PortManager::set_port_descriptor_ptr(brgemm->output(0), std::make_shared(tensor, subtensor)); + lowered::PortDescriptorUtils::set_port_descriptor_ptr(brgemm->output(0), std::make_shared(tensor, subtensor)); } MatMulToBrgemm::MatMulToBrgemm() { diff --git a/src/common/snippets/src/pass/set_softmax_ports.cpp b/src/common/snippets/src/pass/set_softmax_ports.cpp index edf28dd40d81d3..1651a6d6217495 100644 --- a/src/common/snippets/src/pass/set_softmax_ports.cpp +++ b/src/common/snippets/src/pass/set_softmax_ports.cpp @@ -47,8 +47,8 @@ ov::snippets::pass::SetSoftmaxPorts::SetSoftmaxPorts() { for (size_t i = axis; i < rank; ++i) subtensor[i] = 
lowered::PortDescriptor::ServiceDimensions::FULL_DIM; - lowered::PortManager::set_port_descriptor_ptr(root->input(0), std::make_shared(root->input(0), subtensor)); - lowered::PortManager::set_port_descriptor_ptr(root->output(0), std::make_shared(root->output(0), subtensor)); + lowered::PortDescriptorUtils::set_port_descriptor_ptr(root->input(0), std::make_shared(root->input(0), subtensor)); + lowered::PortDescriptorUtils::set_port_descriptor_ptr(root->output(0), std::make_shared(root->output(0), subtensor)); return true; }; diff --git a/src/common/snippets/src/pass/transpose_decomposition.cpp b/src/common/snippets/src/pass/transpose_decomposition.cpp index 24331bcddcf31f..bb581105a7523a 100644 --- a/src/common/snippets/src/pass/transpose_decomposition.cpp +++ b/src/common/snippets/src/pass/transpose_decomposition.cpp @@ -12,6 +12,7 @@ namespace ov { namespace snippets { namespace pass { +using namespace lowered; const std::set> TransposeDecomposition::supported_cases = {{0, 2, 3, 1}}; @@ -48,10 +49,10 @@ TransposeDecomposition::TransposeDecomposition() { auto load = std::make_shared(data_input, subtensor[0], 0, layout); auto store = std::make_shared(load, subtensor[0]); - lowered::PortManager::set_port_descriptor_ptr(load->input(0), std::make_shared(load->get_input_shape(0), subtensor, layout)); - lowered::PortManager::set_port_descriptor_ptr(load->output(0), std::make_shared(load->get_output_shape(0), subtensor)); - lowered::PortManager::set_port_descriptor_ptr(store->input(0), std::make_shared(store->get_input_shape(0), subtensor)); - lowered::PortManager::set_port_descriptor_ptr(store->output(0), std::make_shared(store->get_output_shape(0), subtensor)); + PortDescriptorUtils::set_port_descriptor_ptr(load->input(0), std::make_shared(load->get_input_shape(0), subtensor, layout)); + PortDescriptorUtils::set_port_descriptor_ptr(load->output(0), std::make_shared(load->get_output_shape(0), subtensor)); + PortDescriptorUtils::set_port_descriptor_ptr(store->input(0), 
std::make_shared(store->get_input_shape(0), subtensor)); + PortDescriptorUtils::set_port_descriptor_ptr(store->output(0), std::make_shared(store->get_output_shape(0), subtensor)); for (auto& input : transpose->output(0).get_target_inputs()) { input.replace_source_output(store->output(0)); diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index 5e5e0ec125a6b0..02ec54af2d8dbe 100644 --- a/src/common/snippets/src/utils.cpp +++ b/src/common/snippets/src/utils.cpp @@ -88,18 +88,18 @@ ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const } ov::PartialShape get_port_planar_shape(const Input& in) { - const auto& port = lowered::PortManager::get_port_descriptor_ptr(in); + const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in); return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); } ov::PartialShape get_port_planar_shape(const Output& out) { - const auto& port = lowered::PortManager::get_port_descriptor_ptr(out); + const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(out); return utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout()); } void safe_copy_runtime_info(const std::shared_ptr& from, const std::shared_ptr& to) { ov::copy_runtime_info(from, to); - lowered::PortManager::clean(to); + lowered::PortDescriptorUtils::clean(to); } } // namespace utils diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index ca42012f1ae00f..ba3a4f91d43e33 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -62,7 +62,7 @@ void LoweringTests::SetUp() { void LoweringTests::TearDown() { ASSERT_TRUE(function); - auto cloned_function = ov::clone_model(*function); + auto cloned_function = function->clone(); if (!function_ref) { function_ref = cloned_function; } diff --git 
a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp index 3b1b97abdba86f..dd01900b52b086 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp @@ -6,22 +6,17 @@ #include -#include "snippets/lowered/expression.hpp" -#include "snippets/op/subgraph.hpp" #include "snippets/snippets_isa.hpp" -#include "snippets/utils.hpp" +#include "snippets/lowered/expression.hpp" +#include "snippets/lowered/tensor.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op//brgemm_cpu.hpp" -#include "snippets/snippets_isa.hpp" -#include "snippets/op/subgraph.hpp" -#include "snippets/lowered/tensor.hpp" using namespace InferenceEngine; -using ov::snippets::op::Subgraph; -using ov::snippets::AllocatedEmitter; using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; +using ov::snippets::AllocatedEmitter; using ov::snippets::lowered::Expression; using ov::snippets::lowered::IOExpression; using ov::snippets::lowered::ExpressionPtr; @@ -68,10 +63,10 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, return physical_regs; }; - for (const auto& lowered_code : expressions) { - const auto& emitter = lowered_code->get_emitter(); + for (const auto& expression : expressions) { + const auto& emitter = expression->get_emitter(); std::vector in_abstract_regs, out_abstract_regs; - std::tie(in_abstract_regs, out_abstract_regs) = lowered_code->get_reg_info(); + std::tie(in_abstract_regs, out_abstract_regs) = expression->get_reg_info(); std::vector in_physical_regs, out_physical_regs; switch (std::dynamic_pointer_cast(emitter)->get_in_out_type()) { case gpr_to_gpr: @@ -96,8 +91,8 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, default: IE_THROW() << "Unhandled in_out type"; } - 
lowered_code->set_reg_info({in_physical_regs, out_physical_regs}); - if (auto container = std::dynamic_pointer_cast(lowered_code->get_emitter())) + expression->set_reg_info({in_physical_regs, out_physical_regs}); + if (auto container = std::dynamic_pointer_cast(expression->get_emitter())) container->map_abstract_registers(gpr_map_pool, vec_map_pool, expressions); } } @@ -310,10 +305,10 @@ void KernelEmitter::emit_impl(const std::vector& in, transform_idxs_to_regs(data_ptr_regs_idx, data_ptr_regs); init_data_pointers(num_inputs, num_inputs + num_outputs, num_unique_buffer, reg_indexes, reg_const_params, data_ptr_regs); - for (const auto& lowered_code : body) { - const auto& emitter = lowered_code->get_emitter(); + for (const auto& expression : body) { + const auto& emitter = expression->get_emitter(); std::vector in_regs, out_regs; - std::tie(in_regs, out_regs) = lowered_code->get_reg_info(); + std::tie(in_regs, out_regs) = expression->get_reg_info(); emitter->emit_code(in_regs, out_regs, vec_regs_pool, gp_regs_pool); } h->postamble(); @@ -745,10 +740,10 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: std::vector> brgemm_inputs = {brgemm_node->input(0), brgemm_copy ? 
brgemm_copy->input(0) : brgemm_node->input(1)}; for (const auto& input : brgemm_inputs) { - init_scheduling_params(snippets::lowered::PortManager::get_port_descriptor_ptr(input)->get_layout(), + init_scheduling_params(snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input)->get_layout(), input.get_shape()); } - init_scheduling_params(snippets::lowered::PortManager::get_port_descriptor_ptr(brgemm_node->output(0))->get_layout(), + init_scheduling_params(snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_node->output(0))->get_layout(), brgemm_node->output(0).get_shape()); const auto& A_shape = brgemm_node->get_input_shape(0); @@ -1105,7 +1100,7 @@ BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, d if (m_with_comp) m_comp_offset = brgemm_repack->get_offset_compensations(); - const auto& layout = snippets::lowered::PortManager::get_port_descriptor_ptr(brgemm_repack->input(0))->get_layout(); + const auto& layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_repack->input(0))->get_layout(); const auto& original_shape = brgemm_repack->get_input_shape(0); auto transposed_shape = original_shape; size_t leading_dimension = *(original_shape.rbegin()); diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index e0a494b78a9f69..d7ed25d06e4075 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -564,14 +564,14 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) { CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::RemoveConverts); CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::MulAddToFMA); - ov::snippets::lowered::pass::PassPipeline target_specific_pipeline; - CPU_REGISTER_PASS_X64(target_specific_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert); + ov::snippets::lowered::pass::PassPipeline control_flow_pipeline; + 
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert); schedule = snippet->generate( pre_dialect, post_dialect, post_precision, - target_specific_pipeline, + control_flow_pipeline, reinterpret_cast(jcp)); } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp index 3916946af027ea..07ff18b167c8f5 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp @@ -79,7 +79,7 @@ std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const Output get_offset_in(), get_offset_out(), is_with_compensations() ? get_offset_compensations() : 0, - snippets::lowered::PortManager::get_port_descriptor_ptr(input(0))->get_layout()); + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout()); } size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index 6ae0d428fa4473..1a378616819293 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -110,15 +110,15 @@ std::shared_ptr BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a if (!is_with_scratchpad()) { new_node = std::make_shared(new_args.at(0), new_args.at(1), m_type, get_offset_a(), get_offset_b(), get_offset_c(), - snippets::lowered::PortManager::get_port_descriptor_ptr(input(0))->get_layout(), - snippets::lowered::PortManager::get_port_descriptor_ptr(input(1))->get_layout(), - snippets::lowered::PortManager::get_port_descriptor_ptr(output(0))->get_layout()); + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), + 
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(), + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout()); } else { new_node = std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_type, get_offset_a(), get_offset_b(), get_offset_scratch(), get_offset_c(), - snippets::lowered::PortManager::get_port_descriptor_ptr(input(0))->get_layout(), - snippets::lowered::PortManager::get_port_descriptor_ptr(input(1))->get_layout(), - snippets::lowered::PortManager::get_port_descriptor_ptr(output(0))->get_layout()); + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(), + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout()); } return new_node; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp index e0414bc9a6c67c..0c492498af6ff3 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp @@ -33,11 +33,11 @@ std::vector make_subtensor(const ov::Shape& tensor) { template void set_full_port_desc(const T& port) { const auto& shape = port.get_shape(); - PortManager::set_port_descriptor_ptr(port, std::make_shared(shape, make_subtensor(shape))); + PortDescriptorUtils::set_port_descriptor_ptr(port, std::make_shared(shape, make_subtensor(shape))); } template void set_port_desc(const T& port, Args... 
params) { - PortManager::set_port_descriptor_ptr(port, std::make_shared(params...)); + PortDescriptorUtils::set_port_descriptor_ptr(port, std::make_shared(params...)); } } // namespace @@ -58,9 +58,9 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { return false; } - const auto& brgemm_in0_desc = PortManager::get_port_descriptor_ptr(brgemm->input(0)); - const auto& brgemm_in1_desc = PortManager::get_port_descriptor_ptr(brgemm->input(1)); - const auto& brgemm_out_desc = PortManager::get_port_descriptor_ptr(brgemm->output(0)); + const auto& brgemm_in0_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(0)); + const auto& brgemm_in1_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(1)); + const auto& brgemm_out_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->output(0)); const auto dimsMatMulIn0 = snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); const auto dimsMatMulIn1 = snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp index f5053df738db1a..ed93ea754b0a45 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.cpp @@ -46,7 +46,7 @@ bool ov::intel_cpu::pass::FuseLoadStoreConvert::fuse_load_convert(snippets::lowe const auto out_port = convert_expr->get_output_port(0); const auto convert_consumers = out_port.get_connected_ports(); - snippets::lowered::PortManager::set_port_descriptor_ptr(load_convert->output(0), out_port.get_descriptor_ptr()->clone()); + snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(load_convert->output(0), out_port.get_descriptor_ptr()->clone()); const auto load_convert_expr 
= linear_ir.create_expression(load_convert, { load_expr->get_input_tensor(0) }); const auto convert_expr_it = convert_it; const auto insertion_pos = std::next(convert_it); @@ -91,7 +91,7 @@ bool ov::intel_cpu::pass::FuseLoadStoreConvert::fuse_store_convert(snippets::low const auto out_port = store_expr->get_output_port(0); const auto store_consumers = out_port.get_connected_ports(); - snippets::lowered::PortManager::set_port_descriptor_ptr(store_convert->output(0), out_port.get_descriptor_ptr()->clone()); + snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(store_convert->output(0), out_port.get_descriptor_ptr()->clone()); const auto store_convert_expr = linear_ir.create_expression(store_convert, { input_td }); const auto convert_expr_it = convert_it; const auto insertion_pos = std::next(convert_it); diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp index 921585f0976418..d4139b11de07a9 100644 --- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp @@ -19,37 +19,39 @@ typedef std::tuple< > MatMulParams; class MatMul : public testing::WithParamInterface, - virtual public ov::test::SnippetsTestsCommon { + virtual public ov::test::SnippetsTestsCommon { public: static std::string getTestCaseName(testing::TestParamInfo obj); protected: void SetUp() override; + + virtual void init_subgraph(const std::vector& inputShapes, const std::vector& types); }; class MatMulFQ : public MatMul { protected: - void SetUp() override; + void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; }; class MatMulBias : public MatMul { protected: - void SetUp() override; + void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; }; class MatMulBiasQuantized : public MatMul { protected: - void SetUp() override; + void init_subgraph(const std::vector& 
inputShapes, const std::vector& types) override; }; class MatMulsQuantized : public MatMul { protected: - void SetUp() override; + void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; }; class MatMulsQuantizedSoftmax : public MatMul { protected: - void SetUp() override; + void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; }; } // namespace snippets diff --git a/src/tests/functional/plugin/shared/include/snippets/mha.hpp b/src/tests/functional/plugin/shared/include/snippets/mha.hpp index 7794c9b286d312..8c15adbc8c3fc4 100644 --- a/src/tests/functional/plugin/shared/include/snippets/mha.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/mha.hpp @@ -29,22 +29,24 @@ class MHA : public testing::WithParamInterface, void SetUp() override; void generate_inputs(const std::vector& targetInputStaticShapes) override; + virtual void init_subgraph(); + + bool m_with_mul = false; }; class MHASelect : public MHA { protected: - void SetUp() override; - void generate_inputs(const std::vector& targetInputStaticShapes) override; + void init_subgraph() override; }; class MHAWOTransposeOnInputs : public MHA { protected: - void SetUp() override; + void init_subgraph() override; }; class MHAWOTranspose : public MHA { - void SetUp() override; + void init_subgraph() override; }; } // namespace snippets diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp index 10e567292f167a..6ef643e3efeee0 100644 --- a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp @@ -35,82 +35,41 @@ void MatMul::SetUp() { std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::MatMulFunction(input_shapes, elem_types); - function = 
f.getOriginal(); + init_subgraph(input_shapes, elem_types); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); } } -void MatMulFQ::SetUp() { - std::vector input_shapes; - std::vector elem_types; - std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::FQMatMulFunction(input_shapes); +void MatMul::init_subgraph(const std::vector& inputShapes, const std::vector& types) { + auto f = ov::test::snippets::MatMulFunction(inputShapes, types); function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } } -void MatMulBias::SetUp() { - std::vector input_shapes; - std::vector elem_types; - std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::MatMulBiasFunction(input_shapes, elem_types); +void MatMulFQ::init_subgraph(const std::vector& inputShapes, const std::vector& types) { + auto f = ov::test::snippets::FQMatMulFunction(inputShapes); function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } } -void MatMulBiasQuantized::SetUp() { - std::vector input_shapes; - std::vector elem_types; - std::tie(input_shapes, elem_types, ref_num_nodes, 
ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::MatMulBiasQuantizedFunction(input_shapes, elem_types); +void MatMulBias::init_subgraph(const std::vector& inputShapes, const std::vector& types) { + auto f = ov::test::snippets::MatMulBiasFunction(inputShapes, types); function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } } -void MatMulsQuantized::SetUp() { - std::vector input_shapes; - std::vector elem_types; - std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::MatMulsQuantizedFunction(input_shapes, elem_types); +void MatMulBiasQuantized::init_subgraph(const std::vector& inputShapes, const std::vector& types) { + auto f = ov::test::snippets::MatMulBiasQuantizedFunction(inputShapes, types); function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } } -void MatMulsQuantizedSoftmax::SetUp() { - std::vector input_shapes; - std::vector elem_types; - std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); +void MatMulsQuantized::init_subgraph(const std::vector& inputShapes, const std::vector& types) { + auto f = ov::test::snippets::MatMulsQuantizedFunction(inputShapes, types); + function = f.getOriginal(); +} - auto f = 
ov::test::snippets::MatMulsQuantizedSoftmaxFunction(input_shapes, elem_types); +void MatMulsQuantizedSoftmax::init_subgraph(const std::vector& inputShapes, const std::vector& types) { + auto f = ov::test::snippets::MatMulsQuantizedSoftmaxFunction(inputShapes, types); function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } } TEST_P(MatMul, CompareWithRefImpl) { diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 7e2b7be9642fcc..2f5d17dbd8159a 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -43,23 +43,23 @@ std::string MHA::getTestCaseName(testing::TestParamInfo inputShapes; - bool withMul; ov::element::Type prc; std::map additionalConfig; - std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); + std::tie(inputShapes, m_with_mul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); - auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, withMul); - function = f.getOriginal(); + init_subgraph(); configuration.insert(additionalConfig.begin(), additionalConfig.end()); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + if (additionalConfig.empty() && !configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); } setInferenceType(prc); inType = outType = prc; + if (prc == 
ov::element::bf16) + abs_threshold = 0.3; } void MHA::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -73,25 +73,9 @@ void MHA::generate_inputs(const std::vector& targetInputStaticSha } } -void MHASelect::SetUp() { - std::vector inputShapes; - bool withMul; - ov::element::Type prc; - std::map additionalConfig; - std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); - - auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes); +void MHA::init_subgraph() { + auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_with_mul); function = f.getOriginal(); - - configuration.insert(additionalConfig.begin(), additionalConfig.end()); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } - - setInferenceType(prc); - inType = outType = prc; } void MHASelect::generate_inputs(const std::vector& targetInputStaticShapes) { @@ -112,47 +96,21 @@ void MHASelect::generate_inputs(const std::vector& targetInputSta } } -void MHAWOTransposeOnInputs::SetUp() { - std::vector inputShapes; - bool withMul; - ov::element::Type prc; - std::map additionalConfig; - std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); +void MHASelect::init_subgraph() { + auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes); + function = f.getOriginal(); +} +void MHAWOTransposeOnInputs::init_subgraph() { auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes); function = f.getOriginal(); - - configuration.insert(additionalConfig.begin(), 
additionalConfig.end()); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } - - setInferenceType(prc); - inType = outType = prc; } -void MHAWOTranspose::SetUp() { - std::vector inputShapes; - bool withMul; - ov::element::Type prc; - std::map additionalConfig; - std::tie(inputShapes, withMul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(inputShapes)); - +void MHAWOTranspose::init_subgraph() { auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes); function = f.getOriginal(); - - configuration.insert(additionalConfig.begin(), additionalConfig.end()); - - setInferenceType(prc); - inType = outType = prc; - if (prc == ov::element::bf16) - abs_threshold = 0.3; } - TEST_P(MHA, CompareWithRefImpl) { SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp index c6485da75acd22..6d9bb3e93f1cb9 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp @@ -80,33 +80,33 @@ std::shared_ptr Transpose0213MatMulLoweredFunction::initLowered() con // Note: validity of transpose_position values is checked in Transpose0213MatMulSinhFunction constructor if (transpose_position < 2) { const auto& anchor = data[transpose_position]->output(0); - const auto& td = ov::snippets::lowered::PortManager::get_port_descriptor_ptr(anchor); + const auto& td = ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(anchor); const auto& tensor = td->get_shape(); const auto& subtensor = td->get_subtensor(); } 
auto matmul = std::make_shared(data[0], data[1], 0, 0, 0, transpose_position == 0 ? layout : std::vector{}, - transpose_position == 1 ? layout : std::vector{}, - transpose_position == 2 ? layout : std::vector{}); + transpose_position == 1 ? layout : std::vector{}, + transpose_position == 2 ? layout : std::vector{}); auto result = std::make_shared(matmul); if (transpose_position == 2) { const auto& anchor = matmul->output(0); - const auto& td = ov::snippets::lowered::PortManager::get_port_descriptor_ptr(anchor); + const auto& td = ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(anchor); const auto& tensor = td->get_shape(); const auto& subtensor = td->get_subtensor(); - ov::snippets::lowered::PortManager::set_port_descriptor_ptr(anchor, + ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(anchor, std::make_shared(tensor, - subtensor, - layout)); + subtensor, + layout)); } if (transpose_position < 2) { const auto& anchor = data[transpose_position]->output(0); - const auto& td = ov::snippets::lowered::PortManager::get_port_descriptor_ptr(anchor); + const auto& td = ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(anchor); const auto& tensor = td->get_shape(); const auto& subtensor = td->get_subtensor(); - ov::snippets::lowered::PortManager::set_port_descriptor_ptr(matmul->input(transpose_position), + ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(matmul->input(transpose_position), std::make_shared(tensor, - subtensor, - layout)); + subtensor, + layout)); } matmul->validate_and_infer_types(); return std::make_shared(NodeVector{matmul}, data);