Skip to content

Commit

Permalink
Avoid failures in case of copyB absence
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 11, 2024
1 parent 0989f0d commit 5ff35d7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,10 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector<size_t>
INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types);
validate_inputs();

// During ctor call, BrgemmCPU doesn't know its port descriptors.
// So we use port descs from source inputs
const auto brgemm_copy = with_repacking(m_type) ? get_brgemm_copy() : nullptr;
// This shape inference can use get_input_partial_shape(1) in all cases
const auto planar_input_shapes =
std::vector<ov::PartialShape>{ snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a),
brgemm_copy ? snippets::utils::get_planar_pshape(brgemm_copy->input(0))
: snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) };
snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) };
auto output_shape = infer_output_partial_shape(planar_input_shapes);
set_output_type(0, get_output_type(), snippets::utils::get_planar_pshape(output_shape, layout_c));

Expand Down Expand Up @@ -141,6 +138,8 @@ std::shared_ptr<BrgemmCopyB> BrgemmCPU::get_brgemm_copy() const {
return brgemm_copy_b;
}
}
std::cout << "[ INFO ] get_brgemm_copy didn't find copy_B\n";
return nullptr;
OPENVINO_THROW("BrgemmCopyB hasn't been found!");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ enum class BRGEMM_TYPE {
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
REPACKING_ONLY // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
REPACKING_ONLY, // low precision or some specific f32 cases - needs BrgemmCopyB on second input for data repacking
};

dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ bool BrgemmCPUBlocking::mark_blocking_loops(LinearIR& linear_ir,
brgemm_expr->get_input_port_descriptor(1)->set_subtensor({k_block, n_block});
brgemm_expr->get_output_port_descriptor(0)->set_subtensor({m_block, n_block});

const auto copy_b_expr = linear_ir.get_expr_by_node(brgemm->get_brgemm_copy());
auto copy_b = brgemm->get_brgemm_copy();
if (!copy_b)
return true;
const auto copy_b_expr = linear_ir.get_expr_by_node(copy_b);
copy_b_expr->get_input_port_descriptor(0)->set_subtensor({get_full_dim_value(), get_full_dim_value()});
copy_b_expr->get_output_port_descriptor(0)->set_subtensor({get_full_dim_value(), get_full_dim_value()});
if (with_compensations(type)) {
Expand Down

0 comments on commit 5ff35d7

Please sign in to comment.