From 5ff35d7529034b3a9fbf61df6cd49e2a717f3ef8 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Fri, 11 Oct 2024 11:02:57 +0200 Subject: [PATCH] Avoid failures in case of copyB absence --- .../src/transformations/snippets/x64/op/brgemm_cpu.cpp | 9 ++++----- .../src/transformations/snippets/x64/op/brgemm_utils.hpp | 2 +- .../snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp | 5 ++++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index b40bd88f31726b..942ebc62bc9b17 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -68,13 +68,10 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types); validate_inputs(); - // During ctor call, BrgemmCPU doesn't know his port descriptors. - // So we use port descs from source inputs - const auto brgemm_copy = with_repacking(m_type) ? get_brgemm_copy() : nullptr; + // This shape inference can use get_input_partial_shape(1) in all cases const auto planar_input_shapes = std::vector{ snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a), - brgemm_copy ? 
snippets::utils::get_planar_pshape(brgemm_copy->input(0)) - : snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) }; + snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) }; auto output_shape = infer_output_partial_shape(planar_input_shapes); set_output_type(0, get_output_type(), snippets::utils::get_planar_pshape(output_shape, layout_c)); @@ -141,6 +138,8 @@ std::shared_ptr BrgemmCPU::get_brgemm_copy() const { return brgemm_copy_b; } } + std::cout << "[ INFO ] get_brgemm_copy didn't find copy_B\n"; + return nullptr; OPENVINO_THROW("BrgemmCopyB hasn't been found!"); } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp index bc627c59920c4b..370c6a11a6b93d 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp @@ -17,7 +17,7 @@ enum class BRGEMM_TYPE { STAND_ALONE, // No extra requirements, used for f32|f32 WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations - REPACKING_ONLY // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking + REPACKING_ONLY, // low precision or some specific f32 cases - needs BrgemmCopyB on second input for data repacking }; dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp index 9b9de6001cdcf0..3fef80b357e9ba 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp +++ 
b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_cpu_blocking.cpp @@ -94,7 +94,10 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir, brgemm_expr->get_input_port_descriptor(1)->set_subtensor({k_block, n_block}); brgemm_expr->get_output_port_descriptor(0)->set_subtensor({m_block, n_block}); - const auto copy_b_expr = linear_ir.get_expr_by_node(brgemm->get_brgemm_copy()); + auto copy_b = brgemm->get_brgemm_copy(); + if (!copy_b) + return true; + const auto copy_b_expr = linear_ir.get_expr_by_node(copy_b); copy_b_expr->get_input_port_descriptor(0)->set_subtensor({get_full_dim_value(), get_full_dim_value()}); copy_b_expr->get_output_port_descriptor(0)->set_subtensor({get_full_dim_value(), get_full_dim_value()}); if (with_compensations(type)) {