Skip to content

Commit

Permalink
[Snippets] Added MHA I8 tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 10, 2023
1 parent 80e945e commit 663a8aa
Show file tree
Hide file tree
Showing 34 changed files with 1,523 additions and 220 deletions.
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class Brgemm : public MemoryAccess {
size_t get_offset_b() const { return get_input_port_descriptor(1).m_offset; }
size_t get_offset_c() const { return get_output_port_descriptor(0).m_offset; }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
Expand Down
16 changes: 12 additions & 4 deletions src/common/snippets/include/snippets/op/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace op {
/**
* @interface Buffer
* @brief This is a base class for memory storage.
* - m_id - identifier of the Buffer in the common Buffer system (the id is analogous to a register)
* Notes:
* - All buffers in a graph share the same memory pointer. So if we have several buffers,
* each corresponding MemoryAccess op for a Buffer should have an offset relative to the common memory pointer of this Buffer
Expand All @@ -23,11 +24,19 @@ class Buffer : public ngraph::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");

void set_id(size_t id) { m_id = id; }

size_t get_id() const { return m_id; }
size_t get_byte_size() const;
virtual ov::PartialShape get_allocation_shape() const = 0;

bool visit_attributes(AttributeVisitor& visitor) override;

protected:
Buffer() = default;
Buffer(size_t id) : m_id(id) {}

size_t m_id = 0;
};

/**
Expand All @@ -41,7 +50,7 @@ class AllocationBuffer : public Buffer {
OPENVINO_OP("AllocationBuffer", "SnippetsOpset", Buffer);

AllocationBuffer() = default;
AllocationBuffer(const ov::Output<ov::Node>& shape, const ov::element::Type element_type);
AllocationBuffer(const ov::Output<ov::Node>& shape, const ov::element::Type element_type, size_t id = 0);

ov::PartialShape get_allocation_shape() const override;

Expand Down Expand Up @@ -72,12 +81,11 @@ class IntermediateBuffer : public Buffer {
OPENVINO_OP("IntermediateBuffer", "SnippetsOpset", Buffer);

IntermediateBuffer() = default;
IntermediateBuffer(const ov::Output<ov::Node>& x);
IntermediateBuffer(const ov::Output<ov::Node>& x, const ov::Output<ov::Node>& shape);
IntermediateBuffer(const ov::Output<ov::Node>& x, size_t id = 0);
IntermediateBuffer(const ov::Output<ov::Node>& x, const ov::Output<ov::Node>& shape, size_t id = 0);

ov::PartialShape get_allocation_shape() const override;

bool visit_attributes(AttributeVisitor& visitor) override { return true; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;

Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/op/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class LoopEnd : public LoopBase {
std::vector<int64_t> ptr_increments, std::vector<int64_t> finalization_offsets);
LoopEnd() = default;
std::shared_ptr<LoopBegin> get_loop_begin();
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
const std::vector<int64_t>& get_finalization_offsets() const;
Expand Down
9 changes: 4 additions & 5 deletions src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class Subgraph : public ov::op::util::SubGraphOp {

size_t get_buffer_scratchpad_size() const { return m_buffer_scratchpad; }
size_t get_virtual_port_count() const { return m_virtual_port_count; }
bool is_buffer_needed() const { return m_buffer_needed; }
bool is_quantized() const { return config.m_is_quantized; }
bool has_type_relaxed_ops() const { return config.m_has_type_relaxed_ops; }
bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; }
Expand All @@ -116,7 +115,6 @@ class Subgraph : public ov::op::util::SubGraphOp {
void set_generator(std::shared_ptr<ngraph::snippets::Generator> generator);
void set_tile_rank(size_t newRank) {tileRank = newRank;}
void set_virtual_port_count(const size_t count);
void set_buffer_needed(const bool need);

void print() const;
void print_statistics(bool verbose);
Expand All @@ -132,19 +130,20 @@ class Subgraph : public ov::op::util::SubGraphOp {
// This check returns True if Constant op which is input of this op should be inside Subgraph body
static auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool;

// Return the estimated unique buffer count (an upper-bound estimate). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;

private:
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
void convert_to_snippet_dialect();
void init_config();
void initialize_buffer_scratchpad_size();
// Count of Subgraph virtual ports:
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
// Need Buffer op or not
// - Buffers. All Buffers are considered as one common additional virtual port. So we cannot summarize them as potential non-scalar Constants
// NOTE: To avoid overheads in each calculation of this count (for example, in validate_and_type_infer()),
// we should MANUALLY calculate it where it is needed.
size_t m_virtual_port_count = 0;
bool m_buffer_needed = false;
size_t m_buffer_scratchpad = 0lu;
Shape exec_domain = {};
std::shared_ptr<ngraph::snippets::Generator> m_generator = nullptr;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pattern/matcher.hpp>

namespace ngraph {
namespace snippets {
namespace pass {

/**
 * @interface BufferIdentification
 * @brief The pass sets identifiers for Buffers in the common Buffer system.
 *        Note: should be called before the ResetBuffer() pass to have correct offsets
 * @ingroup snippets
 */
class BufferIdentification: public ngraph::pass::FunctionPass {
public:
    // The RTTI type name must be this pass's own name so that it is not
    // registered under the name of a different pass (e.g. "InsertLoops").
    OPENVINO_RTTI("BufferIdentification", "0");
    BufferIdentification() = default;

    /**
     * @brief Runs the pass on the given model.
     * @param m the model to process
     * @return the pass result as defined by the FunctionPass interface
     */
    bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
};

} // namespace pass
} // namespace snippets
} // namespace ngraph
9 changes: 9 additions & 0 deletions src/common/snippets/include/snippets/pass/tokenization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"

namespace ngraph {
namespace snippets {
Expand All @@ -19,8 +20,16 @@ namespace pass {
SkippedByPlugin - indicate that snippets can't include this node in subgraph. Can be set by Plugin via SetSnippetsNodeType(...).
*/
enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
/*
NotSet - default value returned if the subgraph wasn't marked and snippets can include nodes in this subgraph
Completed - indicate that snippets can't include any nodes in this subgraph.
It's used in separate tokenization pass, for example, tokenization by matcher (MHA Tokenization).
*/
enum class SnippetsSubgraphType : int64_t {NotSet, Completed};
void SetSnippetsNodeType(const std::shared_ptr<Node>&, SnippetsNodeType);
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph>&, SnippetsSubgraphType);
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node>&);
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph>&);
void SetTopologicalOrder(const std::shared_ptr<Node>&, int64_t);
int64_t GetTopologicalOrder(const std::shared_ptr<const Node>&);

Expand Down
27 changes: 17 additions & 10 deletions src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,29 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
get_offset_a(), get_offset_b(), get_offset_c());
}

ov::element::Type Brgemm::get_output_type() const {
const auto element_type_a = get_input_element_type(0);
const auto element_type_b = get_input_element_type(1);
const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) {
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
return element::f32;
return element::f32;
} else if (is_int8) {
return element::i32;
} else {
return element::undefined;
}
}

ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
throw ngraph_error("BrgemmCPU node has incompatible input element types: " +
element_type_a.get_type_name() +
" and " +
element_type_b.get_type_name());
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
}

return output_type;
}

ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const {
Expand Down
21 changes: 14 additions & 7 deletions src/common/snippets/src/op/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,29 @@ size_t ngraph::snippets::op::Buffer::get_byte_size() const {
return ngraph::shape_size(shape) * get_element_type().size();
}

snippets::op::AllocationBuffer::AllocationBuffer(const Output<Node>& shape, const ov::element::Type element_type)
: Buffer(), m_element_type(element_type) {
// Exposes the Buffer attributes to the given visitor: the buffer identifier
// `m_id` is reported under the attribute name "id".
// Always returns true.
bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(Buffer_visit_attributes);
visitor.on_attribute("id", m_id);
return true;
}

snippets::op::AllocationBuffer::AllocationBuffer(const Output<Node>& shape, const ov::element::Type element_type, size_t id)
: Buffer(id), m_element_type(element_type) {
set_arguments({shape});
constructor_validate_and_infer_types();
}

// Exposes the AllocationBuffer attributes to the given visitor.
// The base Buffer attributes are visited first, then `m_element_type` is
// reported under the attribute name "element_type".
// Always returns true; the result of Buffer::visit_attributes is not checked.
bool snippets::op::AllocationBuffer::visit_attributes(AttributeVisitor& visitor) {
INTERNAL_OP_SCOPE(AllocationBuffer_visit_attributes);
Buffer::visit_attributes(visitor);
visitor.on_attribute("element_type", m_element_type);
return true;
}

std::shared_ptr<Node> snippets::op::AllocationBuffer::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(AllocationBuffer_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<AllocationBuffer>(new_args.at(0), m_element_type);
return std::make_shared<AllocationBuffer>(new_args.at(0), m_element_type, m_id);
}

void snippets::op::AllocationBuffer::validate_and_infer_types() {
Expand All @@ -60,12 +67,12 @@ ov::PartialShape ngraph::snippets::op::AllocationBuffer::get_allocation_shape()
return shape;
}

snippets::op::IntermediateBuffer::IntermediateBuffer(const ov::Output<ov::Node>& x) : Buffer() {
snippets::op::IntermediateBuffer::IntermediateBuffer(const ov::Output<ov::Node>& x, size_t id) : Buffer(id) {
set_arguments({x});
constructor_validate_and_infer_types();
}

snippets::op::IntermediateBuffer::IntermediateBuffer(const ov::Output<ov::Node>& x, const ov::Output<ov::Node>& shape) : Buffer() {
snippets::op::IntermediateBuffer::IntermediateBuffer(const ov::Output<ov::Node>& x, const ov::Output<ov::Node>& shape, size_t id) : Buffer(id) {
set_arguments({x, shape});
constructor_validate_and_infer_types();
}
Expand All @@ -74,9 +81,9 @@ std::shared_ptr<Node> snippets::op::IntermediateBuffer::clone_with_new_inputs(co
INTERNAL_OP_SCOPE(IntermediateBuffer_clone_with_new_inputs);
check_new_args_count(this, new_args);
if (new_args.size() == 2) {
return std::make_shared<IntermediateBuffer>(new_args.at(0), new_args.at(1));
return std::make_shared<IntermediateBuffer>(new_args.at(0), new_args.at(1), m_id);
} else if (new_args.size() == 1) {
return std::make_shared<IntermediateBuffer>(new_args.at(0));
return std::make_shared<IntermediateBuffer>(new_args.at(0), m_id);
}

throw ngraph_error("The IntermediateBuffer op got invalid input count");
Expand Down
9 changes: 9 additions & 0 deletions src/common/snippets/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ void LoopEnd::validate_and_infer_types() {
get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr());
}

// Exposes the LoopEnd attributes to the given visitor.
// The LoopBase attributes are visited first (the result is not checked). Then,
// for every index i of `ptr_increments`, two attributes are reported:
//   "ptr_increment_<i>"        -> ptr_increments[i]
//   "finalization_offsets_<i>" -> finalization_offsets[i]
// Always returns true.
bool LoopEnd::visit_attributes(AttributeVisitor& visitor) {
LoopBase::visit_attributes(visitor);
// NOTE(review): finalization_offsets is indexed using the size of ptr_increments,
// so this assumes both vectors have the same length - confirm this is enforced elsewhere.
for (size_t i = 0; i < ptr_increments.size(); ++i) {
visitor.on_attribute("ptr_increment_" + std::to_string(i), ptr_increments[i]);
visitor.on_attribute("finalization_offsets_" + std::to_string(i), finalization_offsets[i]);
}
return true;
}

} // namespace op
} // namespace snippets
} // namespace ngraph
Loading

0 comments on commit 663a8aa

Please sign in to comment.