Skip to content

Commit

Permalink
[Snippets] Added MHA I8 tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Mar 28, 2023
1 parent 38c924a commit c0875bb
Show file tree
Hide file tree
Showing 34 changed files with 1,536 additions and 227 deletions.
2 changes: 2 additions & 0 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Brgemm : public MemoryAccess {
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

Expand Down
15 changes: 11 additions & 4 deletions src/common/snippets/include/snippets/op/buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@ namespace op {
* - All buffers in a graph have the same memory pointer. So if we have a few buffers,
* each corresponding MemoryAccess op for a Buffer should have an offset into the common memory pointer of this Buffer
* - Buffer should be a single consumer for operation output port
* @param m_type - type of Buffer: IntermediateMemory/NewMemory
* @param m_shape - output allocation shape for Buffer with type NewMemory
* @param m_id - Buffer ID in common Buffer system
* @ingroup snippets
*/
class Buffer : public ngraph::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");
Buffer() = default;
Buffer(const ov::Shape& shape);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1);
Buffer(const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);

bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
Expand All @@ -38,16 +41,20 @@ class Buffer : public ngraph::op::Op {
IntermediateMemory
};

void set_id(size_t id) { m_id = id; }

size_t get_id() const { return m_id; }
size_t get_byte_size() const;
Type get_type() const { return m_type; }
ov::Shape get_allocation_shape() const { return m_shape; }
size_t get_byte_size() const;

bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }

private:
Type m_type = Type::IntermediateMemory;
ov::Shape m_shape = {};
size_t m_id = 0;
};

} // namespace op
Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/op/loop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class LoopEnd : public LoopBase {
std::vector<int64_t> ptr_increments, std::vector<int64_t> finalization_offsets);
LoopEnd() = default;
std::shared_ptr<LoopBegin> get_loop_begin();
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
const std::vector<int64_t>& get_finalization_offsets() const;
Expand Down
9 changes: 3 additions & 6 deletions src/common/snippets/include/snippets/op/subgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class Subgraph : public ov::op::util::SubGraphOp {

size_t get_buffer_scratchpad_size() const { return m_buffer_scratchpad; }
size_t get_virtual_port_count() const { return m_virtual_port_count; }
bool is_buffer_needed() const { return m_buffer_needed; }
bool is_quantized() const { return config.m_is_quantized; }
bool has_type_relaxed_ops() const { return config.m_has_type_relaxed_ops; }
bool has_domain_sensitive_ops() const { return config.m_has_domain_sensitive_ops; }
Expand All @@ -122,7 +121,6 @@ class Subgraph : public ov::op::util::SubGraphOp {
void set_generator(std::shared_ptr<ngraph::snippets::Generator> generator);
void set_tile_rank(size_t newRank) {tileRank = newRank;}
void set_virtual_port_count(const size_t count);
void set_buffer_needed(const bool need);

void print() const;
void print_statistics(bool verbose);
Expand All @@ -137,8 +135,10 @@ class Subgraph : public ov::op::util::SubGraphOp {
// should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast)
// This check returns True if Constant op which is input of this op should be inside Subgraph body
static auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool;

static bool check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept;
// Return the estimated unique buffer count (an upper-bound estimate). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
static auto is_domain_sensitive_op(const std::shared_ptr<ov::Node>& op) -> bool;

private:
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
Expand All @@ -147,12 +147,9 @@ class Subgraph : public ov::op::util::SubGraphOp {
void initialize_buffer_scratchpad_size();
// Count of Subgraph virtual ports:
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
// Need Buffer op or not
// - Buffers. All Buffers are considered as one common additional virtual port. So we cannot summarize them as potential non-scalar Constants
// NOTE: To avoid overheads in each calculation of this count (for example, in validate_and_infer_types()),
// we should MANUALLY calculate it where it is needed.
size_t m_virtual_port_count = 0;
bool m_buffer_needed = false;
size_t m_buffer_scratchpad = 0lu;
Shape exec_domain = {};
std::shared_ptr<ngraph::snippets::Generator> m_generator = nullptr;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pattern/matcher.hpp>

namespace ngraph {
namespace snippets {
namespace pass {

/**
 * @interface BufferIdentification
 * @brief The pass sets identifiers for Buffers in the common Buffer system
 *        Note: should be called before the ResetBuffer() pass to have correct offsets
 * @ingroup snippets
 */
class BufferIdentification: public ngraph::pass::FunctionPass {
public:
    // Fixed: the RTTI name previously said "InsertLoops" (copy-paste error from another
    // pass); it must match this class so type info and pass lookup are correct.
    OPENVINO_RTTI("BufferIdentification", "0");
    BufferIdentification() = default;

    // Assigns an ID to every Buffer op in the model's common Buffer system.
    // Returns true if the model was modified.
    bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
};

} // namespace pass
} // namespace snippets
} // namespace ngraph
9 changes: 9 additions & 0 deletions src/common/snippets/include/snippets/pass/tokenization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"

namespace ngraph {
namespace snippets {
Expand All @@ -19,8 +20,16 @@ namespace pass {
SkippedByPlugin - indicate that snippets can't include this node in subgraph. Can be set by Plugin via SetSnippetsNodeType(...).
*/
enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
/*
NotSet - default value returned if the subgraph wasn't marked and snippets can include nodes in this subgraph
Completed - indicate that snippets can't include any nodes in this subgraph.
It's used in separate tokenization pass, for example, tokenization by matcher (MHA Tokenization).
*/
enum class SnippetsSubgraphType : int64_t {NotSet, Completed};
void SetSnippetsNodeType(const std::shared_ptr<Node>&, SnippetsNodeType);
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph>&, SnippetsSubgraphType);
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node>&);
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph>&);
void SetTopologicalOrder(const std::shared_ptr<Node>&, int64_t);
int64_t GetTopologicalOrder(const std::shared_ptr<const Node>&);

Expand Down
1 change: 1 addition & 0 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "snippets/op/subgraph.hpp"
#include "snippets/op/kernel.hpp"
#include <snippets/itt.hpp>
#include <snippets/snippets_isa.hpp>

#include <ngraph/pass/manager.hpp>
#include <openvino/core/type.hpp>
Expand Down
29 changes: 18 additions & 11 deletions src/common/snippets/src/op/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Brgemm::Brgemm(const Output<Node>& A, const Output<Node>& B,
set_output_size(1);
set_input_offset(offset_a, 0);
set_input_offset(offset_b, 1);
set_output_offset(offset_a, 0);
set_output_offset(offset_c, 0);
constructor_validate_and_infer_types();
}

Expand Down Expand Up @@ -45,22 +45,29 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
return std::make_shared<Brgemm>(new_args.at(0), new_args.at(1), get_offset_a(), get_offset_b(), get_offset_c());
}

ov::element::Type Brgemm::get_output_type() const {
const auto element_type_a = get_input_element_type(0);
const auto element_type_b = get_input_element_type(1);
const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
// Deduce the Brgemm output precision from the two input precisions.
// Supported combinations: f32 x f32 -> f32, bf16 x bf16 -> f32, (i8|u8) x i8 -> i32.
// Returns element::undefined for any unsupported pair; callers decide how to report it.
ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) {
    const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
    const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
    const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
    if (is_f32 || is_bf16) {
        // Fixed: the original contained a duplicated (unreachable) `return element::f32;` line —
        // a diff/merge artifact; a single return is kept.
        return element::f32;
    } else if (is_int8) {
        return element::i32;
    } else {
        return element::undefined;
    }
}

// Member overload: deduce the output precision from this node's actual input element
// types and throw if the combination is unsupported.
ov::element::Type Brgemm::get_output_type() const {
    const auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
    if (output_type == element::undefined) {
        // Fixed: the visible body still referenced the removed locals `element_type_a`/
        // `element_type_b` (leftover pre-refactor lines), which would not compile;
        // the message is now built from get_input_element_type(...) directly.
        throw ngraph_error("BrgemmCPU node has incompatible input element types: " +
                           get_input_element_type(0).get_type_name() +
                           " and " +
                           get_input_element_type(1).get_type_name());
    }

    return output_type;
}

ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const {
Expand Down
19 changes: 8 additions & 11 deletions src/common/snippets/src/op/buffer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -12,22 +12,18 @@
using namespace std;
using namespace ngraph;

// Normalize a possibly-negative allocation rank to an absolute rank:
// negative values count back from shape_rank (e.g. -1 with rank 4 -> 3),
// non-negative values are returned unchanged.
auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t {
    if (allocation_rank >= 0)
        return allocation_rank;
    return static_cast<int32_t>(shape_rank) + allocation_rank;
}

snippets::op::Buffer::Buffer(const ov::Shape& shape)
: Op(), m_type(Type::NewMemory), m_shape(shape) {
snippets::op::Buffer::Buffer(const ov::Shape& shape, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_id(id) {
constructor_validate_and_infer_types();
}

snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape)
: Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape) {
snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id)
: Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape), m_id(id) {
constructor_validate_and_infer_types();
}

snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank)
: Op({arg}), m_type(Type::IntermediateMemory) {
snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank, size_t id)
: Op({arg}), m_type(Type::IntermediateMemory), m_id(id) {
const auto pshape = arg.get_partial_shape();
OPENVINO_ASSERT(pshape.is_static(), "Buffer supports only static input shape");
const auto shape = pshape.get_shape();
Expand All @@ -40,6 +36,7 @@ snippets::op::Buffer::Buffer(const ov::Output<ov::Node>& arg, int32_t allocation
// Visit (serialize) the Buffer's attributes: the allocation shape and the Buffer's
// ID in the common Buffer system. Always reports success.
bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(Buffer_visit_attributes);
    visitor.on_attribute("allocation_shape", m_shape);
    visitor.on_attribute("id", m_id);
    return true;
}

Expand Down
9 changes: 9 additions & 0 deletions src/common/snippets/src/op/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ void LoopEnd::validate_and_infer_types() {
get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr());
}

// Visit (serialize) LoopEnd's attributes: the base-class fields plus one
// indexed attribute per loop port for pointer increments and finalization offsets.
bool LoopEnd::visit_attributes(AttributeVisitor& visitor) {
    LoopBase::visit_attributes(visitor);
    // NOTE(review): the loop is bounded by ptr_increments.size() but also indexes
    // finalization_offsets[i] — assumes both vectors are the same length; confirm
    // that invariant is enforced at construction.
    for (size_t i = 0; i < ptr_increments.size(); ++i) {
        visitor.on_attribute("ptr_increment_" + std::to_string(i), ptr_increments[i]);
        visitor.on_attribute("finalization_offsets_" + std::to_string(i), finalization_offsets[i]);
    }
    return true;
}

} // namespace op
} // namespace snippets
} // namespace ngraph
Loading

0 comments on commit c0875bb

Please sign in to comment.