[Snippets] Added support of BF16/I8/U8 for MatMul #15063

Merged
43 changes: 22 additions & 21 deletions src/common/snippets/include/snippets/generator.hpp
@@ -41,21 +41,6 @@ class TargetMachine {
*/
virtual size_t get_lanes() const = 0;

/**
* @interface opRegType
* @brief Register type of operations
* Note that currently there are 4 types of ops:
* gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc)
* gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc.
* vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
*/
enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
/**
* @brief gets register type by op type
* @return register type
*/
opRegType get_op_reg_type(const std::shared_ptr<Node>& op) const;

/**
* @brief called by the generator to get all emitters for a target machine
* @return a map by node's type info with callbacks to create an instance of emitter for corresponding operation type
@@ -78,12 +63,6 @@
virtual ~TargetMachine() = default;

protected:
/**
* @brief gets register type by specific plugin op type
* @return register type
*/
virtual opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const = 0;

std::map<const ngraph::DiscreteTypeInfo, std::function<std::shared_ptr<Emitter>(std::shared_ptr<ngraph::Node>)>> jitters;
};

@@ -164,7 +143,29 @@ class Generator {
*/
std::shared_ptr<const TargetMachine> get_target_machine() const;

/**
* @interface opRegType
* @brief Register type of operations
* Note that there are currently 4 types of ops:
* gpr->gpr: Parameter, Result, LoopBegin, LoopEnd, etc.
* gpr->vec or vec->gpr: Load/LoadConvert, Store/StoreConvert, BroadcastLoad, etc.
* vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
*/
enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
/**
* @brief gets register type by op type
* TODO: Should be a static attribute of emitters
* @return register type
*/
opRegType get_op_reg_type(const std::shared_ptr<Node>& op) const;

protected:
/**
* @brief gets register type by specific plugin op type
* @return register type
*/
virtual opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const;

std::shared_ptr<TargetMachine> target;
// todo: we need to save the lowered code to access compiled brgemm kernels at execution time (normally the lowered code is destructed by then).
// This is a temporary solution; remove it when kernel caching is implemented. Don't forget to make generate() a const method.
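With opRegType and get_op_reg_type() moved from TargetMachine to Generator, plugin generators override get_specific_op_reg_type() for ops the base class cannot classify. Below is a minimal, hedged sketch of such an override; SketchGenerator is a hypothetical class, and a real plugin would dispatch on its own op types here (Load/Store are only used to keep the example self-contained, since the base method already covers common snippets ops):

```cpp
#include <memory>

#include "openvino/core/type.hpp"
#include "snippets/generator.hpp"
#include "snippets/op/load.hpp"
#include "snippets/op/store.hpp"

// Illustrative only: a real plugin would check its own (plugin-specific) op types here.
class SketchGenerator : public ngraph::snippets::Generator {
public:
    using ngraph::snippets::Generator::Generator;  // reuse the base constructor(s)

protected:
    opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override {
        if (ov::is_type<ngraph::snippets::op::Load>(op))
            return gpr2vec;  // memory (gpr) -> vector register
        if (ov::is_type<ngraph::snippets::op::Store>(op))
            return vec2gpr;  // vector register -> memory (gpr)
        return vec2vec;      // everything else computes on vector registers
    }
};
```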
15 changes: 4 additions & 11 deletions src/common/snippets/include/snippets/op/brgemm.hpp
@@ -19,18 +19,14 @@ namespace op {
class Brgemm : public MemoryAccess {
public:
OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess);
Brgemm(const Output<Node>& A, const Output<Node>& B, bool transposed_a = false, bool transposed_b = false,
Brgemm(const Output<Node>& A, const Output<Node>& B,
const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
Brgemm() = default;

bool transposed_a() const { return m_transposed_a; }
bool transposed_b() const { return m_transposed_b; }
size_t get_offset_a() const { return get_input_offset(0); }
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

size_t get_offset_a() const { return get_input_port_descriptor(0).m_offset; }
size_t get_offset_b() const { return get_input_port_descriptor(1).m_offset; }
size_t get_offset_c() const { return get_output_port_descriptor(0).m_offset; }

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

@@ -39,9 +35,6 @@ class Brgemm : public MemoryAccess {
protected:
ov::element::Type get_output_type() const;
ov::PartialShape get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;

bool m_transposed_a;
bool m_transposed_b;
};

} // namespace op
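For reference, a hedged sketch of how the slimmed-down Brgemm constructor above might be used (the helper name, shapes, and offset values are illustrative, not from the PR): the transposed_a/transposed_b flags are gone, and all offsets now live in the MemoryAccess port descriptors.

```cpp
#include <memory>

#include "openvino/op/parameter.hpp"
#include "snippets/op/brgemm.hpp"

std::shared_ptr<ngraph::snippets::op::Brgemm> make_brgemm_sketch() {
    auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 32, 64});
    auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 64, 48});

    // Offsets for inputs A/B and output C are stored in the MemoryAccess port descriptors.
    auto brgemm = std::make_shared<ngraph::snippets::op::Brgemm>(a, b,
                                                                 /*offset_a=*/0lu,
                                                                 /*offset_b=*/64lu,
                                                                 /*offset_c=*/128lu);

    // The getters now simply delegate to MemoryAccess::get_input_offset / get_output_offset.
    const size_t offset_b = brgemm->get_offset_b();  // 64
    (void)offset_b;
    return brgemm;
}
```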
@@ -24,7 +24,7 @@ class BroadcastLoad : public MemoryAccess {
BroadcastLoad(const Output<Node>& x, ov::PartialShape output_shape, size_t offset = 0lu);
BroadcastLoad() = default;

size_t get_offset() const { return get_input_port_descriptor(0).m_offset; }
size_t get_offset() const { return get_input_offset(0); }

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
71 changes: 18 additions & 53 deletions src/common/snippets/include/snippets/op/buffer.hpp
@@ -13,6 +13,8 @@ namespace op {
/**
* @interface Buffer
* @brief This is a base class for memory storage.
* If Buffer has a parent, the operation is for intermediate data storage - IntermediateMemory type.
* Otherwise, the operation allocates new empty memory with shape `m_shape` - NewMemory type.
* Notes:
* - All buffers in a graph share the same memory pointer. So if there are several buffers,
* each corresponding MemoryAccess op should have an offset into the common memory pointer of this Buffer
@@ -22,67 +24,30 @@
class Buffer : public ngraph::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");

size_t get_byte_size() const;
virtual ov::PartialShape get_allocation_shape() const = 0;

protected:
Buffer() = default;
};

/**
* @interface AllocationBuffer
* @brief The operation is for allocation of new empty memory. The operation has one parent that is equal to allocation shape
* - m_element_type - element type of memory
* @ingroup snippets
*/
class AllocationBuffer : public Buffer {
public:
OPENVINO_OP("AllocationBuffer", "SnippetsOpset", Buffer);

AllocationBuffer() = default;
AllocationBuffer(const ov::Output<ov::Node>& shape, const ov::element::Type element_type);

ov::PartialShape get_allocation_shape() const override;
Buffer(const ov::Shape& shape);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1);

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

protected:
ov::element::Type m_element_type;
};

/**
* @interface IntermediateBuffer
* @brief The operation is for intermediate data storage.
* If Buffer has only one parent, the Buffer will allocate a full memory with input shape of Buffer.
* If Buffer has second parent as well, the Buffer will allocate memory with shape that is equal to values from second input but
* saves the input shape for shape inference and input element type.
* For example,
* Parameter [5, 3, 128] Constant [2] (with values {3, 128})
* \ /
* Buffer with allocated memory 3x128 size
* |
* Result [5, 3, 128]
* @ingroup snippets
*/
class IntermediateBuffer : public Buffer {
public:
OPENVINO_OP("IntermediateBuffer", "SnippetsOpset", Buffer);

IntermediateBuffer() = default;
IntermediateBuffer(const ov::Output<ov::Node>& x);
IntermediateBuffer(const ov::Output<ov::Node>& x, const ov::Output<ov::Node>& shape);
enum Type {
NewMemory,
IntermediateMemory
};

ov::PartialShape get_allocation_shape() const override;
Type get_type() const { return m_type; }
ov::Shape get_allocation_shape() const { return m_shape; }
size_t get_byte_size() const;

bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }

static std::shared_ptr<ov::Node> create_shape_constant(const ov::PartialShape& shape, size_t allocation_rank);
static std::shared_ptr<ov::Node> create_shape_constant(const ov::PartialShape& shape);
private:
Type m_type = Type::IntermediateMemory;
ov::Shape m_shape = {};
};

} // namespace op
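A hedged sketch of the two usage modes of the unified Buffer op above (the helper name, shapes, and allocation_rank value are illustrative): a Buffer built from a bare shape allocates new memory, while a Buffer built from a parent output stores intermediate data.

```cpp
#include <memory>

#include "openvino/op/parameter.hpp"
#include "snippets/op/buffer.hpp"

void buffer_sketch() {
    using ngraph::snippets::op::Buffer;

    // 1) New empty memory of shape {3, 128}: no parent, Type::NewMemory
    auto new_mem = std::make_shared<Buffer>(ov::Shape{3, 128});

    // 2) Intermediate storage for a parent output; allocation_rank = 2 is expected to allocate
    //    only the two innermost dimensions {3, 128} while the full shape is kept for shape
    //    inference (Type::IntermediateMemory)
    auto parent = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{5, 3, 128});
    auto interm = std::make_shared<Buffer>(parent->output(0), /*allocation_rank=*/2);

    // Allocation shape and byte size are now queried directly from the op
    const ov::Shape alloc = interm->get_allocation_shape();  // expected {3, 128}
    const size_t bytes = interm->get_byte_size();            // 3 * 128 * sizeof(float), assuming f32
    (void)new_mem; (void)alloc; (void)bytes;
}
```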
10 changes: 8 additions & 2 deletions src/common/snippets/include/snippets/op/load.hpp
@@ -25,8 +25,11 @@ class Load : public MemoryAccess {
Load(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Load() = default;

size_t get_offset() const { return get_input_port_descriptor(0).m_offset; }
size_t get_count() const { return get_input_port_descriptor(0).m_count; }
size_t get_offset() const { return get_input_offset(0); }
size_t get_count() const { return get_input_count(0); }

void set_offset(size_t offset) { set_input_offset(offset, 0); }
void set_count(size_t count) { set_input_count(count, 0); }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
@@ -45,6 +48,9 @@ class LoadReshape : public Load {
LoadReshape(const Output<Node>& x, size_t count = 1lu, const size_t offset = 0lu, std::vector<size_t> order = {});
LoadReshape() = default;

void set_offset(size_t offset) { set_output_offset(offset, 0); }
void set_count(size_t count) { set_output_count(count, 0); }

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
65 changes: 36 additions & 29 deletions src/common/snippets/include/snippets/op/memory_access.hpp
@@ -10,31 +10,6 @@ namespace ngraph {
namespace snippets {
namespace op {

class MemoryAccess;

/**
* @interface PortDescriptor
* @brief This class describes port of MemoryAccess operation
* @param m_count - count of elements to load/store
* @param m_offset - starting index of elements to load/store
* @param m_index - port index
* @ingroup snippets
*/

struct PortDescriptor {
PortDescriptor(size_t count, size_t offset) : m_count(count), m_offset(offset) {}
PortDescriptor() = default;

size_t m_count = 0lu;
size_t m_offset = 0lu;
size_t m_index = 0lu;

private:
PortDescriptor(size_t count, size_t offset, size_t index) : m_count(count), m_offset(offset), m_index(index) {}

friend class MemoryAccess;
};

/**
* @interface MemoryAccess
* @brief This is a base class for memory access operations (like Load and Store).
@@ -48,14 +23,46 @@ class MemoryAccess : public ngraph::op::Op {
public:
OPENVINO_OP("MemoryAccess", "SnippetsOpset");

/**
* @interface PortDescriptor
* @brief This class describes port of MemoryAccess operation
* @param m_count - count of elements to load/store
* @param m_offset - starting index of elements to load/store
* @param m_index - port index
* @ingroup snippets
*/
struct PortDescriptor {
PortDescriptor(size_t count, size_t offset) : m_count(count), m_offset(offset) {}
PortDescriptor() = default;

size_t m_count = 0lu;
size_t m_offset = 0lu;
size_t m_index = 0lu;

private:
PortDescriptor(size_t count, size_t offset, size_t index) : m_count(count), m_offset(offset), m_index(index) {}

friend class MemoryAccess;
};

void set_input_port_descriptor(const PortDescriptor& desc, const size_t i);
void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
PortDescriptor get_input_port_descriptor(const size_t i) const;
PortDescriptor get_output_port_descriptor(const size_t i) const;
PortDescriptor& get_input_port_descriptor(const size_t i);
PortDescriptor& get_output_port_descriptor(const size_t i);
const PortDescriptor& get_input_port_descriptor(const size_t i) const;
const PortDescriptor& get_output_port_descriptor(const size_t i) const;

void set_input_count(size_t count, size_t idx);
void set_output_count(size_t count, size_t idx);
Reviewer (Contributor):

1. It makes sense to have a default idx = 0 value.
2. We used to have only one count for both inputs and outputs, so I'm a little confused here now. If I want to set the count for Load, should I use set_input_count or set_output_count? And if only set_input_count is legal for such operations, what happens on set_output_count? I've seen that you created set_count for Load and Store, and I think this is the right direction, but set_input_count can still be called on Load, which is a bit confusing. I don't insist, but should we maybe make a separate class for single-port operations?

Author (Contributor Author):

1. Nice idea, let's do it, thanks.
2. Load works with data on its input and with a vector register on its output, so you should use set_input_count; the opposite holds for Store. That seemed like the logical split to me, but we can discuss it.

void set_input_offset(size_t offset, size_t idx);
void set_output_offset(size_t offset, size_t idx);

size_t get_input_count(size_t idx) const;
size_t get_output_count(size_t idx) const;
size_t get_input_offset(size_t idx) const;
size_t get_output_offset(size_t idx) const;


bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

protected:
explicit MemoryAccess(const OutputVector& arguments);
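Following the discussion above, a hedged usage sketch (the helper name and the count/offset values are illustrative): Load reads memory on its input, so its count and offset live in the input port descriptor; Store writes memory on its output, so they live in the output port descriptor. The set_count/set_offset wrappers on Load and Store hide that distinction.

```cpp
#include <memory>

#include "openvino/op/parameter.hpp"
#include "snippets/op/load.hpp"
#include "snippets/op/store.hpp"

void memory_access_sketch() {
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 64});

    // Load 16 elements starting at element offset 0
    auto load = std::make_shared<ngraph::snippets::op::Load>(param, /*count=*/16lu, /*offset=*/0lu);
    load->set_count(32lu);                        // wrapper around set_input_count(32, 0)
    const size_t load_count = load->get_count();  // == get_input_count(0) == 32

    // Store 16 elements starting at element offset 64
    auto store = std::make_shared<ngraph::snippets::op::Store>(load, /*count=*/16lu, /*offset=*/64lu);
    const size_t store_offset = store->get_offset();  // == get_output_offset(0) == 64

    (void)load_count; (void)store_offset;
}
```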
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/op/store.hpp
@@ -25,8 +25,8 @@ class Store : public MemoryAccess {
Store(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Store() = default;

size_t get_offset() const { return get_output_port_descriptor(0).m_offset; }
size_t get_count() const { return get_output_port_descriptor(0).m_count; }
size_t get_offset() const { return get_output_offset(0); }
size_t get_count() const { return get_output_count(0); }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
@@ -20,13 +20,13 @@ namespace pass {
*/
class AssignRegisters : public ngraph::pass::FunctionPass {
public:
explicit AssignRegisters(const std::shared_ptr<const TargetMachine>& target_machine) : m_target_machine(target_machine) {
explicit AssignRegisters(const std::function<Generator::opRegType(const std::shared_ptr<Node>& op)>& mapper) : m_reg_type_mapper(mapper) {
set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
std::shared_ptr<const TargetMachine> m_target_machine = nullptr;
std::function<Generator::opRegType(const std::shared_ptr<Node>& op)> m_reg_type_mapper;
};

} // namespace pass
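A hedged sketch of how the new constructor above might be wired up (the surrounding function is illustrative, not the PR's actual generate() code): instead of the whole TargetMachine, AssignRegisters now receives only a callback that maps an op to its register type, which a Generator can bind to its own get_op_reg_type().

```cpp
#include <memory>

#include "ngraph/pass/manager.hpp"
#include "openvino/core/model.hpp"
#include "snippets/generator.hpp"
#include "snippets/pass/assign_registers.hpp"

void assign_registers_sketch(const ngraph::snippets::Generator& generator,
                             const std::shared_ptr<ov::Model>& body) {
    // Bind the register-type query of this particular generator into the pass
    auto reg_type_mapper = [&generator](const std::shared_ptr<ngraph::Node>& op) {
        return generator.get_op_reg_type(op);  // gpr2gpr / gpr2vec / vec2gpr / vec2vec
    };

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::snippets::pass::AssignRegisters>(reg_type_mapper);
    manager.run_passes(body);
}
```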
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/utils.hpp
@@ -29,8 +29,8 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out);
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
std::vector<size_t> get_node_output_layout(const std::shared_ptr<Node>& node);
std::vector<size_t> get_node_output_layout(const Node* node);
void set_output_layout(const ov::Output<Node>& port, const std::shared_ptr<opset1::Transpose>& node);
void set_output_layout(const ov::Output<Node>& port, const std::vector<size_t>& layout);
void set_transpose_output_layout(const ov::Output<Node>& port, const std::shared_ptr<opset1::Transpose>& node);
void set_transpose_output_layout(const ov::Output<Node>& port, const std::vector<size_t>& layout);

inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); }
inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); }