Skip to content

Commit

Permalink
[Snippets] TPP FP32 MHA support (#22210)
Browse files Browse the repository at this point in the history
### Details:
 - *Enable FP32 MHA lowering using TPP backend*
- *This PR is a productization of the [TPP integration
POC](#20956)*
### Prerequisites:
- #21303
- #21672

Branch to Branch PR in to review the changes before the Prerequisites
are merged: IvanNovoselov#18

---------

Co-authored-by: egeorgan <[email protected]>
  • Loading branch information
IvanNovoselov and egeor authored Apr 22, 2024
1 parent a19c34a commit e563109
Show file tree
Hide file tree
Showing 84 changed files with 2,590 additions and 211 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,6 @@
[submodule "thirdparty/telemetry"]
path = thirdparty/telemetry
url = https://github.com/openvinotoolkit/telemetry.git
[submodule "src/plugins/intel_cpu/thirdparty/libxsmm"]
path = src/plugins/intel_cpu/thirdparty/libxsmm
url = https://github.com/libxsmm/libxsmm.git
2 changes: 2 additions & 0 deletions cmake/features.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ ov_dependent_option (ENABLE_GPU_DEBUG_CAPS "enable GPU debug capabilities at run
ov_dependent_option (ENABLE_CPU_DEBUG_CAPS "enable CPU debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS;ENABLE_INTEL_CPU" OFF)
ov_dependent_option (ENABLE_SNIPPETS_DEBUG_CAPS "enable Snippets debug capabilities at runtime" ON "ENABLE_DEBUG_CAPS" OFF)

ov_dependent_option (ENABLE_SNIPPETS_LIBXSMM_TPP "allow Snippets to use LIBXSMM Tensor Processing Primitives" OFF "ENABLE_INTEL_CPU AND X86_64" OFF)

ov_option (ENABLE_PROFILING_ITT "Build with ITT tracing. Optionally configure pre-built ittnotify library though INTEL_VTUNE_DIR variable." OFF)

ov_option_enum(ENABLE_PROFILING_FILTER "Enable or disable ITT counter groups.\
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/include/snippets/emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace snippets {
* @interface RegType
* @brief Register type of input and output operations
*/
enum class RegType { gpr, vec };
enum class RegType { gpr, vec, undefined };
/**
* @interface Reg
* @brief Register representation: type of register and index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class InitLoops : public Pass {
InitLoops() = default;
bool run(LinearIR& linear_ir) override;

static void init_loop_info(const LinearIR::LoopManager::LoopInfoPtr& loop_info, bool only_runtime_args = false);
static void init_loop_info(const LinearIR::LoopManager::LoopInfoPtr& loop_info, size_t loop_id, bool only_runtime_args = false);
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class InsertBuffers : public RangedPass {
const LinearIR::constExprIt& end_it,
const LinearIR::LoopManagerPtr& loop_manager,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_entries,
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits);
const std::vector<LinearIR::LoopManager::LoopPort>& loop_exits) const;

LinearIR::constExprIt insertion_position(const LinearIR& linear_ir,
const LinearIR::LoopManagerPtr& loop_manager,
const ExpressionPtr& expr,
const ExpressionPtr& down_expr);
static LinearIR::constExprIt insertion_position(const LinearIR& linear_ir,
const LinearIR::LoopManagerPtr& loop_manager,
const ExpressionPtr& expr,
const ExpressionPtr& down_expr);

int32_t m_buffer_allocation_rank;
};
Expand Down
11 changes: 11 additions & 0 deletions src/common/snippets/include/snippets/lowered/port_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ class PortDescriptorVectorAttribute : public ov::RuntimeAttribute {
std::vector<PortDescriptorPtr> outputs{};
};

template<typename T>
void set_port_desc(const T& port, std::vector<size_t> subtensor) {
    // Clamp each trailing subtensor dimension to the corresponding shape
    // dimension (aligned from the right), then attach the resulting
    // descriptor to the port. FULL_DIM entries are kept as-is since they
    // denote "the whole dimension" rather than a concrete size.
    const auto& shape = port.get_shape();
    const size_t overlap = std::min(subtensor.size(), shape.size());
    for (size_t back_idx = 1; back_idx <= overlap; ++back_idx) {
        auto& st_dim = subtensor[subtensor.size() - back_idx];
        if (st_dim != PortDescriptor::ServiceDimensions::FULL_DIM)
            st_dim = std::min(st_dim, shape[shape.size() - back_idx]);
    }
    PortDescriptorUtils::set_port_descriptor_ptr(port, std::make_shared<PortDescriptor>(shape, subtensor));
}

} // namespace lowered
} // namespace snippets
} // namespace ov
26 changes: 22 additions & 4 deletions src/common/snippets/include/snippets/op/brgemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,51 @@ namespace op {
* @brief Brgemm is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows
* @ingroup snippets
*/
class Brgemm : public MemoryAccess {
class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op {
public:
OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess);
OPENVINO_OP("Brgemm", "SnippetsOpset");
Brgemm(const Output<Node>& A, const Output<Node>& B,
const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu,
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
size_t blk_size_m = 0, size_t blk_size_k = 0, size_t blk_size_n = 0);
Brgemm(const Output<Node>& A, const Output<Node>& B,
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
size_t blk_size_m = 0, size_t blk_size_k = 0, size_t blk_size_n = 0);
Brgemm() = default;

size_t get_offset_a() const { return get_input_offset(0); }
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

size_t get_m_block_size() const { return m_M_blk; }
size_t get_k_block_size() const { return m_K_blk; }
size_t get_n_block_size() const { return m_N_blk; }
float get_beta() const { return m_beta; }

void set_m_block_size(size_t block_size) { m_M_blk = block_size; }
void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
void set_beta(float beta) { m_beta = beta; }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

bool has_evaluate() const override { return false; }
bool visit_attributes(AttributeVisitor& visitor) override;

protected:
ov::element::Type get_output_type() const;
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
ov::PartialShape get_output_partial_shape(const std::vector<ov::PartialShape>& input_shapes) const;
ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const;
void compute_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n);
size_t m_M_blk = 0;
size_t m_K_blk = 0;
size_t m_N_blk = 0;
float m_beta = 0.f;

private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/include/snippets/op/broadcastload.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace op {
* @brief Is generated for broadcasting by least varying dimension for non-blocked cases and the second varying dimension for blocked
* @ingroup snippets
*/
class BroadcastLoad : public MemoryAccess {
class BroadcastLoad : public modifier::MemoryAccess, public ov::op::Op {
public:
OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ov::snippets::op::MemoryAccess);
OPENVINO_OP("BroadcastLoad", "SnippetsOpset");

BroadcastLoad(const Output<Node>& x, ov::Dimension bcast_dimension, size_t offset = 0lu);
BroadcastLoad() = default;
Expand Down
5 changes: 3 additions & 2 deletions src/common/snippets/include/snippets/op/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace op {
* and memory offset for loading is determined by "offset" (Default value is "0" - to load starting from the first element)
* @ingroup snippets
*/
class Load : public MemoryAccess {
class Load : public modifier::MemoryAccess, public ov::op::Op {
public:
OPENVINO_OP("Load", "SnippetsOpset", MemoryAccess);
OPENVINO_OP("Load", "SnippetsOpset");

Load(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Load() = default;
Expand All @@ -34,6 +34,7 @@ class Load : public MemoryAccess {

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;

protected:
void validate_memory_access_params() const;
Expand Down
35 changes: 22 additions & 13 deletions src/common/snippets/include/snippets/op/memory_access.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace ov {
namespace snippets {
namespace op {
namespace modifier {

/**
* @interface MemoryAccess
Expand All @@ -19,10 +19,8 @@ namespace op {
* @ingroup snippets
*/

class MemoryAccess : public ov::op::Op {
class MemoryAccess {
public:
OPENVINO_OP("MemoryAccess", "SnippetsOpset");

/**
* @interface PortDescriptor
* @brief This class describes port of MemoryAccess operation
Expand All @@ -34,13 +32,16 @@ class MemoryAccess : public ov::op::Op {
struct PortDescriptor {
PortDescriptor(size_t count, size_t offset) : count(count), offset(offset) {}
PortDescriptor() = default;

// TODO: should we deprecate count in favor of subtensors, ticket: 130004
size_t count = 0lu;
size_t offset = 0lu;
// Note: stride is interpreted as leading dimension for 2D subtensor ops
size_t stride = 0lu;
size_t index = 0lu;

private:
PortDescriptor(size_t count, size_t offset, size_t index) : count(count), offset(offset), index(index) {}
PortDescriptor(size_t count, size_t offset, size_t stride, size_t index) :
count(count), offset(offset), stride(stride), index(index) {}

friend class MemoryAccess;
};
Expand All @@ -50,27 +51,33 @@ class MemoryAccess : public ov::op::Op {
void set_output_count(size_t count, size_t idx = 0);
void set_input_offset(size_t offset, size_t idx = 0);
void set_output_offset(size_t offset, size_t idx = 0);
void set_input_stride(size_t stride, size_t idx = 0);
void set_output_stride(size_t stride, size_t idx = 0);

size_t get_input_count(size_t idx = 0) const;
size_t get_output_count(size_t idx = 0) const;
size_t get_input_offset(size_t idx = 0) const;
size_t get_output_offset(size_t idx = 0) const;
size_t get_input_stride(size_t idx = 0) const;
size_t get_output_stride(size_t idx = 0) const;

PortMap get_memory_access_input_ports() const { return m_input_ports; }
PortMap get_memory_access_output_ports() const { return m_output_ports; }

bool is_memory_access_input_port(size_t idx) const;
bool is_memory_access_output_port(size_t idx) const;

// All input and output ports are MemoryAccess
bool is_full_memory_access_op() const;
/**
* @brief Checks if the provided operation memory access on all ports
*/
bool is_full_memory_access_op(const std::shared_ptr<ov::Node>& op) const;

bool visit_attributes(AttributeVisitor& visitor) override;
bool visit_attributes(AttributeVisitor& visitor);

protected:
explicit MemoryAccess(const OutputVector& arguments, size_t input_count = 0, size_t output_count = 0);
explicit MemoryAccess(const OutputVector& arguments, const std::set<size_t>& input_ports, const std::set<size_t>& output_ports);
explicit MemoryAccess(const OutputVector& arguments, const PortMap& input_ports, const PortMap& output_ports);
explicit MemoryAccess(size_t input_count, size_t output_count = 0);
explicit MemoryAccess(const std::set<size_t>& input_ports, const std::set<size_t>& output_ports);
explicit MemoryAccess(const PortMap& input_ports, const PortMap& output_ports);
MemoryAccess() = default;

// This method can be called only in ctors
Expand All @@ -80,12 +87,14 @@ class MemoryAccess : public ov::op::Op {
void set_output_port_descriptor(const PortDescriptor& desc, const size_t i);
const PortDescriptor& get_input_port_descriptor(const size_t i) const;
const PortDescriptor& get_output_port_descriptor(const size_t i) const;
PortDescriptor& get_input_port_descriptor(const size_t i);
PortDescriptor& get_output_port_descriptor(const size_t i);

// [port_num, port_desc]
PortMap m_input_ports;
PortMap m_output_ports;
};

} // namespace op
} // namespace modifier
} // namespace snippets
} // namespace ov
5 changes: 3 additions & 2 deletions src/common/snippets/include/snippets/op/store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ namespace op {
* and memory offset for storing is determined by "offset" (Default value is "0" - to store starting at start memory ptr)
* @ingroup snippets
*/
class Store : public MemoryAccess {
class Store : public modifier::MemoryAccess, public ov::op::Op {
public:
OPENVINO_OP("Store", "SnippetsOpset", MemoryAccess);
OPENVINO_OP("Store", "SnippetsOpset");

Store(const Output<Node>& x, const size_t count = 1lu, const size_t offset = 0lu);
Store() = default;
Expand All @@ -33,6 +33,7 @@ class Store : public MemoryAccess {

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
};

} // namespace op
Expand Down
7 changes: 6 additions & 1 deletion src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "snippets/lowered/pass/optimize_loop_single_evaluation.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/op/kernel.hpp"
#include "snippets/op/memory_access.hpp"

namespace ov {
namespace snippets {
Expand Down Expand Up @@ -72,6 +73,9 @@ std::shared_ptr<const TargetMachine> Generator::get_target_machine() const {
}

RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
auto reg_type = get_specific_op_out_reg_type(out);
if (reg_type != RegType::undefined)
return reg_type;
const auto op = out.get_node_shared_ptr();
if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(op) ||
std::dynamic_pointer_cast<ov::op::v0::Result>(op) ||
Expand Down Expand Up @@ -107,7 +111,8 @@ RegType Generator::get_op_out_reg_type(const ov::Output<Node>& out) const {
std::dynamic_pointer_cast<op::Fill>(op))
return RegType::vec;
else
return get_specific_op_out_reg_type(op);
OPENVINO_THROW("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!");
return reg_type;
}

RegType Generator::get_specific_op_out_reg_type(const ov::Output<Node>& out) const {
Expand Down
4 changes: 2 additions & 2 deletions src/common/snippets/src/lowered/pass/allocate_buffers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
const auto& parent_expr = parent_output.get_expr();
const auto port = parent_output.get_index();
const auto& parent_node = parent_expr->get_node();
auto memory_access = ov::as_type_ptr<ov::snippets::op::MemoryAccess>(parent_node);
auto memory_access = std::dynamic_pointer_cast<modifier::MemoryAccess>(parent_node);
if (memory_access && memory_access->is_memory_access_output_port(port)) {
memory_access->set_output_offset(offset, port);
} else {
Expand All @@ -54,7 +54,7 @@ void AllocateBuffers::set_buffer_offset(const ExpressionPtr& buffer_expr, const
const auto& child_expr = child_expr_input.get_expr();
const auto port = child_expr_input.get_index();
const auto& child_node = child_expr->get_node();
auto memory_access = ov::as_type_ptr<ov::snippets::op::MemoryAccess>(child_node);
auto memory_access = std::dynamic_pointer_cast<modifier::MemoryAccess>(child_node);
if (memory_access && memory_access->is_memory_access_input_port(port)) {
memory_access->set_input_offset(offset, port);
} else if (ov::is_type<op::LoopEnd>(child_node)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ bool DefineBufferClusters::are_buffer_neighbours(const ExpressionPtr& up, const
}

void DefineBufferClusters::parse_memory_access_op(const ExpressionPtr& expr) {
const auto ma = ov::as_type_ptr<op::MemoryAccess>(expr->get_node());
if (!ma->is_full_memory_access_op())
const auto ma = std::dynamic_pointer_cast<modifier::MemoryAccess>(expr->get_node());
if (!ma->is_full_memory_access_op(expr->get_node()))
return;
// TODO: Some full MemoryAccess ops can have inplace inputs and outputs in general.
// Need to add mechanism of inplace ports using MemoryAccess::PortDescriptor::inplace
Expand Down Expand Up @@ -331,7 +331,7 @@ bool DefineBufferClusters::run(lowered::LinearIR& linear_ir, lowered::LinearIR::
continue;
}

if (ov::is_type<op::MemoryAccess>(op)) {
if (std::dynamic_pointer_cast<modifier::MemoryAccess>(op)) {
parse_memory_access_op(expr);
continue;
}
Expand Down
Loading

0 comments on commit e563109

Please sign in to comment.