Skip to content

Commit

Permalink
[Snippets] Added support of BF16/I8/U8 for MatMul (#15063)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova authored Mar 28, 2023
1 parent 253e4eb commit 38c924a
Show file tree
Hide file tree
Showing 60 changed files with 1,952 additions and 589 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ ie_faster_build(${TARGET_NAME}
)

target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime
PRIVATE ngraph_reference ov_shape_inference openvino::runtime::dev)
PRIVATE ngraph_reference openvino::runtime::dev)

target_include_directories(${TARGET_NAME} PUBLIC $<BUILD_INTERFACE:${PUBLIC_HEADERS_DIR}>
PRIVATE $<BUILD_INTERFACE:${SHAPE_INFER_INCLUDE_DIR}>)
Expand Down
23 changes: 22 additions & 1 deletion src/common/snippets/include/snippets/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class TargetMachine {
*/
virtual size_t get_lanes() const = 0;


/**
* @brief called by the generator to get all emitters for a target machine
* @return a map by node's type info with callbacks to create an instance of emitter for corresponding operation type
Expand Down Expand Up @@ -155,7 +154,29 @@ class Generator {
*/
std::shared_ptr<const TargetMachine> get_target_machine() const;

/**
* @interface opRegType
* @brief Register type of operations
* Note that currently there are 4 types of ops:
* gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc)
* gpr->vec or vec->gpr: Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc.
* vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
*/
enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
/**
* @brief gets register type by op type
* TODO: Should be static attribute of emitters
* @return register type
*/
opRegType get_op_reg_type(const std::shared_ptr<Node>& op) const;

protected:
/**
* @brief gets register type by specific plugin op type
* @return register type
*/
virtual opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const;

std::shared_ptr<TargetMachine> target;
// todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then).
// This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method.
Expand Down
29 changes: 12 additions & 17 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "ngraph/op/op.hpp"
#include "ngraph/op/matmul.hpp"
#include "memory_access.hpp"

namespace ngraph {
namespace snippets {
Expand All @@ -16,30 +16,25 @@ namespace op {
* @brief Brgemm is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows
* @ingroup snippets
*/
class Brgemm : public ngraph::op::v0::MatMul {
class Brgemm : public MemoryAccess {
public:
OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul);
Brgemm(const Output<Node>& A, const Output<Node>& B, const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess);
Brgemm(const Output<Node>& A, const Output<Node>& B,
const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
Brgemm() = default;

bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_offset_a() const { return get_input_offset(0); }
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool has_evaluate() const override { return false; }

size_t get_offset_a() const { return m_offset_a; }
size_t get_offset_b() const { return m_offset_b; }
size_t get_offset_c() const { return m_offset_c; }

void set_offset_a(const size_t offset) { m_offset_a = offset; }
void set_offset_b(const size_t offset) { m_offset_b = offset; }
void set_offset_c(const size_t offset) { m_offset_c = offset; }

private:
size_t m_offset_a = 0lu; // offset for first input
size_t m_offset_b = 0lu; // offset for second input
size_t m_offset_c = 0lu; // offset for output
protected:
ov::element::Type get_output_type() const;
ov::PartialShape get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
};

} // namespace op
Expand Down
11 changes: 5 additions & 6 deletions src/common/snippets/include/snippets/op/broadcastload.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include <snippets/op/broadcastmove.hpp>
#include <snippets/op/memory_access.hpp>

#include "ngraph/op/op.hpp"

Expand All @@ -17,22 +17,21 @@ namespace op {
* @brief Is generated for broadcasting by least varying dimension for non-blocked cases and the second varying dimension for blocked
* @ingroup snippets
*/
class BroadcastLoad : public BroadcastMove {
class BroadcastLoad : public MemoryAccess {
public:
OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::BroadcastMove);
OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::MemoryAccess);

BroadcastLoad(const Output<Node>& x, ov::PartialShape output_shape, size_t offset = 0lu);
BroadcastLoad() = default;

size_t get_offset() const { return m_offset; }
void set_offset(const size_t offset) { m_offset = offset; }
size_t get_offset() const { return get_input_offset(0); }

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

private:
size_t m_offset = 0lu;
ov::PartialShape output_shape;
};

} // namespace op
Expand Down
32 changes: 20 additions & 12 deletions src/common/snippets/include/snippets/op/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ namespace op {

/**
* @interface Buffer
* @brief The operation is for intermediate data storage
* - m_allocation_rank - rank of shape for memory allocation: shape[shape_rank - normalize(m_allocation_rank) : shape_rank].
* It's needed to allocate needed memory size that depends on Tile rank, for example.
* Default value is -1 (full shape)
* @brief This is a base class for memory storage.
* If Buffer has a parent, the operation is for intermediate data storage - IntermediateMemory type.
* Otherwise, the operation is for allocation of new empty memory with shape `m_shape` - NewMemory type
* Notes:
* - All buffers in a graph have the same memory pointer. So if we have a few buffers,
* each corresponding MemoryAccess op for a Buffer should have an offset into the common memory pointer of this Buffer
Expand All @@ -25,21 +24,30 @@ namespace op {
class Buffer : public ngraph::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");

Buffer(const Output<Node>& x, const int32_t allocation_rank = -1);
Buffer() = default;
Buffer(const ov::Shape& shape);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

int32_t get_allocation_rank() const { return m_allocation_rank; }
void set_allocation_rank(int32_t rank) { m_allocation_rank = rank; }
enum Type {
NewMemory,
IntermediateMemory
};

Type get_type() const { return m_type; }
ov::Shape get_allocation_shape() const { return m_shape; }
size_t get_byte_size() const;

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }

private:
int32_t m_allocation_rank = -1;
Type m_type = Type::IntermediateMemory;
ov::Shape m_shape = {};
};

} // namespace op
Expand Down
12 changes: 11 additions & 1 deletion src/common/snippets/include/snippets/op/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@ namespace op {
*/
class Load : public MemoryAccess {
public:
OPENVINO_OP("Load", "SnippetsOpset");
OPENVINO_OP("Load", "SnippetsOpset", MemoryAccess);

Load(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Load() = default;

size_t get_offset() const { return get_input_offset(0); }
size_t get_count() const { return get_input_count(0); }

void set_offset(size_t offset) { set_input_offset(offset, 0); }
void set_count(size_t count) { set_input_count(count, 0); }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

Expand All @@ -41,6 +48,9 @@ class LoadReshape : public Load {
LoadReshape(const Output<Node>& x, size_t count = 1lu, const size_t offset = 0lu, std::vector<size_t> order = {});
LoadReshape() = default;

void set_offset(size_t offset) { set_output_offset(offset, 0); }
void set_count(size_t count) { set_output_count(count, 0); }

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
Expand Down
60 changes: 48 additions & 12 deletions src/common/snippets/include/snippets/op/memory_access.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -13,28 +13,64 @@ namespace op {
/**
* @interface MemoryAccess
* @brief This is a base class for memory access operations (like Load and Store).
* It provides universal set/get interface to manipulate the number
* of elements accessed during one operation call ("count").
* Default "count" value is "1" - it means to load/store one element
* It provides universal interface to manipulate with memory: load/store.
* @param m_input_ports - vector of input descriptors: variables of PortDescriptor class
* @param m_output_ports - vector of output descriptors: variables of PortDescriptor class
* @ingroup snippets
*/

class MemoryAccess : public ngraph::op::Op {
public:
OPENVINO_OP("MemoryAccess", "SnippetsOpset");

size_t get_count() const;
size_t get_offset() const;
void set_count(const size_t count);
void set_offset(const size_t offset);
/**
* @interface PortDescriptor
* @brief This class describes port of MemoryAccess operation
* @param m_count - count of elements to load/store
* @param m_offset - starting index of elements to load/store
* @param m_index - port index
* @ingroup snippets
*/
struct PortDescriptor {
PortDescriptor(size_t count, size_t offset) : count(count), offset(offset) {}
PortDescriptor() = default;

size_t count = 0lu;
size_t offset = 0lu;
size_t index = 0lu;

private:
PortDescriptor(size_t count, size_t offset, size_t index) : count(count), offset(offset), index(index) {}

friend class MemoryAccess;
};

void set_input_count(size_t count, size_t idx = 0);
void set_output_count(size_t count, size_t idx = 0);
void set_input_offset(size_t offset, size_t idx = 0);
void set_output_offset(size_t offset, size_t idx = 0);

size_t get_input_count(size_t idx = 0) const;
size_t get_output_count(size_t idx = 0) const;
size_t get_input_offset(size_t idx = 0) const;
size_t get_output_offset(size_t idx = 0) const;

size_t get_input_port_count() const { return m_input_ports.size(); }
size_t get_output_port_count() const { return m_output_ports.size(); }

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

protected:
explicit MemoryAccess(const Output<Node>& x, size_t count = 1lu, size_t offset = 0lu);
explicit MemoryAccess(const OutputVector& arguments, size_t input_count = 0, size_t output_count = 0);
MemoryAccess() = default;
size_t m_count = 0lu;
size_t m_offset = 0lu;

void set_input_port_descriptor(const PortDescriptor& desc, const size_t i);
void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
const PortDescriptor& get_input_port_descriptor(const size_t i) const;
const PortDescriptor& get_output_port_descriptor(const size_t i) const;

std::vector<PortDescriptor> m_input_ports;
std::vector<PortDescriptor> m_output_ports;
};

} // namespace op
Expand Down
9 changes: 8 additions & 1 deletion src/common/snippets/include/snippets/op/store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@ namespace op {
*/
class Store : public MemoryAccess {
public:
OPENVINO_OP("Store", "SnippetsOpset");
OPENVINO_OP("Store", "SnippetsOpset", MemoryAccess);

Store(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Store() = default;

size_t get_offset() const { return get_output_offset(0); }
size_t get_count() const { return get_output_count(0); }

void set_offset(size_t offset) { set_output_offset(offset, 0); }
void set_count(size_t count) { set_output_count(count, 0); }

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <ngraph/pass/pass.hpp>

#include "snippets/generator.hpp"

namespace ngraph {
namespace snippets {
namespace pass {
Expand All @@ -18,10 +20,13 @@ namespace pass {
*/
class AssignRegisters : public ngraph::pass::FunctionPass {
public:
explicit AssignRegisters() {
explicit AssignRegisters(const std::function<Generator::opRegType(const std::shared_ptr<Node>& op)>& mapper) : m_reg_type_mapper(mapper) {
set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
std::function<Generator::opRegType(const std::shared_ptr<Node>& op)> m_reg_type_mapper;
};

} // namespace pass
Expand Down
21 changes: 21 additions & 0 deletions src/common/snippets/include/snippets/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,31 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out);
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
std::vector<size_t> get_node_output_layout(const std::shared_ptr<Node>& node);
std::vector<size_t> get_node_output_layout(const Node* node);
void set_transpose_output_layout(const ov::Output<Node>& port, const std::shared_ptr<opset1::Transpose>& node);
void set_output_layout(const ov::Output<Node>& port, const std::vector<size_t>& layout);

// Returns the innermost (last) dimension of the shape.
// NOTE(review): assumes the shape has rank >= 1 — rbegin() on a rank-0 shape is UB; confirm callers guarantee this.
inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *shape.rbegin(); }
// Returns the second-to-last dimension of the shape.
// NOTE(review): assumes the shape has rank >= 2 — confirm callers guarantee this.
inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *std::next(shape.rbegin()); }

// Normalizes an allocation rank against a shape rank.
// Non-negative ranks are returned unchanged; negative ranks count back from the
// full shape, so -1 maps to shape_rank (i.e. "allocate the full shape").
inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
    if (allocation_rank >= 0)
        return allocation_rank;
    return static_cast<int32_t>(shape_rank) + allocation_rank + 1;
}

// Base case: true iff `val` compares equal to the single candidate `item`.
template <typename T, typename P>
constexpr bool one_of(T val, P item) { return val == item; }

// Variadic case: true iff `val` compares equal to at least one of the candidates.
// Short-circuits left to right, matching the order the candidates are listed.
template <typename T, typename P, typename... Args>
constexpr bool one_of(T val, P item, Args... item_others) {
    return one_of(val, item) || one_of(val, item_others...);
}

// Base case: true iff `val` compares equal to the single candidate `item`.
template <typename T, typename P>
constexpr bool everyone_is(T val, P item) { return val == item; }

// Variadic case: true iff `val` compares equal to every listed candidate.
// Short-circuits on the first mismatch, left to right.
template <typename T, typename P, typename... Args>
constexpr bool everyone_is(T val, P item, Args... item_others) {
    return everyone_is(val, item) && everyone_is(val, item_others...);
}
} // namespace utils
} // namespace snippets
} // namespace ngraph
Loading

0 comments on commit 38c924a

Please sign in to comment.