Commit b639a86

Ivan's comments applied

v-Golubev committed Oct 10, 2023
1 parent 6014440 commit b639a86

Showing 9 changed files with 86 additions and 89 deletions.
8 changes: 0 additions & 8 deletions src/common/snippets/include/snippets/op/brgemm.hpp

@@ -39,14 +39,6 @@ class Brgemm : public MemoryAccess {

     bool has_evaluate() const override { return false; }

-    class ShapeInfer : public IShapeInferSnippets {
-    protected:
-        std::vector<std::vector<size_t>> m_io_layouts;
-    public:
-        explicit ShapeInfer(const std::shared_ptr<Node>& n);
-        Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
-    };
-
 protected:
     ov::element::Type get_output_type() const;
     std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp

@@ -61,5 +61,12 @@ class HorizonOpShapeInfer : public IShapeInferSnippets {
     Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
 };

+class BrgemmShapeInfer : public IShapeInferSnippets {
+    std::vector<std::vector<size_t>> m_io_layouts;
+public:
+    explicit BrgemmShapeInfer(const std::shared_ptr<Node>& n);
+    Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
+};
+
 } // namespace snippets
 } // namespace ov
71 changes: 0 additions & 71 deletions src/common/snippets/src/op/brgemm.cpp

@@ -188,77 +188,6 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialS
     }
     return output_shape;
 }
-
-Brgemm::ShapeInfer::ShapeInfer(const std::shared_ptr<Node>& n) {
-    for (const auto& in : n->inputs()) {
-        const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
-        m_io_layouts.push_back(port->get_layout());
-    }
-    m_io_layouts.push_back(get_output_layout(n));
-}
-
-IShapeInferSnippets::Result Brgemm::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
-    OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference");
-
-    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
-    const auto& arg0_shape = snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
-    const auto& arg1_shape = snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
-
-    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
-
-    // temporary shapes to calculate output shape
-    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
-
-    // one-dimensional tensors unsqueezing is applied to each input independently.
-    if (arg0_rank == 1) {
-        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
-        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
-        // For example {S} will be reshaped to {1, S}.
-        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
-        arg0_rank = arg0_shape_tmp.size();
-    }
-    if (arg1_rank == 1) {
-        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
-        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
-        // For example {S} will be reshaped to {S, 1}.
-        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
-        arg1_rank = arg1_shape_tmp.size();
-    }
-
-    // add 1 to begin to align shape ranks if needed
-    if (arg0_rank < arg1_rank)
-        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
-    else if (arg0_rank > arg1_rank)
-        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
-
-    size_t max_rank = arg0_shape_tmp.size();
-    VectorDims output_shape(max_rank);
-    for (size_t i = 0; i < max_rank - 2; ++i) {
-        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
-            output_shape[i] = arg0_shape_tmp[i];
-        } else {
-            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
-                output_shape[i] = arg1_shape_tmp[i];
-            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
-                output_shape[i] = arg0_shape_tmp[i];
-            else
-                OPENVINO_THROW("Incompatible Brgemm batch dimension");
-        }
-    }
-    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
-    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
-
-    // removing the temporary axes from originally 1D tensors.
-    if (arg0_shape.size() == 1) {
-        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
-    }
-    if (arg1_shape.size() == 1) {
-        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
-    }
-    output_shape = snippets::utils::get_planar_vdims(output_shape, m_io_layouts[2]);
-    return {{output_shape}, snippets::ShapeInferStatus::success};
-}
-
 } // namespace op
 } // namespace snippets
 } // namespace ov
72 changes: 72 additions & 0 deletions src/common/snippets/src/shape_inference/shape_infer_instances.cpp

@@ -3,6 +3,7 @@
 //
 #include "snippets/shape_inference/shape_infer_instances.hpp"
 #include "snippets/snippets_isa.hpp"
+#include "snippets/utils.hpp"
 #include "openvino/op/select.hpp"
 namespace ov {
 namespace snippets {
@@ -160,5 +161,76 @@ Result HorizonOpShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes
     return {{output_shapes}, ShapeInferStatus::success};
 }

+BrgemmShapeInfer::BrgemmShapeInfer(const std::shared_ptr<Node>& n) {
+    for (const auto& in : n->inputs()) {
+        const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
+        m_io_layouts.push_back(port->get_layout());
+    }
+    const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->output(0));
+    m_io_layouts.push_back(port->get_layout());
+}
+
+Result BrgemmShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
+    OPENVINO_ASSERT(input_shapes.size() == 2 || input_shapes.size() == 3, "BRGEMM expects 2 or 3 input shapes for shape inference");
+
+    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
+    const auto& arg0_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
+    const auto& arg1_shape = ov::snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
+
+    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
+
+    // temporary shapes to calculate output shape
+    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
+
+    // one-dimensional tensors unsqueezing is applied to each input independently.
+    if (arg0_rank == 1) {
+        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
+        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
+        // For example {S} will be reshaped to {1, S}.
+        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
+        arg0_rank = arg0_shape_tmp.size();
+    }
+    if (arg1_rank == 1) {
+        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
+        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
+        // For example {S} will be reshaped to {S, 1}.
+        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
+        arg1_rank = arg1_shape_tmp.size();
+    }
+
+    // add 1 to begin to align shape ranks if needed
+    if (arg0_rank < arg1_rank)
+        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
+    else if (arg0_rank > arg1_rank)
+        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
+
+    size_t max_rank = arg0_shape_tmp.size();
+    VectorDims output_shape(max_rank);
+    for (size_t i = 0; i < max_rank - 2; ++i) {
+        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
+            output_shape[i] = arg0_shape_tmp[i];
+        } else {
+            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
+                output_shape[i] = arg1_shape_tmp[i];
+            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
+                output_shape[i] = arg0_shape_tmp[i];
+            else
+                OPENVINO_THROW("Incompatible Brgemm batch dimension");
+        }
+    }
+    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
+    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
+
+    // removing the temporary axes from originally 1D tensors.
+    if (arg0_shape.size() == 1) {
+        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
+    }
+    if (arg1_shape.size() == 1) {
+        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
+    }
+    output_shape = ov::snippets::utils::get_planar_vdims(output_shape, m_io_layouts.back());
+    return {{output_shape}, snippets::ShapeInferStatus::success};
+}
+
 } // namespace snippets
 } // namespace ov
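A minimal usage sketch of the relocated shape-infer class (not part of the commit). It assumes VectorDims is std::vector<size_t> and VectorDimsRef is std::reference_wrapper<const VectorDims>, which the .get() calls above suggest, and that Result exposes the inferred shapes as dims alongside status; the node argument is whatever Brgemm-like op the factory passes in.

#include <functional>  // std::cref
#include <memory>

#include "snippets/shape_inference/shape_infer_instances.hpp"

void brgemm_shape_infer_example(const std::shared_ptr<ov::Node>& brgemm_node) {
    using ov::snippets::VectorDims;
    // The constructor collects one layout per input plus the output layout
    // from the node's port descriptors.
    ov::snippets::BrgemmShapeInfer shape_infer(brgemm_node);

    VectorDims a{2, 16, 64, 32};   // [batch0, batch1, M, K]
    VectorDims b{1, 16, 32, 128};  // [batch0, batch1, K, N]; batch0 == 1 broadcasts against 2
    const auto res = shape_infer.infer({std::cref(a), std::cref(b)});

    // With planar layouts, the batch dims broadcast numpy-style and M/N come from the inputs:
    // res.dims[0] == {2, 16, 64, 128}, res.status == ShapeInferStatus::success
}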
@@ -58,11 +58,11 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
     SHAPE_INFER_PREDEFINED(op::Kernel, EmptyShapeInfer),
     SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
     SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
+    SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
     // Note that Result has no output PortConnectors, so the shape must be empty
     SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
     //
     SHAPE_INFER_OP_SPECIFIC(op::LoadReshape),
-    SHAPE_INFER_OP_SPECIFIC(op::Brgemm),
     SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
     SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
 };
@@ -58,8 +58,11 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t
     // So we use port descs from source inputs
     const auto element_type = get_input_element_type(0);
     const auto& pshape = get_input_partial_shape(0);
+    // The data always store in planar shape after repacking
     const auto planar_pshape = snippets::utils::get_planar_pshape(pshape, layout_input);
+    // data repacking output
     set_output_type(0, element_type, planar_pshape);
+    // If compensations are needed, they are provided in 2nd output (which is used in BrgemmCPU)
     if (is_with_compensations()) {
         set_output_type(1, ov::element::f32, planar_pshape);
     }
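A hedged illustration of what the planar-shape call in this hunk computes, assuming get_planar_pshape(shape, layout) returns the shape reordered so that dim i of the result is shape[layout[i]]; the layout values below are hypothetical.

// Illustration only, not part of the commit.
ov::PartialShape pshape{1, 32, 16, 64};
std::vector<size_t> layout_input{0, 2, 1, 3};  // hypothetical layout with dims 1 and 2 swapped
const auto planar_pshape = ov::snippets::utils::get_planar_pshape(pshape, layout_input);
// planar_pshape == {1, 16, 32, 64}; both the repacked output and, when compensations are
// needed, the second f32 output are typed with this planar shape.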
@@ -169,7 +169,5 @@ size_t BrgemmCPU::get_offset_scratch() const {
     return get_input_offset(2);
 }

-BrgemmCPU::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) : Brgemm::ShapeInfer(n) {}
-
 } // namespace intel_cpu
 } // namespace ov
@@ -69,12 +69,6 @@ class BrgemmCPU : public snippets::op::Brgemm {

     constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;

-    class ShapeInfer : public Brgemm::ShapeInfer {
-    public:
-        explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
-    };
-
-
 private:
     void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
     void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n);
@@ -28,6 +28,8 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov
     { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>();} }
 #define SHAPE_INFER_OP_SPECIFIC(OP) \
     { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<OP::ShapeInfer>(n);} }
+#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \
+    { OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>(n);} }

 const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry {
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer),
@@ -36,9 +38,9 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer),
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer),
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer),
+    SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::BrgemmCPU, BrgemmShapeInfer),
     //
     SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB),
-    SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU),
 };
 #undef SHAPE_INFER_OP_SPECIFIC
 #undef SHAPE_INFER_PREDEFINED
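For readers unfolding the macro, the new BrgemmCPU entry expands to roughly the following registry pair, per the SHAPE_INFER_OP_SPECIFIC_EXTERNAL definition in this hunk (a sketch, shown in the registry-pair form):

// Expansion sketch of SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::BrgemmCPU, BrgemmShapeInfer):
{ ov::intel_cpu::BrgemmCPU::get_type_info_static(),
  [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<BrgemmShapeInfer>(n); } }
// BrgemmCPU now resolves to the common BrgemmShapeInfer, which is why the nested
// BrgemmCPU::ShapeInfer class and the Brgemm::ShapeInfer base it relied on could be deleted.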