diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp index bbff2c1d07a91a..465d400795bd7e 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp @@ -58,11 +58,11 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vectorget_shape()); const auto& element_type = get_input_element_type(0); - set_output_type(0, element_type, shape); + const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout()); + set_output_type(0, element_type, planar_pshape); if (is_with_compensations()) { - set_output_type(1, ov::element::f32, shape); + set_output_type(1, ov::element::f32, planar_pshape); } - const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout()); validate(planar_pshape, element_type); } @@ -90,27 +90,26 @@ void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin(); } -ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::PartialShape& planar_shape) const { - const size_t N = planar_shape.rbegin()->get_length(); - const size_t K = (planar_shape.rbegin() + 1)->get_length(); +ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const { + const auto& N = *planar_dims.rbegin(); + const auto& K = *(planar_dims.rbegin() + 1); return ov::Shape{rnd_up(K, m_brgemmVNNIFactor), rnd_up(N, m_N_blk)}; } -ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::PartialShape& planar_shape) const { - const size_t N = planar_shape.rbegin()->get_length(); +ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const { + const auto& N = *planar_dims.rbegin(); return ov::Shape{rnd_up(N, m_N_blk)}; } std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs); check_new_args_count(this, new_args); - auto clone = std::make_shared(new_args.at(0), m_src_type, m_type, - get_input_port_descriptor(0), - get_output_port_descriptor(0), - is_with_compensations() ? get_output_port_descriptor(1) : PortDescriptor{}, - snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), - m_K_blk, m_N_blk); - return clone; + return std::make_shared(new_args.at(0), m_src_type, m_type, + get_input_port_descriptor(0), + get_output_port_descriptor(0), + is_with_compensations() ? get_output_port_descriptor(1) : PortDescriptor{}, + snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(), + m_K_blk, m_N_blk); } size_t BrgemmCopyB::get_offset_compensations() const { @@ -119,5 +118,19 @@ size_t BrgemmCopyB::get_offset_compensations() const { return get_output_offset(1); } +BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr& n) { + const auto& brg_copyb = ov::as_type_ptr(n); + OPENVINO_ASSERT(brg_copyb, "Got invalid node in BrgemmCopyB::ShapeInfer"); + m_layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0))->get_layout(); + m_num_outs = brg_copyb->get_output_size(); +} + +ov::snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector& input_shapes) { + OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes"); + const auto planar_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_layout); + std::vector new_shapes(m_num_outs, planar_shape); + return {new_shapes, ov::snippets::ShapeInferStatus::success}; +} + } // namespace intel_cpu } // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp index 77f78f246545ae..9274ad026e5f01 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp @@ -5,6 +5,7 @@ #pragma once #include "snippets/op/memory_access.hpp" +#include "snippets/shape_types.hpp" #include namespace ov { @@ -43,8 +44,8 @@ class BrgemmCopyB : public snippets::op::MemoryAccess { void set_k_block_size(size_t block_size) { m_K_blk = block_size; } void set_n_block_size(size_t block_size) { m_N_blk = block_size; } - ov::Shape get_data_repacking_shape(const ov::PartialShape& planar_shape) const; - ov::Shape get_compensation_shape(const ov::PartialShape& planar_shape) const; + ov::Shape get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const; + ov::Shape get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const; Type get_type() const { return m_type; } size_t get_brgemm_vnni_factor() const { return m_brgemmVNNIFactor; } @@ -56,6 +57,14 @@ class BrgemmCopyB : public snippets::op::MemoryAccess { bool has_evaluate() const override { return false; } std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + class ShapeInfer : public snippets::IShapeInferSnippets { + std::vector m_layout{}; + size_t m_num_outs = 1; + public: + explicit ShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; + }; + private: void custom_constructor_validate_and_infer_types(std::vector layout_input = {}); void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp index 96997bf803f9f4..91bec8aee60d4a 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp @@ -26,11 +26,10 @@ bool ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape::run(snippets::lowered::Lin if (auto copy_b = ov::as_type_ptr(expr->get_node())) { const auto buffer = get_buffer_from_output(expr, 0); const auto& out_desc = expr->get_output_port_descriptor(0); - const auto planar_pshape = ov::PartialShape(ov::snippets::utils::get_planar_vdims(out_desc->get_shape(), out_desc->get_layout())); - buffer->set_allocation_shape(copy_b->get_data_repacking_shape(planar_pshape)); + buffer->set_allocation_shape(copy_b->get_data_repacking_shape(out_desc->get_shape())); if (copy_b->is_with_compensations()) { const auto compensations_buffer = get_buffer_from_output(expr, 1); - compensations_buffer->set_allocation_shape(copy_b->get_compensation_shape(planar_pshape)); + compensations_buffer->set_allocation_shape(copy_b->get_compensation_shape(out_desc->get_shape())); } modified = true; } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp index 891ac89bb3c9ed..d09f3f218e67d9 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp @@ -36,8 +36,8 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer), SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer), SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer), - SHAPE_INFER_PREDEFINED(ov::intel_cpu::BrgemmCopyB, PassThroughShapeInfer), // + SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB), SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU), }; #undef SHAPE_INFER_OP_SPECIFIC