Skip to content

Commit

Permalink
[Snippets] Add support of MHA Tokenization for different precisions (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova authored Jun 8, 2023
1 parent bdfa970 commit eb3e6a6
Show file tree
Hide file tree
Showing 30 changed files with 1,105 additions and 277 deletions.
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class Brgemm : public MemoryAccess {
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

Expand Down
6 changes: 4 additions & 2 deletions src/common/snippets/include/snippets/op/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Buffer : public ov::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");
Buffer() = default;
Buffer(const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Shape& shape, ov::element::Type element_type = ov::element::u8, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);

Expand All @@ -48,9 +48,10 @@ class Buffer : public ov::op::Op {
int64_t get_offset() const { return m_offset; }
void set_id(size_t id) { m_id = id; }
void set_offset(int64_t offset) { m_offset = offset; }

size_t get_byte_size() const;

void set_element_type(ov::element::Type element_type);

bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }

Expand All @@ -59,6 +60,7 @@ class Buffer : public ov::op::Op {
ov::Shape m_shape = {};
int64_t m_offset = 0;
size_t m_id = 0; // Default ID - 0. All Buffers are from the same set
ov::element::Type m_element_type = ov::element::u8; // u8 - default 1 byte
};

} // namespace op
Expand Down
1 change: 0 additions & 1 deletion src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ class Subgraph : public ov::op::util::SubGraphOp {
// should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast)
// This check returns True if Constant op which is input of this op should be inside Subgraph body
static auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool;

static bool check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept;
// Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
Expand Down
26 changes: 24 additions & 2 deletions src/common/snippets/include/snippets/pass/mha_tokenization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
Expand All @@ -14,13 +15,34 @@ namespace pass {
/**
* @interface TokenizeMHASnippets
* @brief The pass tokenizes MHA-pattern into Subgraph
* TODO: Write pattern
* Pattern: Transpose1
* |
* Transpose0 [Eltwise, Select]
* \ /
* MatMul0
* |
* [Eltwise, Select, Reshape]
* |
* Softmax
* |
* [Eltwise, Select, Reshape] Transpose2
* \ /
* MatMul1
* |
* [Eltwise, Select, Transpose3]
* Notes:
* - Transposes can be missed
* - Transpose0, Transpose2 and Transpose3 may have only [0,2,1,3] order
* - Transpose1 may have only [0,2,3,1] order
* - [...] means any count of different nodes from list. But:
* * Reshapes can be only explicitly around Softmax (Reshape -> Softmax -> Reshape)
* * After MatMul1 may be only Transpose3 or any count of Eltwise, Select ops.
* @ingroup snippets
*/
class TokenizeMHASnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeMHASnippets", "0");
TokenizeMHASnippets();
TokenizeMHASnippets(const SnippetsTokenization::Config& config = {});
};

} // namespace pass
Expand Down
29 changes: 27 additions & 2 deletions src/common/snippets/include/snippets/pass/tokenization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"

namespace ov {
namespace snippets {
Expand All @@ -19,8 +18,16 @@ namespace pass {
SkippedByPlugin - indicate that snippets can't include this node in subgraph. Can be set by Plugin via SetSnippetsNodeType(...).
*/
enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
/*
NotSet - default value returned if the subgraph wasn't marked and snippets can include nodes in this subgraph
Completed - indicate that snippets can't include any nodes in this subgraph.
It's used in separate tokenization pass, for example, tokenization by matcher (MHA Tokenization).
*/
enum class SnippetsSubgraphType : int64_t {NotSet, Completed};
void SetSnippetsNodeType(const std::shared_ptr<Node>&, SnippetsNodeType);
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph>&, SnippetsSubgraphType);
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node>&);
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph>&);
void SetTopologicalOrder(const std::shared_ptr<Node>&, int64_t);
int64_t GetTopologicalOrder(const std::shared_ptr<const Node>&);

Expand Down Expand Up @@ -48,8 +55,26 @@ class EnumerateNodes : public ov::pass::ModelPass {
*/
class SnippetsTokenization : public ov::pass::ModelPass {
public:
/**
* @interface Config
* @brief Allow to adjust tokenization passes
* @ingroup snippets
*/
struct Config {
Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {}

// False if all Transposes aren't tokenized in MHA Tokenization.
// Otherwise, they may be fused into Subgraph if possible
// TODO [106921]: Remove please when the ticket 106921 is implemented
bool mha_token_enable_transpose = true;
};

OPENVINO_RTTI("SnippetsTokenization", "0");
SnippetsTokenization(const Config& config) : m_config(config) {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
Config m_config{};
};


Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ bool AllocateBuffers::run(LinearIR& linear_ir) {

const auto current_allocated_memory_size = m_buffer_scratchpad_size - offset;
if (buffer_size > current_allocated_memory_size) {
m_buffer_scratchpad_size += (buffer_size - current_allocated_memory_size);
// Note: we don't update offset because we just add memory to needed size
allocate(buffer, expr, buffer_size);
continue;
}
propagate_offset(linear_ir, *expr_it, offset);
allocated_buffers.insert(expr);
Expand Down
6 changes: 3 additions & 3 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
// Otherwise WIN build fails with "IS_MANUALLY_ALLOCATED_REG cannot be implicitly captured because no default capture mode has been specified"
// the same problem with all the other lambdas in this file
auto enumerate_out_tensors = [=] (const ExpressionPtr& expr,
decltype(regs_vec)& reg_map,
const std::map<tensor, Reg>& manually_assigned_regs,
size_t& counter) {
decltype(regs_vec)& reg_map,
const std::map<tensor, Reg>& manually_assigned_regs,
size_t& counter) {
for (const auto& out_tensor : expr->get_output_port_connectors()) {
// Note that some ops might have identical input&output tensors (Result and Tile* for ex.)
// so we have to check that the tensor has not been enumerated already
Expand Down
27 changes: 17 additions & 10 deletions src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,29 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
}

ov::element::Type Brgemm::get_output_type() const {
const auto element_type_a = get_input_element_type(0);
const auto element_type_b = get_input_element_type(1);
const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) {
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
return element::f32;
return element::f32;
} else if (is_int8) {
return element::i32;
} else {
return element::undefined;
}
}

ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
element_type_a.get_type_name() +
" and " +
element_type_b.get_type_name());
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
}

return output_type;
}

std::vector<ov::PartialShape> Brgemm::get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const {
Expand Down
20 changes: 13 additions & 7 deletions src/common/snippets/src/op/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ namespace snippets {
namespace op {


Buffer::Buffer(const ov::Shape& shape, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id) {
Buffer::Buffer(const ov::Shape& shape, ov::element::Type element_type, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id), m_element_type(std::move(element_type)) {
constructor_validate_and_infer_types();
}

Expand All @@ -40,34 +40,33 @@ bool Buffer::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("allocation_shape", m_shape);
visitor.on_attribute("offset", m_offset);
visitor.on_attribute("id", m_id);
visitor.on_attribute("element_type", m_element_type);
return true;
}

void Buffer::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
ov::element::Type output_type;
ov::Shape output_shape;
if (m_type == Type::NewMemory) {
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must to not have arguments!");
output_shape = m_shape;
output_type = ov::element::u8; // 1Byte
} else if (m_type == Type::IntermediateMemory) {
const auto& input_shape = get_input_partial_shape(0);
OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape");
output_type = get_input_element_type(0);
m_element_type = get_input_element_type(0);
output_shape = input_shape.get_shape();
} else {
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
}
set_output_type(0, output_type, output_shape);
set_output_type(0, m_element_type, output_shape);
}

std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs);
check_new_args_count(this, new_args);
std::shared_ptr<op::Buffer> new_buffer = nullptr;
if (m_type == Type::NewMemory) {
new_buffer = std::make_shared<Buffer>(m_shape, m_id);
new_buffer = std::make_shared<Buffer>(m_shape, m_element_type, m_id);
} else if (m_type == Type::IntermediateMemory) {
new_buffer = std::make_shared<Buffer>(new_args.at(0), m_shape, m_id);
} else {
Expand All @@ -82,6 +81,13 @@ size_t Buffer::get_byte_size() const {
return ov::shape_size(shape) * get_element_type().size();
}

void Buffer::set_element_type(ov::element::Type element_type) {
OPENVINO_ASSERT(is_new_memory(), "Only Buffer with NewMemory can change his output precision!");
m_element_type = std::move(element_type);
// Apply the change
validate_and_infer_types();
}

} // namespace op
} // namespace snippets
} // namespace ov
38 changes: 16 additions & 22 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "snippets/pass/convert_constants.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/transform_convert.hpp"
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"
Expand Down Expand Up @@ -75,12 +74,11 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::No
}

void snippets::op::Subgraph::init_config() {
auto update = [](bool& flag, bool status) { flag = flag || status; };
const auto ops = body_ptr()->get_ops();
for (const auto& op : ops) {
config.m_is_quantized = config.m_is_quantized ||
ov::is_type<ov::op::v0::FakeQuantize>(op);
config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops ||
is_domain_sensitive_op(op);
update(config.m_is_quantized, ov::is_type<ov::op::v0::FakeQuantize>(op));
update(config.m_has_domain_sensitive_ops, is_domain_sensitive_op(op));
}
}

Expand All @@ -93,6 +91,13 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
// and where will be Loops - we can just predict.
// Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32)
std::vector<size_t> used_precision_size;

auto push_prc_size = [&used_precision_size](size_t precision_size) {
if (used_precision_size.empty() || used_precision_size.back() != precision_size) {
used_precision_size.push_back(precision_size);
}
};

for (const auto& op : ops) {
if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(op)) {
// At the moment Transposes are supported only on Results and Parameters but
Expand All @@ -106,34 +111,23 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
}) ||
!ov::is_type<ov::op::v0::Parameter>(transpose->get_input_node_shared_ptr(0));
if (are_prev_or_next_ops) {
const auto prc_size = transpose->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(transpose->get_element_type().size());
}
} else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) {
// Softmax always uses 2 FP32 Buffers
const auto prc_size = ov::element::f32.size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
// Softmax always uses 2 FP32 Buffers after decomposition.
// They are inplace and the same so we can push precision size only once
push_prc_size(ov::element::f32.size());
} else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) {
// First input check is enough because MatMul requires the same prc size on inputs
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(0)) ||
!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(1))) {
const auto prc_size = matmul->get_input_element_type(0).size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_input_element_type(0).size());
}

const auto consumers = matmul->get_output_target_inputs(0);
if (std::none_of(consumers.begin(), consumers.end(),
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
const auto prc_size = matmul->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_element_type().size());
}
}
}
Expand Down
Loading

0 comments on commit eb3e6a6

Please sign in to comment.