diff --git a/src/common/snippets/include/snippets/lowered/loop_manager.hpp b/src/common/snippets/include/snippets/lowered/loop_manager.hpp index 2b71e8605ab393..e0e13b56da2c59 100644 --- a/src/common/snippets/include/snippets/lowered/loop_manager.hpp +++ b/src/common/snippets/include/snippets/lowered/loop_manager.hpp @@ -22,7 +22,16 @@ class LinearIR::LoopManager { struct LoopPort { LoopPort() = default; LoopPort(const ExpressionPort& port, bool is_incremented = true, size_t dim_idx = 0) - : expr_port(std::make_shared(port)), is_incremented(is_incremented), dim_idx(dim_idx) {} + : expr_port(std::make_shared(port)), + is_incremented(is_incremented), + dim_idx(dim_idx) { + OPENVINO_ASSERT(dim_idx < port.get_descriptor_ptr()->get_shape().size(), + "LoopPort dim_idx (", + dim_idx, + ") must be less than the corresponding expression port shape rank (", + port.get_descriptor_ptr()->get_shape().size(), + ")"); + } friend bool operator==(const LoopPort& lhs, const LoopPort& rhs); friend bool operator!=(const LoopPort& lhs, const LoopPort& rhs); @@ -93,12 +102,12 @@ class LinearIR::LoopManager { const std::vector& entries, const std::vector& exits) { const auto loop_info = std::make_shared(work_amount, work_amount_increment, entries, exits); - for (auto& entry : loop_info->entry_points) { - entry.dim_idx = dim_idx; - } - for (auto& exit : loop_info->exit_points) { - exit.dim_idx = dim_idx; - } + auto set_common_dim_idx = [dim_idx](std::vector& ports) { + for (auto& port : ports) + port.dim_idx = dim_idx; + }; + set_common_dim_idx(loop_info->entry_points); + set_common_dim_idx(loop_info->exit_points); const auto loop_id = this->add_loop_info(loop_info); for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) { insert_loop_id(*expr_it, loop_id); diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index e3a0d2aec40250..6e0ac6090e840a 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ 
b/src/common/snippets/src/lowered/loop_manager.cpp @@ -28,8 +28,7 @@ LinearIR::LoopManager::LoopInfo::LoopInfo(size_t work_amount, size_t increment, } size_t LinearIR::LoopManager::LoopInfo::get_dim_idx() const { - if (entry_points.empty()) - return SIZE_MAX; + OPENVINO_ASSERT(!entry_points.empty(), "Loop info must have at least one entry point"); auto equal_dim_idxes = [&](const LinearIR::LoopManager::LoopPort& p) { return p.dim_idx == entry_points[0].dim_idx; }; diff --git a/src/common/snippets/src/lowered/pass/identify_buffers.cpp b/src/common/snippets/src/lowered/pass/identify_buffers.cpp index 02aabc93ead6ac..dc23da6be6d9f5 100644 --- a/src/common/snippets/src/lowered/pass/identify_buffers.cpp +++ b/src/common/snippets/src/lowered/pass/identify_buffers.cpp @@ -51,39 +51,8 @@ std::vector IdentifyBuffers::create_adjacency_matrix(const LinearIR& linea } }; - auto is_buffer = [](const ExpressionPort& port) { - return ov::is_type(port.get_expr()->get_node()); - }; - for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) { const auto &expr = *expr_it; - if (const auto brgemm = ov::as_type_ptr(expr->get_node())) { - const auto consumers = expr->get_output_port_connector(0)->get_consumers(); - - auto buffer_it = std::find_if(consumers.begin(), consumers.end(), is_buffer); - if (buffer_it == consumers.end()) - continue; - OPENVINO_ASSERT(std::count_if(consumers.begin(), consumers.end(), is_buffer) == 1, "Brgemm mustn't have more than 1 consumer buffer"); - - std::vector> adjacency_buffers; - adjacency_buffers.push_back(ov::as_type_ptr(buffer_it->get_expr()->get_node())); - - for (const auto& input_connector : expr->get_input_port_connectors()) { - const auto parent_node = input_connector->get_source().get_expr()->get_node(); - if (const auto neighbour_buffer = ov::as_type_ptr(parent_node)) { - adjacency_buffers.push_back(neighbour_buffer); - } - } - for (auto buffer_it = adjacency_buffers.begin(); buffer_it != adjacency_buffers.end(); 
++buffer_it) { - for (auto neighbour_it = std::next(buffer_it); neighbour_it != adjacency_buffers.end(); ++neighbour_it) { - const auto buffer_idx = get_buffer_idx(*buffer_it); - const auto neighbour_idx = get_buffer_idx(*neighbour_it); - adj[index(size, neighbour_idx, buffer_idx)] = adj[index(size, buffer_idx, neighbour_idx)] = true; - } - } - continue; - } - const auto& loop_end = ov::as_type_ptr(expr->get_node()); if (!loop_end) continue; diff --git a/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp b/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp index f864dc65f8184d..d9be4627073d1a 100644 --- a/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp +++ b/src/common/snippets/src/lowered/pass/insert_tail_loop.cpp @@ -337,7 +337,7 @@ bool InsertTailLoop::run(LinearIR& linear_ir) { continue; const auto loop_info = loop_manager->get_loop_info(loop_end->get_id()); - if (loop_info->fst_iter_handler != nullptr) { + if (loop_info->fst_iter_handler) { modified |= loop_info->fst_iter_handler(linear_ir, expr_it); continue; } diff --git a/src/common/snippets/src/lowered/pass/optimize_loop_single_evaluation.cpp b/src/common/snippets/src/lowered/pass/optimize_loop_single_evaluation.cpp index 4244c09c7e658c..317eb32f7ab1fe 100644 --- a/src/common/snippets/src/lowered/pass/optimize_loop_single_evaluation.cpp +++ b/src/common/snippets/src/lowered/pass/optimize_loop_single_evaluation.cpp @@ -19,8 +19,8 @@ bool OptimizeLoopSingleEvaluation::run(LinearIR& linear_ir) { return false; bool is_modified = false; - for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { - if (auto loop_end = ov::as_type_ptr(expr_it->get()->get_node())) { + for (const auto& expr : linear_ir) { + if (auto loop_end = ov::as_type_ptr(expr->get_node())) { // *1* solo vector/tail loop + empty outer loop // => skip increments (both counter & ptr) : set evaluate_once flag // *2* solo vector/tail loop + non-empty outer loop diff --git 
a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index fd0aca4042e81f..848e7e81536355 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -107,13 +107,13 @@ std::shared_ptr const ov::snippets::lowered::pass::PassPipeline& lowered_pre_common, const ov::snippets::lowered::pass::PassPipeline& lowered_post_common, const std::shared_ptr& generator, - const std::shared_ptr& factory) { + const std::shared_ptr& shape_infer_factory) { auto subgraph = getTokenizedSubgraph(f); subgraph->set_generator(generator == nullptr ? std::make_shared() : generator); subgraph->set_master_shape(master_shape); subgraph->set_tile_rank(2); // Note: lowered_pipeline would have no effect on subgraph body, since it's applied on linear IR - subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common, factory); + subgraph->generate(backend_passes, lowered_pre_common, lowered_post_common, shape_infer_factory); return subgraph; } diff --git a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp index cf89f4c58937a1..0246a986890809 100644 --- a/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/x64/jit_snippets_emitters.cpp @@ -803,20 +803,8 @@ std::set> BrgemmEmitter::get_supported_precisions(con } void BrgemmEmitter::validate_arguments(const std::vector &in, const std::vector &out) const { - std::set unique_ids{in[0], in[1], out[0]}; - size_t unique_ids_count = 3; - auto add_reg_to_unique_ids = [&](const size_t reg_number) { - unique_ids.insert(reg_number); - unique_ids_count++; - }; - - if (m_with_scratch) { - if (in.size() != 3) - IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp"; - add_reg_to_unique_ids(in[2]); - } - if (unique_ids.size() != unique_ids_count) { - IE_THROW() << "BRGEMM Emitter expects 
that all input/output registers are unique"; + if (m_with_scratch && in.size() != 3) { + IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp"; } } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp index 643b5d74fc963b..193d2ce808f002 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp @@ -57,32 +57,31 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vectorget_shape()); - const auto& element_type = get_input_element_type(0); const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout()); set_output_type(0, element_type, planar_pshape); if (is_with_compensations()) { set_output_type(1, ov::element::f32, planar_pshape); } - validate(planar_pshape, element_type); } -void BrgemmCopyB::validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type) { +void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) { OPENVINO_ASSERT(one_of(element_type, element::bf16, element::i8), "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp index 9274ad026e5f01..f803e5d55fcb8d 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp @@ -67,7 +67,7 @@ class BrgemmCopyB : public snippets::op::MemoryAccess { private: void custom_constructor_validate_and_infer_types(std::vector layout_input = {}); - void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type); + void validate_element_type(const 
ov::element::Type& element_type); void compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n); Type m_type = Type::OnlyRepacking; diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index 224c1de826677a..76c69a831af276 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -16,7 +16,7 @@ namespace intel_cpu { BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type type, const size_t offset_a, const size_t offset_b, const size_t offset_c, std::vector layout_a, std::vector layout_b, std::vector layout_c, - const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) - : Brgemm(), m_type(type) { + const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta) + : Brgemm(), m_type(type), m_beta(beta) { // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call set_arguments({A, B}); @@ -32,8 +32,8 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type ty BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c, std::vector layout_a, std::vector layout_b, std::vector layout_c, - const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) - : Brgemm(), m_type(type) { + const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta) + : Brgemm(), m_type(type), m_beta(beta) { set_arguments({A, B, scratch}); set_output_size(1); ctor_initialize(std::set{0, 1, 2}, std::set{0}); @@ -48,8 +48,8 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output< BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type type, const PortDescriptor& desc_a, const 
PortDescriptor& desc_b, const PortDescriptor& desc_c, std::vector layout_a, std::vector layout_b, std::vector layout_c, - const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) - : Brgemm(), m_type(type) { + const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta) + : Brgemm(), m_type(type), m_beta(beta) { set_arguments({A, B}); set_output_size(1); m_input_ports = {{0, desc_a}, {1, desc_b}}; @@ -61,8 +61,8 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type ty BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c, std::vector layout_a, std::vector layout_b, std::vector layout_c, - const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) - : Brgemm(), m_type(type) { + const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n, const float beta) + : Brgemm(), m_type(type), m_beta(beta) { set_arguments({A, B, scratch}); set_output_size(1); m_input_ports = {{0, desc_a}, {1, desc_b}, {2, desc_scratch}}; @@ -136,22 +136,19 @@ std::shared_ptr BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a check_new_args_count(this, new_args); - std::shared_ptr brgemm; if (!is_with_scratchpad()) { - brgemm = std::make_shared(new_args.at(0), new_args.at(1), m_type, + return std::make_shared(new_args.at(0), new_args.at(1), m_type, get_input_port_descriptor(0), get_input_port_descriptor(1), get_output_port_descriptor(0), snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(), snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(), - m_M_blk, m_K_blk, m_N_blk); + m_M_blk, m_K_blk, m_N_blk, m_beta); } else { - brgemm = std::make_shared(new_args.at(0), 
new_args.at(1), new_args.at(2), m_type, + return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_type, get_input_port_descriptor(0), get_input_port_descriptor(1), get_input_port_descriptor(2), get_output_port_descriptor(0), snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(), snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(), - m_M_blk, m_K_blk, m_N_blk); + m_M_blk, m_K_blk, m_N_blk, m_beta); } - brgemm->set_beta(get_beta()); - return brgemm; } std::shared_ptr BrgemmCPU::get_brgemm_copy() const { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp index 8f483d8d4e7733..1ea2418f995463 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp @@ -32,19 +32,19 @@ class BrgemmCPU : public snippets::op::Brgemm { BrgemmCPU(const Output& A, const Output& B, const Type type, const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0, std::vector layout_a = {}, std::vector layout_b = {}, std::vector layout_c = {}, - const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0); + const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f); BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0, std::vector layout_a = {}, std::vector layout_b = {}, std::vector layout_c = {}, - const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0); + const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 
0, const float beta = 0.f); BrgemmCPU(const Output& A, const Output& B, const Type type, const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c, std::vector layout_a = {}, std::vector layout_b = {}, std::vector layout_c = {}, - const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0); + const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f); BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c, std::vector layout_a = {}, std::vector layout_b = {}, std::vector layout_c = {}, - const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0); + const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0, const float beta = 0.f); BrgemmCPU() = default; void validate_and_infer_types() override; diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp index e419f01074ebe9..797afeb3e4c9e7 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp @@ -7,6 +7,7 @@ #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "snippets/itt.hpp" +#include "snippets/utils.hpp" #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/loop_manager.hpp" #include "snippets/lowered/pass/insert_tail_loop.hpp" @@ -31,7 +32,7 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { const auto& loop_manager = linear_ir.get_loop_manager(); auto blocking_loop_exists = [&](const ExpressionPtr& brgemm_expr, const std::shared_ptr& brgemm) { auto check_port = 
[&](const LoopPort& p) { - return p.expr_port->get_expr() == brgemm_expr && (p.dim_idx == 0 || p.dim_idx == 1); + return p.expr_port->get_expr() == brgemm_expr && ov::snippets::utils::one_of(p.dim_idx, 0ul, 1ul); }; const auto& loop_ids = brgemm_expr->get_loop_ids(); @@ -74,7 +75,8 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { *(input_0_subtensor.rbegin() + 1) = block_size_m; *(output_subtensor.rbegin() + 1) = block_size_m; - std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), LoopPort(brgemm_expr->get_input_port(1), false)}; + std::vector entries{LoopPort(brgemm_expr->get_input_port(0), true), + LoopPort(brgemm_expr->get_input_port(1), false)}; if (brgemm->is_with_scratchpad()) entries.emplace_back(brgemm_expr->get_input_port(2), false); std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; @@ -98,8 +100,11 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { std::vector entries{LoopPort(brgemm_expr->get_input_port(0), false), LoopPort(brgemm_expr->get_input_port(1), true)}; - if (brgemm->is_with_scratchpad()) - entries.emplace_back(brgemm_expr->get_input_port(2), true); + if (brgemm->is_with_scratchpad()) { + // Input port 2 holds the scratch buffer in the AMX case, so it mustn't be incremented; it is incremented only for compensations + const bool is_incremented = brgemm->is_with_compensations(); 
+ entries.emplace_back(brgemm_expr->get_input_port(2), is_incremented); + } std::vector exits{LoopPort(brgemm_expr->get_output_port(0), true)}; loop_manager->mark_loop(expr_it, std::next(expr_it), n, block_size_n, 0, entries, exits); } @@ -129,6 +134,7 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) { auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt expr_it) { const auto loop_end = ov::as_type_ptr(expr_it->get()->get_node()); + OPENVINO_ASSERT(loop_end, "First loop iteration handler must be called on LoopEnd expression"); const auto loop_id = loop_end->get_id(); const auto& loop_manager = linear_ir.get_loop_manager(); const auto& loop_info = loop_manager->get_loop_info(loop_id);