Skip to content

Commit

Permalink
BrgemmCopyB: always planar layout on output
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 10, 2023
1 parent a71874e commit 6014440
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t
// So we use port descs from source inputs
const auto element_type = get_input_element_type(0);
const auto& pshape = get_input_partial_shape(0);
set_output_type(0, element_type, pshape);
const auto planar_pshape = snippets::utils::get_planar_pshape(pshape, layout_input);
set_output_type(0, element_type, planar_pshape);
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, pshape);
set_output_type(1, ov::element::f32, planar_pshape);
}
const auto planar_pshape = snippets::utils::get_planar_pshape(pshape, layout_input);
validate(planar_pshape, element_type);
}

Expand All @@ -71,11 +71,11 @@ void BrgemmCopyB::validate_and_infer_types() {
const auto port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0));
const auto shape = ov::Shape(port->get_shape());
const auto& element_type = get_input_element_type(0);
set_output_type(0, element_type, shape);
const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout());
set_output_type(0, element_type, planar_pshape);
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, shape);
set_output_type(1, ov::element::f32, planar_pshape);
}
const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout());
validate(planar_pshape, element_type);
}

Expand All @@ -90,27 +90,26 @@ void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k,
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin();
}

ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::PartialShape& planar_shape) const {
const size_t N = planar_shape.rbegin()->get_length();
const size_t K = (planar_shape.rbegin() + 1)->get_length();
/// Computes the allocation shape used for the repacked data buffer.
/// @param planar_dims dimensions in planar order; the last entry is read as N and the one before it as K.
/// @return the 2D shape {rnd_up(K, m_brgemmVNNIFactor), rnd_up(N, m_N_blk)}.
/// @throws via OPENVINO_ASSERT when fewer than two dimensions are passed.
ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const {
    // Reading `rbegin() + 1` on a vector with fewer than two elements is undefined behavior, so fail loudly instead.
    OPENVINO_ASSERT(planar_dims.size() >= 2, "BrgemmCopyB expects at least 2 dimensions to compute data repacking shape");
    const auto& N = *planar_dims.rbegin();
    const auto& K = *(planar_dims.rbegin() + 1);
    return ov::Shape{rnd_up(K, m_brgemmVNNIFactor), rnd_up(N, m_N_blk)};
}

ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::PartialShape& planar_shape) const {
const size_t N = planar_shape.rbegin()->get_length();
/// Computes the allocation shape used for the compensations buffer.
/// @param planar_dims dimensions in planar order; only the last entry (read as N) is used.
/// @return the 1D shape {rnd_up(N, m_N_blk)}.
/// @throws via OPENVINO_ASSERT when an empty dimension vector is passed.
ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const {
    // Dereferencing rbegin() of an empty vector is undefined behavior, so reject it explicitly.
    OPENVINO_ASSERT(!planar_dims.empty(), "BrgemmCopyB expects at least 1 dimension to compute compensation shape");
    const auto& N = *planar_dims.rbegin();
    return ov::Shape{rnd_up(N, m_N_blk)};
}

// Builds a new BrgemmCopyB on top of new_args.at(0), passing along m_src_type, m_type, the input
// and output port descriptors, the layout of the input(0) port descriptor and the K/N block sizes.
// The second output descriptor is forwarded only when is_with_compensations() is true; otherwise
// a default-constructed PortDescriptor is passed.
// NOTE(review): this text comes from a commit diff view, so both the removed form
// (`auto clone = ...; return clone;`) and the added direct-return form appear below --
// confirm against the actual file which one is current.
std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
check_new_args_count(this, new_args);
auto clone = std::make_shared<BrgemmCopyB>(new_args.at(0), m_src_type, m_type,
get_input_port_descriptor(0),
get_output_port_descriptor(0),
is_with_compensations() ? get_output_port_descriptor(1) : PortDescriptor{},
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
m_K_blk, m_N_blk);
return clone;
return std::make_shared<BrgemmCopyB>(new_args.at(0), m_src_type, m_type,
get_input_port_descriptor(0),
get_output_port_descriptor(0),
is_with_compensations() ? get_output_port_descriptor(1) : PortDescriptor{},
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
m_K_blk, m_N_blk);
}

size_t BrgemmCopyB::get_offset_compensations() const {
Expand All @@ -119,5 +118,19 @@ size_t BrgemmCopyB::get_offset_compensations() const {
return get_output_offset(1);
}

/// Captures what infer() needs from the node: the layout of its input(0) port descriptor
/// and its number of outputs. Asserts that the node really is a BrgemmCopyB.
BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
    const auto copy_b_node = ov::as_type_ptr<BrgemmCopyB>(n);
    OPENVINO_ASSERT(copy_b_node, "Got invalid node in BrgemmCopyB::ShapeInfer");

    const auto& in_port_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0));
    m_layout = in_port_desc->get_layout();
    m_num_outs = copy_b_node->get_output_size();
}

/// Infers output shapes from the single input shape: the input dims are reordered to planar form
/// using the stored layout, and that planar shape is reported for each of the m_num_outs outputs.
ov::snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector<ov::snippets::VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");

    const auto& in_dims = input_shapes[0].get();
    const auto planar_dims = ov::snippets::utils::get_planar_vdims(in_dims, m_layout);

    // Every output gets the same planar shape.
    std::vector<ov::snippets::VectorDims> out_shapes(m_num_outs, planar_dims);
    return {out_shapes, ov::snippets::ShapeInferStatus::success};
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include "snippets/op/memory_access.hpp"
#include "snippets/shape_types.hpp"
#include <snippets/shape_inference/shape_inference.hpp>

namespace ov {
Expand Down Expand Up @@ -43,8 +44,8 @@ class BrgemmCopyB : public snippets::op::MemoryAccess {
/// Stores the given value as the K block size (m_K_blk).
void set_k_block_size(size_t block_size) {
    m_K_blk = block_size;
}
/// Stores the given value as the N block size (m_N_blk).
void set_n_block_size(size_t block_size) {
    m_N_blk = block_size;
}

ov::Shape get_data_repacking_shape(const ov::PartialShape& planar_shape) const;
ov::Shape get_compensation_shape(const ov::PartialShape& planar_shape) const;
ov::Shape get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const;
ov::Shape get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const;

/// Returns the stored op type (m_type).
Type get_type() const {
    return m_type;
}
/// Returns the stored VNNI factor (m_brgemmVNNIFactor).
size_t get_brgemm_vnni_factor() const {
    return m_brgemmVNNIFactor;
}
Expand All @@ -56,6 +57,14 @@ class BrgemmCopyB : public snippets::op::MemoryAccess {
/// Always reports false: this op does not advertise an evaluate() implementation.
bool has_evaluate() const override {
    return false;
}
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

class ShapeInfer : public snippets::IShapeInferSnippets {
std::vector<size_t> m_layout{};
size_t m_num_outs = 1;
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
Result infer(const std::vector<snippets::VectorDimsRef>& input_shapes) override;
};

private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ bool ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape::run(snippets::lowered::Lin
if (auto copy_b = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(expr->get_node())) {
const auto buffer = get_buffer_from_output(expr, 0);
const auto& out_desc = expr->get_output_port_descriptor(0);
const auto planar_pshape = ov::PartialShape(ov::snippets::utils::get_planar_vdims(out_desc->get_shape(), out_desc->get_layout()));
buffer->set_allocation_shape(copy_b->get_data_repacking_shape(planar_pshape));
buffer->set_allocation_shape(copy_b->get_data_repacking_shape(out_desc->get_shape()));
if (copy_b->is_with_compensations()) {
const auto compensations_buffer = get_buffer_from_output(expr, 1);
compensations_buffer->set_allocation_shape(copy_b->get_compensation_shape(planar_pshape));
compensations_buffer->set_allocation_shape(copy_b->get_compensation_shape(out_desc->get_shape()));
}
modified = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec
SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer),
SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer),
SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer),
SHAPE_INFER_PREDEFINED(ov::intel_cpu::BrgemmCopyB, PassThroughShapeInfer),
//
SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB),
SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU),
};
#undef SHAPE_INFER_OP_SPECIFIC
Expand Down

0 comments on commit 6014440

Please sign in to comment.