From d9503fd365ecdab534fe63a318737aa5553191a6 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Thu, 29 Dec 2022 15:55:33 +0400 Subject: [PATCH] [Snippets] Added support of I8/U8/BF16 for MatMul Rewrote MemoryAccess --- .../snippets/include/snippets/generator.hpp | 20 + .../snippets/include/snippets/op/brgemm.hpp | 28 +- .../include/snippets/op/broadcastload.hpp | 11 +- .../snippets/include/snippets/op/buffer.hpp | 7 + .../snippets/include/snippets/op/load.hpp | 3 +- .../include/snippets/op/memory_access.hpp | 50 +- .../snippets/include/snippets/op/store.hpp | 3 +- .../snippets/pass/assign_registers.hpp | 7 +- .../snippets/include/snippets/utils.hpp | 17 + src/common/snippets/src/generator.cpp | 37 +- src/common/snippets/src/op/brgemm.cpp | 74 ++- src/common/snippets/src/op/broadcastload.cpp | 9 +- src/common/snippets/src/op/buffer.cpp | 38 +- src/common/snippets/src/op/load.cpp | 13 +- src/common/snippets/src/op/memory_access.cpp | 60 +- src/common/snippets/src/op/store.cpp | 11 +- src/common/snippets/src/op/subgraph.cpp | 95 ++-- .../snippets/src/pass/align_element_type.cpp | 4 +- .../snippets/src/pass/assign_registers.cpp | 49 +- .../snippets/src/pass/collapse_subgraph.cpp | 11 +- .../src/pass/fuse_transpose_brgemm.cpp | 11 +- .../snippets/src/pass/insert_load_store.cpp | 4 +- src/common/snippets/src/pass/insert_loops.cpp | 6 + .../load_movebroadcast_to_broadcastload.cpp | 9 +- .../snippets/src/pass/matmul_to_brgemm.cpp | 3 +- .../snippets/src/pass/vector_to_scalar.cpp | 6 +- src/common/snippets/src/utils.cpp | 10 + .../snippets/tests/include/lowering_utils.hpp | 3 + .../set_scalar_count_for_load_and_store.cpp | 59 +- src/common/snippets/tests/src/registers.cpp | 7 +- .../intel_cpu/src/emitters/cpu_generator.cpp | 14 +- .../intel_cpu/src/emitters/cpu_generator.hpp | 3 + .../src/emitters/jit_snippets_emitters.cpp | 515 ++++++++++++++---- .../src/emitters/jit_snippets_emitters.hpp | 56 +- src/plugins/intel_cpu/src/extension.cpp | 4 + 
src/plugins/intel_cpu/src/nodes/subgraph.cpp | 3 + .../brgemm_to_brgemm_cpu.cpp | 98 ++++ .../brgemm_to_brgemm_cpu.hpp | 28 + .../fuse_load_store_and_convert.cpp | 12 +- .../op/brgemm_copy_b.cpp | 73 +++ .../op/brgemm_copy_b.hpp | 38 ++ .../op/brgemm_cpu.cpp | 123 +++++ .../op/brgemm_cpu.hpp | 40 ++ .../op/load_convert.cpp | 8 +- .../op/store_convert.cpp | 8 +- .../snippets/matmul.cpp | 33 +- .../snippets/transpose_matmul.cpp | 58 +- .../plugin/shared/include/snippets/matmul.hpp | 28 +- .../include/snippets/transpose_matmul.hpp | 7 +- .../plugin/shared/src/snippets/matmul.cpp | 98 ++-- .../shared/src/snippets/transpose_matmul.cpp | 35 +- .../include/subgraph_lowered.hpp | 2 +- .../include/subgraph_matmul.hpp | 46 +- .../src/subgraph_lowered.cpp | 4 +- .../src/subgraph_matmul.cpp | 115 +++- 55 files changed, 1612 insertions(+), 502 deletions(-) create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp create mode 100644 src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index e88c638c5205d7..3b44825a943b26 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -41,6 +41,20 @@ class TargetMachine { */ virtual size_t get_lanes() const = 0; + /** + * @interface opRegType + * @brief Register type of operations + * Note that currently there are 4 types of ops: + * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc) + * gpr->vec: or vec->gpr Load/LoadConvert, 
Store/StoreConvert, BroadcastLoad etc. + * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc. + */ + enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec}; + /** + * @brief gets register type by op type + * @return register type + */ + opRegType get_op_reg_type(const std::shared_ptr& op) const; /** * @brief called by generator to all the emitter for a target machine @@ -64,6 +78,12 @@ class TargetMachine { virtual ~TargetMachine() = default; protected: + /** + * @brief gets register type by specific plugin op type + * @return register type + */ + virtual opRegType get_specific_op_reg_type(const std::shared_ptr& op) const = 0; + std::map(std::shared_ptr)>> jitters; }; diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index 2746d974a06400..531d101cb23d14 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -5,7 +5,7 @@ #pragma once #include "ngraph/op/op.hpp" -#include "ngraph/op/matmul.hpp" +#include "memory_access.hpp" namespace ngraph { namespace snippets { @@ -16,30 +16,28 @@ namespace op { * @brief Brgemm is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows * @ingroup snippets */ -class Brgemm : public ngraph::op::v0::MatMul { +class Brgemm : public MemoryAccess { public: - OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul); - Brgemm(const Output& A, const Output& B, const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu); + OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess); + Brgemm(const Output& A, const Output& B, bool transposed_a = false, bool transposed_b = false, + const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu); Brgemm() = default; + bool transposed_a() const { return m_transposed_a; } + bool transposed_b() const { return 
m_transposed_b; } + bool visit_attributes(AttributeVisitor& visitor) override; void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; bool has_evaluate() const override { return false; } - size_t get_offset_a() const { return m_offset_a; } - size_t get_offset_b() const { return m_offset_b; } - size_t get_offset_c() const { return m_offset_c; } - - void set_offset_a(const size_t offset) { m_offset_a = offset; } - void set_offset_b(const size_t offset) { m_offset_b = offset; } - void set_offset_c(const size_t offset) { m_offset_c = offset; } +protected: + ov::element::Type get_output_type() const; + ov::PartialShape get_output_partial_shape(const std::vector& input_shapes) const; -private: - size_t m_offset_a = 0lu; // offset for first input - size_t m_offset_b = 0lu; // offset for second input - size_t m_offset_c = 0lu; // offset for output + bool m_transposed_a; + bool m_transposed_b; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/broadcastload.hpp b/src/common/snippets/include/snippets/op/broadcastload.hpp index db42a716f657e2..6268fdd736a722 100644 --- a/src/common/snippets/include/snippets/op/broadcastload.hpp +++ b/src/common/snippets/include/snippets/op/broadcastload.hpp @@ -4,7 +4,7 @@ #pragma once -#include +#include #include "ngraph/op/op.hpp" @@ -17,22 +17,19 @@ namespace op { * @brief Is generated for broadcasting by least varying dimension for non-blocked cases and the second varying dimension for blocked * @ingroup snippets */ -class BroadcastLoad : public BroadcastMove { +class BroadcastLoad : public MemoryAccess { public: - OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::BroadcastMove); + OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::MemoryAccess); BroadcastLoad(const Output& x, ov::PartialShape output_shape, size_t offset = 0lu); BroadcastLoad() = default; - size_t get_offset() const { return m_offset; } - void 
set_offset(const size_t offset) { m_offset = offset; } - bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; private: - size_t m_offset = 0lu; + ov::PartialShape output_shape; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/buffer.hpp b/src/common/snippets/include/snippets/op/buffer.hpp index f75fc95e742edb..35cd143ba4a32e 100644 --- a/src/common/snippets/include/snippets/op/buffer.hpp +++ b/src/common/snippets/include/snippets/op/buffer.hpp @@ -16,6 +16,9 @@ namespace op { * - m_allocation_rank - rank of shape for memory allocation: shape[shape_rank - normalize(m_allocation_rank) : shape_rank]. * It's needed to allocate needed memory size that depends on Tile rank, for example. * Default value is -1 (full shape) + * - m_static_shape - static shape that describes Buffer size in cases when Buffer doesn't have parent node. + * - m_element_type - element type in cases when Buffer doesn't have parent node. + * - m_single - True if Buffer doesn't have parent node else False * Notes: * - All buffers in a graph have the same memory pointer. 
So if we have a few buffers, * each the corresponding MemoryAccess op for Buffer should have offset for common memory pointer of this Buffer @@ -27,6 +30,7 @@ class Buffer : public ngraph::op::Op { OPENVINO_OP("Buffer", "SnippetsOpset"); Buffer(const Output& x, const int32_t allocation_rank = -1); + Buffer(const ov::Shape shape, const ov::element::Type element_type, int32_t allocation_rank = -1); Buffer() = default; int32_t get_allocation_rank() const { return m_allocation_rank; } @@ -40,6 +44,9 @@ class Buffer : public ngraph::op::Op { private: int32_t m_allocation_rank = -1; + ov::Shape m_static_shape; + ov::element::Type m_element_type; + bool m_is_single = false; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/load.hpp b/src/common/snippets/include/snippets/op/load.hpp index 1b8a391ebfe740..2639918dd1cd7c 100644 --- a/src/common/snippets/include/snippets/op/load.hpp +++ b/src/common/snippets/include/snippets/op/load.hpp @@ -20,11 +20,12 @@ namespace op { */ class Load : public MemoryAccess { public: - OPENVINO_OP("Load", "SnippetsOpset"); + OPENVINO_OP("Load", "SnippetsOpset", MemoryAccess); Load(const Output& x, const size_t count = 1lu, const size_t offset = 0lu); Load() = default; + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; }; diff --git a/src/common/snippets/include/snippets/op/memory_access.hpp b/src/common/snippets/include/snippets/op/memory_access.hpp index f1b2d8ebb2f00d..418af53a0cf1b7 100644 --- a/src/common/snippets/include/snippets/op/memory_access.hpp +++ b/src/common/snippets/include/snippets/op/memory_access.hpp @@ -10,12 +10,37 @@ namespace ngraph { namespace snippets { namespace op { +class MemoryAccess; + +/** +* @interface PortDescriptor +* @brief This class describes port of MemoryAccess operation +* @param m_count - count of elements to load/store +* @param m_offset - starting index of elements to load/store +* @param m_index - port 
index +* @ingroup snippets +*/ + +struct PortDescriptor { + PortDescriptor(size_t count, size_t offset) : m_count(count), m_offset(offset) {} + PortDescriptor() = default; + + size_t m_count = 0lu; + size_t m_offset = 0lu; + size_t m_index = 0lu; + +private: + PortDescriptor(size_t count, size_t offset, size_t index) : m_count(count), m_offset(offset), m_index(index) {} + + friend class MemoryAccess; +}; + /** * @interface MemoryAccess * @brief This is a base class for memory access operations (like Load and Store). - * It provides universal set/get interface to manipulate the number - * of elements accessed during one operation call ("count"). - * Default "count" value is "1" - it means to load/store one element + * It provides universal interface to manipulate with memory: load/store. + * @param m_input_ports - vector of input descriptors: variables of PortDescriptor class + * @param m_output_ports - vector of output descriptors: variables of PortDescriptor class * @ingroup snippets */ @@ -23,18 +48,21 @@ class MemoryAccess : public ngraph::op::Op { public: OPENVINO_OP("MemoryAccess", "SnippetsOpset"); - size_t get_count() const; - size_t get_offset() const; - void set_count(const size_t count); - void set_offset(const size_t offset); + void set_input_port_descriptor(const PortDescriptor& desc, const size_t i); + void set_output_port_descriptor(const PortDescriptor& desc, const size_t i); + PortDescriptor get_input_port_descriptor(const size_t i) const; + PortDescriptor get_output_port_descriptor(const size_t i) const; + PortDescriptor& get_input_port_descriptor(const size_t i); + PortDescriptor& get_output_port_descriptor(const size_t i); + bool visit_attributes(AttributeVisitor& visitor) override; - void validate_and_infer_types() override; protected: - explicit MemoryAccess(const Output& x, size_t count = 1lu, size_t offset = 0lu); + explicit MemoryAccess(const OutputVector& arguments); MemoryAccess() = default; - size_t m_count = 0lu; - size_t m_offset = 0lu; 
+ + std::vector m_input_ports; + std::vector m_output_ports; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/store.hpp b/src/common/snippets/include/snippets/op/store.hpp index 48c7466b924cff..74245daf71d021 100644 --- a/src/common/snippets/include/snippets/op/store.hpp +++ b/src/common/snippets/include/snippets/op/store.hpp @@ -20,11 +20,12 @@ namespace op { */ class Store : public MemoryAccess { public: - OPENVINO_OP("Store", "SnippetsOpset"); + OPENVINO_OP("Store", "SnippetsOpset", MemoryAccess); Store(const Output& x, const size_t count = 1lu, const size_t offset = 0lu); Store() = default; + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; }; diff --git a/src/common/snippets/include/snippets/pass/assign_registers.hpp b/src/common/snippets/include/snippets/pass/assign_registers.hpp index 0eff4bcc7d7033..caaecfc58c130c 100644 --- a/src/common/snippets/include/snippets/pass/assign_registers.hpp +++ b/src/common/snippets/include/snippets/pass/assign_registers.hpp @@ -6,6 +6,8 @@ #include +#include "snippets/generator.hpp" + namespace ngraph { namespace snippets { namespace pass { @@ -18,10 +20,13 @@ namespace pass { */ class AssignRegisters : public ngraph::pass::FunctionPass { public: - explicit AssignRegisters() { + explicit AssignRegisters(const std::shared_ptr& target_machine) : m_target_machine(target_machine) { set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true); } bool run_on_model(const std::shared_ptr& m) override; + +private: + std::shared_ptr m_target_machine = nullptr; }; } // namespace pass diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 2c6ca823aeec8c..b6d4f59919d84d 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -29,10 +29,27 @@ ov::PartialShape get_port_planar_shape(const Output& out); 
ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector& layout); std::vector get_node_output_layout(const std::shared_ptr& node); std::vector get_node_output_layout(const Node* node); +void set_output_layout(const ov::Output& port, const std::shared_ptr& node); +void set_output_layout(const ov::Output& port, const std::vector& layout); inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); } inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); } +template +constexpr bool one_of(T val, P item) { return val == item; } + +template +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template +constexpr bool everyone_is(T val, P item) { return val == item; } + +template +constexpr bool everyone_is(T val, P item, Args... item_others) { + return val == item && everyone_is(val, item_others...); +} } // namespace utils } // namespace snippets } // namespace ngraph \ No newline at end of file diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 3859479e85c110..45e7cd2824f16b 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -17,6 +17,30 @@ namespace ngraph { namespace snippets { +TargetMachine::opRegType TargetMachine::get_op_reg_type(const std::shared_ptr& op) const { + if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2gpr; + else if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2vec; + else if (std::dynamic_pointer_cast(op)) + return vec2gpr; + else if (ov::op::util::is_unary_elementwise_arithmetic(op) || + ov::op::util::is_binary_elementwise_arithmetic(op) || + 
ov::op::util::is_binary_elementwise_comparison(op) || + ov::op::util::is_binary_elementwise_logical(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return vec2vec; + else + return get_specific_op_reg_type(op); +} + auto getRegisters(const std::shared_ptr &n) -> RegInfo { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::getRegisters") @@ -77,8 +101,17 @@ auto tail_transformations(NodeVector& tail, const size_t tail_size, const ngraph } } } else if (const auto memory_access = std::dynamic_pointer_cast(op)) { - if (memory_access->get_count() != 1) { - memory_access->set_count(tail_size); + for (size_t i = 0; i < memory_access->get_input_size(); ++i) { + auto& desc = memory_access->get_input_port_descriptor(i); + if (desc.m_count != 1) { + desc.m_count = tail_size; + } + } + for (size_t i = 0; i < memory_access->get_output_size(); ++i) { + auto& desc = memory_access->get_output_port_descriptor(i); + if (desc.m_count != 1) { + desc.m_count = tail_size; + } } } updated_tile.push_back(op); diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 7bf999cb15e423..09b55ee2224ca7 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -13,50 +13,78 @@ namespace ngraph { namespace snippets { namespace op { -Brgemm::Brgemm(const Output& A, const Output& B, const size_t offset_a, const size_t offset_b, const size_t offset_c) - : MatMul(), m_offset_a(offset_a), m_offset_b(offset_b), m_offset_c(offset_c) { - set_arguments({A, B}); +Brgemm::Brgemm(const Output& A, const Output& B, bool transposed_a, bool transposed_b, + const size_t offset_a, const size_t offset_b, const size_t offset_c) + : MemoryAccess({A, B}), m_transposed_a(transposed_a), m_transposed_b(transposed_b) { set_output_size(1); + set_input_port_descriptor({0, offset_a}, 0); + set_input_port_descriptor({0, offset_b}, 1); + set_output_port_descriptor({0, offset_c}, 0); 
constructor_validate_and_infer_types(); } bool Brgemm::visit_attributes(AttributeVisitor& visitor) { - MatMul::visit_attributes(visitor); - visitor.on_attribute("offset_a", m_offset_a); - visitor.on_attribute("offset_b", m_offset_b); - visitor.on_attribute("offset_c", m_offset_c); + MemoryAccess::visit_attributes(visitor); + visitor.on_attribute("transposed_a", m_transposed_a); + visitor.on_attribute("transposed_b", m_transposed_b); return true; } void Brgemm::validate_and_infer_types() { INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types); - element::Type result_et; - NODE_VALIDATION_CHECK(this, - element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)), - "Arguments do not have the same element type (arg0 element type: ", - get_input_element_type(0), - ", arg1 element type: ", - get_input_element_type(1), - ")."); // If no leading dimensions are provided, assume dense row-major inputs-outputs NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), "Brgemm currently supports only static shapes."); - std::vector planar_input_shapes; - for (const auto& in : input_values()) - planar_input_shapes.emplace_back(utils::get_port_planar_shape(in)); + std::vector planar_input_shapes = { + utils::get_port_planar_shape(input_value(0)), + utils::get_port_planar_shape(input_value(1)) + }; - std::vector output_shapes = {ov::PartialShape{}}; - ov::op::v0::shape_infer(this, planar_input_shapes, output_shapes); + auto output_shape = get_output_partial_shape(planar_input_shapes); const auto& output_layout = utils::get_node_output_layout(this); - output_shapes[0] = utils::get_reordered_planar_shape(output_shapes[0], output_layout); - set_output_type(0, result_et, output_shapes[0]); + set_output_type(0, + get_output_type(), + utils::get_reordered_planar_shape(output_shape, output_layout)); } std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args) const { 
INTERNAL_OP_SCOPE(Brgemm_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), new_args.at(1), m_offset_a, m_offset_b, m_offset_c); + return std::make_shared(new_args.at(0), new_args.at(1), + m_transposed_a, m_transposed_b, + get_input_port_descriptor(0).m_offset, + get_input_port_descriptor(1).m_offset, + get_output_port_descriptor(0).m_offset); +} + +ov::element::Type Brgemm::get_output_type() const { + const auto element_type_a = get_input_element_type(0); + const auto element_type_b = get_input_element_type(1); + const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b); + const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8; + const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b); + if (is_f32 || is_bf16) { + return element::f32; + } else if (is_int8) { + return element::i32; + } else { + throw ngraph_error("BrgemmCPU node has incompatible input element types: " + + element_type_a.get_type_name() + + " and " + + element_type_b.get_type_name()); + } +} + +ov::PartialShape Brgemm::get_output_partial_shape(const std::vector& input_shapes) const { + NGRAPH_CHECK(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference"); + auto matmul_in0 = std::make_shared(ngraph::element::f32, input_shapes[0]); + auto matmul_in1 = std::make_shared(ngraph::element::f32, input_shapes[1]); + auto matmul = std::make_shared(matmul_in0, matmul_in1, m_transposed_a, m_transposed_b); + + std::vector output_shapes = {ov::PartialShape{}}; + ov::op::v0::shape_infer(matmul.get(), input_shapes, output_shapes); + return output_shapes.front(); } } // namespace op diff --git a/src/common/snippets/src/op/broadcastload.cpp b/src/common/snippets/src/op/broadcastload.cpp index 927b47f94498bc..801528d0dcaa27 100644 --- a/src/common/snippets/src/op/broadcastload.cpp +++ b/src/common/snippets/src/op/broadcastload.cpp 
@@ -11,21 +11,20 @@ using namespace std; using namespace ngraph; -snippets::op::BroadcastLoad::BroadcastLoad(const Output& x, ov::PartialShape shape, size_t offset) - : BroadcastMove(x, std::move(shape)), m_offset(offset) { +snippets::op::BroadcastLoad::BroadcastLoad(const Output& x, ov::PartialShape shape, size_t offset) : MemoryAccess({x}), output_shape(std::move(shape)) { constructor_validate_and_infer_types(); + set_input_port_descriptor({1, offset}, 0); } bool snippets::op::BroadcastLoad::visit_attributes(AttributeVisitor& visitor) { - BroadcastMove::visit_attributes(visitor); - visitor.on_attribute("offset", m_offset); + MemoryAccess::visit_attributes(visitor); return true; } std::shared_ptr snippets::op::BroadcastLoad::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(BroadcastLoad); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), output_shape, m_offset); + return std::make_shared(new_args.at(0), output_shape, get_input_port_descriptor(0).m_offset); } void snippets::op::BroadcastLoad::validate_and_infer_types() { diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index ad05ae2e046932..1255c13427f147 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -16,38 +16,62 @@ auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t return allocation_rank < 0 ? 
allocation_rank + static_cast(shape_rank) : allocation_rank; } -snippets::op::Buffer::Buffer(const Output& x, const int32_t allocation_rank) : Op({x}), m_allocation_rank(allocation_rank) { +snippets::op::Buffer::Buffer(const Output& x, const int32_t allocation_rank) + : Op({x}), m_allocation_rank(allocation_rank), m_is_single(false) { + constructor_validate_and_infer_types(); +} + +snippets::op::Buffer::Buffer(const ov::Shape shape, const ov::element::Type element_type, const int32_t allocation_rank) + : Op(), m_static_shape(shape), m_element_type(element_type), m_allocation_rank(allocation_rank), m_is_single(true) { constructor_validate_and_infer_types(); } bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(Buffer_visit_attributes); visitor.on_attribute("allocation_rank", m_allocation_rank); + if (m_is_single) { + visitor.on_attribute("shape", m_static_shape); + visitor.on_attribute("element_type", m_element_type); + } return true; } std::shared_ptr snippets::op::Buffer::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs); check_new_args_count(this, new_args); - auto new_buffer = std::make_shared(new_args.at(0), m_allocation_rank); - return new_buffer; + if (m_is_single) { + return std::make_shared(m_static_shape, m_element_type, m_allocation_rank); + } + + return std::make_shared(new_args.at(0), m_allocation_rank); } void snippets::op::Buffer::validate_and_infer_types() { INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types); - const auto shape_rank = get_input_partial_shape(0).rank(); + ov::PartialShape output_shape; + ov::element::Type output_type; + if (m_is_single) { + output_shape = m_static_shape; + output_type = m_element_type; + } else { + output_shape = get_input_partial_shape(0); + output_type = get_input_element_type(0); + } + + const auto shape_rank = output_shape.rank(); if (shape_rank.is_static()) { const auto normalized_rank = 
normalize_rank(m_allocation_rank, shape_rank.get_length()); NGRAPH_CHECK(normalized_rank >= 0 && normalized_rank <= shape_rank.get_length(), "Buffer has incorrect allocation rank: " + std::to_string(m_allocation_rank)); } - set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); + + set_output_type(0, output_type, output_shape); } size_t ngraph::snippets::op::Buffer::get_byte_size() const { - const auto pshape = get_input_partial_shape(0); + const auto pshape = get_output_partial_shape(0); NGRAPH_CHECK(pshape.is_static(), "Buffer should have static shapes for memory allocation"); const auto shape = pshape.get_shape(); const auto normalized_rank = normalize_rank(m_allocation_rank, shape.size()); - return ngraph::shape_size(shape.rbegin(), shape.rbegin() + normalized_rank) * get_element_type().size(); + return ngraph::shape_size(shape.rbegin(), shape.rbegin() + normalized_rank + 1) * get_element_type().size(); } diff --git a/src/common/snippets/src/op/load.cpp b/src/common/snippets/src/op/load.cpp index f7637fbc7962a5..da531db5e9a76d 100644 --- a/src/common/snippets/src/op/load.cpp +++ b/src/common/snippets/src/op/load.cpp @@ -12,14 +12,20 @@ namespace ngraph { namespace snippets { namespace op { -Load::Load(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, count, offset) { +Load::Load(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}) { constructor_validate_and_infer_types(); + set_input_port_descriptor({count, offset}, 0); +} + +void snippets::op::Load::validate_and_infer_types() { + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); } std::shared_ptr Load::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Load); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_count, m_offset); + return std::make_shared( + new_args.at(0), get_input_port_descriptor(0).m_count, get_input_port_descriptor(0).m_offset); } @@ 
-53,7 +59,8 @@ bool snippets::op::LoadReshape::visit_attributes(AttributeVisitor& visitor) { std::shared_ptr snippets::op::LoadReshape::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(LoadReshape); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_count, m_offset, m_order); + return std::make_shared( + new_args.at(0), get_input_port_descriptor(0).m_count, get_input_port_descriptor(0).m_offset, m_order); } }// namespace op diff --git a/src/common/snippets/src/op/memory_access.cpp b/src/common/snippets/src/op/memory_access.cpp index 2530ea77b6352b..059e0d74087419 100644 --- a/src/common/snippets/src/op/memory_access.cpp +++ b/src/common/snippets/src/op/memory_access.cpp @@ -12,32 +12,66 @@ namespace ngraph { namespace snippets { namespace op { -MemoryAccess::MemoryAccess(const Output& x, const size_t count, const size_t offset) : Op({x}), m_count(count), m_offset(offset) {} +MemoryAccess::MemoryAccess(const OutputVector& arguments) : Op(arguments) {} bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) { - visitor.on_attribute("count", m_count); - visitor.on_attribute("offset", m_offset); + for (size_t i = 0; i < m_input_ports.size(); ++i) { + auto port = m_input_ports[i]; + visitor.on_attribute("count_in_" + std::to_string(i), port.m_count); + visitor.on_attribute("offset_in_" + std::to_string(i), port.m_offset); + } + for (size_t i = 0; i < m_output_ports.size(); ++i) { + auto port = m_output_ports[i]; + visitor.on_attribute("count_out_" + std::to_string(i), port.m_count); + visitor.on_attribute("offset_out_" + std::to_string(i), port.m_offset); + } return true; } -size_t MemoryAccess::get_count() const { - return m_count; +void MemoryAccess::set_input_port_descriptor(const PortDescriptor& desc, const size_t i) { + // Logic is as same as ov::Node::get_input_descriptor + while (m_input_ports.size() <= i) { + m_input_ports.emplace_back(PortDescriptor{0, 0, m_input_ports.size()}); + } + 
m_input_ports[i] = { desc.m_count, desc.m_offset, i}; } -size_t MemoryAccess::get_offset() const { - return m_offset; +PortDescriptor MemoryAccess::get_input_port_descriptor(const size_t i) const { + // We cannot use the same way as in ov::Node::get_input_descriptor because this method must be static + // to allow call const Derived::clone_with_new_inputs() method + NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + return m_input_ports[i]; } -void MemoryAccess::set_count(const size_t count) { - m_count = count; +PortDescriptor& MemoryAccess::get_input_port_descriptor(const size_t i) { + // Logic is as same as ov::Node::get_input_descriptor + while (m_input_ports.size() <= i) { + m_input_ports.emplace_back(PortDescriptor{0, 0, m_input_ports.size()}); + } + return m_input_ports[i]; } -void MemoryAccess::set_offset(const size_t offset) { - m_offset = offset; +void MemoryAccess::set_output_port_descriptor(const PortDescriptor& desc, const size_t i) { + // Logic is as same as ov::Node::get_output_descriptor + while (m_output_ports.size() <= i) { + m_output_ports.emplace_back(PortDescriptor{0, 0, m_output_ports.size()}); + } + m_output_ports[i] = { desc.m_count, desc.m_offset, i}; } -void MemoryAccess::validate_and_infer_types() { - set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +PortDescriptor MemoryAccess::get_output_port_descriptor(const size_t i) const { + // We cannot use the same way as in ov::Node::get_input_descriptor because this method must be static + // to allow call const Derived::clone_with_new_inputs() method + NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + return m_output_ports[i]; +} + +PortDescriptor& MemoryAccess::get_output_port_descriptor(const size_t i) { + // Logic is as same as ov::Node::get_output_descriptor + while (m_output_ports.size() <= i) { + 
m_output_ports.emplace_back(PortDescriptor{0, 0, m_output_ports.size()}); + } + return m_output_ports[i]; } } // namespace op diff --git a/src/common/snippets/src/op/store.cpp b/src/common/snippets/src/op/store.cpp index 90750de6b65fec..d5cc35de8bc03d 100644 --- a/src/common/snippets/src/op/store.cpp +++ b/src/common/snippets/src/op/store.cpp @@ -12,13 +12,20 @@ namespace ngraph { namespace snippets { namespace op { -snippets::op::Store::Store(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, count, offset) { +snippets::op::Store::Store(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}) { constructor_validate_and_infer_types(); + set_output_port_descriptor({count, offset}, 0); } + +void snippets::op::Store::validate_and_infer_types() { + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +} + std::shared_ptr snippets::op::Store::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Store_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_count, m_offset); + return std::make_shared( + new_args.at(0), get_output_port_descriptor(0).m_count, get_output_port_descriptor(0).m_offset); } } // namespace op diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index c2fc57add6d811..5e30a6d597ea02 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -398,6 +398,23 @@ void snippets::op::Subgraph::align_element_types(const BlockedShapeVector& outpu manager.register_pass(); } manager.run_passes(body_ptr()); + + // TODO: NEED NORMAL ALIGN ELEMENT TYPE + const auto ops = body_ptr()->get_ops(); + for (const auto& op : ops) { + if (auto brgemm = ov::as_type_ptr(op)) { + if (brgemm->get_input_element_type(0).is_integral() && brgemm->get_input_element_type(1).is_integral()) { + auto target_input = brgemm->get_output_target_inputs(0).begin(); 
+ if (ov::is_type(target_input->get_node())) { + target_input = target_input->get_node()->get_output_target_inputs(0).begin(); + } + auto convert = std::make_shared( + target_input->get_source_output(), + ov::element::f32); + target_input->get_node()->shared_from_this()->set_argument(target_input->get_index(), convert); + } + } + } } void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { @@ -431,22 +448,21 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { // Propagate to up: in Store. Buffer can have only one Store { - auto parent = buffer->get_input_node_shared_ptr(0); - auto idx = buffer->input(0).get_source_output().get_index(); - // There may be graph with several LoopBegin and LoopEnd between Store/Brgemm and Buffer, - // so we should iterate through LoopBase - while (ov::is_type(parent)) { - const auto source_output = parent->input_value(idx); - parent = source_output.get_node_shared_ptr(); - idx = source_output.get_index(); - } - if (auto store = ov::as_type_ptr(parent)) { - store->set_offset(offset); - } else if (const auto brgemm = ov::as_type_ptr(parent)) { - // Brgemm encapsulates work with loading and storing of data - brgemm->set_offset_c(offset); - } else { - throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Store op for offset propagation"); + if (buffer->get_input_size() > 0) { + auto parent = buffer->get_input_node_shared_ptr(0); + auto idx = buffer->input(0).get_source_output().get_index(); + while (ov::is_type(parent)) { + const auto source_output = parent->input_value(idx); + parent = source_output.get_node_shared_ptr(); + idx = source_output.get_index(); + } + if (auto memory_access = ov::as_type_ptr(parent)) { + auto &out_desc = memory_access->get_output_port_descriptor(idx); + out_desc.m_offset = offset; + } else { + throw ngraph_error( + "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); + } } } @@ 
-463,17 +479,11 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { for (const auto loop_target_output : child->output(index).get_target_inputs()) { propagate_down(loop_target_output); } - } else if (const auto load = ov::as_type_ptr(child)) { - load->set_offset(offset); - } else if (const auto brgemm = ov::as_type_ptr(child)) { - // Brgemm encapsulates work with loading and storing of data - if (target_input.get_index() == 0) { - brgemm->set_offset_a(offset); - } else if (target_input.get_index() == 1) { - brgemm->set_offset_b(offset); - } + } else if (auto memory_access = ov::as_type_ptr(child)) { + auto& in_desc = memory_access->get_input_port_descriptor(target_input.get_index()); + in_desc.m_offset = offset; } else { - throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Load op for offset propagation"); + throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); } }; @@ -494,26 +504,24 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { continue; } - // Transpose and MatMul ops should have different memories on inputs and outputs to avoid data corruption, - // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace. - const auto parent = buffer->get_input_node_shared_ptr(0); - if (ov::is_type(parent) || is_transpose_loop(parent)) { + if (buffer->get_input_size() > 0) { + // Transpose, MatMul and other non-decomposed ops should have different memories on inputs and outputs to avoid data corruption, + // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace inside Loop. 
+ const auto parent = buffer->get_input_node_shared_ptr(0); + if (!ov::is_type(parent) || is_transpose_loop(parent)) { + offset = m_buffer_scratchpad; + propagate_offset(buffer, offset); + m_buffer_scratchpad += buffer_size; + continue; + } + + propagate_offset(buffer, offset); + } else { + // Single Buffer without input should allocate new memory offset = m_buffer_scratchpad; propagate_offset(buffer, offset); m_buffer_scratchpad += buffer_size; - continue; - } - - // If Buffer op requires memory size more that has been already allocated, - // we increase current memory size to the needed size - // For example, it's possible when we have a sequence of Eltwise ops with broadcasting - const auto current_allocated_memory_size = m_buffer_scratchpad - offset; - if (buffer_size > current_allocated_memory_size) { - m_buffer_scratchpad += (buffer_size - current_allocated_memory_size); - // Note: we don't update offset because we just add memory to needed size } - - propagate_offset(buffer, offset); } } } @@ -620,13 +628,13 @@ snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt, convert_to_snippet_dialect(); opt.run_passes(body_ptr()); // After all passes, when all optimizations are completed and all MemoryAccess ops are inserted, // we can calculate common buffer scratchpad size and propagate offset from Buffer to the corresponding MemoryAccess ops if (config.m_has_domain_sensitive_ops) initialize_buffer_scratchpad_size(); - snippets::pass::AssignRegisters().run_on_model(body_ptr()); + snippets::pass::AssignRegisters(m_generator->get_target_machine()).run_on_model(body_ptr()); const auto ops = body_ptr()->get_ops(); ngraph::snippets::Generator::GeneratorConfig generatorConfig; diff --git a/src/common/snippets/src/pass/align_element_type.cpp b/src/common/snippets/src/pass/align_element_type.cpp index 73636b46d99d9e..f2d52288de6550 100644 --- a/src/common/snippets/src/pass/align_element_type.cpp +++
b/src/common/snippets/src/pass/align_element_type.cpp @@ -30,7 +30,9 @@ inline auto op_supports_only_exec_type(const std::shared_ptr& n) -> bo !ov::is_type(n) && !ov::is_type(n) && !ov::is_type(n) && - !ov::is_type(n); + !ov::is_type(n) && + !ov::is_type(n) && + !ov::is_type(n); } } // namespace diff --git a/src/common/snippets/src/pass/assign_registers.cpp b/src/common/snippets/src/pass/assign_registers.cpp index 04cbadf5a608cd..08c737c4523ae9 100644 --- a/src/common/snippets/src/pass/assign_registers.cpp +++ b/src/common/snippets/src/pass/assign_registers.cpp @@ -17,31 +17,10 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr using Reg = size_t; using tensor = std::shared_ptr; auto ops = f->get_ordered_ops(); - // Note that currently there are 3 types of ops: - // * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd) will also be Buffer? - // * gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc. - // * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc. 
- enum op_reg_type {gpr2gpr, gpr2vec, vec2gpr, vec2vec}; - auto get_op_reg_type = [](const std::shared_ptr& op) { - if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return gpr2gpr; - else if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return gpr2vec; - else if (std::dynamic_pointer_cast(op)) - return vec2gpr; - else - return vec2vec; - }; - std::vector>> typed_ops; + std::vector>> typed_ops; for (const auto& op : ops) - typed_ops.emplace_back(std::make_pair(get_op_reg_type(op), op)); + typed_ops.emplace_back(std::make_pair(m_target_machine->get_op_reg_type(op), op)); size_t counter_vec = 0; size_t counter_gpr = 0; std::map regs_vec, regs_gpr; @@ -109,12 +88,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr }; for (const auto& t_op : typed_ops) { switch (t_op.first) { - case vec2vec: - case gpr2vec: + case TargetMachine::opRegType::vec2vec: + case TargetMachine::opRegType::gpr2vec: enumerate_out_tensors(t_op.second, regs_vec, manually_assigned_vecs, counter_vec); break; - case gpr2gpr: - case vec2gpr: + case TargetMachine::opRegType::gpr2gpr: + case TargetMachine::opRegType::vec2gpr: enumerate_out_tensors(t_op.second, regs_gpr, manually_assigned_gprs, counter_gpr); break; } @@ -144,19 +123,19 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr for (const auto& out : t_op.second->outputs()) defined_tensors.push_back(out.get_tensor_ptr()); switch (t_op.first) { - case vec2vec: + case TargetMachine::opRegType::vec2vec: used_vec[i] = tensor2reg(used_tensors, regs_vec); defined_vec[i] = tensor2reg(defined_tensors, regs_vec); break; - case gpr2gpr: + case TargetMachine::opRegType::gpr2gpr: used_gpr[i] = tensor2reg(used_tensors, regs_gpr); defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr); break; - case gpr2vec: + 
case TargetMachine::opRegType::gpr2vec: used_gpr[i] = tensor2reg(used_tensors, regs_gpr); defined_vec[i] = tensor2reg(defined_tensors, regs_vec); break; - case vec2gpr: + case TargetMachine::opRegType::vec2gpr: used_vec[i] = tensor2reg(used_tensors, regs_vec); defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr); break; @@ -191,12 +170,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr if (k == ops.size()) throw ngraph_error("assign registers can't find target op in the body"); switch (typed_ops[k].first) { - case vec2vec: - case vec2gpr: + case TargetMachine::opRegType::vec2vec: + case TargetMachine::opRegType::vec2gpr: life_out_vec[n].insert(life_in_vec[k].begin(), life_in_vec[k].end()); break; - case gpr2gpr: - case gpr2vec: + case TargetMachine::opRegType::gpr2gpr: + case TargetMachine::opRegType::gpr2vec: life_out_gpr[n].insert(life_in_gpr[k].begin(), life_in_gpr[k].end()); break; } diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 1c20b154d2d893..7fe4df58d415d0 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -49,9 +49,16 @@ auto outputs_are_not_broadcastable(const std::shared_ptr& node) -> b auto is_supported_op(const std::shared_ptr &n) -> bool { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::is_supported_op") auto is_supported_matmul = [](const std::shared_ptr& n) -> bool { - const auto& matmul = is_type(n); + const auto& matmul = ov::as_type_ptr(n); const auto& out_shape = n->get_output_partial_shape(0); - return matmul && out_shape.is_static() && out_shape.size() == 4; + if (!matmul || out_shape.is_dynamic() || out_shape.size() != 4) + return false; + const auto intype_0 = matmul->get_input_element_type(0); + const auto intype_1 = matmul->get_input_element_type(1); + const bool is_f32 = intype_0 == element::f32 && intype_1 == element::f32; + 
const bool is_int8 = (intype_0 == element::i8 || intype_0 == element::u8) && (intype_1 == element::i8); + const bool is_bf16 = intype_0 == element::bf16 && intype_1 == element::bf16; + return is_f32 || is_bf16 || is_int8; }; auto is_supported_transpose = [](const std::shared_ptr& n) -> bool { const auto& transpose = as_type_ptr(n); diff --git a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp index 73347c6475bba0..fcde932d2ecca3 100644 --- a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp +++ b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp @@ -49,13 +49,8 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { auto callback = [=](pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "FuseTransposeBrgemm") - auto set_layout_from_order = [](const std::shared_ptr& node, const ov::Output& port) { - const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); - std::vector layout = const_order->cast_vector(); - auto& rt_info = port.get_node_shared_ptr()->get_rt_info(); - rt_info["Layout"] = layout; - }; auto brgemm = as_type_ptr(m.get_match_root()); + // Transpose on the Brgemm's output if (!brgemm) { brgemm = as_type_ptr(m.get_match_root()->get_input_node_shared_ptr(0)); @@ -63,13 +58,13 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { const auto& transpose_out = m.get_match_value(); for (const auto& in : transpose_out.get_target_inputs()) in.replace_source_output(brgemm->output(0)); - set_layout_from_order(as_type_ptr(transpose_out.get_node_shared_ptr()), brgemm_out); + utils::set_output_layout(brgemm_out, as_type_ptr(transpose_out.get_node_shared_ptr())); } for (int i = 0; i < brgemm->get_input_size(); i++) { const auto& in_value = brgemm->input_value(i); if (transpose_matcher->match(in_value)) { const auto& transpose = as_type_ptr(in_value.get_node_shared_ptr()); - set_layout_from_order(transpose, transpose->input_value(0)); + 
utils::set_output_layout(transpose->input_value(0), transpose); brgemm->set_argument(i, transpose->input_value(0)); } } diff --git a/src/common/snippets/src/pass/insert_load_store.cpp b/src/common/snippets/src/pass/insert_load_store.cpp index efa0d6396c63fd..da01bce98ca235 100644 --- a/src/common/snippets/src/pass/insert_load_store.cpp +++ b/src/common/snippets/src/pass/insert_load_store.cpp @@ -30,7 +30,7 @@ ngraph::snippets::pass::InsertLoad::InsertLoad(const size_t count) { const auto& consumer_node = consumer.get_node(); if (ov::is_type(consumer_node) || ov::is_type(consumer_node) || - ov::is_type(consumer_node) || + ov::is_type(consumer_node) || ov::is_type(consumer_node)) { return false; } @@ -67,7 +67,7 @@ ngraph::snippets::pass::InsertStore::InsertStore(const size_t count) { const auto& parent_node = input.get_source_output().get_node(); if (ov::is_type(parent_node) || ov::is_type(parent_node) || - ov::is_type(parent_node) || + ov::is_type(parent_node) || ov::is_type(parent_node)) { return false; } diff --git a/src/common/snippets/src/pass/insert_loops.cpp b/src/common/snippets/src/pass/insert_loops.cpp index f6d83bf6da733f..f26e48ec71ad07 100644 --- a/src/common/snippets/src/pass/insert_loops.cpp +++ b/src/common/snippets/src/pass/insert_loops.cpp @@ -217,6 +217,6 @@ bool InsertLoops::run_on_model(const std::shared_ptr &model) { if (m_master_shape.is_dynamic()) throw ngraph_error("InsertLoops doesn't support dynamic shapes yet"); const auto inner_work_amount = utils::get_inner_dim(m_master_shape).get_length(); const auto outer_work_amount = m_loop_depth == 2 ?
utils::get_outer_dim(m_master_shape).get_length() : 1; @@ -277,6 +280,6 @@ bool InsertLoops::run_on_model(const std::shared_ptr &model) { } } return true; } diff --git a/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp index 9945724c83e88d..e4347364f6d01f 100644 --- a/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp @@ -24,20 +24,21 @@ ngraph::snippets::pass::LoadMoveBroadcastToBroadcastLoad::LoadMoveBroadcastToBro auto root = m.get_match_root(); const auto &pm = m.get_pattern_value_map(); - const auto input = pm.at(load_pattern).get_node_shared_ptr(); + const auto load = ov::as_type_ptr(pm.at(load_pattern).get_node_shared_ptr()); const auto param = pm.at(param_pattern).get_node_shared_ptr(); // Cannot rewrite Broadcast + Load if load has more than 1 user // or more than one input, or if Broadcast has several inputs - if (input->output(0).get_target_inputs().size() != 1 || - root->inputs().size() != 1 || input->inputs().size() != 1) { + if (load->output(0).get_target_inputs().size() != 1 || + root->inputs().size() != 1 || load->inputs().size() != 1) { return false; } auto inshape = root->input(0).get_partial_shape(); auto outshape = root->output(0).get_partial_shape(); - auto broadcastload = std::make_shared(param, outshape, ov::as_type_ptr(input)->get_offset()); + const auto load_in_desc = load->get_input_port_descriptor(0); + auto broadcastload = std::make_shared(param, outshape, load_in_desc.m_offset); ngraph::copy_runtime_info(root, broadcastload); ngraph::replace_node(root, broadcastload); diff --git a/src/common/snippets/src/pass/matmul_to_brgemm.cpp
b/src/common/snippets/src/pass/matmul_to_brgemm.cpp index b74fb3e68cc47e..df46030b4a34f1 100644 --- a/src/common/snippets/src/pass/matmul_to_brgemm.cpp +++ b/src/common/snippets/src/pass/matmul_to_brgemm.cpp @@ -29,7 +29,8 @@ MatMulToBrgemm::MatMulToBrgemm() { if (matmul->get_transpose_a() || matmul->get_transpose_b()) return false; - auto brgemm = std::make_shared(matmul->get_input_source_output(0), matmul->get_input_source_output(1)); + auto brgemm = std::make_shared(matmul->get_input_source_output(0), matmul->get_input_source_output(1), + matmul->get_transpose_a(), matmul->get_transpose_b()); brgemm->set_friendly_name(matmul->get_friendly_name()); ngraph::copy_runtime_info(matmul, brgemm); ngraph::replace_node(matmul, brgemm); diff --git a/src/common/snippets/src/pass/vector_to_scalar.cpp b/src/common/snippets/src/pass/vector_to_scalar.cpp index b8de68eafd8258..4b1928d64f3f16 100644 --- a/src/common/snippets/src/pass/vector_to_scalar.cpp +++ b/src/common/snippets/src/pass/vector_to_scalar.cpp @@ -24,7 +24,8 @@ ngraph::snippets::pass::SetScalarCountForLoad::SetScalarCountForLoad() { if (!load) return false; - load->set_count(1lu); + auto& desc = load->get_input_port_descriptor(0); + desc.m_count = 1lu; return true; }); } @@ -43,7 +44,8 @@ ngraph::snippets::pass::SetScalarCountForStore::SetScalarCountForStore() { if (!store) return false; - store->set_count(1lu); + auto& desc = store->get_output_port_descriptor(0); + desc.m_count = 1lu; return true; }); } diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index d904317d6029f7..1cf9a0f6f6b7f3 100644 --- a/src/common/snippets/src/utils.cpp +++ b/src/common/snippets/src/utils.cpp @@ -106,6 +106,16 @@ ov::PartialShape get_port_planar_shape(const Output& out) { return get_reordered_planar_shape(tensor_shape, layout); } +void set_output_layout(const ov::Output& port, const std::shared_ptr& node) { + const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); + 
set_output_layout(port, const_order->cast_vector()); +} + +void set_output_layout(const ov::Output& port, const std::vector& layout) { + auto& rt_info = port.get_node_shared_ptr()->get_rt_info(); + rt_info["Layout"] = layout; +} + } // namespace utils } // namespace snippets } // namespace ngraph diff --git a/src/common/snippets/tests/include/lowering_utils.hpp b/src/common/snippets/tests/include/lowering_utils.hpp index c629b1c13f59f6..85366ebc021bec 100644 --- a/src/common/snippets/tests/include/lowering_utils.hpp +++ b/src/common/snippets/tests/include/lowering_utils.hpp @@ -30,6 +30,9 @@ class DummyTargetMachine : public ngraph::snippets::TargetMachine { bool is_supported() const override { return true; } ngraph::snippets::code get_snippet() const override { return nullptr; } size_t get_lanes() const override { return 10; } + +protected: + opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override { return vec2vec; }; }; class DummyGenerator : public ngraph::snippets::Generator { diff --git a/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp b/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp index 9305faa50119be..334a350cbb2daa 100644 --- a/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp +++ b/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp @@ -19,18 +19,20 @@ using namespace ngraph; // todo: Rewrite this test using Snippets test infrastructure. 
See ./include/canonicalization.hpp for example -template -size_t get_count(const std::shared_ptr& f, const std::string& name) { - size_t load_count = std::numeric_limits::max(); +size_t get_count(const std::shared_ptr& f, const std::string& name, bool is_load = true) { + size_t count = std::numeric_limits::max(); for (auto op : f->get_ops()) { if (op->get_friendly_name() == name) { - load_count = ov::as_type_ptr(op)->get_count(); + if (const auto memory_access = std::dynamic_pointer_cast(op)) { + count = is_load ? memory_access->get_input_port_descriptor(0).m_count + : memory_access->get_output_port_descriptor(0).m_count; + } } } - return load_count; + return count; } -TEST(TransformationTests, SetScalarCountForLoad) { +TEST(TransformationTests, SetScalarCountForLoadStore) { std::shared_ptr f(nullptr), f_ref(nullptr); const auto count = 16; { @@ -39,11 +41,13 @@ TEST(TransformationTests, SetScalarCountForLoad) { load->set_friendly_name("load"); auto neg = std::make_shared(load); auto store = std::make_shared(neg, count); + store->set_friendly_name("store"); f = std::make_shared(NodeVector{store}, ParameterVector{data}); pass::Manager m; m.register_pass(); m.register_pass(); + m.register_pass(); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } @@ -52,39 +56,6 @@ { auto data = std::make_shared(element::f32, Shape{2, 2}); auto load = std::make_shared(data, 1lu); load->set_friendly_name("load_ref"); auto neg = std::make_shared(load); - auto store = std::make_shared(neg, count); - f_ref = std::make_shared(NodeVector{store}, ParameterVector{data}); - } - - auto res = compare_functions(f, f_ref); - ASSERT_TRUE(res.first) << res.second; - - auto load_count = get_count(f, "load"); - auto load_count_ref = get_count(f_ref, "load_ref"); - ASSERT_EQ(load_count, load_count_ref); -} - -TEST(TransformationTests, SetScalarCountForStore) { - std::shared_ptr f(nullptr), f_ref(nullptr); - const auto count = 16; - { - auto data = std::make_shared(element::f32, Shape{2, 2}); - auto load
= std::make_shared(data, count); - auto neg = std::make_shared(load); - auto store = std::make_shared(neg, count); - store->set_friendly_name("store"); - f = std::make_shared(NodeVector{store}, ParameterVector{data}); - - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); - ASSERT_NO_THROW(check_rt_info(f)); - } - { - auto data = std::make_shared(element::f32, Shape{2, 2}); - auto load = std::make_shared(data, count); - auto neg = std::make_shared(load); auto store = std::make_shared(neg, 1lu); store->set_friendly_name("store_ref"); f_ref = std::make_shared(NodeVector{store}, ParameterVector{data}); @@ -93,7 +64,11 @@ TEST(TransformationTests, SetScalarCountForStore) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; - int64_t store_count = get_count(f, "store"); - int64_t store_count_ref = get_count(f_ref, "store_ref"); + auto load_count = get_count(f, "load"); + auto load_count_ref = get_count(f_ref, "load_ref"); + ASSERT_EQ(load_count, load_count_ref); + + auto store_count = get_count(f, "store", false); + auto store_count_ref = get_count(f_ref, "store_ref", false); ASSERT_EQ(store_count, store_count_ref); -} \ No newline at end of file +} diff --git a/src/common/snippets/tests/src/registers.cpp b/src/common/snippets/tests/src/registers.cpp index 4b53f0e8092f67..37d92bf434bbd2 100644 --- a/src/common/snippets/tests/src/registers.cpp +++ b/src/common/snippets/tests/src/registers.cpp @@ -14,6 +14,7 @@ #include #include "common_test_utils/ngraph_test_utils.hpp" +#include "lowering_utils.hpp" using namespace testing; using namespace ngraph; @@ -21,6 +22,7 @@ using namespace ngraph; // todo: Rewrite this test using Snippets test infrastructure. 
See ./include/canonicalization.hpp for example TEST(TransformationTests, AssignRegisters) { + const auto generator = std::make_shared(); std::shared_ptr f(nullptr); { auto p0 = std::make_shared(element::f32, Shape(1)); @@ -38,7 +40,7 @@ TEST(TransformationTests, AssignRegisters) { pass::Manager m; m.register_pass(); - m.register_pass(); + m.register_pass(generator->get_target_machine()); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } @@ -74,6 +76,7 @@ TEST(TransformationTests, AssignRegisters) { } TEST(TransformationTests, AssignRegisters2) { + const auto generator = std::make_shared(); std::shared_ptr f(nullptr); { auto p0 = std::make_shared(ngraph::element::f32, Shape()); @@ -127,7 +130,7 @@ TEST(TransformationTests, AssignRegisters2) { pass::Manager m; m.register_pass(); - m.register_pass(); + m.register_pass(generator->get_target_machine()); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index fb3f12a9761386..66fb1f56537bf3 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -17,7 +17,8 @@ #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" -#include "snippets/op/brgemm.hpp" +#include "snippets_transformations/op/brgemm_copy_b.hpp" +#include "snippets_transformations/op/brgemm_cpu.hpp" #include "ngraph_transformations/op/swish_cpu.hpp" #include @@ -136,7 +137,8 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_ jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = CREATE_EMITTER(KernelEmitter); jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter); jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = CREATE_EMITTER(LoopEndEmitter); - jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] 
= CREATE_EMITTER(BrgemmEmitter); + jitters[ov::intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter); + jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBEmitter); } size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const { @@ -159,5 +161,13 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const { return h->jit_ker(); } +ngraph::snippets::TargetMachine::opRegType ov::intel_cpu::CPUTargetMachine::get_specific_op_reg_type(const std::shared_ptr& op) const { + if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2gpr; + else + return vec2vec; +} + ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared(isa_)) { } diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp index 7301fcb177b93f..93d062ae41d595 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp @@ -20,6 +20,9 @@ class CPUTargetMachine : public ngraph::snippets::TargetMachine { ngraph::snippets::code get_snippet() const override; size_t get_lanes() const override; +protected: + opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override; + private: std::unique_ptr h; dnnl::impl::cpu::x64::cpu_isa_t isa; diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index 728c5de139be7c..60b99c23f98094 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -7,9 +7,10 @@ #include #include "jit_snippets_emitters.hpp" -#include "snippets/op/brgemm.hpp" #include "snippets/op/subgraph.hpp" #include "snippets/utils.hpp" +#include "snippets_transformations/op/brgemm_copy_b.hpp" +#include "snippets_transformations/op/brgemm_cpu.hpp" using namespace Xbyak; 
using ngraph::snippets::op::Subgraph; @@ -65,7 +66,8 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, // where all utility emitters align with conventional Op emitters if (std::dynamic_pointer_cast(emitter) || std::dynamic_pointer_cast(emitter) || - std::dynamic_pointer_cast(emitter)) + std::dynamic_pointer_cast(emitter) || + std::dynamic_pointer_cast(emitter)) in_physical_regs = std::move(map_regs(in_abstract_regs, gpr_map_pool)); else in_physical_regs = std::move(in_abstract_regs); @@ -182,7 +184,8 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: // todo: how this will be handled if Brgemm in & out are op::Buffer // Brgemm is a special case since it incorporates input and output (we use onednn kernel) // Just like Load & Store it requires offsets calculation - const auto is_brgemm = std::dynamic_pointer_cast(emitter) != nullptr; + const auto is_brgemm = std::dynamic_pointer_cast(emitter) || + std::dynamic_pointer_cast(emitter); return emitter_type == gpr_to_vec || emitter_type == vec_to_gpr || is_brgemm; }); // Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two @@ -545,8 +548,9 @@ StoreEmitter::StoreEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::c IE_THROW() << "StoreEmitter supports only equal input and output types but gets: " << src_prc.name() << " and " << dst_prc.name(); const auto store = ov::as_type_ptr(n); - count = store->get_count(); - byte_offset = store->get_offset(); + const auto desc = store->get_output_port_descriptor(0); + count = desc.m_count; + byte_offset = desc.m_offset; in_out_type_ = emitter_in_out_map::vec_to_gpr; store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count)); } @@ -586,11 +590,9 @@ LoadEmitter::LoadEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu IE_THROW() << "LoadEmitter supports only equal input and output types but gets: " << src_prc.name() << " and " << 
dst_prc.name(); const auto load = std::dynamic_pointer_cast(n); - if (!load) - IE_THROW() << "LoadEmitter expects Load snippets op"; - - count = load->get_count(); - byte_offset = load->get_offset(); + const auto desc = load->get_input_port_descriptor(0); + count = desc.m_count; + byte_offset = desc.m_offset; in_out_type_ = emitter_in_out_map::gpr_to_vec; load_emitter.reset(new jit_load_emitter(h, isa, src_prc, dst_prc, count)); } @@ -630,10 +632,8 @@ BroadcastLoadEmitter::BroadcastLoadEmitter(dnnl::impl::cpu::x64::jit_generator* IE_THROW() << "BroadcastEmitters support only equal input and output types but gets: " << src_prc.name() << " and " << dst_prc.name(); const auto broadcast_load = std::dynamic_pointer_cast(n); - if (!broadcast_load) - IE_THROW() << "BroadcastLoadEmitter expects BroadcastLoad snippets op"; - - byte_offset = broadcast_load->get_offset(); + const auto desc = broadcast_load->get_input_port_descriptor(0); + byte_offset = desc.m_offset; in_out_type_ = emitter_in_out_map::gpr_to_vec; } @@ -673,8 +673,9 @@ void BroadcastLoadEmitter::emit_isa(const std::vector &in, const std::ve LoadConvertEmitter::LoadConvertEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : MemoryEmitter(h, isa, n) { const auto load = ov::as_type_ptr(n); - count = load->get_count(); - byte_offset = load->get_offset(); + const auto desc = load->get_input_port_descriptor(0); + count = desc.m_count; + byte_offset = desc.m_offset; in_out_type_ = emitter_in_out_map::gpr_to_vec; load_emitter.reset(new jit_load_emitter(h, isa, src_prc, dst_prc, count)); } @@ -709,8 +710,9 @@ void LoadConvertEmitter::emit_data() const { StoreConvertEmitter::StoreConvertEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : MemoryEmitter(h, isa, n) { const auto store = ov::as_type_ptr(n); - count = store->get_count(); - byte_offset = store->get_offset(); + const auto desc = 
store->get_output_port_descriptor(0); + count = desc.m_count; + byte_offset = desc.m_offset; in_out_type_ = emitter_in_out_map::vec_to_gpr; if (ov::is_type(n)) { @@ -750,12 +752,15 @@ size_t BrgemmEmitter::getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const { return mIdx * 4 + kIdx * 2 + nIdx; } BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const std::shared_ptr& node) : jit_emitter(h, isa, node) { + const std::shared_ptr& node) : jit_emitter(h, isa, node) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; - const auto& brgemm_node = as_type_ptr(node); + const auto& brgemm_node = as_type_ptr(node); if (brgemm_node->is_dynamic()) IE_THROW() << "Snippets don't support code generation for dynamic Brgemm"; - const OutputVector io_values {brgemm_node->input_value(0), brgemm_node->input_value(1), brgemm_node->output(0)}; + const auto brgemm_copy = brgemm_node->get_brgemm_copy(); + const OutputVector io_values {brgemm_node->input_value(0), + brgemm_copy ? brgemm_copy->input_value(0) : brgemm_node->input_value(1), + brgemm_node->output(0)}; std::vector leading_dimensions; std::vector> io_layouts; for (const auto& val : io_values) { @@ -799,13 +804,16 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: auto brg1Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(1)); io_data_size = {brg0Prc.size(), brg1Prc.size(), brgemm_node->get_output_element_type(0).size()}; brg0VnniFactor = 4 / brg0Prc.size(); - bool brg0WithAMX = isAMXSupported && brg0Prc != Precision::FP32 && (K % brg0VnniFactor == 0) && (N % brg0VnniFactor == 0); + bool brgWithAMX = isAMXSupported && brg0Prc != Precision::FP32 && (K % brg0VnniFactor == 0) && (N % brg0VnniFactor == 0); - N_blk = brg0Prc == Precision::FP32 ? N : - brg0Prc == Precision::BF16 ? 
32 : 64; + with_scratch = brgemm_node->get_input_size() == 3; + with_comp = !brgWithAMX && brg0Prc == Precision::I8; + + N_blk = brg1Prc == Precision::FP32 ? N : + brg1Prc == Precision::BF16 ? 32 : 64; N_tail = N % N_blk; - K_blk = brg0WithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 - : K; + K_blk = brgWithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 + : K; K_tail = K % K_blk; size_t brg0BaseIdx = -1; @@ -824,7 +832,7 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: brgemmCtx.N = N_; brgemmCtx.K = K_; brgemmCtx.LDA = leading_dimensions[0]; - brgemmCtx.LDB = leading_dimensions[1]; + brgemmCtx.LDB = brg1Prc == Precision::FP32 ? leading_dimensions[1] : rnd_up(N, N_blk); brgemmCtx.LDC = leading_dimensions[2]; brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc)); brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc)); @@ -834,22 +842,25 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: if (M_ != 0 && K_ != 0 && N_ != 0) { if (brg0BaseIdx == -1) brg0BaseIdx = getBrgIdx(m, k, n); - initBrgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brg0WithAMX); + initBrgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brgWithAMX); } } } } - load_offset_a = brgemm_node->get_offset_a(); - load_offset_b = brgemm_node->get_offset_b(); - store_offset_c = brgemm_node->get_offset_c(); + load_offset_a = brgemm_node->get_input_port_descriptor(0).m_offset; + load_offset_b = brgemm_node->get_input_port_descriptor(1).m_offset; + load_offset_scratch = brgemm_node->get_input_port_descriptor(2).m_offset; + store_offset_c = brgemm_node->get_output_port_descriptor(0).m_offset; } void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const { brgemm_t brgDesc; brgemm_strides_t strides {static_cast(ctx.M * ctx.K), static_cast(ctx.K * ctx.N)}; - // When implementing int8 support, note that isa logics is more complicated in the MHA node - auto status 
= brgemm_desc_init(&brgDesc, host_isa_, brgemm_strd, ctx.dt_in0, ctx.dt_in1, + const bool is_int8 = utils::one_of(ctx.dt_in0, data_type::u8, data_type::s8) && utils::one_of(ctx.dt_in1, data_type::u8, data_type::s8); + auto isa = use_amx ? isa_any + : ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : (is_int8 ? avx512_core_vnni : avx512_core); + auto status = brgemm_desc_init(&brgDesc, isa, brgemm_strd, ctx.dt_in0, ctx.dt_in1, false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, &strides); if (status != dnnl_success) IE_THROW() << "BrgemmEmitter cannot initialize brgemm descriptor due to invalid params"; @@ -873,20 +884,63 @@ void BrgemmEmitter::emit_impl(const std::vector& in, const std::vector& pool, const std::vector& gpr, const ov::intel_cpu::emitter_context *emit_context) const { - if (host_isa_ == cpu::x64::sse41 || host_isa_ == cpu::x64::avx2) { - IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set"; - } else if (host_isa_ == cpu::x64::avx512_core) { - emit_isa(in, out); + if (host_isa_ == cpu::x64::avx512_core) { + Reg64 input_0(static_cast(in[0])); + Reg64 input_1(static_cast(in[1])); + Reg64 input_2(static_cast(in[0])); // scratch. + if (with_scratch) { + if (in.size() != 3) { + IE_THROW() << "BRGEMM Emitter expects 3 inputs if there is compensations/wsp"; + } + input_2 = Reg64(static_cast(in[2])); + } + Reg64 output_0(static_cast(out[0])); + + for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { + const bool is_M_tail = (M - mb * M_blk < M_blk); + + size_t brgIdx0 = getBrgIdx(0, 0, 0); + size_t K0_step0 = brgCtxs0[brgIdx0].K; + size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB; + size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor; + size_t N0_step1 = brgCtxs0[brgIdx0].N; + for (size_t n = 0; n < 2; n++) { + for (size_t k = 0; k < 2; k++) { + size_t mIdx = is_M_tail ? 
1 : 0; + auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)]; + + if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { + const size_t in0_offset = load_offset_a + (k * K0_step0 + mb * M_blk * brgemmCtx.LDA) * io_data_size[0]; + const size_t in1_offset = load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1]; + const size_t in2_offset = load_offset_scratch + (with_comp ? n * N0_step1 * sizeof(int32_t) : 0); + const size_t out0_offset = store_offset_c + (n * N0_step1 + mb * M_blk * brgemmCtx.LDC) * io_data_size[2]; + + emit_brgemm_kernel_call(brgKernels0[getBrgIdx(mIdx, k, n)].get(), + brgemmCtx, + input_0, + input_1, + input_2, + output_0, + in0_offset, + in1_offset, + in2_offset, + out0_offset); + } + } + } + } } else { - assert(!"unsupported isa"); + IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set"; } } -template -void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, int bs, - Reg64 addr_A, Reg64 addr_B, - const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch, - const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; + +void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, const brgemmCtx& ctx, + Reg64 addr_A, Reg64 addr_B, Reg64 scratch, Reg64 addr_C, + const size_t in0_kernel_offset, const size_t in1_kernel_offset, + const size_t in2_kernel_offset, const size_t out0_kernel_offset) const { + if (ctx.is_with_amx) + amx_tile_configure(ctx.palette); + size_t gpr_size = 8; Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; @@ -898,14 +952,12 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in // caller obligation to save k-regs as callee may use them size_t n_k_regs_to_save = 8; - if (isa == cpu::x64::avx512_core) { - h->sub(h->rsp, n_k_regs_to_save * k_mask_size); - for 
(size_t i = 0; i < n_k_regs_to_save; ++i) { - if (mayiuse(avx512_core)) - h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); - else - h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); - } + h->sub(h->rsp, n_k_regs_to_save * k_mask_size); + for (size_t i = 0; i < n_k_regs_to_save; ++i) { + if (mayiuse(avx512_core)) + h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + else + h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); } // 1. Caller obligation to save vector registers as callee may use them. @@ -915,13 +967,16 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in // `host_isa::vecs_count`. h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); for (size_t i = 0; i < get_max_vecs_count(); ++i) - h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i)); + h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i)); + size_t num_args_passed_on_stack = 0; // save function address in gpr to pass in call instruction const auto& brgemm_kernel_overload = static_cast(kernel_execute); + void*, + void*, + int)>(kernel_execute); h->mov(h->rbp, reinterpret_cast(brgemm_kernel_overload)); // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. @@ -929,16 +984,32 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->uni_vmovq(Xmm(0), addr_A); h->uni_vmovq(Xmm(1), addr_B); h->uni_vmovq(Xmm(2), addr_C); - + h->uni_vmovq(Xmm(3), scratch); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
const auto data_ptr_reg = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { h->uni_vmovq(reg, xmm); if (bytes_offset) h->add(reg, bytes_offset); }; - h->mov(abi_param1, reinterpret_cast(brgKernel)); + h->mov(abi_param1, reinterpret_cast(brg_kernel)); data_ptr_reg(Xmm(0), abi_param2, in0_kernel_offset); data_ptr_reg(Xmm(1), abi_param3, in1_kernel_offset); data_ptr_reg(Xmm(2), abi_param4, out0_kernel_offset); +#ifdef _WIN32 + const auto data_ptr_stack = [&](Xmm xmm, size_t idx, size_t bytes_offset) { + h->uni_vmovq(h->qword[h->rsp + idx * gpr_size], xmm); + if (bytes_offset) h->add(h->qword[h->rsp + idx * gpr_size], bytes_offset); + }; + + num_args_passed_on_stack = 4; + h->sub(h->rsp, num_args_passed_on_stack * gpr_size); + data_ptr_stack(Xmm(3), 0, in2_kernel_offset); + h->mov(h->ptr[h->rsp + 1 * gpr_size], static_cast(with_comp)); +#else + data_ptr_reg(Xmm(3), abi_param5, in2_kernel_offset); + h->mov(abi_param6, static_cast(with_comp)); +#endif + // align stack on 16-byte as ABI requires // note that RBX must not be changed by the callee h->mov(h->rbx, h->rsp); @@ -948,22 +1019,21 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->call(h->rbp); h->add(h->rsp, h->rbx); + h->add(h->rsp, num_args_passed_on_stack * gpr_size); // restore vector registers for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { - h->uni_vmovups(Vmm(i), h->ptr[h->rsp + i * get_vec_length()]); + h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]); } h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); // restore k registers - if (isa == cpu::x64::avx512_core) { - for (int i = n_k_regs_to_save - 1; i >= 0; --i) { - if (mayiuse(avx512_core)) - h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - else - h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - } - h->add(h->rsp, n_k_regs_to_save * k_mask_size); + for (int i = n_k_regs_to_save - 1; i >= 0; --i) { + if (mayiuse(avx512_core)) + h->kmovq(Opmask(i), 
h->ptr[h->rsp + i * k_mask_size]); + else + h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); } + h->add(h->rsp, n_k_regs_to_save * k_mask_size); // restore gpr registers for (int i = n_gprs_to_save - 1; i >= 0; --i) @@ -971,7 +1041,8 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->add(h->rsp, n_gprs_to_save * gpr_size); } -void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C) { +void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, + const void *A, const void *B, void *C, void *scratch, int with_comp) { // TODO: There are 4 available abi_params on Windows so we have the copy of brgemm_kernel_execute() function // with 4 runtime parameters (kernel and I/O) and 4 default parameter values (batch, bs and scratch) brgemm_kernel_params_t brgemm_p; @@ -981,57 +1052,291 @@ void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, const void brgemm_p.ptr_B = B; brgemm_p.ptr_C = C; brgemm_p.ptr_D = C; - brgemm_p.ptr_buf = nullptr; // default value + brgemm_p.ptr_buf = scratch; brgemm_p.ptr_bias = nullptr; - brgemm_p.do_post_ops = 0; - brgemm_p.do_apply_comp = 0; + brgemm_p.do_post_ops = with_comp; + brgemm_p.do_apply_comp = with_comp; brgemm_p.skip_accm = 0; brgemm_p.BS = 1; // default value assert(brg_kernel); (*brg_kernel)(&brgemm_p); } -template -void BrgemmEmitter::emit_isa(const std::vector &in, const std::vector &out) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; - Reg64 input_0(static_cast(in[0])); - Reg64 input_1(static_cast(in[1])); - Reg64 output_0(static_cast(out[0])); - - for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { - const bool is_M_tail = (M - mb * M_blk < M_blk); - - size_t brgIdx0 = getBrgIdx(0, 0, 0); - size_t K0_step0 = brgCtxs0[brgIdx0].K; - size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB; - size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor; - size_t N0_step1 = brgCtxs0[brgIdx0].N; 
- for (size_t n = 0; n < 2; n++) { - for (size_t k = 0; k < 2; k++) { - size_t mIdx = is_M_tail ? 1 : 0; - auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)]; - - if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { - const size_t in0_offset = load_offset_a + (k * K0_step0 + mb * M_blk * brgemmCtx.LDA) * io_data_size[0]; - const size_t in1_offset = load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1]; - const size_t out0_offset = store_offset_c + (n * N0_step1 + mb * M_blk * brgemmCtx.LDC) * io_data_size[2]; - - emit_brgemm_kernel_call(brgKernels0[getBrgIdx(mIdx, k, n)].get(), - 1, - input_0, - input_1, - nullptr, - output_0, - nullptr, - in0_offset, - in1_offset, - out0_offset); - } +BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) + : jit_emitter(h, isa, n) { + in_out_type_ = emitter_in_out_map::gpr_to_gpr; + const auto brgemm_repack = ov::as_type_ptr(n); + if (!brgemm_repack) + IE_THROW() << "BrgemmCopyBEmitters expects BrgemmCopyB node"; + + brgemm_prc_in0 = brgemm_repack->get_src_element_type(); + brgemm_prc_in1 = brgemm_repack->get_input_element_type(0); + brgemmVNNIFactor = 4 / brgemm_prc_in0.size(); + with_comp = brgemm_repack->is_with_comp(); + in_offset = brgemm_repack->get_input_port_descriptor(0).m_offset; + out_offset = brgemm_repack->get_output_port_descriptor(0).m_offset; + if (with_comp) + comp_offset = brgemm_repack->get_output_port_descriptor(1).m_offset; + + auto layout = ngraph::snippets::utils::get_node_output_layout(brgemm_repack->get_input_node_shared_ptr(0)); + const auto& original_shape = brgemm_repack->get_input_shape(0); + auto transposed_shape = original_shape; + size_t leading_dimension = *(original_shape.rbegin()); + if (!layout.empty()) { + transposed_shape.resize(layout.size(), 1); + for (size_t i = 0; i < layout.size(); ++i) { + transposed_shape[i] = original_shape[layout[i]]; + } + // The idea here is to find "2" (for 4D shapes) in the 
layout and multiply dimensions that are to the right + // This implies that "3" is the last layout value, otherwise this layout is not supported. + // counting from the end since shape could be prepended with ones + const int64_t num_last_dims = layout.end() - std::find(layout.begin(), layout.end(), layout.size() - 2) - 1; + if (layout.back() != layout.size() - 1 || num_last_dims < 1) + IE_THROW() << "BrgemmRepackEmitter detected invalid layout values: " << + "check that this shape + layout combination is schedulable"; + leading_dimension = std::accumulate(original_shape.end() - num_last_dims, original_shape.end(), 1, std::multiplies()); + } + + N = *(transposed_shape.rbegin()); + K = *(transposed_shape.rbegin() + 1); + + const bool isAMXSupported = mayiuse(avx512_core_bf16_amx_int8) || mayiuse(avx512_core_bf16_amx_bf16); + const auto use_amx = isAMXSupported && brgemm_prc_in0 != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0); + + N_blk = brgemm_prc_in1 == ov::element::f32 ? N : + brgemm_prc_in1 == ov::element::bf16 ? 32 : 64; + K_blk = use_amx ? brgemm_prc_in0 == ov::element::bf16 ? 32 : 64 + : K; + N_tail = N % N_blk; + K_tail = K % K_blk; + LDB = brgemm_prc_in1 == ov::element::f32 ? leading_dimension : rnd_up(N, N_blk); + + const auto dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(brgemm_prc_in0))); + const auto dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(brgemm_prc_in1))); + init_brgemm_copy(kernel, leading_dimension, N_blk, N_tail, LDB, (K_tail == 0 ? 
K : K_tail), use_amx, dt_in0, dt_in1); +} + +void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr& kernel, + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const { + matmul::brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_dt = dt_in0; + brgCopyKernelConf.wei_dt = dt_in1; + brgCopyKernelConf.wei_n_blk = N_blk; + brgCopyKernelConf.wei_tag = dnnl_abcd; // What's about other ranks? + brgCopyKernelConf.copy_B_wei_stride = 0; + brgCopyKernelConf.LDB = LDB; + brgCopyKernelConf.N = N; + brgCopyKernelConf.N_tail = N_tail; + brgCopyKernelConf.N_blk = N_blk; + brgCopyKernelConf.K = K; + brgCopyKernelConf.K_blk = K; + brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; + brgCopyKernelConf.b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.req_wei_vnni_downconvert = false; + + if (is_with_amx) { + brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16_amx_bf16 : avx512_core_bf16_amx_int8; + brgCopyKernelConf.s8s8_compensation_required = false; + } else { + brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? 
avx512_core_bf16 : avx512_core_vnni; + brgCopyKernelConf.s8s8_compensation_required = dt_in0 == dnnl_data_type_t::dnnl_s8; + } + + brgCopyKernelConf.has_zero_point_a = false; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none; + + auto status = matmul::create_brgemm_matmul_copy_b(kernel, &brgCopyKernelConf); + if (status != dnnl_success) + IE_THROW() << "BrgemmRepackEmitter cannot create kernel due to invalid params"; +} + +void BrgemmCopyBEmitter::emit_impl(const std::vector& in, + const std::vector& out, + const std::vector& pool, + const std::vector& gpr, + const ov::intel_cpu::emitter_context *emit_context) const { + if (host_isa_ == cpu::x64::avx512_core) { + Reg64 src(static_cast(in[0])); + Reg64 dst(static_cast(out[0])); + Reg64 comp(static_cast(0)); + if (with_comp) { + if (out.size() != 2) { + IE_THROW() << "BrgemmCopyBEmitter with compensations requires separate register for them"; } + comp = Reg64(static_cast(out[1])); + } + + const size_t data_size = brgemm_prc_in1.size(); + for (size_t nb = 0; nb < div_up(N, N_blk); nb++) { + const size_t offset_in = in_offset + nb * N_blk * data_size; + const size_t offset_out = out_offset + nb * N_blk * brgemmVNNIFactor * data_size; + const size_t offset_comp = with_comp ? comp_offset + nb * N_blk * sizeof(int32_t) : 0; + + const bool is_N_tail = (N - nb * N_blk < N_blk); + const auto current_N_blk = is_N_tail ? 
N_tail : N_blk; + + emit_kernel_call(kernel.get(), src, dst, comp, current_N_blk, K, offset_in, offset_out, offset_comp); } + } else { + IE_THROW() << "BrgemmCopyBEmitter requires at least avx512_core instruction set"; } } +void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, + size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const { + size_t gpr_size = 8; + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, + h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; + size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); + + h->sub(h->rsp, n_gprs_to_save * gpr_size); + for (size_t i = 0; i < n_gprs_to_save; ++i) + h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); + + // caller obligation to save k-regs as callee may use them + size_t n_k_regs_to_save = 8; + h->sub(h->rsp, n_k_regs_to_save * k_mask_size); + for (size_t i = 0; i < n_k_regs_to_save; ++i) { + if (mayiuse(avx512_core)) + h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + else + h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + } + + // 1. Caller obligation to save vector registers as callee may use them. + // 2. There is an implicit assumption that the host code uses the same + // `isa` as the injector. Once the assumption is wrong, `vecs_count` and + // `vlen` should be replaced with `host_isa::vlen` and + // `host_isa::vecs_count`. 
+ h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); + for (size_t i = 0; i < get_max_vecs_count(); ++i) + h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i)); + + size_t num_args_passed_on_stack = 0; + if (with_comp) { + // save function address in gpr to pass in call instruction + const auto &kernel_overload = static_cast(execute_with_comp); + h->mov(h->rbp, reinterpret_cast(kernel_overload)); + // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted + // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. + // It's likely that a more efficient solution exists. + h->uni_vmovq(Xmm(0), src); + h->uni_vmovq(Xmm(1), dst); + h->uni_vmovq(Xmm(2), comp); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. + h->mov(abi_param1, reinterpret_cast(kernel)); + + const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { + h->uni_vmovq(reg, xmm); + if (bytes_offset) h->add(reg, bytes_offset); + }; + data_ptr(Xmm(0), abi_param2, offset_in); + data_ptr(Xmm(1), abi_param3, offset_out); + data_ptr(Xmm(2), abi_param4, offset_comp); + +#ifdef _WIN32 + num_args_passed_on_stack = 2; + h->sub(h->rsp, gpr_size * num_args_passed_on_stack); + h->mov(h->qword[h->rsp], reinterpret_cast(K)); + h->mov(h->qword[h->rsp + gpr_size], reinterpret_cast(N)); +#else + h->mov(abi_param5, N); + h->mov(abi_param6, K); +#endif + } else { + // save function address in gpr to pass in call instruction + const auto &kernel_overload = static_cast(execute); + h->mov(h->rbp, reinterpret_cast(kernel_overload)); + // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted + // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. + // It's likely that a more efficient solution exists. 
+ h->uni_vmovq(Xmm(0), src); + h->uni_vmovq(Xmm(1), dst); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. + h->mov(abi_param1, reinterpret_cast(kernel)); + + const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { + h->uni_vmovq(reg, xmm); + if (bytes_offset) h->add(reg, bytes_offset); + }; + data_ptr(Xmm(0), abi_param2, offset_in); + data_ptr(Xmm(1), abi_param3, offset_out); + + h->mov(abi_param4, N); + +#ifdef _WIN32 + num_args_passed_on_stack = 1; + h->sub(h->rsp, gpr_size * num_args_passed_on_stack); + h->mov(h->qword[h->rsp], reinterpret_cast(K)); +#else + h->mov(abi_param5, K); +#endif + } + // align stack on 16-byte as ABI requires + // note that RBX must not be changed by the callee + h->mov(h->rbx, h->rsp); + h->and_(h->rbx, 0xf); + h->sub(h->rsp, h->rbx); + + h->call(h->rbp); + + h->add(h->rsp, h->rbx); + h->add(h->rsp, gpr_size * num_args_passed_on_stack); + // restore vector registers + for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { + h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]); + } + h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); + + // restore k registers + for (int i = n_k_regs_to_save - 1; i >= 0; --i) { + if (mayiuse(avx512_core)) + h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + else + h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + } + h->add(h->rsp, n_k_regs_to_save * k_mask_size); + + // restore gpr registers + for (int i = n_gprs_to_save - 1; i >= 0; --i) + h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); + h->add(h->rsp, n_gprs_to_save * gpr_size); +} + +void BrgemmCopyBEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, size_t N, size_t K) { + execute_with_comp(kernel, src, dst, nullptr, N, K); +} + +void BrgemmCopyBEmitter::execute_with_comp(matmul::jit_brgemm_matmul_copy_b_t *kernel, const void *src, + const void *dst, const void *comp, size_t 
N, size_t K) { + if (!kernel) + IE_THROW() << "Kernel for `brgemm_copy_b` hasn't been created"; + + auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); + ctx.current_N_blk = N; + ctx.src = src; + ctx.tr_src = dst; + ctx.compensation_ptr = comp; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K; + + (*kernel)(&ctx); +} + HorizonMaxEmitter::HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : jit_emitter(h, isa, n, Precision::FP32, emitter_in_out_map::vec_to_vec) {} diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp index 51b2d2d7840cfb..fc8ae80d8b2379 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp @@ -375,8 +375,6 @@ class BrgemmEmitter : public jit_emitter { const std::vector& gpr, const ov::intel_cpu::emitter_context *emit_context) const override; - template - void emit_isa(const std::vector &in, const std::vector &out) const; std::vector io_data_size {}; struct brgemmCtx { size_t M, N, K, LDA, LDB, LDC; @@ -387,15 +385,15 @@ class BrgemmEmitter : public jit_emitter { float beta; }; void initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const; - template void callBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, const void* pin0, const void* pin1, void* pout, void* wsp) const; size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const; - template - void emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, int bs, - Reg64 addr_A, Reg64 addr_B, - const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch, - const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const; - static void kernel_execute(const brgemm_kernel_t *brg_kernel, const void 
*A, const void *B, void *C); + + void emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, const brgemmCtx& ctx, + Reg64 addr_A, Reg64 addr_B, Reg64 scratch, Reg64 addr_C, + const size_t in0_kernel_offset, const size_t in1_kernel_offset, + const size_t in2_kernel_offset, const size_t out0_kernel_offset) const; + static void kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp); + static constexpr size_t BRGEMM_KERNELS_NUM = 8; static constexpr size_t matmulOptimalM = 32; brgemmCtx brgCtxs0[BRGEMM_KERNELS_NUM]; @@ -406,11 +404,51 @@ class BrgemmEmitter : public jit_emitter { size_t N, N_blk, N_tail; size_t brg0VnniFactor; + bool with_scratch = false; + bool with_comp = false; + size_t load_offset_a = 0lu; size_t load_offset_b = 0lu; + size_t load_offset_scratch = 0lu; size_t store_offset_c = 0lu; }; +class BrgemmCopyBEmitter : public jit_emitter { +public: + BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); + + size_t get_inputs_num() const override {return 2;} + +private: + void emit_impl(const std::vector& in, + const std::vector& out, + const std::vector& pool, + const std::vector& gpr, + const ov::intel_cpu::emitter_context *emit_context) const override; + + void init_brgemm_copy(std::unique_ptr& kernel, + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const; + void emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, + size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const; + + static void execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, size_t N, size_t K); + static void execute_with_comp(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K); + + std::unique_ptr 
kernel; + + ov::element::Type brgemm_prc_in0, brgemm_prc_in1; + size_t N, N_blk, N_tail; + size_t K, K_blk, K_tail; + size_t LDB; + size_t brgemmVNNIFactor; + bool with_comp = false; + + size_t in_offset = 0lu; + size_t out_offset = 0lu; + size_t comp_offset = 0lu; +}; + class HorizonMaxEmitter : public jit_emitter { public: HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 5b09112b041ec1..f1a0d68c8dbaaa 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -11,6 +11,8 @@ #include "ngraph_transformations/op/mha.hpp" #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" +#include "snippets_transformations/op/brgemm_cpu.hpp" +#include "snippets_transformations/op/brgemm_copy_b.hpp" #include #include @@ -54,6 +56,8 @@ std::map Extension::getOpSets() { NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu) NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu) NGRAPH_OP(StoreConvertTruncation, ov::intel_cpu) + NGRAPH_OP(BrgemmCPU, ov::intel_cpu) + NGRAPH_OP(BrgemmCopyB, ov::intel_cpu) #undef NGRAPH_OP return opset; diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 6d4eabfb64df83..98e2462ac22c33 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -24,6 +24,7 @@ #include "emitters/cpu_generator.hpp" #include "utils/cpu_utils.hpp" #include "snippets_transformations/fuse_load_store_and_convert.hpp" +#include "snippets_transformations/brgemm_to_brgemm_cpu.hpp" #include "ngraph_transformations/convert_to_swish_cpu.hpp" using namespace InferenceEngine; @@ -502,6 +503,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) { ov::pass::Manager optManager; optManager.register_pass(); 
optManager.register_pass(); + optManager.register_pass(); optManager.register_pass(); // LoadConvert uses Load emitter that support conversion from any type to only f32 @@ -519,6 +521,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) { return convert->get_input_element_type(0) != ov::element::f32; return true; }); + schedule = snippet->generate(optManager, reinterpret_cast(jcp)); } diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp new file mode 100644 index 00000000000000..f7dc571fce82e8 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp @@ -0,0 +1,98 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" + +#include "brgemm_to_brgemm_cpu.hpp" +#include "snippets/snippets_isa.hpp" +#include "snippets/utils.hpp" +#include "op/brgemm_copy_b.hpp" +#include "op/brgemm_cpu.hpp" + +#include "ngraph/rt_info.hpp" +#include "ngraph/pattern/op/wrap_type.hpp" + +#include + +#include "cpu_shape.h" +#include "utils/general_utils.h" + + +namespace ov { +namespace intel_cpu { + +pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { + MATCHER_SCOPE(BrgemmToBrgemmCPU); + + auto m_brgemm = ngraph::pattern::wrap_type(); + + auto callback = [=](ngraph::pattern::Matcher& m) { + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::BrgemmToBrgemmCPU") + auto& pm = m.get_pattern_value_map(); + const auto brgemm = ov::as_type_ptr(pm.at(m_brgemm).get_node_shared_ptr()); + if (!brgemm) + return false; + + if (brgemm->get_input_partial_shape(0).is_dynamic() || brgemm->get_input_partial_shape(1).is_dynamic()) { + return false; + } + + const auto dimsMatMulIn0 = ngraph::snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); + const auto dimsMatMulIn1 = 
ngraph::snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); + + const auto K = *dimsMatMulIn0.rbegin(); + const auto N = *dimsMatMulIn1.rbegin(); + + const auto element_type_a = brgemm->get_input_element_type(0); + const auto element_type_b = brgemm->get_input_element_type(1); + const auto brgemmVNNIFactor = 4 / element_type_a.size(); + const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_int8) || + dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16); + const bool with_amx = isAMXSupported && element_type_a != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0); + // TODO: Only i8? What's about u8? + const bool with_comp = element_type_a == ov::element::i8 && !with_amx; + + const auto offset_a = brgemm->get_input_port_descriptor(0).m_offset; + const auto offset_b = brgemm->get_input_port_descriptor(1).m_offset; + const auto offset_c = brgemm->get_output_port_descriptor(0).m_offset; + + std::shared_ptr brgemm_cpu = nullptr; + if (one_of(element_type_a, ov::element::f32)) { + brgemm_cpu = std::make_shared(brgemm->input_value(0), brgemm->input_value(1), + brgemm->transposed_a(), brgemm->transposed_b(), with_comp, + offset_a, offset_b, offset_c); + } else { + const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr()); + const auto brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, with_comp, offset_b); + const auto buffer = std::make_shared(brgemmRepackIn1->output(0)); + const auto scratch = with_amx ? std::make_shared(ov::Shape{4 * 1024}, ov::element::i32) : + with_comp ? 
std::make_shared(brgemmRepackIn1->output(1)) : + nullptr; + + if (with_amx || with_comp) { + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, scratch, + brgemm->transposed_a(), brgemm->transposed_b(), with_comp, + offset_a, offset_b, offset_c); + } else if (one_of(element_type_a, ov::element::u8, ov::element::bf16)) { + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, + brgemm->transposed_a(), brgemm->transposed_b(), with_comp, + offset_a, offset_b, offset_c); + } else { + IE_THROW() << "Invalid configuration for BRGEMM CPU"; + } + } + + brgemm_cpu->set_friendly_name(brgemm->get_friendly_name()); + ngraph::snippets::utils::set_output_layout(brgemm_cpu->output(0), ngraph::snippets::utils::get_node_output_layout(brgemm)); + ngraph::copy_runtime_info(brgemm, brgemm_cpu); + ngraph::replace_node(brgemm, brgemm_cpu); + + return true; + }; + + auto m = std::make_shared(m_brgemm, matcher_name); + register_matcher(m, callback); +} +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp new file mode 100644 index 00000000000000..bafaeca58fdbac --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/pass/graph_rewrite.hpp" +#include "ngraph/pattern/matcher.hpp" + +namespace ov { +namespace intel_cpu { +namespace pass { + +/** + * @interface BrgemmToBrgemmCPU + * @brief TODO + * @ingroup snippets + */ +class BrgemmToBrgemmCPU: public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("BrgemmToBrgemmCPU", "0"); + BrgemmToBrgemmCPU(); +}; + + +} // namespace pass +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp 
b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp index 021b3f6c1293ec..7246eb6f0f5c3b 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp @@ -10,7 +10,6 @@ #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" -#include "ngraph/opsets/opset1.hpp" #include "ngraph/rt_info.hpp" #include "ngraph/pattern/op/wrap_type.hpp" @@ -32,6 +31,7 @@ ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { const auto load = std::dynamic_pointer_cast(load_shared); if (!load) return false; + const auto desc = load->get_input_port_descriptor(0); const auto convert = pm.at(convert_pattern).get_node_shared_ptr(); if (transformation_callback(convert)) @@ -42,12 +42,12 @@ ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { std::dynamic_pointer_cast(convert)) { load_convert = std::make_shared(param, convert_saturation->get_destination_type(), - load->get_count(), load->get_offset()); + desc.m_count, desc.m_offset); } else if (const auto convert_truncation = std::dynamic_pointer_cast(convert)) { load_convert = std::make_shared(param, convert_truncation->get_destination_type(), - load->get_count(), load->get_offset()); + desc.m_count, desc.m_offset); } else { throw ngraph::ngraph_error( "Type of Convert op is undefined. 
Supports only fusing Load and ConvertTruncation or ConvertSaturation ops"); @@ -81,6 +81,7 @@ ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() { const auto store = std::dynamic_pointer_cast(pm.at(store_pattern).get_node_shared_ptr()); if (!store) return false; + const auto desc = store->get_output_port_descriptor(0); const auto convert = pm.at(convert_pattern).get_node_shared_ptr(); if (convert->output(0).get_target_inputs().size() != 1 || transformation_callback(convert)) @@ -91,18 +92,17 @@ ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() { std::dynamic_pointer_cast(convert)) { store_convert = std::make_shared(input, convert_saturation->get_destination_type(), - store->get_count(), store->get_offset()); + desc.m_count, desc.m_offset); } else if (const auto convert_truncation = std::dynamic_pointer_cast(convert)) { store_convert = std::make_shared(input, convert_truncation->get_destination_type(), - store->get_count(), store->get_offset()); + desc.m_count, desc.m_offset); } else { throw ngraph::ngraph_error( "Type of Convert op is undefined. 
Supports only fusing Store and ConvertTruncation or ConvertSaturation ops"); } - if (!store_convert) return false; diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp new file mode 100644 index 00000000000000..b300b7d47eda21 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp @@ -0,0 +1,73 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" +#include "snippets/utils.hpp" + +#include "brgemm_copy_b.hpp" + +#include "utils/general_utils.h" + +using namespace std; +using namespace ov; + +intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const bool with_comp, + const size_t offset_in, const size_t offset_out0, const size_t offset_out1) + : ngraph::snippets::op::MemoryAccess({x}), m_with_comp(with_comp), m_src_type(src_type) { + set_input_port_descriptor({0, offset_in}, 0); + set_output_port_descriptor({0, offset_out0}, 0); + if (with_comp) { + set_output_port_descriptor({0, offset_out1}, 1); + set_output_size(2); + } else { + set_output_size(1); + } + constructor_validate_and_infer_types(); +} + +bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes); + MemoryAccess::visit_attributes(visitor); + visitor.on_attribute("with_comp", m_with_comp); + visitor.on_attribute("src_type", m_src_type); + return true; +} + +void intel_cpu::BrgemmCopyB::validate_and_infer_types() { + INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); + + const auto element_type = get_input_element_type(0); + NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), + "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); + + const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0)); + if (pshape.is_dynamic()) { + 
set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()}); + if (m_with_comp) { + set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); + } + + return; + } + + const auto shape = pshape.get_shape(); + const auto N = *shape.rbegin(); + const auto K = *(shape.rbegin() + 1); + const auto N_blk = element_type == element::bf16 ? 32 : 64; + const auto brgemmVNNIFactor = 4 / m_src_type.size(); + + set_output_type(0, element_type, ov::PartialShape{rnd_up(K, brgemmVNNIFactor), rnd_up(N, N_blk)}); + if (m_with_comp) { + set_output_type(1, ov::element::f32, ov::PartialShape{rnd_up(N, N_blk)}); + } +} + +std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args.at(0), m_src_type, m_with_comp, + get_input_port_descriptor(0).m_offset, + get_output_port_descriptor(0).m_offset, + m_with_comp ? 
get_output_port_descriptor(1).m_offset : 0); +} diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp new file mode 100644 index 00000000000000..30b592552de096 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/op/memory_access.hpp" + +namespace ov { +namespace intel_cpu { + +/** +* @interface BrgemmCopyB +* @brief TODO +* @ingroup snippets +*/ +class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess { +public: + OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", MemoryAccess); + BrgemmCopyB(const Output& x, const element::Type src_type, const bool with_comp = false, + const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu); + BrgemmCopyB() = default; + + element::Type get_src_element_type() const { return m_src_type; } + bool is_with_comp() const { return m_with_comp; } + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + bool has_evaluate() const override { return false; } + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + +private: + bool m_with_comp = false; + element::Type m_src_type; // src element type of the corresponding BRGEMM +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp new file mode 100644 index 00000000000000..918a66f58c7bbe --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp @@ -0,0 +1,123 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" +#include "brgemm_cpu.hpp" 
+#include "ngraph/runtime/host_tensor.hpp" +#include "openvino/core/rt_info.hpp" +#include "snippets/utils.hpp" +#include "matmul_shape_inference.hpp" +#include "utils/general_utils.h" + + +namespace ov { +namespace intel_cpu { + +BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, bool transposed_a, bool transposed_b, const bool with_comp, + const size_t offset_a, const size_t offset_b, const size_t offset_c) + : Brgemm(), m_with_comp(with_comp) { + // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call + set_arguments({A, B}); + set_output_size(1); + set_input_port_descriptor({0, offset_a}, 0); + set_input_port_descriptor({0, offset_b}, 1); + set_output_port_descriptor({0, offset_c}, 0); + m_transposed_a = transposed_a; + m_transposed_b = transposed_b; + constructor_validate_and_infer_types(); +} + +BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, + bool transposed_a, bool transposed_b, const bool with_comp, + const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c) + : Brgemm(), m_with_comp(with_comp) { + set_arguments({A, B, scratch}); + set_output_size(1); + set_input_port_descriptor({0, offset_a}, 0); + set_input_port_descriptor({0, offset_b}, 1); + set_output_port_descriptor({0, offset_c}, 0); + set_input_port_descriptor({0, offset_scratch}, 2); + m_transposed_a = transposed_a; + m_transposed_b = transposed_b; + constructor_validate_and_infer_types(); +} + +bool BrgemmCPU::visit_attributes(AttributeVisitor& visitor) { + MemoryAccess::visit_attributes(visitor); + visitor.on_attribute("transposed_a", m_transposed_a); + visitor.on_attribute("transposed_b", m_transposed_b); + visitor.on_attribute("with_comp", m_with_comp); + return true; +} + +void BrgemmCPU::validate_and_infer_types() { + INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types); + // If no leading dimensions are provided, assume dense row-major inputs-outputs + 
NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), + "BrgemmCPU currently supports only static shapes."); + + const auto brgemm_copy = get_brgemm_copy(); + std::vector planar_input_shapes = { + ngraph::snippets::utils::get_port_planar_shape(input_value(0)), + ngraph::snippets::utils::get_port_planar_shape(brgemm_copy ? brgemm_copy->input_value(0) : input_value(1)) + }; + + auto output_shape = get_output_partial_shape(planar_input_shapes); + const auto& output_layout = ngraph::snippets::utils::get_node_output_layout(this); + set_output_type(0, + get_output_type(), + ngraph::snippets::utils::get_reordered_planar_shape(output_shape, output_layout)); + + // Verify Scratch input + if (get_input_size() == 3) { + const auto shape = get_input_partial_shape(2); + NGRAPH_CHECK(shape.is_static(), "BRGEMM Scratch must have static shape"); + const auto type = get_input_element_type(2); + if (m_with_comp) { + const auto element_type_b = get_input_element_type(0); + const auto shape_b = planar_input_shapes[1].get_shape(); + const auto N = *shape_b.rbegin(); + const auto N_blk = element_type_b == element::f32 ? N : + element_type_b == element::bf16 ? 
32 : 64; + const auto expected_shape = ov::Shape{rnd_up(N, N_blk)}; + const auto expected_type = ov::element::f32; + NGRAPH_CHECK(expected_shape == shape.get_shape() && expected_type == type, + "BRGEMM Scratch with compensations must have shape {rnd_up(N, N_blk)} and FP32 element type"); + } else { + NGRAPH_CHECK(ngraph::shape_size(shape.get_shape()) == 4 * 1024 && type == element::i32, + "BRGEMM Scratch for space workplace must be static, have F32 element type and 1024 shape size"); + } + } +} + +std::shared_ptr BrgemmCPU::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(BrgemmCPU_clone_with_new_inputs); + check_new_args_count(this, new_args); + std::shared_ptr new_node = nullptr; + if (new_args.size() == 2) { + new_node = std::make_shared(new_args.at(0), new_args.at(1), + m_transposed_a, m_transposed_b, m_with_comp, + get_input_port_descriptor(0).m_offset, + get_input_port_descriptor(1).m_offset, + get_output_port_descriptor(0).m_offset); + } else { + new_node = std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), + m_transposed_a, m_transposed_b, m_with_comp, + get_input_port_descriptor(0).m_offset, + get_input_port_descriptor(1).m_offset, + get_input_port_descriptor(2).m_offset, + get_output_port_descriptor(0).m_offset); + } + return new_node; +} + +std::shared_ptr BrgemmCPU::get_brgemm_copy() const { + if (const auto buffer = ov::as_type_ptr(get_input_node_shared_ptr(1))) { + return ov::as_type_ptr(buffer->get_input_node_shared_ptr(0)); + } + return nullptr; +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp new file mode 100644 index 00000000000000..183392d321a2dc --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2018-2022 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/op/brgemm.hpp" +#include "brgemm_copy_b.hpp" + +namespace ov { +namespace intel_cpu { + +/** + * @interface BrgemmCPU + * @brief BrgemmCPU is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows + * with support of several precisions on plugin level + * @ingroup snippets + */ +class BrgemmCPU : public ngraph::snippets::op::Brgemm { +public: + OPENVINO_OP("BrgemmCPU", "SnippetsOpset", ngraph::snippets::op::Brgemm); + BrgemmCPU(const Output& A, const Output& B, bool transposed_a = false, bool transposed_b = false, const bool with_comp = false, + const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0); + BrgemmCPU(const Output& A, const Output& B, const Output& scratch, + bool transposed_a = false, bool transposed_b = false, const bool with_comp = false, + const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0); + BrgemmCPU() = default; + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + std::shared_ptr get_brgemm_copy() const; + +private: + bool m_with_comp = false; // compensations +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp index 675c214ed7ae2b..65f7bd67c8f4bf 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp @@ -19,6 +19,7 @@ intel_cpu::LoadConvertSaturation::LoadConvertSaturation(const Output& x, c bool intel_cpu::LoadConvertSaturation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(LoadConvert_visit_attributes); + 
MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -31,7 +32,8 @@ void intel_cpu::LoadConvertSaturation::validate_and_infer_types() { std::shared_ptr intel_cpu::LoadConvertSaturation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(LoadConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_input_port_descriptor(0).m_count, get_input_port_descriptor(0).m_offset); } intel_cpu::LoadConvertTruncation::LoadConvertTruncation(const Output& x, const ov::element::Type& destination_type, @@ -42,6 +44,7 @@ intel_cpu::LoadConvertTruncation::LoadConvertTruncation(const Output& x, c bool intel_cpu::LoadConvertTruncation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(LoadConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -54,5 +57,6 @@ void intel_cpu::LoadConvertTruncation::validate_and_infer_types() { std::shared_ptr intel_cpu::LoadConvertTruncation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(LoadConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_input_port_descriptor(0).m_count, get_input_port_descriptor(0).m_offset); } diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp index 6a4180c54299c5..ade88450aff218 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp @@ -19,6 +19,7 @@ 
intel_cpu::StoreConvertSaturation::StoreConvertSaturation(const Output& x, bool intel_cpu::StoreConvertSaturation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(StoreConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -31,7 +32,8 @@ void intel_cpu::StoreConvertSaturation::validate_and_infer_types() { std::shared_ptr intel_cpu::StoreConvertSaturation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(StoreConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_output_port_descriptor(0).m_count, get_output_port_descriptor(0).m_offset); } intel_cpu::StoreConvertTruncation::StoreConvertTruncation(const Output& x, const ov::element::Type& destination_type, @@ -42,6 +44,7 @@ intel_cpu::StoreConvertTruncation::StoreConvertTruncation(const Output& x, bool intel_cpu::StoreConvertTruncation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(StoreConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -54,5 +57,6 @@ void intel_cpu::StoreConvertTruncation::validate_and_infer_types() { std::shared_ptr intel_cpu::StoreConvertTruncation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(StoreConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_output_port_descriptor(0).m_count, get_output_port_descriptor(0).m_offset); } diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp 
b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 9ab22c79d2eb39..b4a2ed8b51c86e 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -16,22 +16,38 @@ std::vector> input_shapes{ {{3, 1, 32, 14}, {1, 2, 14, 32}}, {{1, 2, 37, 23}, {2, 1, 23, 37}}, {{1, 1, 37, 23}, {1, 2, 23, 33}}, - {{2, 1, 69, 43}, {1, 1, 43, 49}} + {{1, 16, 384, 64}, {1, 16, 64, 384}} +}; +std::vector> precisions = { + {element::f32, element::f32} +}; +std::vector> all_precisions = { + {element::f32, element::f32}, + {element::i8, element::i8}, + {element::u8, element::i8} }; -std::vector precisions{element::f32}; INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, ::testing::Combine( ::testing::ValuesIn(input_shapes), - ::testing::ValuesIn(precisions), + ::testing::ValuesIn(all_precisions), ::testing::Values(1), // MatMu; ::testing::Values(1), // Tokenized MatMul ::testing::Values(CommonTestUtils::DEVICE_CPU)), MatMul::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulFQ, MatMulFQ, + ::testing::Combine( + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(precisions), + ::testing::Values(1), // MatMul; + ::testing::Values(1), // Tokenized MatMul + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, ::testing::Combine( ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 1, 69, 49}}), - ::testing::ValuesIn(precisions), + ::testing::ValuesIn(all_precisions), ::testing::Values(1), // Subgraph; ::testing::Values(1), // Tokenized MatMul+Bias ::testing::Values(CommonTestUtils::DEVICE_CPU)), @@ -55,15 +71,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMa ::testing::Values(CommonTestUtils::DEVICE_CPU)), MatMul::getTestCaseName); 
-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMulMatMulBias, ExplicitTransposeMulMatMulBias, - ::testing::Combine( - ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 2, 1, 1}, {1, 1, 69, 49}}), - ::testing::ValuesIn(precisions), - ::testing::Values(1), // Subgraph; - ::testing::Values(1), // Tokenized MatMul+Bias - ::testing::Values(CommonTestUtils::DEVICE_CPU)), - MatMul::getTestCaseName); - } // namespace } // namespace snippets } // namespace test diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp index 8e3af45fd52da2..8d9b0945f512ad 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp @@ -4,6 +4,7 @@ #include "snippets/transpose_matmul.hpp" #include "common_test_utils/test_constants.hpp" +#include "ie_system_conf.h" namespace ov { namespace test { @@ -11,7 +12,17 @@ namespace snippets { namespace { -std::vector precisions{element::f32}; +static inline std::vector> precisions() { + std::vector> prc = { + {element::f32, element::f32}, + {element::i8, element::i8}, + {element::u8, element::i8} + }; + if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16}); + } + return prc; +} namespace transpose_zero_input { std::vector> transpose_input_shapes{ {{1, 49, 2, 23}, {2, 2, 23, 39}} @@ -20,11 +31,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul, ::testing::Combine( ::testing::ValuesIn(transpose_input_shapes), ::testing::Values(0), // Transpose on 0th Matmul input - ::testing::ValuesIn(precisions), - ::testing::Values(1), // MatMul; + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // MatMul ::testing::Values(1), // 
Tokenized MatMul + FusedTranspose ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeMatMul::getTestCaseName); + +// TODO: FuseTransposeToBrgemm supports fusing only if Transpose is before Parameter in cases when Transpose is on input +// INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ, +// ::testing::Combine( +// ::testing::ValuesIn(transpose_input_shapes), +// ::testing::Values(0), // Transpose on 0th Matmul input +// ::testing::Values(ov::element::i8), +// ::testing::Values(1), // MatMul +// ::testing::Values(1), // Tokenized MatMul + FusedTranspose +// ::testing::Values(CommonTestUtils::DEVICE_CPU)), +// TransposeMatMulFQ::getTestCaseName); } // namespace transpose_zero_input namespace transpose_first_input { @@ -35,26 +57,48 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul, ::testing::Combine( ::testing::ValuesIn(transpose_input_shapes), ::testing::Values(1), // Transpose on 1st Matmul input - ::testing::ValuesIn(precisions), - ::testing::Values(1), // MatMu; + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul + FusedTranspose ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeMatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ, + ::testing::Combine( + ::testing::ValuesIn(transpose_input_shapes), + ::testing::Values(1), // Transpose on 1st Matmul input + ::testing::Values(std::vector{ov::element::f32}), + ::testing::Values(1), // MatMul + ::testing::Values(1), // Tokenized MatMul + FusedTranspose + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + TransposeMatMulFQ::getTestCaseName); } // namespace transpose_first_input namespace transpose_output { std::vector> transpose_input_shapes{ {{2, 1, 49, 13}, {1, 2, 13, 39}} }; +// TODO: Propagate shape through Brgemm with Transpose down INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul, ::testing::Combine( 
::testing::ValuesIn(transpose_input_shapes), ::testing::Values(2), // Transpose on Matmul output - ::testing::ValuesIn(precisions), - ::testing::Values(1), // MatMu; + ::testing::Values(std::vector{ov::element::f32, ov::element::f32}), + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul + FusedTranspose ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeMatMul::getTestCaseName); + +// TODO: Propagate shape through Brgemm with Transpose down +// INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ, +// ::testing::Combine( +// ::testing::ValuesIn(transpose_input_shapes), +// ::testing::Values(2), // Transpose on Matmul output +// ::testing::Values(ov::element::i8), +// ::testing::Values(1), // MatMul +// ::testing::Values(1), // Tokenized MatMul + FusedTranspose +// ::testing::Values(CommonTestUtils::DEVICE_CPU)), +// TransposeMatMulFQ::getTestCaseName); } // namespace transpose_output } // namespace diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp index bfa2a82921f416..9ce8cf8a6258b9 100644 --- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp @@ -12,21 +12,12 @@ namespace snippets { typedef std::tuple< std::vector, // Input Shapes - ov::element::Type, // Element type + std::vector,// Input Element types size_t, // Expected num nodes size_t, // Expected num subgraphs std::string // Target Device > MatMulParams; -typedef std::tuple< - std::vector, // Input Shapes - size_t , // Transpose position - ov::element::Type, // Element type - size_t, // Expected num nodes - size_t, // Expected num subgraphs - std::string // Target Device -> TransposeMatMulParams; - class MatMul : public testing::WithParamInterface, virtual public ov::test::SnippetsTestsCommon { public: @@ -36,31 +27,22 @@ class MatMul : public testing::WithParamInterface, - virtual 
public ov::test::SnippetsTestsCommon { -public: - static std::string getTestCaseName(testing::TestParamInfo obj); - +class ExplicitTransposeMatMulBias : public MatMul { protected: void SetUp() override; }; diff --git a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp index f949e9df9d5c3b..6be2324e938fad 100644 --- a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp @@ -13,7 +13,7 @@ namespace snippets { typedef std::tuple< std::vector, // Input Shapes size_t , // Transpose position - ov::element::Type, // Element type + std::vector,// Input Element types size_t, // Expected num nodes size_t, // Expected num subgraphs std::string // Target Device @@ -28,6 +28,11 @@ class TransposeMatMul : public testing::WithParamInterface obj) { std::vector input_shapes; - ov::element::Type elem_type; + std::vector elem_types; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(input_shapes, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param; + std::tie(input_shapes, elem_types, num_nodes, num_subgraphs, targetDevice) = obj.param; std::ostringstream result; for (size_t i = 0; i < input_shapes.size(); i++) result << "IS[" << i <<"]=" << CommonTestUtils::partialShape2str({input_shapes[i]}) << "_"; - result << "T=" << elem_type << "_"; + for (size_t i = 0; i < elem_types.size(); i++) + result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; result << "targetDevice=" << targetDevice; @@ -30,11 +31,11 @@ std::string MatMul::getTestCaseName(testing::TestParamInfo input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, 
ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::MatMulFunction(input_shapes); + auto f = ov::test::snippets::MatMulFunction(input_shapes, elem_types); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -42,27 +43,13 @@ void MatMul::SetUp() { } } -void MatMulBias::SetUp() { - std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::MatMulBiasFunction(input_shapes); - function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } -} - -void ExplicitTransposeMatMul::SetUp() { +void MatMulFQ::SetUp() { std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes); + auto f = ov::test::snippets::FQMatMulFunction(input_shapes); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -70,13 +57,13 @@ void ExplicitTransposeMatMul::SetUp() { } } -void 
ExplicitTransposeMatMulBias::SetUp() { +void MatMulBias::SetUp() { std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes); + auto f = ov::test::snippets::MatMulBiasFunction(input_shapes, elem_types); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -84,13 +71,13 @@ void ExplicitTransposeMatMulBias::SetUp() { } } -void ExplicitTransposeMulMatMulBias::SetUp() { +void ExplicitTransposeMatMul::SetUp() { std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::TransposeMulMatMulBiasFunction(input_shapes); + auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -98,34 +85,13 @@ void ExplicitTransposeMulMatMulBias::SetUp() { } } -std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo obj) { - std::vector input_shapes; - size_t transpose_position; - ov::element::Type elem_type; - std::string targetDevice; - size_t num_nodes, num_subgraphs; - std::tie(input_shapes, 
transpose_position, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param; - if (input_shapes.size() != 2) - IE_THROW() << "Invalid input shapes vector size"; - std::ostringstream result; - result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_"; - result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_"; - result << "Pos=" << transpose_position << "_"; - result << "T=" << elem_type << "_"; - result << "#N=" << num_nodes << "_"; - result << "#S=" << num_subgraphs << "_"; - result << "targetDevice=" << targetDevice; - return result.str(); -} - -void TransposeMatMul::SetUp() { +void ExplicitTransposeMatMulBias::SetUp() { std::vector input_shapes; - size_t transpose_position; - ov::element::Type elem_type; - std::tie(input_shapes, transpose_position, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, transpose_position); + auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -134,31 +100,31 @@ void TransposeMatMul::SetUp() { } TEST_P(MatMul, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } -TEST_P(MatMulBias, CompareWithRefImpl) { - run(); - validateNumSubgraphs(); -} - -TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) { +TEST_P(MatMulFQ, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } -TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) { +TEST_P(MatMulBias, CompareWithRefImpl) { + 
SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } -TEST_P(ExplicitTransposeMulMatMulBias, CompareWithRefImpl) { +TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } -TEST_P(TransposeMatMul, CompareWithRefImpl) { +TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } diff --git a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp index 68a2140339f5e5..60fb17c5788a37 100644 --- a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp @@ -15,17 +15,18 @@ namespace snippets { std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo obj) { std::vector input_shapes; size_t transpose_position; - ov::element::Type elem_type; + std::vector elem_types; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(input_shapes, transpose_position, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param; + std::tie(input_shapes, transpose_position, elem_types, num_nodes, num_subgraphs, targetDevice) = obj.param; if (input_shapes.size() != 2) IE_THROW() << "Invalid input shapes vector size"; std::ostringstream result; result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_"; result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_"; result << "Pos=" << transpose_position << "_"; - result << "T=" << elem_type << "_"; + for (size_t i = 0; i < elem_types.size(); i++) + result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; result << "targetDevice=" << targetDevice; @@ -35,11 +36,26 @@ std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo input_shapes; size_t transpose_position; - 
ov::element::Type elem_type; - std::tie(input_shapes, transpose_position, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, transpose_position); + auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, elem_types, transpose_position); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void TransposeMatMulFQ::SetUp() { + std::vector input_shapes; + size_t transpose_position; + std::vector elem_types; + std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::FQMatMulFunction(input_shapes, transpose_position); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -48,6 +64,13 @@ void TransposeMatMul::SetUp() { } TEST_P(TransposeMatMul, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(TransposeMatMulFQ, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp index c583b5882ab710..40f8c20c9f3a65 100644 
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp @@ -56,7 +56,7 @@ class EltwiseThreeInputsLoweredFunction : public EltwiseThreeInputsFunction { class Transpose0213MatMulLoweredFunction : public Transpose0213MatMulFunction { public: explicit Transpose0213MatMulLoweredFunction(const std::vector& inputShapes, size_t position = 0) : - Transpose0213MatMulFunction(inputShapes, position) { + Transpose0213MatMulFunction(inputShapes, std::vector{ov::element::f32, ov::element::f32}, position) { } protected: std::shared_ptr initLowered() const override; diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp index ea533334e80a88..15954605e69fdd 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp @@ -6,6 +6,7 @@ #include "ngraph/ngraph.hpp" #include "./snippets_helpers.hpp" +#include "snippets/utils.hpp" /* This file contains definitions of relatively simple functions (models) that will be used * to test snippets-specific behavior. 
All the functions are expected to be direct descendants of @@ -20,48 +21,77 @@ namespace snippets { // in1 in2 // Matmul // Result -// todo: remove once "no subgraph after input" limitation is relaxed class MatMulFunction : public SnippetsFunctionBase { public: - explicit MatMulFunction(const std::vector& inputShapes) - : SnippetsFunctionBase(inputShapes) { + explicit MatMulFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes"); + verify_precisions(precisions); + } + static void verify_precisions(const std::vector& precisions) { + NGRAPH_CHECK(precisions.size() == 2, "Got invalid number of input element types"); + const bool is_f32 = ngraph::snippets::utils::everyone_is(element::f32, precisions[0], precisions[1]); + const bool is_int8 = ngraph::snippets::utils::one_of(precisions[0], element::i8, element::u8) && precisions[1] == element::i8; + const bool is_bf16 = ngraph::snippets::utils::everyone_is(element::bf16, precisions[0], precisions[1]); + NGRAPH_CHECK(is_f32 || is_bf16 || is_int8, "Invalid precisions"); } protected: std::shared_ptr initOriginal() const override; std::shared_ptr initReference() const override; + + std::vector precisions; +}; + +class FQMatMulFunction : public SnippetsFunctionBase { +public: + explicit FQMatMulFunction(const std::vector& inputShapes, int pos = -1) : SnippetsFunctionBase({inputShapes[0]}), pos(pos) { + NGRAPH_CHECK(inputShapes.size() == 2, "Got invalid number of input shapes"); + NGRAPH_CHECK(pos >=-1 && pos <= 2, "Got invalid transpose position"); + const_shape = inputShapes[1]; + } +protected: + std::shared_ptr initOriginal() const override; + + ov::PartialShape const_shape; + int pos = -1; }; // As same as MatMulFunction but with biases class MatMulBiasFunction : public SnippetsFunctionBase { public: - explicit MatMulBiasFunction(const std::vector& inputShapes) - : 
SnippetsFunctionBase(inputShapes) { + explicit MatMulBiasFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + MatMulFunction::verify_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; + + std::vector precisions; }; /// Minimal graph to test MatMul+Transpose combinations. Transpose location is specified via the position argument: /// 0 - before the first MatMul input; 1 - before the second MatMul input; 2 - after the MatMul output. /// Tokenized simply by starting subgraph, // in1 in2 -// Transpose / +// Transpose / // Matmul // Result class Transpose0213MatMulFunction : public SnippetsFunctionBase { public: - explicit Transpose0213MatMulFunction(const std::vector& inputShapes, size_t position = 0) - : SnippetsFunctionBase(inputShapes), transpose_position(position) { + explicit Transpose0213MatMulFunction(const std::vector& inputShapes, const std::vector& precisions, + size_t position = 0) + : SnippetsFunctionBase(inputShapes), transpose_position(position), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes"); NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4, "Only rank 4 input shapes are supported by this test"); NGRAPH_CHECK(transpose_position >=0 && transpose_position <= 2, "Got invalid transpose position"); + MatMulFunction::verify_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; size_t transpose_position; + std::vector precisions; }; class TransposeMatMulFunction : public SnippetsFunctionBase { diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp index 22b86982e9e0e1..31e66c97534164 100644 --- 
a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp @@ -107,8 +107,8 @@ std::shared_ptr EltwiseThreeInputsLoweredFunction::initLowered() cons } std::shared_ptr Transpose0213MatMulLoweredFunction::initLowered() const { - ParameterVector data{std::make_shared(precision, input_shapes[0]), - std::make_shared(precision, input_shapes[1])}; + ParameterVector data{std::make_shared(precisions[0], input_shapes[0]), + std::make_shared(precisions[1], input_shapes[1])}; std::vector layout{0, 2, 1, 3}; // Note: validity of transpose_position values is checked in Transpose0213MatMulSinhFunction constructor if (transpose_position <= 1) { diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp index af312a2ee2d812..d8e49abf573aae 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp @@ -5,50 +5,133 @@ #include "subgraph_matmul.hpp" #include "common_test_utils/data_utils.hpp" #include +#include "ngraph_functions/builders.hpp" +#include "ov_ops/type_relaxed.hpp" + namespace ov { namespace test { namespace snippets { std::shared_ptr MatMulFunction::initOriginal() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto matmul = std::make_shared(data0, data1); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(data1, 
element::f32).get()); + } else { + matmul = std::make_shared(data0, data1); + } return std::make_shared(NodeVector{matmul}, ParameterVector{data0, data1}); } std::shared_ptr MatMulFunction::initReference() const { + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto indata0 = std::make_shared(precisions[0], data0->get_output_partial_shape(0)); + auto indata1 = std::make_shared(precisions[1], data1->get_output_partial_shape(0)); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(indata0, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(indata1, element::f32).get()); + } else { + matmul = std::make_shared(indata0, indata1); + } + const auto subgraph = std::make_shared(NodeVector{data0, data1}, + std::make_shared(NodeVector{matmul}, + ParameterVector{indata0, indata1})); + return std::make_shared(NodeVector{subgraph}, ParameterVector{data0, data1}); +} +std::shared_ptr FQMatMulFunction::initOriginal() const { + auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 1, 3}); auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto indata0 = std::make_shared(precision, data0->get_output_partial_shape(0)); - auto indata1 = std::make_shared(precision, data1->get_output_partial_shape(0)); - auto matmul = std::make_shared(NodeVector{data0, data1}, - std::make_shared(NodeVector{std::make_shared(indata0, indata1)}, - ParameterVector{indata0, indata1})); - return std::make_shared(NodeVector{matmul}, ParameterVector{data0, data1}); + auto ih = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{34.7436294}); + auto il = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{-35.0172004}); + auto oh = 
std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{34.7436294}); + auto ol = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{-35.0172004}); + auto fq = std::make_shared(data0, il, ih, ol, oh, 256); + std::shared_ptr in0 = fq; + if (pos == 0) { + in0 = std::make_shared(in0, const_order); + } + auto constant = ngraph::builder::makeConstant(ov::element::i8, const_shape.get_shape(), std::vector{}, true); + auto convert = std::make_shared(constant, ov::element::f32); + auto deq_mul = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{0.00499185826}); + auto mul = std::make_shared(convert, deq_mul); + std::shared_ptr in1 = mul; + if (pos == 1) { + in1 = std::make_shared(in1, const_order); + } + auto matmul = std::make_shared(in0, in1); + std::shared_ptr out = matmul; + if (pos == 2) { + out = std::make_shared(out, const_order); + } + return std::make_shared(NodeVector{out}, ParameterVector{data0}); } std::shared_ptr MatMulBiasFunction::initOriginal() const { auto data0 = std::make_shared(precision, input_shapes[0]); auto data1 = std::make_shared(precision, input_shapes[1]); - auto matmul = std::make_shared(data0, data1); auto data2 = std::make_shared(precision, input_shapes[2]); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + matmul = std::make_shared(data0, data1); + } auto bias = std::make_shared(matmul, data2); return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2}); } std::shared_ptr Transpose0213MatMulFunction::initOriginal() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = 
std::make_shared(precisions[1], input_shapes[1]); auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 1, 3}); std::shared_ptr result; switch (transpose_position) { case 0: { auto transpose = std::make_shared(data0, const_order); - result = std::make_shared(transpose, data1); + if (precisions[1] == ov::element::i8) { + result = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(transpose, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + result = std::make_shared(transpose, data1); + } break; } case 1: { auto transpose = std::make_shared(data1, const_order); - result = std::make_shared(data0, transpose); + if (precisions[1] == ov::element::i8) { + result = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(transpose, element::f32).get()); + } else { + result = std::make_shared(data0, transpose); + } break; } case 2: { - auto matmul = std::make_shared(data0, data1); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ngraph::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ngraph::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + matmul = std::make_shared(data0, data1); + } result = std::make_shared(matmul, const_order); break; }