[Snippets] Add support of MHA Tokenization for different precisions #15647

Merged
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/op/brgemm.hpp
@@ -28,6 +28,8 @@ class Brgemm : public MemoryAccess {
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

6 changes: 4 additions & 2 deletions src/common/snippets/include/snippets/op/buffer.hpp
@@ -29,7 +29,7 @@ class Buffer : public ov::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");
Buffer() = default;
Buffer(const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Shape& shape, ov::element::Type element_type = ov::element::u8, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);

@@ -48,9 +48,10 @@ class Buffer : public ov::op::Op {
int64_t get_offset() const { return m_offset; }
void set_id(size_t id) { m_id = id; }
void set_offset(int64_t offset) { m_offset = offset; }

size_t get_byte_size() const;

void set_element_type(ov::element::Type element_type);

bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }

@@ -59,6 +60,7 @@
ov::Shape m_shape = {};
int64_t m_offset = 0;
size_t m_id = 0; // Default ID = 0: all Buffers belong to the same set
ov::element::Type m_element_type = ov::element::u8; // u8 by default (1 byte)
};

} // namespace op
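The new `m_element_type` field makes buffer sizing precision-aware. A minimal sketch of the computation behind `get_byte_size()` (the shape and type here are illustrative, not from the PR):

```cpp
#include "openvino/core/shape.hpp"
#include "openvino/core/type/element_type.hpp"

// Byte size = shape volume * element width, so replacing the default
// u8 (1 byte) with bf16 (2 bytes) doubles the scratchpad requirement.
ov::Shape shape{2, 32, 64};
ov::element::Type et = ov::element::bf16;
size_t byte_size = ov::shape_size(shape) * et.size();  // 2 * 32 * 64 * 2 = 8192
```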
1 change: 0 additions & 1 deletion src/common/snippets/include/snippets/op/subgraph.hpp
@@ -136,7 +136,6 @@ class Subgraph : public ov::op::util::SubGraphOp {
// should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast)
// This check returns True if Constant op which is input of this op should be inside Subgraph body
static auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool;

static bool check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept;
// Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
26 changes: 24 additions & 2 deletions src/common/snippets/include/snippets/pass/mha_tokenization.hpp
@@ -6,6 +6,7 @@

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
@@ -14,13 +15,34 @@ namespace pass {
/**
* @interface TokenizeMHASnippets
* @brief The pass tokenizes MHA-pattern into Subgraph
* TODO: Write pattern
* Pattern: Transpose1
* |
* Transpose0 [Eltwise, Select]
* \ /
* MatMul0
* |
* [Eltwise, Select, Reshape]
* |
* Softmax
* |
* [Eltwise, Select, Reshape] Transpose2
* \ /
* MatMul1
* |
* [Eltwise, Select, Transpose3]
* Notes:
 * - Transposes may be absent
 * - Transpose0, Transpose2 and Transpose3 support only the [0,2,1,3] order
 * - Transpose1 supports only the [0,2,3,1] order
 * - [...] means any number of nodes from the list, with the following restrictions:
 *     * Reshapes may appear only immediately around Softmax (Reshape -> Softmax -> Reshape)
 *     * After MatMul1, only Transpose3 or any number of Eltwise/Select ops may appear
* @ingroup snippets
*/
class TokenizeMHASnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeMHASnippets", "0");
TokenizeMHASnippets();
TokenizeMHASnippets(const SnippetsTokenization::Config& config = {});
};

} // namespace pass
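To make the documented pattern concrete, here is a hedged sketch of a minimal graph that would match it (the Transposes and optional Eltwise/Select/Reshape nodes are omitted, which the pass permits; the shapes and helper name are illustrative, not from the PR):

```cpp
#include "openvino/openvino.hpp"

// Minimal MHA-like body: MatMul0 -> Softmax -> MatMul1.
std::shared_ptr<ov::Model> make_minimal_mha() {
    using namespace ov;
    auto q = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 16, 64, 64});
    auto k = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 16, 64, 64});
    auto v = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 16, 64, 64});
    auto mm0 = std::make_shared<op::v0::MatMul>(q, k, false, true);  // MatMul0
    auto sm  = std::make_shared<op::v8::Softmax>(mm0, -1);           // Softmax
    auto mm1 = std::make_shared<op::v0::MatMul>(sm, v);              // MatMul1
    return std::make_shared<Model>(OutputVector{mm1}, ParameterVector{q, k, v});
}
```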
29 changes: 27 additions & 2 deletions src/common/snippets/include/snippets/pass/tokenization.hpp
@@ -7,8 +7,7 @@
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"

namespace ov {
namespace snippets {
@@ -19,8 +18,16 @@ namespace pass {
SkippedByPlugin - indicates that snippets cannot include this node in a subgraph. Can be set by the plugin via SetSnippetsNodeType(...).
*/
enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
/*
NotSet - default value, returned if the subgraph wasn't marked; snippets may still include nodes in this subgraph
Completed - indicates that snippets cannot include any new nodes in this subgraph.
It is set by separate tokenization passes, e.g. matcher-based tokenization (MHA Tokenization).
*/
enum class SnippetsSubgraphType : int64_t {NotSet, Completed};
void SetSnippetsNodeType(const std::shared_ptr<Node>&, SnippetsNodeType);
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph>&, SnippetsSubgraphType);
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node>&);
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph>&);
void SetTopologicalOrder(const std::shared_ptr<Node>&, int64_t);
int64_t GetTopologicalOrder(const std::shared_ptr<const Node>&);

@@ -48,8 +55,26 @@ class EnumerateNodes : public ov::pass::ModelPass {
*/
class SnippetsTokenization : public ov::pass::ModelPass {
public:
/**
* @interface Config
 * @brief Allows adjusting the tokenization passes
* @ingroup snippets
*/
struct Config {
Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {}

// If false, Transposes are not tokenized by MHA Tokenization.
// Otherwise, they may be fused into the Subgraph when possible.
// TODO [106921]: Remove when ticket 106921 is implemented
bool mha_token_enable_transpose = true;
};

OPENVINO_RTTI("SnippetsTokenization", "0");
SnippetsTokenization(const Config& config) : m_config(config) {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
Config m_config{};
};


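A hedged usage sketch of the new Config knob: registering the pipeline with Transpose tokenization disabled (the `model` variable and the pass-manager setup are assumed context, not part of this PR):

```cpp
#include "openvino/pass/manager.hpp"
#include "snippets/pass/tokenization.hpp"

// `model` is an ov::Model assumed to be created elsewhere.
ov::pass::Manager manager;
manager.register_pass<ov::snippets::pass::SnippetsTokenization>(
    ov::snippets::pass::SnippetsTokenization::Config(/*enable_transpose=*/false));
manager.run_passes(model);
```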
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
@@ -124,8 +124,8 @@ bool AllocateBuffers::run(LinearIR& linear_ir) {

const auto current_allocated_memory_size = m_buffer_scratchpad_size - offset;
if (buffer_size > current_allocated_memory_size) {
m_buffer_scratchpad_size += (buffer_size - current_allocated_memory_size);
// Note: we don't update offset because we only extend the memory to the required size
allocate(buffer, expr, buffer_size);
continue;
}
propagate_offset(linear_ir, *expr_it, offset);
allocated_buffers.insert(expr);
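For clarity, a worked example of the check above with assumed byte values (the allocate()/propagate_offset() machinery is this PR's; the numbers are not):

```cpp
// Assumed values to illustrate the fallback branch:
size_t scratchpad_size = 128;  // m_buffer_scratchpad_size
size_t offset = 96;            // candidate reuse offset inside the scratchpad
size_t buffer_size = 64;       // byte size of the buffer being placed

size_t current_allocated = scratchpad_size - offset;  // 32 bytes left at the tail
if (buffer_size > current_allocated) {
    // 64 > 32: the tail is too small to reuse, so the pass falls back to
    // allocate(buffer, expr, buffer_size) instead of propagating the offset.
}
```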
6 changes: 3 additions & 3 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
@@ -100,9 +100,9 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
// Otherwise WIN build fails with "IS_MANUALLY_ALLOCATED_REG cannot be implicitly captured because no default capture mode has been specified"
// the same problem with all the other lambdas in this file
auto enumerate_out_tensors = [=] (const ExpressionPtr& expr,
decltype(regs_vec)& reg_map,
const std::map<tensor, Reg>& manually_assigned_regs,
size_t& counter) {
decltype(regs_vec)& reg_map,
const std::map<tensor, Reg>& manually_assigned_regs,
size_t& counter) {
for (const auto& out_tensor : expr->get_output_port_connectors()) {
// Note that some ops might have identical input&output tensors (Result and Tile* for ex.)
// so we have to check that the tensor has not been enumerated already
27 changes: 17 additions & 10 deletions src/common/snippets/src/op/brgemm.cpp
@@ -62,22 +62,29 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
}

ov::element::Type Brgemm::get_output_type() const {
const auto element_type_a = get_input_element_type(0);
const auto element_type_b = get_input_element_type(1);
const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) {
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
return element::f32;
return element::f32;
} else if (is_int8) {
return element::i32;
} else {
return element::undefined;
}
}

ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
element_type_a.get_type_name() +
" and " +
element_type_b.get_type_name());
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
}

return output_type;
}

std::vector<ov::PartialShape> Brgemm::get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const {
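A short sketch of the precision mapping the new static overload implements; since it no longer needs a node, callers (e.g. tokenization) can query it up front. The expected results follow directly from the code above:

```cpp
using ov::element::f32; using ov::element::bf16;
using ov::element::u8;  using ov::element::i8;
using ov::snippets::op::Brgemm;

auto t0 = Brgemm::get_output_type(f32,  f32);   // f32
auto t1 = Brgemm::get_output_type(bf16, bf16);  // f32 (bf16 accumulates in f32)
auto t2 = Brgemm::get_output_type(u8,   i8);    // i32
auto t3 = Brgemm::get_output_type(i8,   u8);    // element::undefined -> the
                                                // member overload would throw
```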
20 changes: 13 additions & 7 deletions src/common/snippets/src/op/buffer.cpp
@@ -14,8 +14,8 @@ namespace snippets {
namespace op {


Buffer::Buffer(const ov::Shape& shape, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id) {
Buffer::Buffer(const ov::Shape& shape, ov::element::Type element_type, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id), m_element_type(std::move(element_type)) {
constructor_validate_and_infer_types();
}

@@ -40,34 +40,33 @@ bool Buffer::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("allocation_shape", m_shape);
visitor.on_attribute("offset", m_offset);
visitor.on_attribute("id", m_id);
visitor.on_attribute("element_type", m_element_type);
return true;
}

void Buffer::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
ov::element::Type output_type;
ov::Shape output_shape;
if (m_type == Type::NewMemory) {
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must not have arguments!");
output_shape = m_shape;
output_type = ov::element::u8; // 1Byte
} else if (m_type == Type::IntermediateMemory) {
const auto& input_shape = get_input_partial_shape(0);
OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape");
output_type = get_input_element_type(0);
m_element_type = get_input_element_type(0);
output_shape = input_shape.get_shape();
} else {
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
}
set_output_type(0, output_type, output_shape);
set_output_type(0, m_element_type, output_shape);
}

std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs);
check_new_args_count(this, new_args);
std::shared_ptr<op::Buffer> new_buffer = nullptr;
if (m_type == Type::NewMemory) {
new_buffer = std::make_shared<Buffer>(m_shape, m_id);
new_buffer = std::make_shared<Buffer>(m_shape, m_element_type, m_id);
} else if (m_type == Type::IntermediateMemory) {
new_buffer = std::make_shared<Buffer>(new_args.at(0), m_shape, m_id);
} else {
@@ -82,6 +81,13 @@ size_t Buffer::get_byte_size() const {
return ov::shape_size(shape) * get_element_type().size();
}

void Buffer::set_element_type(ov::element::Type element_type) {
OPENVINO_ASSERT(is_new_memory(), "Only Buffer with NewMemory can change its output precision!");
m_element_type = std::move(element_type);
// Apply the change
validate_and_infer_types();
}

} // namespace op
} // namespace snippets
} // namespace ov
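A hedged usage sketch of the precision-aware Buffer (shape and types are illustrative): the element type is now serialized, cloned, and reflected in the byte size, and only NewMemory buffers may change it after construction:

```cpp
#include "snippets/op/buffer.hpp"

// NewMemory buffer allocated as bf16 scratchpad.
auto buf = std::make_shared<ov::snippets::op::Buffer>(ov::Shape{2, 64}, ov::element::bf16);
size_t bytes = buf->get_byte_size();      // 2 * 64 * 2 = 256 bytes

buf->set_element_type(ov::element::f32);  // OK: NewMemory, re-runs type/shape inference
// For IntermediateMemory buffers the type is taken from the input instead,
// and set_element_type() would assert.
```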
38 changes: 16 additions & 22 deletions src/common/snippets/src/op/subgraph.cpp
@@ -14,7 +14,6 @@
#include "snippets/pass/convert_constants.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/transform_convert.hpp"
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"
@@ -75,12 +74,11 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::No
}

void snippets::op::Subgraph::init_config() {
auto update = [](bool& flag, bool status) { flag = flag || status; };
const auto ops = body_ptr()->get_ops();
for (const auto& op : ops) {
config.m_is_quantized = config.m_is_quantized ||
ov::is_type<ov::op::v0::FakeQuantize>(op);
config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops ||
is_domain_sensitive_op(op);
update(config.m_is_quantized, ov::is_type<ov::op::v0::FakeQuantize>(op));
update(config.m_has_domain_sensitive_ops, is_domain_sensitive_op(op));
}
}

@@ -93,6 +91,13 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
// and where the Loops will be - we can only predict.
// Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32)
std::vector<size_t> used_precision_size;

auto push_prc_size = [&used_precision_size](size_t precision_size) {
if (used_precision_size.empty() || used_precision_size.back() != precision_size) {
used_precision_size.push_back(precision_size);
}
};

for (const auto& op : ops) {
if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(op)) {
// At the moment Transposes are supported only on Results and Parameters but
@@ -106,34 +111,23 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
}) ||
!ov::is_type<ov::op::v0::Parameter>(transpose->get_input_node_shared_ptr(0));
if (are_prev_or_next_ops) {
const auto prc_size = transpose->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(transpose->get_element_type().size());
}
} else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) {
// Softmax always uses 2 FP32 Buffers
const auto prc_size = ov::element::f32.size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
// Softmax always uses 2 FP32 Buffers after decomposition.
// They are in-place and identical, so we push the precision size only once
push_prc_size(ov::element::f32.size());
} else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) {
// Checking the first input is enough because MatMul requires the same precision size on both inputs
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(0)) ||
!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(1))) {
const auto prc_size = matmul->get_input_element_type(0).size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_input_element_type(0).size());
}

const auto consumers = matmul->get_output_target_inputs(0);
if (std::none_of(consumers.begin(), consumers.end(),
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
const auto prc_size = matmul->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_element_type().size());
}
}
}
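The push_prc_size helper deduplicates only consecutive equal sizes, which is the core of the estimate: adjacent ops with the same precision size can share a buffer, while a precision change forces a new one. A standalone sketch with assumed sizes:

```cpp
#include <vector>

std::vector<size_t> used_precision_size;
auto push_prc_size = [&used_precision_size](size_t precision_size) {
    if (used_precision_size.empty() || used_precision_size.back() != precision_size) {
        used_precision_size.push_back(precision_size);
    }
};

// e.g. MatMul0 (bf16 = 2 bytes) -> Softmax (f32 = 4) -> MatMul1 (f32 = 4):
for (size_t s : {2, 4, 4})
    push_prc_size(s);
// used_precision_size == {2, 4} -> estimated unique buffer count is 2
```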