[Snippets] Added support of I8/U8/BF16 for MatMul
Rewrote MemoryAccess
a-sidorova committed Jan 12, 2023
1 parent c33b941 commit d9503fd
Showing 55 changed files with 1,612 additions and 502 deletions.
20 changes: 20 additions & 0 deletions src/common/snippets/include/snippets/generator.hpp
@@ -41,6 +41,20 @@ class TargetMachine {
*/
virtual size_t get_lanes() const = 0;

/**
* @interface opRegType
* @brief Register type of operations
* Note that currently there are 4 types of ops:
*  gpr->gpr: Parameter, Result, LoopBegin, LoopEnd, etc.
*  gpr->vec / vec->gpr: Load/LoadConvert, BroadcastLoad (gpr->vec); Store/StoreConvert (vec->gpr), etc.
*  vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
*/
enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
/**
* @brief gets register type by op type
* @return register type
*/
opRegType get_op_reg_type(const std::shared_ptr<Node>& op) const;

/**
* @brief called by the generator to get the emitter for a target machine
@@ -64,6 +78,12 @@
virtual ~TargetMachine() = default;

protected:
/**
* @brief gets register type for a plugin-specific op type
* @return register type
*/
virtual opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const = 0;

std::map<const ngraph::DiscreteTypeInfo, std::function<std::shared_ptr<Emitter>(std::shared_ptr<ngraph::Node>)>> jitters;
};
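
The new `get_op_reg_type` dispatch (implemented below in generator.cpp) handles the common snippets ops, while the pure-virtual `get_specific_op_reg_type` lets a plugin classify its own operations. A minimal sketch of a plugin-side override, assuming a hypothetical op name; the remaining pure-virtual members of `TargetMachine` are omitted:

```cpp
#include <ngraph/except.hpp>
#include "snippets/generator.hpp"

// Hypothetical plugin target machine: only the new hook is shown, so the class stays abstract.
class ExamplePluginTargetMachine : public ngraph::snippets::TargetMachine {
protected:
    opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override {
        // e.g. a plugin-only fused op that works purely on vector registers (hypothetical name)
        if (std::string(op->get_type_name()) == "ExampleFusedOp")
            return vec2vec;
        throw ngraph::ngraph_error("Register type of the operation " +
                                   std::string(op->get_type_name()) + " isn't determined!");
    }
};
```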

28 changes: 13 additions & 15 deletions src/common/snippets/include/snippets/op/brgemm.hpp
@@ -5,7 +5,7 @@
#pragma once

#include "ngraph/op/op.hpp"
#include "ngraph/op/matmul.hpp"
#include "memory_access.hpp"

namespace ngraph {
namespace snippets {
@@ -16,30 +16,28 @@ namespace op {
* @brief Brgemm is a batch-reduced matrix multiplication with support for arbitrary strides between matrix rows
* @ingroup snippets
*/
class Brgemm : public ngraph::op::v0::MatMul {
class Brgemm : public MemoryAccess {
public:
OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul);
Brgemm(const Output<Node>& A, const Output<Node>& B, const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess);
Brgemm(const Output<Node>& A, const Output<Node>& B, bool transposed_a = false, bool transposed_b = false,
const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu);
Brgemm() = default;

bool transposed_a() const { return m_transposed_a; }
bool transposed_b() const { return m_transposed_b; }

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool has_evaluate() const override { return false; }

size_t get_offset_a() const { return m_offset_a; }
size_t get_offset_b() const { return m_offset_b; }
size_t get_offset_c() const { return m_offset_c; }

void set_offset_a(const size_t offset) { m_offset_a = offset; }
void set_offset_b(const size_t offset) { m_offset_b = offset; }
void set_offset_c(const size_t offset) { m_offset_c = offset; }
protected:
ov::element::Type get_output_type() const;
ov::PartialShape get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;

private:
size_t m_offset_a = 0lu; // offset for first input
size_t m_offset_b = 0lu; // offset for second input
size_t m_offset_c = 0lu; // offset for output
bool m_transposed_a;
bool m_transposed_b;
};

} // namespace op
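
With Brgemm now derived from MemoryAccess and carrying transpose flags, a node could be built roughly as follows (a sketch; shapes, precisions and offsets are illustrative, not taken from this diff):

```cpp
#include <memory>
#include <ngraph/opsets/opset1.hpp>
#include "snippets/op/brgemm.hpp"

std::shared_ptr<ngraph::snippets::op::Brgemm> make_example_brgemm() {
    // bf16 is one of the precisions enabled for MatMul by this commit
    auto A = std::make_shared<ngraph::opset1::Parameter>(ov::element::bf16, ov::Shape{1, 2, 16, 64});
    auto B = std::make_shared<ngraph::opset1::Parameter>(ov::element::bf16, ov::Shape{1, 2, 64, 32});
    // offsets address the shared Buffer memory; 0 means "start of the buffer"
    return std::make_shared<ngraph::snippets::op::Brgemm>(
        A, B, /*transposed_a=*/false, /*transposed_b=*/false,
        /*offset_a=*/0, /*offset_b=*/0, /*offset_c=*/0);
}
```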
11 changes: 4 additions & 7 deletions src/common/snippets/include/snippets/op/broadcastload.hpp
@@ -4,7 +4,7 @@

#pragma once

#include <snippets/op/broadcastmove.hpp>
#include <snippets/op/memory_access.hpp>

#include "ngraph/op/op.hpp"

@@ -17,22 +17,19 @@ namespace op {
* @brief Is generated for broadcasting by the least varying dimension for non-blocked cases and by the second varying dimension for blocked ones
* @ingroup snippets
*/
class BroadcastLoad : public BroadcastMove {
class BroadcastLoad : public MemoryAccess {
public:
OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::BroadcastMove);
OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::MemoryAccess);

BroadcastLoad(const Output<Node>& x, ov::PartialShape output_shape, size_t offset = 0lu);
BroadcastLoad() = default;

size_t get_offset() const { return m_offset; }
void set_offset(const size_t offset) { m_offset = offset; }

bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

private:
size_t m_offset = 0lu;
ov::PartialShape output_shape;
};

} // namespace op
7 changes: 7 additions & 0 deletions src/common/snippets/include/snippets/op/buffer.hpp
@@ -16,6 +16,9 @@ namespace op {
* - m_allocation_rank - rank of shape for memory allocation: shape[shape_rank - normalize(m_allocation_rank) : shape_rank].
*   It is needed to allocate the required memory size, which depends on the Tile rank, for example.
*   Default value is -1 (full shape)
* - m_static_shape - static shape that describes the Buffer size in cases when the Buffer doesn't have a parent node.
* - m_element_type - element type in cases when the Buffer doesn't have a parent node.
* - m_is_single - True if the Buffer doesn't have a parent node, otherwise False
* Notes:
* - All Buffers in a graph share the same memory pointer. So if there are several Buffers,
*   each MemoryAccess op that works with a Buffer should have an offset relative to this common memory pointer
@@ -27,6 +30,7 @@ class Buffer : public ngraph::op::Op {
OPENVINO_OP("Buffer", "SnippetsOpset");

Buffer(const Output<Node>& x, const int32_t allocation_rank = -1);
Buffer(const ov::Shape shape, const ov::element::Type element_type, int32_t allocation_rank = -1);
Buffer() = default;

int32_t get_allocation_rank() const { return m_allocation_rank; }
@@ -40,6 +44,9 @@

private:
int32_t m_allocation_rank = -1;
ov::Shape m_static_shape;
ov::element::Type m_element_type;
bool m_is_single = false;
};

} // namespace op
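
The new constructor covers Buffers that have no parent node, where the shape and element type must be given explicitly. A sketch of both flavours (values are illustrative, not from this diff):

```cpp
#include <memory>
#include <ngraph/opsets/opset1.hpp>
#include "snippets/op/buffer.hpp"

void example_buffers() {
    // 1) Buffer with a parent node: shape and element type are taken from the input
    auto parent = std::make_shared<ngraph::opset1::Parameter>(ov::element::f32, ov::Shape{1, 16, 256});
    auto intermediate = std::make_shared<ngraph::snippets::op::Buffer>(parent->output(0), /*allocation_rank=*/2);

    // 2) "Single" Buffer without a parent node: static shape and type are passed explicitly
    //    (presumably useful as scratchpad memory)
    auto scratch = std::make_shared<ngraph::snippets::op::Buffer>(ov::Shape{4 * 1024}, ov::element::u8);
}
```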
3 changes: 2 additions & 1 deletion src/common/snippets/include/snippets/op/load.hpp
@@ -20,11 +20,12 @@ namespace op {
*/
class Load : public MemoryAccess {
public:
OPENVINO_OP("Load", "SnippetsOpset");
OPENVINO_OP("Load", "SnippetsOpset", MemoryAccess);

Load(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Load() = default;

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

50 changes: 39 additions & 11 deletions src/common/snippets/include/snippets/op/memory_access.hpp
@@ -10,31 +10,59 @@ namespace ngraph {
namespace snippets {
namespace op {

class MemoryAccess;

/**
* @interface PortDescriptor
* @brief This class describes a port of a MemoryAccess operation
* @param m_count - count of elements to load/store
* @param m_offset - starting index of elements to load/store
* @param m_index - port index
* @ingroup snippets
*/

struct PortDescriptor {
PortDescriptor(size_t count, size_t offset) : m_count(count), m_offset(offset) {}
PortDescriptor() = default;

size_t m_count = 0lu;
size_t m_offset = 0lu;
size_t m_index = 0lu;

private:
PortDescriptor(size_t count, size_t offset, size_t index) : m_count(count), m_offset(offset), m_index(index) {}

friend class MemoryAccess;
};

/**
* @interface MemoryAccess
* @brief This is a base class for memory access operations (like Load and Store).
* It provides universal set/get interface to manipulate the number
* of elements accessed during one operation call ("count").
* Default "count" value is "1" - it means to load/store one element
* It provides a universal interface to work with memory: load/store.
* @param m_input_ports - vector of input port descriptors (PortDescriptor instances)
* @param m_output_ports - vector of output port descriptors (PortDescriptor instances)
* @ingroup snippets
*/

class MemoryAccess : public ngraph::op::Op {
public:
OPENVINO_OP("MemoryAccess", "SnippetsOpset");

size_t get_count() const;
size_t get_offset() const;
void set_count(const size_t count);
void set_offset(const size_t offset);
void set_input_port_descriptor(const PortDescriptor& desc, const size_t i);
void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
PortDescriptor get_input_port_descriptor(const size_t i) const;
PortDescriptor get_output_port_descriptor(const size_t i) const;
PortDescriptor& get_input_port_descriptor(const size_t i);
PortDescriptor& get_output_port_descriptor(const size_t i);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

protected:
explicit MemoryAccess(const Output<Node>& x, size_t count = 1lu, size_t offset = 0lu);
explicit MemoryAccess(const OutputVector& arguments);
MemoryAccess() = default;
size_t m_count = 0lu;
size_t m_offset = 0lu;

std::vector<PortDescriptor> m_input_ports;
std::vector<PortDescriptor> m_output_ports;
};

} // namespace op
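
Because counts and offsets now live in per-port descriptors, passes can re-address any memory-access op through the base class without knowing its concrete type. A rough usage sketch (values are illustrative):

```cpp
#include <memory>
#include "snippets/op/load.hpp"
#include "snippets/op/memory_access.hpp"

void example_port_descriptors(const ov::Output<ov::Node>& input) {
    // Load still takes count/offset in its constructor...
    auto load = std::make_shared<ngraph::snippets::op::Load>(input, /*count=*/16, /*offset=*/0);

    // ...but they are stored as a PortDescriptor and can be updated generically
    auto ma = std::dynamic_pointer_cast<ngraph::snippets::op::MemoryAccess>(load);
    auto desc = ma->get_input_port_descriptor(0);               // m_count = 16, m_offset = 0
    ma->set_input_port_descriptor({desc.m_count, /*offset=*/64}, 0);
}
```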
3 changes: 2 additions & 1 deletion src/common/snippets/include/snippets/op/store.hpp
@@ -20,11 +20,12 @@ namespace op {
*/
class Store : public MemoryAccess {
public:
OPENVINO_OP("Store", "SnippetsOpset");
OPENVINO_OP("Store", "SnippetsOpset", MemoryAccess);

Store(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Store() = default;

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

src/common/snippets/include/snippets/pass/assign_registers.hpp
@@ -6,6 +6,8 @@

#include <ngraph/pass/pass.hpp>

#include "snippets/generator.hpp"

namespace ngraph {
namespace snippets {
namespace pass {
@@ -18,10 +20,13 @@ namespace pass {
*/
class AssignRegisters : public ngraph::pass::FunctionPass {
public:
explicit AssignRegisters() {
explicit AssignRegisters(const std::shared_ptr<const TargetMachine>& target_machine) : m_target_machine(target_machine) {
set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
std::shared_ptr<const TargetMachine> m_target_machine = nullptr;
};

} // namespace pass
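
Since register classification is now target-dependent, the pass receives the TargetMachine at construction time. Registering it could look roughly like this (a sketch; the surrounding pipeline is assumed, not shown in this diff):

```cpp
#include <ngraph/pass/manager.hpp>
#include "snippets/pass/assign_registers.hpp"

void assign_registers_example(const std::shared_ptr<ov::Model>& body,
                              const std::shared_ptr<const ngraph::snippets::TargetMachine>& target) {
    ngraph::pass::Manager manager;
    // the pass can query target->get_op_reg_type(op) to pick between gpr and vec registers
    manager.register_pass<ngraph::snippets::pass::AssignRegisters>(target);
    manager.run_passes(body);
}
```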
17 changes: 17 additions & 0 deletions src/common/snippets/include/snippets/utils.hpp
@@ -29,10 +29,27 @@ ov::PartialShape get_port_planar_shape(const Output<Node>& out);
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector<size_t>& layout);
std::vector<size_t> get_node_output_layout(const std::shared_ptr<Node>& node);
std::vector<size_t> get_node_output_layout(const Node* node);
void set_output_layout(const ov::Output<Node>& port, const std::shared_ptr<opset1::Transpose>& node);
void set_output_layout(const ov::Output<Node>& port, const std::vector<size_t>& layout);

inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); }
inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); }

template <typename T, typename P>
constexpr bool one_of(T val, P item) { return val == item; }

template <typename T, typename P, typename... Args>
constexpr bool one_of(T val, P item, Args... item_others) {
return val == item || one_of(val, item_others...);
}

template <typename T, typename P>
constexpr bool everyone_is(T val, P item) { return val == item; }

template <typename T, typename P, typename... Args>
constexpr bool everyone_is(T val, P item, Args... item_others) {
return val == item && everyone_is(val, item_others...);
}
} // namespace utils
} // namespace snippets
} // namespace ngraph
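
These variadic helpers just fold equality checks: `one_of` is true if the first argument equals any of the remaining ones, `everyone_is` if it equals all of them. Since both are constexpr, a small standalone check illustrates the semantics:

```cpp
#include "snippets/utils.hpp"

using ngraph::snippets::utils::one_of;
using ngraph::snippets::utils::everyone_is;

static_assert(one_of(3, 1, 2, 3), "one_of: val matches at least one item");
static_assert(!one_of(4, 1, 2, 3), "one_of: no item matches");
static_assert(everyone_is(5, 5, 5, 5), "everyone_is: all items equal val");
static_assert(!everyone_is(5, 5, 6), "everyone_is: one item differs");
```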
37 changes: 35 additions & 2 deletions src/common/snippets/src/generator.cpp
@@ -17,6 +17,30 @@
namespace ngraph {
namespace snippets {

TargetMachine::opRegType TargetMachine::get_op_reg_type(const std::shared_ptr<Node>& op) const {
if (std::dynamic_pointer_cast<opset1::Parameter>(op) ||
std::dynamic_pointer_cast<opset1::Result>(op) ||
std::dynamic_pointer_cast<op::LoopBegin>(op) ||
std::dynamic_pointer_cast<op::LoopEnd>(op) ||
std::dynamic_pointer_cast<op::Brgemm>(op) ||
std::dynamic_pointer_cast<op::Buffer>(op))
return gpr2gpr;
else if (std::dynamic_pointer_cast<snippets::op::Load>(op) ||
std::dynamic_pointer_cast<snippets::op::BroadcastLoad>(op))
return gpr2vec;
else if (std::dynamic_pointer_cast<snippets::op::Store>(op))
return vec2gpr;
else if (ov::op::util::is_unary_elementwise_arithmetic(op) ||
ov::op::util::is_binary_elementwise_arithmetic(op) ||
ov::op::util::is_binary_elementwise_comparison(op) ||
ov::op::util::is_binary_elementwise_logical(op) ||
std::dynamic_pointer_cast<opset1::Convert>(op) ||
std::dynamic_pointer_cast<opset1::Select>(op))
return vec2vec;
else
return get_specific_op_reg_type(op);
}

auto getRegisters(const std::shared_ptr<ngraph::Node> &n) -> RegInfo {
OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::getRegisters")

@@ -77,8 +101,17 @@ auto tail_transformations(NodeVector& tail, const size_t tail_size, const ngraph
}
}
} else if (const auto memory_access = std::dynamic_pointer_cast<ngraph::snippets::op::MemoryAccess>(op)) {
if (memory_access->get_count() != 1) {
memory_access->set_count(tail_size);
for (size_t i = 0; i < memory_access->get_input_size(); ++i) {
auto& desc = memory_access->get_input_port_descriptor(i);
if (desc.m_count != 1) {
desc.m_count = tail_size;
}
}
for (size_t i = 0; i < memory_access->get_output_size(); ++i) {
auto& desc = memory_access->get_output_port_descriptor(i);
if (desc.m_count != 1) {
desc.m_count = tail_size;
}
}
}
updated_tile.push_back(op);