From 5c39a4e732812d895e885112deabba49ead746bd Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Thu, 16 Mar 2023 18:10:46 +0400 Subject: [PATCH] Applied Ivan comments --- .../snippets/include/snippets/generator.hpp | 43 ++-- .../snippets/include/snippets/op/brgemm.hpp | 15 +- .../include/snippets/op/broadcastload.hpp | 2 +- .../snippets/include/snippets/op/buffer.hpp | 71 ++---- .../snippets/include/snippets/op/load.hpp | 10 +- .../include/snippets/op/memory_access.hpp | 65 +++--- .../snippets/include/snippets/op/store.hpp | 4 +- .../snippets/pass/assign_registers.hpp | 4 +- .../snippets/include/snippets/utils.hpp | 4 +- src/common/snippets/src/generator.cpp | 68 +++--- src/common/snippets/src/op/brgemm.cpp | 25 +-- src/common/snippets/src/op/broadcastload.cpp | 1 + src/common/snippets/src/op/buffer.cpp | 127 ++++------- src/common/snippets/src/op/load.cpp | 2 + src/common/snippets/src/op/memory_access.cpp | 89 +++++--- src/common/snippets/src/op/store.cpp | 1 + src/common/snippets/src/op/subgraph.cpp | 19 +- .../snippets/src/pass/assign_registers.cpp | 49 ++--- .../src/pass/fuse_transpose_brgemm.cpp | 4 +- .../snippets/src/pass/insert_buffer.cpp | 8 +- src/common/snippets/src/pass/insert_loops.cpp | 7 +- src/common/snippets/src/pass/loop_fusion.cpp | 10 +- .../snippets/src/pass/matmul_to_brgemm.cpp | 3 +- src/common/snippets/src/pass/reset_buffer.cpp | 4 +- .../src/pass/softmax_decomposition.cpp | 4 +- .../snippets/src/pass/vector_to_scalar.cpp | 6 +- src/common/snippets/src/utils.cpp | 7 +- .../snippets/tests/include/lowering_utils.hpp | 6 +- .../snippets/tests/src/lowering_utils.cpp | 3 +- .../snippets/tests/src/pass/merge_loops.cpp | 4 +- src/common/snippets/tests/src/registers.cpp | 12 +- .../intel_cpu/src/emitters/cpu_generator.cpp | 19 +- .../intel_cpu/src/emitters/cpu_generator.hpp | 6 +- .../src/emitters/jit_snippets_emitters.cpp | 204 +++++++++--------- .../src/emitters/jit_snippets_emitters.hpp | 59 ++--- src/plugins/intel_cpu/src/extension.cpp | 3 +- .../brgemm_to_brgemm_cpu.cpp | 46 ++-- .../brgemm_to_brgemm_cpu.hpp | 19 +- .../fuse_load_store_and_convert.cpp | 13 +- .../op/brgemm_copy_b.cpp | 30 +-- .../op/brgemm_copy_b.hpp | 20 +- .../op/brgemm_cpu.cpp | 61 +++--- .../op/brgemm_cpu.hpp | 26 ++- .../snippets/matmul.cpp | 18 -- .../snippets/transpose_matmul.cpp | 46 +++- .../plugin/shared/include/snippets/matmul.hpp | 10 - .../include/snippets/transpose_matmul.hpp | 10 + .../plugin/shared/src/snippets/matmul.cpp | 40 ---- .../shared/src/snippets/transpose_matmul.cpp | 49 ++++- .../src/subgraph_lowered.cpp | 6 +- 50 files changed, 677 insertions(+), 685 deletions(-) diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index c5c516ec35c80f..b02f995c6e5c19 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -41,21 +41,6 @@ class TargetMachine { */ virtual size_t get_lanes() const = 0; - /** - * @interface opRegType - * @brief Register type of operations - * Note that currently there are 4 types of ops: - * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc) - * gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc. - * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc. 
- */ - enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec}; - /** - * @brief gets register type by op type - * @return register type - */ - opRegType get_op_reg_type(const std::shared_ptr& op) const; - /** * @brief called by generator to all the emitter for a target machine * @return a map by node's type info with callbacks to create an instance of emitter for corresponding operation type @@ -78,12 +63,6 @@ class TargetMachine { virtual ~TargetMachine() = default; protected: - /** - * @brief gets register type by specific plugin op type - * @return register type - */ - virtual opRegType get_specific_op_reg_type(const std::shared_ptr& op) const = 0; - std::map(std::shared_ptr)>> jitters; }; @@ -164,7 +143,29 @@ class Generator { */ std::shared_ptr get_target_machine() const; + /** + * @interface opRegType + * @brief Register type of operations + * Note that currently there are 4 types of ops: + * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc) + * gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc. + * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc. + */ + enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec}; + /** + * @brief gets register type by op type + * TODO: Should be static attribute of emitters + * @return register type + */ + opRegType get_op_reg_type(const std::shared_ptr& op) const; + protected: + /** + * @brief gets register type by specific plugin op type + * @return register type + */ + virtual opRegType get_specific_op_reg_type(const std::shared_ptr& op) const; + std::shared_ptr target; // todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then). // This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method. 
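The hunks above move `opRegType` and `get_op_reg_type()` from `TargetMachine` to `Generator`, keeping a protected virtual `get_specific_op_reg_type()` as the fallback for plugin-specific ops (the base implementation, added later in this patch, simply throws). Below is a minimal illustrative sketch of how a plugin-side generator is expected to hook into this; it is not part of the patch, `MatMul` merely stands in for a real plugin-specific op, and the node pointer type (`ov::Node`) is assumed from context:

#include <memory>
#include <string>

#include <openvino/core/except.hpp>
#include <openvino/op/matmul.hpp>

#include "snippets/generator.hpp"

// Sketch of a plugin generator overriding the new virtual fallback.
class MyPluginGenerator : public ngraph::snippets::Generator {
public:
    explicit MyPluginGenerator(const std::shared_ptr<ngraph::snippets::TargetMachine>& target)
        : Generator(target) {}

protected:
    // Invoked by Generator::get_op_reg_type() only for ops the base class cannot classify.
    opRegType get_specific_op_reg_type(const std::shared_ptr<ov::Node>& op) const override {
        if (std::dynamic_pointer_cast<ov::op::v0::MatMul>(op))
            return vec2vec;  // stand-in for a plugin op that computes on vector registers
        throw ov::Exception("Register type of the operation " +
                            std::string(op->get_type_name()) + " isn't determined!");
    }
};
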
diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index 7b13f5e3053b0b..58c70f164799a6 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -19,18 +19,14 @@ namespace op { class Brgemm : public MemoryAccess { public: OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess); - Brgemm(const Output& A, const Output& B, bool transposed_a = false, bool transposed_b = false, + Brgemm(const Output& A, const Output& B, const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu); Brgemm() = default; - bool transposed_a() const { return m_transposed_a; } - bool transposed_b() const { return m_transposed_b; } + size_t get_offset_a() const { return get_input_offset(0); } + size_t get_offset_b() const { return get_input_offset(1); } + size_t get_offset_c() const { return get_output_offset(0); } - size_t get_offset_a() const { return get_input_port_descriptor(0).m_offset; } - size_t get_offset_b() const { return get_input_port_descriptor(1).m_offset; } - size_t get_offset_c() const { return get_output_port_descriptor(0).m_offset; } - - bool visit_attributes(AttributeVisitor& visitor) override; void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; @@ -39,9 +35,6 @@ class Brgemm : public MemoryAccess { protected: ov::element::Type get_output_type() const; ov::PartialShape get_output_partial_shape(const std::vector& input_shapes) const; - - bool m_transposed_a; - bool m_transposed_b; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/broadcastload.hpp b/src/common/snippets/include/snippets/op/broadcastload.hpp index 0c9489f8917d36..edcbe170a371f6 100644 --- a/src/common/snippets/include/snippets/op/broadcastload.hpp +++ b/src/common/snippets/include/snippets/op/broadcastload.hpp @@ -24,7 +24,7 @@ class BroadcastLoad : public MemoryAccess { BroadcastLoad(const Output& x, ov::PartialShape output_shape, size_t offset = 0lu); BroadcastLoad() = default; - size_t get_offset() const { return get_input_port_descriptor(0).m_offset; } + size_t get_offset() const { return get_input_offset(0); } bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; diff --git a/src/common/snippets/include/snippets/op/buffer.hpp b/src/common/snippets/include/snippets/op/buffer.hpp index 70da0c5dab9f22..a1975b42f5c2d9 100644 --- a/src/common/snippets/include/snippets/op/buffer.hpp +++ b/src/common/snippets/include/snippets/op/buffer.hpp @@ -13,6 +13,8 @@ namespace op { /** * @interface Buffer * @brief This is a base class for memory storage. + * If Buffer has a parent, the operation is for intermediate data storage - Intermediate type. + * Otherwise, the operation is for allocation of new empty memory with shape `m_shape` - Empty type * Notes: * - All buffers in a graph have the same memory pointer. So if we have a few buffers, * each the corresponding MemoryAccess op for Buffer should have offset for common memory pointer of this Buffer @@ -22,67 +24,30 @@ namespace op { class Buffer : public ngraph::op::Op { public: OPENVINO_OP("Buffer", "SnippetsOpset"); - - size_t get_byte_size() const; - virtual ov::PartialShape get_allocation_shape() const = 0; - -protected: Buffer() = default; -}; - -/** - * @interface AllocationBuffer - * @brief The operation is for allocation of new empty memory. 
The operation has one parent that is equal to allocation shape - * - m_element_type - element type of memory - * @ingroup snippets - */ -class AllocationBuffer : public Buffer { -public: - OPENVINO_OP("AllocationBuffer", "SnippetsOpset", Buffer); - - AllocationBuffer() = default; - AllocationBuffer(const ov::Output& shape, const ov::element::Type element_type); - - ov::PartialShape get_allocation_shape() const override; + Buffer(const ov::Shape& shape); + Buffer(const ov::Output& arg, const ov::Shape& shape); + Buffer(const ov::Output& arg, int32_t allocation_rank = -1); bool visit_attributes(AttributeVisitor& visitor) override; - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; -protected: - ov::element::Type m_element_type; -}; - -/** - * @interface IntermediateBuffer - * @brief The operation is for intermediate data storage. - * If Buffer has only one parent, the Buffer will allocate a full memory with input shape of Buffer. - * If Buffer has second parent as well, the Buffer will allocate memory with shape that is equal to values from second input but - * saves the input shape for shape inference and input element type. - * For example, - * Parameter [5, 3, 128] Constant [2] (with values {3, 128}) - * \ / - * Buffer with allocated memory 3x128 size - * | - * Result [5, 3, 128] - * @ingroup snippets - */ -class IntermediateBuffer : public Buffer { -public: - OPENVINO_OP("IntermediateBuffer", "SnippetsOpset", Buffer); - - IntermediateBuffer() = default; - IntermediateBuffer(const ov::Output& x); - IntermediateBuffer(const ov::Output& x, const ov::Output& shape); + enum Type { + NewMemory, + IntermediateMemory + }; - ov::PartialShape get_allocation_shape() const override; + Type get_type() const { return m_type; } + ov::Shape get_allocation_shape() const { return m_shape; } + size_t get_byte_size() const; - bool visit_attributes(AttributeVisitor& visitor) override { return true; } - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - void validate_and_infer_types() override; + bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; } + bool is_new_memory() const { return m_type == Type::NewMemory; } - static std::shared_ptr create_shape_constant(const ov::PartialShape& shape, size_t allocation_rank); - static std::shared_ptr create_shape_constant(const ov::PartialShape& shape); +private: + Type m_type = Type::IntermediateMemory; + ov::Shape m_shape = {}; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/load.hpp b/src/common/snippets/include/snippets/op/load.hpp index 32166247808aa6..38acd0e8a10255 100644 --- a/src/common/snippets/include/snippets/op/load.hpp +++ b/src/common/snippets/include/snippets/op/load.hpp @@ -25,8 +25,11 @@ class Load : public MemoryAccess { Load(const Output& x, const size_t count = 1lu, const size_t offset = 0lu); Load() = default; - size_t get_offset() const { return get_input_port_descriptor(0).m_offset; } - size_t get_count() const { return get_input_port_descriptor(0).m_count; } + size_t get_offset() const { return get_input_offset(0); } + size_t get_count() const { return get_input_count(0); } + + void set_offset(size_t offset) { set_input_offset(offset, 0); } + void set_count(size_t count) { set_input_count(count, 0); } void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const 
OutputVector& new_args) const override; @@ -45,6 +48,9 @@ class LoadReshape : public Load { LoadReshape(const Output& x, size_t count = 1lu, const size_t offset = 0lu, std::vector order = {}); LoadReshape() = default; + void set_offset(size_t offset) { set_output_offset(offset, 0); } + void set_count(size_t count) { set_output_count(count, 0); } + bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; diff --git a/src/common/snippets/include/snippets/op/memory_access.hpp b/src/common/snippets/include/snippets/op/memory_access.hpp index 418af53a0cf1b7..7ac0daea2096b6 100644 --- a/src/common/snippets/include/snippets/op/memory_access.hpp +++ b/src/common/snippets/include/snippets/op/memory_access.hpp @@ -10,31 +10,6 @@ namespace ngraph { namespace snippets { namespace op { -class MemoryAccess; - -/** -* @interface PortDescriptor -* @brief This class describes port of MemoryAccess operation -* @param m_count - count of elements to load/store -* @param m_offset - starting index of elements to load/store -* @param m_index - port index -* @ingroup snippets -*/ - -struct PortDescriptor { - PortDescriptor(size_t count, size_t offset) : m_count(count), m_offset(offset) {} - PortDescriptor() = default; - - size_t m_count = 0lu; - size_t m_offset = 0lu; - size_t m_index = 0lu; - -private: - PortDescriptor(size_t count, size_t offset, size_t index) : m_count(count), m_offset(offset), m_index(index) {} - - friend class MemoryAccess; -}; - /** * @interface MemoryAccess * @brief This is a base class for memory access operations (like Load and Store). @@ -48,14 +23,46 @@ class MemoryAccess : public ngraph::op::Op { public: OPENVINO_OP("MemoryAccess", "SnippetsOpset"); + /** + * @interface PortDescriptor + * @brief This class describes port of MemoryAccess operation + * @param m_count - count of elements to load/store + * @param m_offset - starting index of elements to load/store + * @param m_index - port index + * @ingroup snippets + */ + struct PortDescriptor { + PortDescriptor(size_t count, size_t offset) : m_count(count), m_offset(offset) {} + PortDescriptor() = default; + + size_t m_count = 0lu; + size_t m_offset = 0lu; + size_t m_index = 0lu; + + private: + PortDescriptor(size_t count, size_t offset, size_t index) : m_count(count), m_offset(offset), m_index(index) {} + + friend class MemoryAccess; + }; + void set_input_port_descriptor(const PortDescriptor& desc, const size_t i); void set_output_port_descriptor(const PortDescriptor& desc, const size_t i); - PortDescriptor get_input_port_descriptor(const size_t i) const; - PortDescriptor get_output_port_descriptor(const size_t i) const; - PortDescriptor& get_input_port_descriptor(const size_t i); - PortDescriptor& get_output_port_descriptor(const size_t i); + const PortDescriptor& get_input_port_descriptor(const size_t i) const; + const PortDescriptor& get_output_port_descriptor(const size_t i) const; + + void set_input_count(size_t count, size_t idx); + void set_output_count(size_t count, size_t idx); + void set_input_offset(size_t offset, size_t idx); + void set_output_offset(size_t offset, size_t idx); + + size_t get_input_count(size_t idx) const; + size_t get_output_count(size_t idx) const; + size_t get_input_offset(size_t idx) const; + size_t get_output_offset(size_t idx) const; + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; protected: explicit MemoryAccess(const 
OutputVector& arguments); diff --git a/src/common/snippets/include/snippets/op/store.hpp b/src/common/snippets/include/snippets/op/store.hpp index 5aca6d3f3f3d2c..e804f7f917d5cd 100644 --- a/src/common/snippets/include/snippets/op/store.hpp +++ b/src/common/snippets/include/snippets/op/store.hpp @@ -25,8 +25,8 @@ class Store : public MemoryAccess { Store(const Output& x, const size_t count = 1lu, const size_t offset = 0lu); Store() = default; - size_t get_offset() const { return get_output_port_descriptor(0).m_offset; } - size_t get_count() const { return get_output_port_descriptor(0).m_count; } + size_t get_offset() const { return get_output_offset(0); } + size_t get_count() const { return get_output_count(0); } void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; diff --git a/src/common/snippets/include/snippets/pass/assign_registers.hpp b/src/common/snippets/include/snippets/pass/assign_registers.hpp index e13aaa7596af8e..81a5e3b2b29d62 100644 --- a/src/common/snippets/include/snippets/pass/assign_registers.hpp +++ b/src/common/snippets/include/snippets/pass/assign_registers.hpp @@ -20,13 +20,13 @@ namespace pass { */ class AssignRegisters : public ngraph::pass::FunctionPass { public: - explicit AssignRegisters(const std::shared_ptr& target_machine) : m_target_machine(target_machine) { + explicit AssignRegisters(const std::function& op)>& mapper) : m_reg_type_mapper(mapper) { set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true); } bool run_on_model(const std::shared_ptr& m) override; private: - std::shared_ptr m_target_machine = nullptr; + std::function& op)> m_reg_type_mapper; }; } // namespace pass diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index c641b452571a04..bcbe2860882c1e 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -29,8 +29,8 @@ ov::PartialShape get_port_planar_shape(const Output& out); ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector& layout); std::vector get_node_output_layout(const std::shared_ptr& node); std::vector get_node_output_layout(const Node* node); -void set_output_layout(const ov::Output& port, const std::shared_ptr& node); -void set_output_layout(const ov::Output& port, const std::vector& layout); +void set_transpose_output_layout(const ov::Output& port, const std::shared_ptr& node); +void set_transpose_output_layout(const ov::Output& port, const std::vector& layout); inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); } inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); } diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 68b5e026aafaf4..1b1e8dabfaba29 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -17,30 +17,6 @@ namespace ngraph { namespace snippets { -TargetMachine::opRegType TargetMachine::get_op_reg_type(const std::shared_ptr& op) const { - if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return gpr2gpr; - else if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return gpr2vec; - else if (std::dynamic_pointer_cast(op)) - return 
vec2gpr; - else if (ov::op::util::is_unary_elementwise_arithmetic(op) || - ov::op::util::is_binary_elementwise_arithmetic(op) || - ov::op::util::is_binary_elementwise_comparison(op) || - ov::op::util::is_binary_elementwise_logical(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return vec2vec; - else - return get_specific_op_reg_type(op); -} - auto getRegisters(const std::shared_ptr &n) -> RegInfo { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::getRegisters") @@ -102,15 +78,13 @@ auto tail_transformations(NodeVector& tail, const size_t tail_size, const ngraph } } else if (const auto memory_access = std::dynamic_pointer_cast(op)) { for (size_t i = 0; i < memory_access->get_input_size(); ++i) { - auto& desc = memory_access->get_input_port_descriptor(i); - if (desc.m_count != 1) { - desc.m_count = tail_size; + if (memory_access->get_input_count(i) != 1) { + memory_access->set_input_count(tail_size, i); } } for (size_t i = 0; i < memory_access->get_output_size(); ++i) { - auto& desc = memory_access->get_output_port_descriptor(i); - if (desc.m_count != 1) { - desc.m_count = tail_size; + if (memory_access->get_output_count(i) != 1) { + memory_access->set_output_count(tail_size, i); } } } @@ -253,5 +227,39 @@ std::shared_ptr Generator::get_target_machine() const { return target; } +Generator::opRegType Generator::get_op_reg_type(const std::shared_ptr& op) const { + if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2gpr; + else if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2vec; + else if (std::dynamic_pointer_cast(op)) + return vec2gpr; + else if (ov::op::util::is_unary_elementwise_arithmetic(op) || + ov::op::util::is_binary_elementwise_arithmetic(op) || + ov::op::util::is_binary_elementwise_comparison(op) || + ov::op::util::is_binary_elementwise_logical(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return vec2vec; + else + return get_specific_op_reg_type(op); +} + +Generator::opRegType Generator::get_specific_op_reg_type(const std::shared_ptr& op) const { + throw ov::Exception("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!"); +} + + }// namespace snippets }// namespace ngraph diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 1ffb295bc08b64..f984348f235550 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -13,25 +13,18 @@ namespace ngraph { namespace snippets { namespace op { -Brgemm::Brgemm(const Output& A, const Output& B, bool transposed_a, bool transposed_b, - const size_t offset_a, const size_t offset_b, const size_t offset_c) - : MemoryAccess({A, B}), m_transposed_a(transposed_a), m_transposed_b(transposed_b) { +Brgemm::Brgemm(const Output& A, const Output& B, + const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}) { set_output_size(1); - set_input_port_descriptor({0, offset_a}, 0); - set_input_port_descriptor({0, offset_b}, 1); - set_output_port_descriptor({0, offset_c}, 0); constructor_validate_and_infer_types(); -} - -bool 
Brgemm::visit_attributes(AttributeVisitor& visitor) { - MemoryAccess::visit_attributes(visitor); - visitor.on_attribute("transposed_a", m_transposed_a); - visitor.on_attribute("transposed_b", m_transposed_b); - return true; + set_input_offset(offset_a, 0); + set_input_offset(offset_b, 1); + set_output_offset(offset_a, 0); } void Brgemm::validate_and_infer_types() { INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types); + MemoryAccess::validate_and_infer_types(); // If no leading dimensions are provided, assume dense row-major inputs-outputs NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), "Brgemm currently supports only static shapes."); @@ -51,9 +44,7 @@ void Brgemm::validate_and_infer_types() { std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Brgemm_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), new_args.at(1), - m_transposed_a, m_transposed_b, - get_offset_a(), get_offset_b(), get_offset_c()); + return std::make_shared(new_args.at(0), new_args.at(1), get_offset_a(), get_offset_b(), get_offset_c()); } ov::element::Type Brgemm::get_output_type() const { @@ -78,7 +69,7 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector(ngraph::element::f32, input_shapes[0]); auto matmul_in1 = std::make_shared(ngraph::element::f32, input_shapes[1]); - auto matmul = std::make_shared(matmul_in0, matmul_in1, m_transposed_a, m_transposed_b); + auto matmul = std::make_shared(matmul_in0, matmul_in1); std::vector output_shapes = {ov::PartialShape{}}; ov::op::v0::shape_infer(matmul.get(), input_shapes, output_shapes); diff --git a/src/common/snippets/src/op/broadcastload.cpp b/src/common/snippets/src/op/broadcastload.cpp index c767870aaea1fe..f24ff3fc46a000 100644 --- a/src/common/snippets/src/op/broadcastload.cpp +++ b/src/common/snippets/src/op/broadcastload.cpp @@ -28,5 +28,6 @@ std::shared_ptr snippets::op::BroadcastLoad::clone_with_new_inputs(const O } void snippets::op::BroadcastLoad::validate_and_infer_types() { + MemoryAccess::validate_and_infer_types(); set_output_type(0, get_input_element_type(0), output_shape); } diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index 35b419398a0330..8a3963119b832b 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -16,103 +16,64 @@ auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t return allocation_rank < 0 ? 
allocation_rank + static_cast(shape_rank) : allocation_rank; } -size_t ngraph::snippets::op::Buffer::get_byte_size() const { - const auto pshape = get_allocation_shape(); - // TODO: Add support of dynamism - NGRAPH_CHECK(pshape.is_static(), "Buffer should have static shapes for memory allocation"); - const auto shape = pshape.get_shape(); - return ngraph::shape_size(shape) * get_element_type().size(); -} - -snippets::op::AllocationBuffer::AllocationBuffer(const Output& shape, const ov::element::Type element_type) - : Buffer(), m_element_type(element_type) { - set_arguments({shape}); +snippets::op::Buffer::Buffer(const ov::Shape& shape) + : Op(), m_type(Type::NewMemory), m_shape(shape) { constructor_validate_and_infer_types(); } -bool snippets::op::AllocationBuffer::visit_attributes(AttributeVisitor& visitor) { - INTERNAL_OP_SCOPE(AllocationBuffer_visit_attributes); - visitor.on_attribute("element_type", m_element_type); - return true; -} - -std::shared_ptr snippets::op::AllocationBuffer::clone_with_new_inputs(const OutputVector& new_args) const { - INTERNAL_OP_SCOPE(AllocationBuffer_clone_with_new_inputs); - check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_element_type); -} - -void snippets::op::AllocationBuffer::validate_and_infer_types() { - INTERNAL_OP_SCOPE(AllocationBuffer_validate_and_infer_types); - set_output_type(0, m_element_type, get_allocation_shape()); -} - -ov::PartialShape ngraph::snippets::op::AllocationBuffer::get_allocation_shape() const { - ov::PartialShape shape = ov::PartialShape::dynamic(); - const auto shape_constant = ov::as_type_ptr(get_input_node_shared_ptr(0)); - if (shape_constant) { - NGRAPH_CHECK(shape_constant->get_element_type() == ov::element::i32, - "The AllocationBuffer expects Constant with shape of I32 element type"); - const auto dims = shape_constant->cast_vector(); - NGRAPH_CHECK(!dims.empty(), "The AllocationBuffer got invalid shape Constant"); - shape = ov::PartialShape(ov::Shape(std::vector(dims.begin(), dims.end()))); - } - return shape; -} - -snippets::op::IntermediateBuffer::IntermediateBuffer(const ov::Output& x) : Buffer() { - set_arguments({x}); +snippets::op::Buffer::Buffer(const ov::Output& arg, const ov::Shape& shape) + : Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape) { constructor_validate_and_infer_types(); } -snippets::op::IntermediateBuffer::IntermediateBuffer(const ov::Output& x, const ov::Output& shape) : Buffer() { - set_arguments({x, shape}); +snippets::op::Buffer::Buffer(const ov::Output& arg, int32_t allocation_rank) + : Op({arg}), m_type(Type::IntermediateMemory) { + const auto pshape = arg.get_partial_shape(); + OPENVINO_ASSERT(pshape.is_static(), "Buffer supports only static input shape"); + const auto shape = pshape.get_shape(); + const auto normalize_rank = utils::normalize_rank(static_cast(allocation_rank), shape.size()); + const auto offset = static_cast(shape.size()) - normalize_rank; + m_shape = {shape.begin() + offset, shape.end()}; constructor_validate_and_infer_types(); } -std::shared_ptr snippets::op::IntermediateBuffer::clone_with_new_inputs(const OutputVector& new_args) const { - INTERNAL_OP_SCOPE(IntermediateBuffer_clone_with_new_inputs); - check_new_args_count(this, new_args); - if (new_args.size() == 2) { - return std::make_shared(new_args.at(0), new_args.at(1)); - } else if (new_args.size() == 1) { - return std::make_shared(new_args.at(0)); - } - - throw ngraph_error("The IntermediateBuffer op got invalid input count"); -} - -void 
snippets::op::IntermediateBuffer::validate_and_infer_types() { - INTERNAL_OP_SCOPE(IntermediateBuffer_validate_and_infer_types); - set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(Buffer_visit_attributes); + visitor.on_attribute("allocation_shape", m_shape); + return true; } -ov::PartialShape ngraph::snippets::op::IntermediateBuffer::get_allocation_shape() const { - if (get_input_size() == 1) { - return get_input_partial_shape(0); +void snippets::op::Buffer::validate_and_infer_types() { + INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types); + ov::element::Type output_type; + ov::Shape output_shape; + if (m_type == Type::NewMemory) { + OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must to not have arguments!"); + output_shape = m_shape; + output_type = ov::element::u8; // 1Byte + } else if (m_type == Type::IntermediateMemory) { + const auto input_shape = get_input_partial_shape(0); + OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape"); + output_type = get_input_element_type(0); + output_shape = input_shape.get_shape(); + } else { + throw ov::Exception("Buffer supports only the following types: NewMemory and IntermediateMemory"); } + set_output_type(0, output_type, output_shape); +} - const auto shape_constant = ov::as_type_ptr(get_input_node_shared_ptr(1)); - if (shape_constant) { - NGRAPH_CHECK(shape_constant->get_element_type() == ov::element::i32, - "The AllocationBuffer expects Constant with shape of I32 element type"); - const auto dims = shape_constant->cast_vector(); - NGRAPH_CHECK(!dims.empty(), "The AllocationBuffer got invalid shape Constant"); - return ov::PartialShape(ov::Shape(std::vector(dims.begin(), dims.end()))); +std::shared_ptr snippets::op::Buffer::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs); + check_new_args_count(this, new_args); + if (m_type == Type::NewMemory) { + return std::make_shared(m_shape); + } else if (m_type == Type::IntermediateMemory) { + return std::make_shared(new_args.at(0), m_shape); } - return ov::PartialShape::dynamic(); + throw ov::Exception("Buffer supports only the following types: NewMemory and IntermediateMemory"); } -std::shared_ptr ngraph::snippets::op::IntermediateBuffer::create_shape_constant(const ov::PartialShape& shape, size_t allocation_rank) { - if (shape.rank().is_dynamic()) - return nullptr; - const auto normalize_rank = utils::normalize_rank(static_cast(allocation_rank), shape.size()); - const auto offset = static_cast(shape.size()) - normalize_rank; - return create_shape_constant(ov::PartialShape(std::vector{shape.begin() + offset, shape.end()})); +size_t ngraph::snippets::op::Buffer::get_byte_size() const { + const auto shape = get_allocation_shape(); + return ngraph::shape_size(shape) * get_element_type().size(); } - -std::shared_ptr ngraph::snippets::op::IntermediateBuffer::create_shape_constant(const ov::PartialShape& shape) { - if (shape.is_dynamic()) - return nullptr; - return std::make_shared(ov::element::i32, ov::Shape{shape.size()}, shape.get_shape()); -} \ No newline at end of file diff --git a/src/common/snippets/src/op/load.cpp b/src/common/snippets/src/op/load.cpp index 20f0ac390bbb90..3de338aa899eef 100644 --- a/src/common/snippets/src/op/load.cpp +++ b/src/common/snippets/src/op/load.cpp @@ -18,6 +18,7 @@ Load::Load(const Output& x, const size_t count, const size_t offset) : Mem 
} void snippets::op::Load::validate_and_infer_types() { + MemoryAccess::validate_and_infer_types(); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); } @@ -42,6 +43,7 @@ LoadReshape::LoadReshape(const Output& x, const size_t count, const si } void snippets::op::LoadReshape::validate_and_infer_types() { + MemoryAccess::validate_and_infer_types(); const auto& old_shape = get_input_partial_shape(0); ov::PartialShape new_shape; for (const auto idx : m_order) diff --git a/src/common/snippets/src/op/memory_access.cpp b/src/common/snippets/src/op/memory_access.cpp index 9352884ca740f3..e723ecfae46dd9 100644 --- a/src/common/snippets/src/op/memory_access.cpp +++ b/src/common/snippets/src/op/memory_access.cpp @@ -3,17 +3,28 @@ // #include - #include "snippets/op/memory_access.hpp" -#include - namespace ngraph { namespace snippets { namespace op { MemoryAccess::MemoryAccess(const OutputVector& arguments) : Op(arguments) {} +void MemoryAccess::validate_and_infer_types() { + // We create descriptors in validate_and_infer_types() (instead of in ctor) + const auto input_count = get_input_size(); + const auto output_count = get_output_size(); + while (m_input_ports.size() < input_count) { + m_input_ports.push_back({0, 0, m_input_ports.size()}); + } + while (m_output_ports.size() < output_count) { + m_output_ports.push_back({0, 0, m_output_ports.size()}); + } + OPENVINO_ASSERT(m_input_ports.size() == input_count, "The count of input ports must be equal to input count"); + OPENVINO_ASSERT(m_output_ports.size() == output_count, "The count of output ports must be equal to output count"); +} + bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) { for (size_t i = 0; i < m_input_ports.size(); ++i) { auto port = m_input_ports[i]; @@ -29,51 +40,59 @@ bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) { } void MemoryAccess::set_input_port_descriptor(const PortDescriptor& desc, const size_t i) { - // Logic is as same as ov::Node::get_input_descriptor - while (m_input_ports.size() <= i) { - m_input_ports.emplace_back(PortDescriptor{0, 0, m_input_ports.size()}); - } - m_input_ports[i] = { desc.m_count, desc.m_offset, i}; -} - -PortDescriptor MemoryAccess::get_input_port_descriptor(const size_t i) const { - // We cannot use the same way as in ov::Node::get_input_descriptor because this method must be const - // to allow call const Derived::clone_with_new_inputs() method NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); - return m_input_ports[i]; -} - -PortDescriptor& MemoryAccess::get_input_port_descriptor(const size_t i) { - // Logic is as same as ov::Node::get_input_descriptor - while (m_input_ports.size() <= i) { - m_input_ports.emplace_back(PortDescriptor{0, 0, m_input_ports.size()}); - } - return m_input_ports[i]; + m_input_ports[i] = { desc.m_count, desc.m_offset, i}; } void MemoryAccess::set_output_port_descriptor(const PortDescriptor& desc, const size_t i) { // Logic is as same as ov::Node::get_output_descriptor - while (m_output_ports.size() <= i) { - m_output_ports.emplace_back(PortDescriptor{0, 0, m_output_ports.size()}); - } + NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); m_output_ports[i] = { desc.m_count, desc.m_offset, i}; } -PortDescriptor MemoryAccess::get_output_port_descriptor(const size_t i) const { - // We cannot use the same way as in ov::Node::get_input_descriptor because this method must be const - // to allow 
call const Derived::clone_with_new_inputs() method +const MemoryAccess::PortDescriptor& MemoryAccess::get_input_port_descriptor(const size_t i) const { + NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + return m_input_ports[i]; +} + +const MemoryAccess::PortDescriptor& MemoryAccess::get_output_port_descriptor(const size_t i) const { NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); return m_output_ports[i]; } -PortDescriptor& MemoryAccess::get_output_port_descriptor(const size_t i) { - // Logic is as same as ov::Node::get_output_descriptor - while (m_output_ports.size() <= i) { - m_output_ports.emplace_back(PortDescriptor{0, 0, m_output_ports.size()}); - } - return m_output_ports[i]; +void MemoryAccess::set_input_count(size_t count, size_t idx) { + NGRAPH_CHECK(idx < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + m_input_ports[idx].m_count = count; +} +void MemoryAccess::set_output_count(size_t count, size_t idx) { + NGRAPH_CHECK(idx < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + m_output_ports[idx].m_count = count; +} +void MemoryAccess::set_input_offset(size_t offset, size_t idx) { + NGRAPH_CHECK(idx < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + m_input_ports[idx].m_offset = offset; +} +void MemoryAccess::set_output_offset(size_t offset, size_t idx) { + NGRAPH_CHECK(idx < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + m_output_ports[idx].m_offset = offset; +} +size_t MemoryAccess::get_input_count(size_t idx) const { + NGRAPH_CHECK(idx < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + return m_input_ports[idx].m_count; +} +size_t MemoryAccess::get_output_count(size_t idx) const { + NGRAPH_CHECK(idx < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + return m_output_ports[idx].m_count; +} +size_t MemoryAccess::get_input_offset(size_t idx) const { + NGRAPH_CHECK(idx < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + return m_input_ports[idx].m_offset; +} +size_t MemoryAccess::get_output_offset(size_t idx) const { + NGRAPH_CHECK(idx < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + return m_output_ports[idx].m_offset; } } // namespace op } // namespace snippets -} // namespace ngraph \ No newline at end of file +} // namespace ngraph diff --git a/src/common/snippets/src/op/store.cpp b/src/common/snippets/src/op/store.cpp index a97b5500afecfe..214871b89d667e 100644 --- a/src/common/snippets/src/op/store.cpp +++ b/src/common/snippets/src/op/store.cpp @@ -18,6 +18,7 @@ snippets::op::Store::Store(const Output& x, const size_t count, const size } void snippets::op::Store::validate_and_infer_types() { + MemoryAccess::validate_and_infer_types(); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); } diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 0cebc7ec092940..93e3b47881524a 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -431,8 +431,9 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { // 
Propagate to up: in Store. Buffer can have only one Store { - auto parent = buffer->get_input_node_shared_ptr(0); - if (!ov::is_type(parent)) { + if (buffer->is_intermediate_memory()) { + OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent"); + auto parent = buffer->get_input_node_shared_ptr(0); auto idx = buffer->input(0).get_source_output().get_index(); while (ov::is_type(parent)) { const auto source_output = parent->input_value(idx); @@ -440,8 +441,7 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { idx = source_output.get_index(); } if (auto memory_access = ov::as_type_ptr(parent)) { - auto &out_desc = memory_access->get_output_port_descriptor(idx); - out_desc.m_offset = offset; + memory_access->set_output_offset(offset, idx); } else { throw ngraph_error( "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); @@ -463,8 +463,7 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { propagate_down(loop_target_output); } } else if (auto memory_access = ov::as_type_ptr(child)) { - auto& in_desc = memory_access->get_input_port_descriptor(target_input.get_index()); - in_desc.m_offset = offset; + memory_access->set_input_offset(offset, target_input.get_index()); } else { throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); } @@ -487,9 +486,10 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { continue; } - if (buffer->get_input_size() > 0) { + if (buffer->is_intermediate_memory()) { // Transpose, MatMul and other non-decomposed ops should have different memories on inputs and outputs to avoid data corruption, // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace inside Loop. 
+ OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent"); const auto parent = buffer->get_input_node_shared_ptr(0); if (!ov::is_type(parent) || is_transpose_loop(parent)) { offset = m_buffer_scratchpad; @@ -617,7 +617,10 @@ snippets::Schedule snippets::op::Subgraph::generate(ngraph::pass::Manager& opt, if (config.m_has_domain_sensitive_ops) initialize_buffer_scratchpad_size(); - snippets::pass::AssignRegisters(m_generator->get_target_machine()).run_on_model(body_ptr()); + std::function& op)> reg_type_mapper = [=](const std::shared_ptr& op) -> Generator::opRegType { + return m_generator->get_op_reg_type(op); + }; + snippets::pass::AssignRegisters(reg_type_mapper).run_on_model(body_ptr()); const auto ops = body_ptr()->get_ops(); ngraph::snippets::Generator::GeneratorConfig generatorConfig; diff --git a/src/common/snippets/src/pass/assign_registers.cpp b/src/common/snippets/src/pass/assign_registers.cpp index 240ed318d31cdc..c9af20443b8938 100644 --- a/src/common/snippets/src/pass/assign_registers.cpp +++ b/src/common/snippets/src/pass/assign_registers.cpp @@ -14,14 +14,7 @@ namespace { constexpr size_t reg_count = 16lu; - -auto filter_ops(const std::shared_ptr& op) -> bool { - if (ov::is_type(op) && - ov::is_type(op->get_output_target_inputs(0).begin()->get_node())) - return false; - return true; -} - +using opRegType = ngraph::snippets::Generator::opRegType; } // namespace bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr& f) { @@ -29,14 +22,11 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::op::AssignRegisters") using Reg = size_t; using tensor = std::shared_ptr; - auto original_ops = f->get_ordered_ops(); - ov::NodeVector ops; - ops.reserve(original_ops.size()); - std::copy_if(original_ops.cbegin(), original_ops.cend(), std::back_inserter(ops), filter_ops); + auto ops = f->get_ordered_ops(); - std::vector>> typed_ops; + std::vector>> typed_ops; for (const auto& op : ops) { - typed_ops.emplace_back(std::make_pair(m_target_machine->get_op_reg_type(op), op)); + typed_ops.emplace_back(std::make_pair(m_reg_type_mapper(op), op)); } size_t counter_vec = 0; @@ -56,9 +46,9 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr // here we use the fact that Result input & output tensors are identical by construction manually_assigned_gprs[op->output(0).get_tensor_ptr()] = static_cast(f->get_result_index(result) + num_parameters); - } else if (ov::is_type(op)) { + } else if (const auto buffer = ov::as_type_ptr(op)) { // All buffers have one common data pointer - if (ov::is_type(op)) { + if (buffer->is_intermediate_memory()) { manually_assigned_gprs[op->input(0).get_tensor_ptr()] = static_cast(num_results + num_parameters); } @@ -108,12 +98,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr }; for (const auto& t_op : typed_ops) { switch (t_op.first) { - case TargetMachine::opRegType::vec2vec: - case TargetMachine::opRegType::gpr2vec: + case opRegType::vec2vec: + case opRegType::gpr2vec: enumerate_out_tensors(t_op.second, regs_vec, manually_assigned_vecs, counter_vec); break; - case TargetMachine::opRegType::gpr2gpr: - case TargetMachine::opRegType::vec2gpr: + case opRegType::gpr2gpr: + case opRegType::vec2gpr: enumerate_out_tensors(t_op.second, regs_gpr, manually_assigned_gprs, counter_gpr); break; } @@ -139,27 +129,24 @@ bool 
ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr const auto& t_op = typed_ops[i]; std::vector used_tensors, defined_tensors; for (const auto& in : t_op.second->inputs()) { - if (ov::is_type(t_op.second) && - ov::is_type(t_op.second->get_input_node_shared_ptr(in.get_index()))) - continue; used_tensors.push_back(in.get_tensor_ptr()); } for (const auto& out : t_op.second->outputs()) defined_tensors.push_back(out.get_tensor_ptr()); switch (t_op.first) { - case TargetMachine::opRegType::vec2vec: + case opRegType::vec2vec: used_vec[i] = tensor2reg(used_tensors, regs_vec); defined_vec[i] = tensor2reg(defined_tensors, regs_vec); break; - case TargetMachine::opRegType::gpr2gpr: + case opRegType::gpr2gpr: used_gpr[i] = tensor2reg(used_tensors, regs_gpr); defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr); break; - case TargetMachine::opRegType::gpr2vec: + case opRegType::gpr2vec: used_gpr[i] = tensor2reg(used_tensors, regs_gpr); defined_vec[i] = tensor2reg(defined_tensors, regs_vec); break; - case TargetMachine::opRegType::vec2gpr: + case opRegType::vec2gpr: used_vec[i] = tensor2reg(used_tensors, regs_vec); defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr); break; @@ -194,12 +181,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr if (k == ops.size()) throw ngraph_error("assign registers can't find target op in the body"); switch (typed_ops[k].first) { - case TargetMachine::opRegType::vec2vec: - case TargetMachine::opRegType::vec2gpr: + case opRegType::vec2vec: + case opRegType::vec2gpr: life_out_vec[n].insert(life_in_vec[k].begin(), life_in_vec[k].end()); break; - case TargetMachine::opRegType::gpr2gpr: - case TargetMachine::opRegType::gpr2vec: + case opRegType::gpr2gpr: + case opRegType::gpr2vec: life_out_gpr[n].insert(life_in_gpr[k].begin(), life_in_gpr[k].end()); break; } diff --git a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp index 6b08f27ca33893..62dd1292b3ffce 100644 --- a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp +++ b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp @@ -58,13 +58,13 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { const auto& transpose_out = m.get_match_value(); for (const auto& in : transpose_out.get_target_inputs()) in.replace_source_output(brgemm->output(0)); - utils::set_output_layout(brgemm_out, as_type_ptr(transpose_out.get_node_shared_ptr())); + utils::set_transpose_output_layout(brgemm_out, as_type_ptr(transpose_out.get_node_shared_ptr())); } for (size_t i = 0; i < brgemm->get_input_size(); i++) { const auto& in_value = brgemm->input_value(i); if (transpose_matcher->match(in_value)) { const auto& transpose = as_type_ptr(in_value.get_node_shared_ptr()); - utils::set_output_layout(transpose->input_value(0), transpose); + utils::set_transpose_output_layout(transpose->input_value(0), transpose); brgemm->set_argument(i, transpose->input_value(0)); } } diff --git a/src/common/snippets/src/pass/insert_buffer.cpp b/src/common/snippets/src/pass/insert_buffer.cpp index 3c2ae74858503d..e7f4c90ae028ed 100644 --- a/src/common/snippets/src/pass/insert_buffer.cpp +++ b/src/common/snippets/src/pass/insert_buffer.cpp @@ -31,9 +31,7 @@ ngraph::snippets::pass::InsertBuffer::InsertBuffer(const int32_t allocation_rank if (!ov::is_type(input_node) && !ov::is_type(input_node) && !ov::is_type(input_node)) { - const auto constant_shape = op::IntermediateBuffer::create_shape_constant(input.get_partial_shape(), allocation_rank); - const 
auto buffer = constant_shape ? std::make_shared(input_node, constant_shape) : - std::make_shared(input_node); + const auto buffer = std::make_shared(input_node, allocation_rank); root->set_argument(input.get_index(), buffer); rewritten |= true; } @@ -70,9 +68,7 @@ ngraph::snippets::pass::InsertBuffer::InsertBuffer(const int32_t allocation_rank } } - const auto constant_shape = op::IntermediateBuffer::create_shape_constant(output.get_partial_shape(), allocation_rank); - const auto buffer = constant_shape ? std::make_shared(output, constant_shape) : - std::make_shared(output); + const auto buffer = std::make_shared(output, allocation_rank); for (const auto& consumer : output.get_target_inputs()) { const auto output_node = consumer.get_node()->shared_from_this(); if (output_node != buffer && diff --git a/src/common/snippets/src/pass/insert_loops.cpp b/src/common/snippets/src/pass/insert_loops.cpp index 6f4f8726c75a28..e88ba92770dcaf 100644 --- a/src/common/snippets/src/pass/insert_loops.cpp +++ b/src/common/snippets/src/pass/insert_loops.cpp @@ -137,10 +137,7 @@ void insert_loops_explicitly(const ov::NodeVector& ops, const size_t vector_size // on LoopBegin to guarantee that the constants are executed inside the Loop. for (const auto& n : body) { if (auto c = std::dynamic_pointer_cast(n)) { - // Except Constant Shape for Buffers - if (!ov::is_type(n->get_output_target_inputs(0).begin()->get_node())) { - c->add_control_dependency(inner_loop_begin); - } + c->add_control_dependency(inner_loop_begin); } } @@ -158,8 +155,6 @@ void insert_loops_explicitly(const ov::NodeVector& ops, const size_t vector_size ov::is_type(op) || ov::is_type(op)) return true; - if (ov::is_type(op) && ov::is_type(op->get_output_target_inputs(0).begin()->get_node())) - return true; auto& rt = op->get_rt_info(); auto outside_rt = rt.find("outside_loop"); bool is_outside = false; diff --git a/src/common/snippets/src/pass/loop_fusion.cpp b/src/common/snippets/src/pass/loop_fusion.cpp index 4751b4403dc583..2291e0746075d9 100644 --- a/src/common/snippets/src/pass/loop_fusion.cpp +++ b/src/common/snippets/src/pass/loop_fusion.cpp @@ -56,7 +56,7 @@ auto can_be_merged(const std::shared_ptr& loop_en auto get_buffer_and_loop_end(const std::shared_ptr& loop_begin_down, std::shared_ptr& loop_end_up, - std::shared_ptr& buffer) -> bool { + std::shared_ptr& buffer) -> bool { size_t fusion_input_num = 0; for (const auto& parent : loop_begin_down->input_values()) { const auto parent_shared = parent.get_node_shared_ptr(); @@ -70,7 +70,7 @@ auto get_buffer_and_loop_end(const std::shared_ptr(parent_shared); - buffer = ov::as_type_ptr(parent_shared); + buffer = ov::as_type_ptr(parent_shared); if (buffer) { if (buffer->output(0).get_target_inputs().size() == 0 || buffer->get_input_source_output(0).get_target_inputs().size() != 1) @@ -86,7 +86,7 @@ auto get_buffer_and_loop_end(const std::shared_ptr& loop_begin, - const std::shared_ptr& buffer, + const std::shared_ptr& buffer, std::vector& new_loop_inputs, std::vector& new_ptr_increments, std::vector& new_finalization_offsets) -> void { @@ -109,7 +109,7 @@ auto collect_loop_inputs(const std::shared_ptr& } auto collect_loop_outputs(const std::shared_ptr& loop_end, - const std::shared_ptr& buffer, + const std::shared_ptr& buffer, std::vector& new_loop_outputs, std::vector& new_ptr_increments, std::vector& new_finalization_offsets, @@ -162,7 +162,7 @@ bool ngraph::snippets::pass::LoopFusion::Merge(const std::shared_ptr loop_end_up = nullptr; - std::shared_ptr buffer = nullptr; + std::shared_ptr 
buffer = nullptr; // Initialize the corresponding upper LoopEnd and Buffer if (!get_buffer_and_loop_end(loop_begin_down, loop_end_up, buffer)) { return false; diff --git a/src/common/snippets/src/pass/matmul_to_brgemm.cpp b/src/common/snippets/src/pass/matmul_to_brgemm.cpp index 9a380bb304969c..add672b0fef3ea 100644 --- a/src/common/snippets/src/pass/matmul_to_brgemm.cpp +++ b/src/common/snippets/src/pass/matmul_to_brgemm.cpp @@ -29,8 +29,7 @@ MatMulToBrgemm::MatMulToBrgemm() { if (matmul->get_transpose_a() || matmul->get_transpose_b()) return false; - auto brgemm = std::make_shared(matmul->get_input_source_output(0), matmul->get_input_source_output(1), - matmul->get_transpose_a(), matmul->get_transpose_b()); + auto brgemm = std::make_shared(matmul->get_input_source_output(0), matmul->get_input_source_output(1)); ov::NodeVector nodes = { brgemm }; if (brgemm->get_output_element_type(0) != matmul->get_output_element_type(0)) { nodes.emplace_back(std::make_shared(brgemm, matmul->get_output_element_type(0))); diff --git a/src/common/snippets/src/pass/reset_buffer.cpp b/src/common/snippets/src/pass/reset_buffer.cpp index 55bc5aad88deae..54bdfef03f7f13 100644 --- a/src/common/snippets/src/pass/reset_buffer.cpp +++ b/src/common/snippets/src/pass/reset_buffer.cpp @@ -92,9 +92,7 @@ ngraph::snippets::pass::ResetBufferState::ResetBufferState() { port_idx = source_output.get_index(); loop_index++; } - const auto pshape = buffer->get_allocation_shape(); - NGRAPH_CHECK(pshape.is_static(), "Buffer must have static allocation shape to calculate finalization offsets"); - const auto result_shape = pshape.get_shape(); + const auto result_shape = buffer->get_allocation_shape(); NGRAPH_CHECK(loop_index < result_shape.size(), "Buffer has invalid Loop index and allocation shape rank"); const auto work_amount = std::accumulate(result_shape.rbegin(), result_shape.rbegin() + loop_index + 1, size_t(1), std::multiplies()); finalization_offsets[i_size + i] = diff --git a/src/common/snippets/src/pass/softmax_decomposition.cpp b/src/common/snippets/src/pass/softmax_decomposition.cpp index 8a20cc06c07d0e..a0259a4061b41e 100644 --- a/src/common/snippets/src/pass/softmax_decomposition.cpp +++ b/src/common/snippets/src/pass/softmax_decomposition.cpp @@ -126,9 +126,7 @@ ngraph::snippets::pass::SoftmaxDecomposition::SoftmaxDecomposition(const size_t apply_increments_sum, finalization_offsets_sum); const auto horizon_sum = std::make_shared(sum); - const auto constant_shape_exp = op::IntermediateBuffer::create_shape_constant(loop_sum_end->output(0).get_partial_shape(), buffer_allocation_rank); - const auto buffer_exp = constant_shape_exp ? 
std::make_shared(loop_sum_end->output(0), constant_shape_exp) : - std::make_shared(loop_sum_end->output(0)); + const auto buffer_exp = std::make_shared(loop_sum_end->output(0), buffer_allocation_rank); /* =========================================== */ diff --git a/src/common/snippets/src/pass/vector_to_scalar.cpp b/src/common/snippets/src/pass/vector_to_scalar.cpp index e5b88ad2b7dcae..4f98a49de4eedd 100644 --- a/src/common/snippets/src/pass/vector_to_scalar.cpp +++ b/src/common/snippets/src/pass/vector_to_scalar.cpp @@ -24,8 +24,7 @@ ngraph::snippets::pass::SetScalarCountForLoad::SetScalarCountForLoad() { if (!load) return false; - auto& desc = load->get_input_port_descriptor(0); - desc.m_count = 1lu; + load->set_input_count(1lu, 0); return true; }); } @@ -44,8 +43,7 @@ ngraph::snippets::pass::SetScalarCountForStore::SetScalarCountForStore() { if (!store) return false; - auto& desc = store->get_output_port_descriptor(0); - desc.m_count = 1lu; + store->set_output_count(1lu, 0); return true; }); } diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index 2adb4b0a3ade4f..f2ce3dbf8ad1a0 100644 --- a/src/common/snippets/src/utils.cpp +++ b/src/common/snippets/src/utils.cpp @@ -115,12 +115,13 @@ ov::PartialShape get_port_planar_shape(const Output& out) { return get_reordered_planar_shape(tensor_shape, layout); } -void set_output_layout(const ov::Output& port, const std::shared_ptr& node) { +void set_transpose_output_layout(const ov::Output& port, const std::shared_ptr& node) { const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); - set_output_layout(port, const_order->cast_vector()); + OPENVINO_ASSERT(const_order != nullptr, "Transpose order must be Constant to set layout!"); + set_transpose_output_layout(port, const_order->cast_vector()); } -void set_output_layout(const ov::Output& port, const std::vector& layout) { +void set_transpose_output_layout(const ov::Output& port, const std::vector& layout) { auto& rt_info = port.get_node_shared_ptr()->get_rt_info(); rt_info["Layout"] = layout; } diff --git a/src/common/snippets/tests/include/lowering_utils.hpp b/src/common/snippets/tests/include/lowering_utils.hpp index f543870a652040..4fd06f760f3207 100644 --- a/src/common/snippets/tests/include/lowering_utils.hpp +++ b/src/common/snippets/tests/include/lowering_utils.hpp @@ -30,15 +30,15 @@ class DummyTargetMachine : public ngraph::snippets::TargetMachine { bool is_supported() const override { return true; } ngraph::snippets::code get_snippet() const override { return nullptr; } size_t get_lanes() const override { return 10; } - -protected: - opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override { return vec2vec; }; }; class DummyGenerator : public ngraph::snippets::Generator { public: DummyGenerator() : ngraph::snippets::Generator(std::make_shared()) {} DummyGenerator(const std::shared_ptr& t) : ngraph::snippets::Generator(t) {} + +protected: + opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override { return vec2vec; }; }; class LoweringTests : public TransformationTestsF { diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index c62c63eb9922cb..a536a0317eae12 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -38,8 +38,7 @@ DummyTargetMachine::DummyTargetMachine(const std::vector& jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = dummy_functor; 
jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] = dummy_functor; - jitters[ngraph::snippets::op::IntermediateBuffer::get_type_info_static()] = dummy_functor; - jitters[ngraph::snippets::op::AllocationBuffer::get_type_info_static()] = dummy_functor; + jitters[ngraph::snippets::op::Buffer::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::VectorBuffer::get_type_info_static()] = dummy_functor; jitters[ngraph::snippets::op::Fill::get_type_info_static()] = dummy_functor; diff --git a/src/common/snippets/tests/src/pass/merge_loops.cpp b/src/common/snippets/tests/src/pass/merge_loops.cpp index a73148dc15850c..048b3e52a76b1b 100644 --- a/src/common/snippets/tests/src/pass/merge_loops.cpp +++ b/src/common/snippets/tests/src/pass/merge_loops.cpp @@ -38,7 +38,7 @@ TEST(TransformationTests, UnaryEltwisesLoops) { OutputVector{inner_loop_end_up->output(0), outer_loop_begin_up->output(1)}, shape[shape.size() - 2], 1, std::vector{0, 0}, std::vector{0, 0}); - auto buffer = std::make_shared(outer_loop_end_up); + auto buffer = std::make_shared(outer_loop_end_up); auto outer_loop_begin_down = std::make_shared(OutputVector{buffer}); auto inner_loop_begin_down = std::make_shared(OutputVector{outer_loop_begin_down}); @@ -108,7 +108,7 @@ TEST(TransformationTests, BinaryEltwisesLoops) { OutputVector{inner_loop_end_up->output(0), outer_loop_begin_up->output(2)}, shape[shape.size() - 2], 1, std::vector{0, 0, 0}, std::vector{0, 0, 0}); - auto buffer = std::make_shared(outer_loop_end_up); + auto buffer = std::make_shared(outer_loop_end_up); auto data2 = std::make_shared(element::f32, shape); diff --git a/src/common/snippets/tests/src/registers.cpp b/src/common/snippets/tests/src/registers.cpp index 004ec45f71b35a..e9d7c503802142 100644 --- a/src/common/snippets/tests/src/registers.cpp +++ b/src/common/snippets/tests/src/registers.cpp @@ -39,7 +39,11 @@ TEST(TransformationTests, AssignRegisters) { pass::Manager m; m.register_pass(); - m.register_pass(generator->get_target_machine()); + std::function& op)> reg_type_mapper = + [=](const std::shared_ptr& op) -> snippets::Generator::opRegType { + return generator->get_op_reg_type(op); + }; + m.register_pass(reg_type_mapper); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); @@ -130,7 +134,11 @@ TEST(TransformationTests, AssignRegisters2) { pass::Manager m; m.register_pass(); - m.register_pass(generator->get_target_machine()); + std::function& op)> reg_type_mapper = + [=](const std::shared_ptr& op) -> snippets::Generator::opRegType { + return generator->get_op_reg_type(op); + }; + m.register_pass(reg_type_mapper); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index a647ead007e220..025cf8e60811b9 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -48,10 +48,9 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_ // data movement jitters[ngraph::opset1::Parameter::get_type_info_static()] = CREATE_EMITTER(NopEmitter); jitters[ngraph::opset1::Result::get_type_info_static()] = CREATE_EMITTER(NopEmitter); - jitters[ngraph::snippets::op::AllocationBuffer::get_type_info_static()] = CREATE_EMITTER(NopEmitter); - jitters[ngraph::snippets::op::IntermediateBuffer::get_type_info_static()] = CREATE_EMITTER(NopEmitter); + 
jitters[ngraph::snippets::op::Buffer::get_type_info_static()] = CREATE_EMITTER(NopEmitter); jitters[ngraph::snippets::op::VectorBuffer::get_type_info_static()] = CREATE_EMITTER(VectorBufferEmitter); - jitters[ngraph::opset1::Constant::get_type_info_static()] = CREATE_EMITTER(NopEmitter); // Not supported + // jitters[ngraph::opset1::Constant::get_type_info_static()] = CREATE_EMITTER(); // Not supported jitters[ngraph::snippets::op::Load::get_type_info_static()] = CREATE_EMITTER(LoadEmitter); jitters[ngraph::snippets::op::LoadReshape::get_type_info_static()] = CREATE_EMITTER(LoadEmitter); @@ -164,13 +163,17 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const { return h->jit_ker(); } -ngraph::snippets::TargetMachine::opRegType ov::intel_cpu::CPUTargetMachine::get_specific_op_reg_type(const std::shared_ptr& op) const { +ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared(isa_)) { +} + +ngraph::snippets::Generator::opRegType ov::intel_cpu::CPUGenerator::get_specific_op_reg_type(const std::shared_ptr& op) const { if (std::dynamic_pointer_cast(op) || std::dynamic_pointer_cast(op)) return gpr2gpr; - else + else if ( + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) return vec2vec; -} - -ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared(isa_)) { + else + throw ov::Exception("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!"); } diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp index 93d062ae41d595..b624d2c0b093bf 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp @@ -20,9 +20,6 @@ class CPUTargetMachine : public ngraph::snippets::TargetMachine { ngraph::snippets::code get_snippet() const override; size_t get_lanes() const override; -protected: - opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override; - private: std::unique_ptr h; dnnl::impl::cpu::x64::cpu_isa_t isa; @@ -31,6 +28,9 @@ class CPUTargetMachine : public ngraph::snippets::TargetMachine { class CPUGenerator : public ngraph::snippets::Generator { public: CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa); + +protected: + opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index 61061b8bf153ba..c59c4fa61752c2 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -21,6 +21,10 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { +namespace { +constexpr size_t gpr_size = 8; +} // namespace + inline static void transform_idxs_to_regs(const std::vector& idxs, std::vector& regs) { regs.resize(idxs.size()); std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx){return Reg64(static_cast(idx));}); @@ -706,7 +710,7 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: const auto& brgemm_node = as_type_ptr(node); if (brgemm_node->is_dynamic()) IE_THROW() << "Snippets don't support code generation for dynamic Brgemm"; - const auto brgemm_copy = brgemm_node->get_brgemm_copy(); + const auto brgemm_copy = brgemm_node->is_with_data_repacking() ? 
brgemm_node->get_brgemm_copy() : nullptr; const OutputVector io_values {brgemm_node->input_value(0), brgemm_copy ? brgemm_copy->input_value(0) : brgemm_node->input_value(1), brgemm_node->output(0)}; @@ -734,54 +738,53 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: io_layouts.push_back(layout); } } - // todo: leave AMX and VNNI related code for now, it'll help to enable int8 and bf16 support - bool isAMXSupported = mayiuse(avx512_core_bf16_amx_int8) || mayiuse(avx512_core_bf16_amx_bf16); const auto& A_shape = io_values[0].get_shape(); const auto& A_layout = io_layouts[0]; const auto& C_shape = io_values[2].get_shape(); const auto& C_layout = io_layouts[2]; - M = C_shape[C_layout[2]]; - K = A_shape[A_layout[3]]; - M_blk = matmulOptimalM; - M_tail = M % M_blk; + m_M = C_shape[C_layout[2]]; + m_K = A_shape[A_layout[3]]; + m_M_blk = matmulOptimalM; + m_M_tail = m_M % m_M_blk; // B_shape[B_layout[3]] - N = C_shape[C_layout[3]]; + m_N = C_shape[C_layout[3]]; auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0)); auto brg1Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(1)); io_data_size = {brg0Prc.size(), brg1Prc.size(), brgemm_node->get_output_element_type(0).size()}; - brg0VnniFactor = 4 / brg0Prc.size(); - bool brgWithAMX = isAMXSupported && brg0Prc != Precision::FP32 && (K % brg0VnniFactor == 0) && (N % brg0VnniFactor == 0); + m_brg0VnniFactor = 4 / brg0Prc.size(); + bool brgWithAMX = brgemm_node->is_amx(); - with_scratch = brgemm_node->get_input_size() == 3; - with_comp = !brgWithAMX && brg0Prc == Precision::I8; + m_with_comp = brgemm_node->is_with_compensations(); + m_with_scratch = brgemm_node->is_with_scratchpad(); + OPENVINO_ASSERT((m_with_scratch && brgemm_node->get_input_size() == 3) || !m_with_scratch, "Brgemm with scratchpad expect 3 inputs"); - N_blk = brg1Prc == Precision::FP32 ? N : - brg1Prc == Precision::BF16 ? 32 : 64; - N_tail = N % N_blk; - K_blk = brgWithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 - : K; - K_tail = K % K_blk; + m_N_blk = brg1Prc == Precision::FP32 ? m_N : + brg1Prc == Precision::BF16 ? 32 : 64; + m_N_tail = m_N % m_N_blk; + m_K_blk = brgWithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 + : m_K; + m_K_tail = m_K % m_K_blk; size_t brg0BaseIdx = -1; for (size_t m = 0; m < 2; m++) { for (size_t k = 0; k < 2; k++) { for (size_t n = 0; n < 2; n++) { - auto& brgemmCtx = brgCtxs0[getBrgIdx(m, k, n)]; + auto& brgemmCtx = m_brgCtxs0[getBrgIdx(m, k, n)]; - auto M_ = m ? M_tail - : M < M_blk ? 0 : M_blk; - auto N_ = n ? N_tail : N - N_tail; - auto K_ = k ? K_tail : K - K_tail; - auto beta = k && brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f; + auto M_ = m ? m_M_tail + : m_M < m_M_blk ? 0 : m_M_blk; + auto N_ = n ? m_N_tail : m_N - m_N_tail; + auto K_ = k ? m_K_tail : m_K - m_K_tail; + auto beta = k && m_brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f; brgemmCtx.M = M_; brgemmCtx.N = N_; brgemmCtx.K = K_; brgemmCtx.LDA = leading_dimensions[0]; - brgemmCtx.LDB = brg1Prc == Precision::FP32 ? leading_dimensions[1] : rnd_up(N, N_blk); + brgemmCtx.LDB = brg1Prc == Precision::FP32 ? 
leading_dimensions[1] : rnd_up(m_N, m_N_blk); brgemmCtx.LDC = leading_dimensions[2]; brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc)); brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc)); @@ -791,17 +794,17 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: if (M_ != 0 && K_ != 0 && N_ != 0) { if (brg0BaseIdx == -1) brg0BaseIdx = getBrgIdx(m, k, n); - initBrgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brgWithAMX); + initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(m, k, n)], brgWithAMX); } } } } - load_offset_a = brgemm_node->get_offset_a(); - load_offset_b = brgemm_node->get_offset_b(); - store_offset_c = brgemm_node->get_offset_c(); - if (with_scratch) - load_offset_scratch = brgemm_node->get_offset_scratch(); + m_load_offset_a = brgemm_node->get_offset_a(); + m_load_offset_b = brgemm_node->get_offset_b(); + m_store_offset_c = brgemm_node->get_offset_c(); + if (m_with_scratch) + m_load_offset_scratch = brgemm_node->get_offset_scratch(); } void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const { @@ -832,37 +835,37 @@ void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& void BrgemmEmitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == cpu::x64::avx512_core) { - Reg64 input_0(static_cast(in[0])); - Reg64 input_1(static_cast(in[1])); - Reg64 input_2(static_cast(0)); // scratch. Default reg index is 0 if there isn't scratch - if (with_scratch) { + Xbyak::Reg64 input_0(static_cast(in[0])); + Xbyak::Reg64 input_1(static_cast(in[1])); + Xbyak::Reg64 input_2(static_cast(0)); // scratch. Default reg index is 0 if there isn't scratch + if (m_with_scratch) { if (in.size() != 3) { IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp"; } - input_2 = Reg64(static_cast(in[2])); + input_2 = Xbyak::Reg64(static_cast(in[2])); } - Reg64 output_0(static_cast(out[0])); + Xbyak::Reg64 output_0(static_cast(out[0])); - for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { - const bool is_M_tail = (M - mb * M_blk < M_blk); + for (size_t mb = 0; mb < div_up(m_M, m_M_blk); mb++) { + const bool is_M_tail = (m_M - mb * m_M_blk < m_M_blk); size_t brgIdx0 = getBrgIdx(0, 0, 0); - size_t K0_step0 = brgCtxs0[brgIdx0].K; - size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB; - size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor; - size_t N0_step1 = brgCtxs0[brgIdx0].N; + size_t K0_step0 = m_brgCtxs0[brgIdx0].K; + size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB; + size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor; + size_t N0_step1 = m_brgCtxs0[brgIdx0].N; for (size_t n = 0; n < 2; n++) { for (size_t k = 0; k < 2; k++) { size_t mIdx = is_M_tail ? 1 : 0; - auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)]; + auto& brgemmCtx = m_brgCtxs0[getBrgIdx(mIdx, k, n)]; if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { - const size_t in0_offset = load_offset_a + (k * K0_step0 + mb * M_blk * brgemmCtx.LDA) * io_data_size[0]; - const size_t in1_offset = load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1]; - const size_t in2_offset = load_offset_scratch + (with_comp ? 
n * N0_step1 * sizeof(int32_t) : 0); - const size_t out0_offset = store_offset_c + (n * N0_step1 + mb * M_blk * brgemmCtx.LDC) * io_data_size[2]; + const size_t in0_offset = m_load_offset_a + (k * K0_step0 + mb * m_M_blk * brgemmCtx.LDA) * io_data_size[0]; + const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1]; + const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0); + const size_t out0_offset = m_store_offset_c + (n * N0_step1 + mb * m_M_blk * brgemmCtx.LDC) * io_data_size[2]; - emit_brgemm_kernel_call(brgKernels0[getBrgIdx(mIdx, k, n)].get(), + emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(mIdx, k, n)].get(), brgemmCtx, input_0, input_1, @@ -886,7 +889,6 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, c const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t in2_kernel_offset, const size_t out0_kernel_offset) const { if (ctx.is_with_amx) { - size_t gpr_size = 8; Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); @@ -915,7 +917,6 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, c h->add(h->rsp, n_gprs_to_save * gpr_size); } - size_t gpr_size = 8; Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); @@ -958,7 +959,7 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, c h->uni_vmovq(Xmm(0), addr_A); h->uni_vmovq(Xmm(1), addr_B); h->uni_vmovq(Xmm(2), addr_C); - if (with_scratch) + if (m_with_scratch) h->uni_vmovq(Xmm(3), scratch); // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
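// Illustrative note, not part of the patch: at this point the A/B/C/scratch pointers have been parked in
// Xmm(0..3) so that the moves into the ABI parameter registers below cannot clobber them; data_ptr_reg()
// presumably restores each pointer into the matching abi_param register and adds the per-call byte offset.
// Those offsets are the ones computed in emit_impl() above, restated here with the renamed m_* members
// (io_data_size holds the element sizes of A, B and C):
//   in0_offset  = m_load_offset_a       + (k * K0_step0 + mb * m_M_blk * brgemmCtx.LDA) * io_data_size[0];
//   in1_offset  = m_load_offset_b       + (k * K0_step1 + n * N0_step0)                 * io_data_size[1];
//   in2_offset  = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
//   out0_offset = m_store_offset_c      + (n * N0_step1 + mb * m_M_blk * brgemmCtx.LDC) * io_data_size[2];
// so each emitted kernel call walks A over K- and M-blocks, B over K- and N-blocks, and C over N- and M-blocks.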
const auto data_ptr_reg = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { @@ -988,12 +989,12 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, c h->mov(abi_not_param1, static_cast(with_comp)); h->mov(h->qword[h->rsp + (abi_param_count + 1) * gpr_size], abi_not_param1); #else - if (with_scratch) { + if (m_with_scratch) { data_ptr_reg(Xmm(3), abi_param5, in2_kernel_offset); } else { h->mov(abi_param5, reinterpret_cast(nullptr)); } - h->mov(abi_param6, static_cast(with_comp)); + h->mov(abi_param6, static_cast(m_with_comp)); #endif // align stack on 16-byte as ABI requires @@ -1039,8 +1040,8 @@ void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, brgemm_p.ptr_D = C; brgemm_p.ptr_buf = scratch; brgemm_p.ptr_bias = nullptr; - brgemm_p.do_post_ops = with_comp; - brgemm_p.do_apply_comp = with_comp; + brgemm_p.do_post_ops = static_cast(with_comp); + brgemm_p.do_apply_comp = static_cast(with_comp); brgemm_p.skip_accm = 0; brgemm_p.BS = 1; // default value assert(brg_kernel); @@ -1054,14 +1055,14 @@ BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, d if (!brgemm_repack) IE_THROW() << "BrgemmCopyBEmitters expects BrgemmCopyB node"; - brgemm_prc_in0 = brgemm_repack->get_src_element_type(); - brgemm_prc_in1 = brgemm_repack->get_input_element_type(0); - brgemmVNNIFactor = 4 / brgemm_prc_in0.size(); - with_comp = brgemm_repack->is_with_comp(); - in_offset = brgemm_repack->get_offset_in(); - out_offset = brgemm_repack->get_offset_out(); - if (with_comp) - comp_offset = brgemm_repack->get_offset_comp(); + m_brgemm_prc_in0 = brgemm_repack->get_src_element_type(); + m_brgemm_prc_in1 = brgemm_repack->get_input_element_type(0); + m_brgemmVNNIFactor = 4 / m_brgemm_prc_in0.size(); + m_with_comp = brgemm_repack->is_with_compensations(); + m_in_offset = brgemm_repack->get_offset_in(); + m_out_offset = brgemm_repack->get_offset_out(); + if (m_with_comp) + m_comp_offset = brgemm_repack->get_offset_compensations(); auto layout = ngraph::snippets::utils::get_node_output_layout(brgemm_repack->get_input_node_shared_ptr(0)); const auto& original_shape = brgemm_repack->get_input_shape(0); @@ -1082,40 +1083,40 @@ BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, d leading_dimension = std::accumulate(original_shape.end() - num_last_dims, original_shape.end(), 1, std::multiplies()); } - N = *(transposed_shape.rbegin()); - K = *(transposed_shape.rbegin() + 1); + m_N = *(transposed_shape.rbegin()); + m_K = *(transposed_shape.rbegin() + 1); const bool isAMXSupported = mayiuse(avx512_core_amx); - const auto use_amx = isAMXSupported && brgemm_prc_in0 != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0); + const auto use_amx = isAMXSupported && m_brgemm_prc_in0 != ov::element::f32 && (m_K % m_brgemmVNNIFactor == 0) && (m_N % m_brgemmVNNIFactor == 0); - N_blk = brgemm_prc_in1 == ov::element::f32 ? N : - brgemm_prc_in1 == ov::element::bf16 ? 32 : 64; - K_blk = use_amx ? brgemm_prc_in0 == ov::element::bf16 ? 32 : 64 - : K; - N_tail = N % N_blk; - K_tail = K % K_blk; - LDB = brgemm_prc_in1 == ov::element::f32 ? leading_dimension : rnd_up(N, N_blk); + m_N_blk = m_brgemm_prc_in1 == ov::element::f32 ? m_N : + m_brgemm_prc_in1 == ov::element::bf16 ? 32 : 64; + m_K_blk = use_amx ? m_brgemm_prc_in0 == ov::element::bf16 ? 32 : 64 + : m_K; + m_N_tail = m_N % m_N_blk; + m_K_tail = m_K % m_K_blk; + m_LDB = m_brgemm_prc_in1 == ov::element::f32 ? 
leading_dimension : rnd_up(m_N, m_N_blk); - const auto dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(brgemm_prc_in0))); - const auto dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(brgemm_prc_in1))); - init_brgemm_copy(kernel, leading_dimension, N_blk, N_tail, LDB, K - K_tail, use_amx, dt_in0, dt_in1); + const auto dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in0))); + const auto dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in1))); + init_brgemm_copy(m_kernel, leading_dimension, m_N_blk, m_N_tail, m_LDB, m_K - m_K_tail, use_amx, dt_in0, dt_in1); } void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr& kernel, - size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, - bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const { + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const { matmul::brgemm_matmul_conf_t brgCopyKernelConf; brgCopyKernelConf.src_dt = dt_in0; brgCopyKernelConf.wei_dt = dt_in1; - brgCopyKernelConf.wei_n_blk = N_blk; + brgCopyKernelConf.wei_n_blk = static_cast(N_blk); brgCopyKernelConf.wei_tag = dnnl_abcd; // What's about other ranks? brgCopyKernelConf.copy_B_wei_stride = 0; - brgCopyKernelConf.LDB = LDB; - brgCopyKernelConf.N = N; - brgCopyKernelConf.N_tail = N_tail; - brgCopyKernelConf.N_blk = N_blk; - brgCopyKernelConf.K = K; - brgCopyKernelConf.K_blk = K; + brgCopyKernelConf.LDB = static_cast(LDB); + brgCopyKernelConf.N = static_cast(N); + brgCopyKernelConf.N_tail = static_cast(N_tail); + brgCopyKernelConf.N_blk = static_cast(N_blk); + brgCopyKernelConf.K = static_cast(K); + brgCopyKernelConf.K_blk = static_cast(K); brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; brgCopyKernelConf.b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); @@ -1141,26 +1142,26 @@ void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr& in, const std::vector& out) const { if (host_isa_ == cpu::x64::avx512_core) { - Reg64 src(static_cast(in[0])); - Reg64 dst(static_cast(out[0])); - Reg64 comp(static_cast(0)); // Compensations. Default reg idx is 0 if there aren't the compensations - if (with_comp) { + Xbyak::Reg64 src(static_cast(in[0])); + Xbyak::Reg64 dst(static_cast(out[0])); + Xbyak::Reg64 comp(static_cast(0)); // Compensations. Default reg idx is 0 if there aren't the compensations + if (m_with_comp) { if (out.size() != 2) { IE_THROW() << "BrgemmCopyBEmitter with compensations requires separate register for them"; } - comp = Reg64(static_cast(out[1])); + comp = Xbyak::Reg64(static_cast(out[1])); } - const size_t data_size = brgemm_prc_in1.size(); - for (size_t nb = 0; nb < div_up(N, N_blk); nb++) { - const size_t offset_in = in_offset + nb * N_blk * data_size; - const size_t offset_out = out_offset + nb * N_blk * brgemmVNNIFactor * data_size; - const size_t offset_comp = with_comp ? 
comp_offset + nb * N_blk * sizeof(int32_t) : 0; + const size_t data_size = m_brgemm_prc_in1.size(); + for (size_t nb = 0; nb < div_up(m_N, m_N_blk); nb++) { + const size_t offset_in = m_in_offset + nb * m_N_blk * data_size; + const size_t offset_out = m_out_offset + nb * m_N_blk * m_brgemmVNNIFactor * data_size; + const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_N_blk * sizeof(int32_t) : 0; - const bool is_N_tail = (N - nb * N_blk < N_blk); - const auto current_N_blk = is_N_tail ? N_tail : N_blk; + const bool is_N_tail = (m_N - nb * m_N_blk < m_N_blk); + const auto current_N_blk = is_N_tail ? m_N_tail : m_N_blk; - emit_kernel_call(kernel.get(), src, dst, comp, current_N_blk, K, offset_in, offset_out, offset_comp); + emit_kernel_call(m_kernel.get(), src, dst, comp, current_N_blk, m_K, offset_in, offset_out, offset_comp); } } else { IE_THROW() << "BrgemmCopyBEmitter requires at least avx512_core instruction set"; @@ -1169,7 +1170,6 @@ void BrgemmCopyBEmitter::emit_impl(const std::vector& in, void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const { - size_t gpr_size = 8; Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); @@ -1223,14 +1223,14 @@ void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b // It's likely that a more efficient solution exists. h->uni_vmovq(Xmm(0), src); h->uni_vmovq(Xmm(1), dst); - if (with_comp) + if (m_with_comp) h->uni_vmovq(Xmm(2), comp); // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
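// Illustrative note, not part of the patch: the offsets passed to this call come from the N-blocking loop
// in emit_impl() above. Restated with the renamed m_* members, for block index nb:
//   offset_in   = m_in_offset   + nb * m_N_blk * data_size;
//   offset_out  = m_out_offset  + nb * m_N_blk * m_brgemmVNNIFactor * data_size;  // repacked B advances by the VNNI factor as well
//   offset_comp = m_with_comp ? m_comp_offset + nb * m_N_blk * sizeof(int32_t) : 0;
// and the last, possibly partial, block uses m_N_tail instead of m_N_blk. The src/dst/comp pointers were
// parked in Xmm(0..2) just above and are presumably restored into abi_param2..abi_param4 by data_ptr()
// together with these offsets.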
h->mov(abi_param1, reinterpret_cast(kernel)); data_ptr(Xmm(0), abi_param2, offset_in); data_ptr(Xmm(1), abi_param3, offset_out); - if (with_comp) { + if (m_with_comp) { data_ptr(Xmm(2), abi_param4, offset_comp); } else { h->mov(abi_param4, reinterpret_cast(nullptr)); diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp index b052763192b6f8..d98ccf31dbf629 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp @@ -336,33 +336,34 @@ class BrgemmEmitter : public jit_emitter { bool is_with_comp; float beta; }; - void initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const; - void callBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, const void* pin0, const void* pin1, void* pout, void* wsp) const; + void initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const; + void callBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, + const void* pin0, const void* pin1, void* pout, void* wsp) const; size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const; - void emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, const brgemmCtx& ctx, - Reg64 addr_A, Reg64 addr_B, Reg64 scratch, Reg64 addr_C, + void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t* brg_kernel, const brgemmCtx& ctx, + Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, Xbyak::Reg64 scratch, Xbyak::Reg64 addr_C, const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t in2_kernel_offset, const size_t out0_kernel_offset) const; - static void kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp); + static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp); static constexpr size_t BRGEMM_KERNELS_NUM = 8; static constexpr size_t matmulOptimalM = 32; - brgemmCtx brgCtxs0[BRGEMM_KERNELS_NUM]; - std::unique_ptr brgKernels0[BRGEMM_KERNELS_NUM]; + brgemmCtx m_brgCtxs0[BRGEMM_KERNELS_NUM]; + std::unique_ptr m_brgKernels0[BRGEMM_KERNELS_NUM]; - size_t M, M_blk, M_tail; - size_t K, K_blk, K_tail; - size_t N, N_blk, N_tail; - size_t brg0VnniFactor; + size_t m_M, m_M_blk, m_M_tail; + size_t m_K, m_K_blk, m_K_tail; + size_t m_N, m_N_blk, m_N_tail; + size_t m_brg0VnniFactor; - bool with_scratch = false; - bool with_comp = false; + bool m_with_scratch = false; + bool m_with_comp = false; - size_t load_offset_a = 0lu; - size_t load_offset_b = 0lu; - size_t load_offset_scratch = 0lu; - size_t store_offset_c = 0lu; + size_t m_load_offset_a = 0lu; + size_t m_load_offset_b = 0lu; + size_t m_load_offset_scratch = 0lu; + size_t m_store_offset_c = 0lu; }; class BrgemmCopyBEmitter : public jit_emitter { @@ -378,25 +379,25 @@ class BrgemmCopyBEmitter : public jit_emitter { void init_brgemm_copy(std::unique_ptr& kernel, size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const; - void emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, - Reg64 src, Reg64 dst, Reg64 comp, size_t N, size_t K, + void emit_kernel_call(const dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, + Xbyak::Reg64 src, Xbyak::Reg64 dst, Xbyak::Reg64 comp, size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const; static void 
execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K); - std::unique_ptr kernel; + std::unique_ptr m_kernel; - ov::element::Type brgemm_prc_in0, brgemm_prc_in1; - size_t N, N_blk, N_tail; - size_t K, K_blk, K_tail; - size_t LDB; - size_t brgemmVNNIFactor; - bool with_comp = false; + ov::element::Type m_brgemm_prc_in0, m_brgemm_prc_in1; + size_t m_N, m_N_blk, m_N_tail; + size_t m_K, m_K_blk, m_K_tail; + size_t m_LDB; + size_t m_brgemmVNNIFactor; + bool m_with_comp = false; - size_t in_offset = 0lu; - size_t out_offset = 0lu; - size_t comp_offset = 0lu; + size_t m_in_offset = 0lu; + size_t m_out_offset = 0lu; + size_t m_comp_offset = 0lu; }; class HorizonMaxEmitter : public jit_emitter { diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 0913da05184b5b..a9d6e08377c971 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -135,8 +135,8 @@ std::map Extension::getOpSets() { ngraph::OpSet opset; #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); - NGRAPH_OP(AllocationBuffer, ngraph::snippets::op) NGRAPH_OP(Brgemm, ngraph::snippets::op) + NGRAPH_OP(Buffer, ngraph::snippets::op) NGRAPH_OP(BroadcastLoad, ngraph::snippets::op) NGRAPH_OP(BroadcastMove, ngraph::snippets::op) NGRAPH_OP(ConvertSaturation, ngraph::snippets::op) @@ -144,7 +144,6 @@ std::map Extension::getOpSets() { NGRAPH_OP(Fill, ngraph::snippets::op) NGRAPH_OP(HorizonMax, ngraph::snippets::op) NGRAPH_OP(HorizonSum, ngraph::snippets::op) - NGRAPH_OP(IntermediateBuffer, ngraph::snippets::op) NGRAPH_OP(Kernel, ngraph::snippets::op) NGRAPH_OP(Load, ngraph::snippets::op) NGRAPH_OP(LoadReshape, ngraph::snippets::op) diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp index a7e0d4b78601d3..a8b2efcdc9fb3c 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp @@ -29,13 +29,13 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { auto callback = [=](ngraph::pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::BrgemmToBrgemmCPU") - auto& pm = m.get_pattern_value_map(); - const auto brgemm = ov::as_type_ptr(pm.at(m_brgemm).get_node_shared_ptr()); - const auto brgemm_plugin = ov::as_type_ptr(pm.at(m_brgemm).get_node_shared_ptr()); + const auto node = m.get_match_root(); + const auto brgemm = ov::as_type_ptr(node); + const auto brgemm_plugin = ov::as_type_ptr(node); if (!brgemm || brgemm_plugin) - return false; + throw ov::Exception("BrgemmCPU cannot be in body before BrgemmToBrgemmCPU pass"); - if (brgemm->get_input_partial_shape(0).is_dynamic() || brgemm->get_input_partial_shape(1).is_dynamic()) { + if (brgemm->is_dynamic()) { return false; } @@ -46,7 +46,6 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto N = *dimsMatMulIn1.rbegin(); const auto element_type_a = brgemm->get_input_element_type(0); - const auto element_type_b = brgemm->get_input_element_type(1); const auto brgemmVNNIFactor = 4 / element_type_a.size(); const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx); const bool with_amx = isAMXSupported && element_type_a != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0); @@ -57,30 +56,25 @@ 
pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto offset_c = brgemm->get_offset_c(); std::shared_ptr brgemm_cpu = nullptr; - if (one_of(element_type_a, ov::element::f32)) { - brgemm_cpu = std::make_shared(brgemm->input_value(0), brgemm->input_value(1), - brgemm->transposed_a(), brgemm->transposed_b(), with_comp, + if (element_type_a == ov::element::f32) { + brgemm_cpu = std::make_shared(brgemm->input_value(0), brgemm->input_value(1), BrgemmCPU::Type::Floating, offset_a, offset_b, offset_c); } else { const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr()); - const auto brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, with_comp, offset_b); - const auto buffer = std::make_shared(brgemmRepackIn1->output(0)); - - if (with_amx || with_comp) { - std::shared_ptr scratch = nullptr; - if (with_amx) { - const auto scratch_size = std::make_shared(ov::element::i32, ov::Shape{1}, std::vector{8 * 1024}); - scratch = std::make_shared(scratch_size, ov::element::f32); - } else if (with_comp) { - scratch = std::make_shared(brgemmRepackIn1->output(1)); - } - - brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, scratch, - brgemm->transposed_a(), brgemm->transposed_b(), with_comp, + const auto copy_b_type = with_comp ? BrgemmCopyB::WithCompensations : BrgemmCopyB::OnlyRepacking; + const auto brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, copy_b_type, offset_b); + const auto buffer = std::make_shared(brgemmRepackIn1->output(0)); + + if (with_amx) { + const auto scratch = std::make_shared(ov::Shape{BrgemmCPU::SCRATCH_BYTE_SIZE}); + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::AMX, + offset_a, offset_b, offset_c); + } else if (with_comp) { + const auto scratch = std::make_shared(brgemmRepackIn1->output(1)); + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::WithCompensations, offset_a, offset_b, offset_c); } else if (one_of(element_type_a, ov::element::u8, ov::element::bf16)) { - brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, - brgemm->transposed_a(), brgemm->transposed_b(), with_comp, + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, BrgemmCPU::Type::WithDataRepacking, offset_a, offset_b, offset_c); } else { IE_THROW() << "Invalid configuration for BRGEMM CPU"; @@ -88,7 +82,7 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { } brgemm_cpu->set_friendly_name(brgemm->get_friendly_name()); - ngraph::snippets::utils::set_output_layout(brgemm_cpu->output(0), ngraph::snippets::utils::get_node_output_layout(brgemm)); + ngraph::snippets::utils::set_transpose_output_layout(brgemm_cpu->output(0), ngraph::snippets::utils::get_node_output_layout(brgemm)); ngraph::copy_runtime_info(brgemm, brgemm_cpu); ngraph::replace_node(brgemm, brgemm_cpu); diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp index bafaeca58fdbac..ddab96604dbb81 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp @@ -13,7 +13,24 @@ namespace pass { /** * @interface BrgemmToBrgemmCPU - * @brief TODO + * @brief The pass decomposes Snippets Brgemm into a specific subgraph that depends on ISA and input precisions: + * - f32|f32: + * BrgemmCPU + * - u8|i8 or bf16|bf16 (non-AMX system): + * \ 
BrgemmCopyB (the operation for data repacking) + * \ Buffer + * BrgemmCPU + * - i8|i8 (non-AMX system) - needs compensations: + * \ BrgemmCopyB + * \ / \ + * \ Buffer (with repacked data) Buffer (with compensations) + * \ | / + * BrgemmCPU + * - i8|i8 or bf16|bf16 on AMX system: + * \ BrgemmCopyB + * \ Buffer (with repacked data) Buffer (with new memory) + * \ | / + * BrgemmCPU * @ingroup snippets */ class BrgemmToBrgemmCPU: public ngraph::pass::MatcherPass { diff --git a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp index d3c14212694ac3..0c64a20b655ed9 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp @@ -15,14 +15,12 @@ ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { MATCHER_SCOPE(FuseLoadConvert); - auto param_pattern = ngraph::pattern::wrap_type(); - auto load_pattern = ngraph::pattern::wrap_type({param_pattern}); + auto load_pattern = ngraph::pattern::wrap_type(); auto convert_pattern = ngraph::pattern::wrap_type({load_pattern}); auto callback = [=](ngraph::pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::FuseLoadConvert") auto& pm = m.get_pattern_value_map(); - const auto param = pm.at(param_pattern).get_node_shared_ptr(); const auto load_shared = pm.at(load_pattern).get_node_shared_ptr(); if (!load_shared || load_shared->output(0).get_target_inputs().size() != 1) { return false; @@ -39,12 +37,12 @@ ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { std::shared_ptr load_convert = nullptr; if (const auto convert_saturation = std::dynamic_pointer_cast(convert)) { - load_convert = std::make_shared(param, + load_convert = std::make_shared(load->input_value(0), convert_saturation->get_destination_type(), load->get_count(), load->get_offset()); } else if (const auto convert_truncation = std::dynamic_pointer_cast(convert)) { - load_convert = std::make_shared(param, + load_convert = std::make_shared(load->input_value(0), convert_truncation->get_destination_type(), load->get_count(), load->get_offset()); } else { @@ -80,7 +78,6 @@ ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() { const auto store = std::dynamic_pointer_cast(pm.at(store_pattern).get_node_shared_ptr()); if (!store) return false; - const auto desc = store->get_output_port_descriptor(0); const auto convert = pm.at(convert_pattern).get_node_shared_ptr(); if (convert->output(0).get_target_inputs().size() != 1 || transformation_callback(convert)) @@ -91,12 +88,12 @@ ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() { std::dynamic_pointer_cast(convert)) { store_convert = std::make_shared(input, convert_saturation->get_destination_type(), - desc.m_count, desc.m_offset); + store->get_count(), store->get_offset()); } else if (const auto convert_truncation = std::dynamic_pointer_cast(convert)) { store_convert = std::make_shared(input, convert_truncation->get_destination_type(), - desc.m_count, desc.m_offset); + store->get_count(), store->get_offset()); } else { throw ngraph::ngraph_error( "Type of Convert op is undefined. 
Supports only fusing Store and ConvertTruncation or ConvertSaturation ops"); diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp index 5f4b54b110fbae..0d19f1a6999e9f 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp @@ -12,30 +12,28 @@ using namespace std; using namespace ov; -intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const bool with_comp, - const size_t offset_in, const size_t offset_out0, const size_t offset_out1) - : ngraph::snippets::op::MemoryAccess({x}), m_with_comp(with_comp), m_src_type(src_type) { +intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const Type type, + const size_t offset_in, const size_t offset_out0, const size_t offset_out1) + : ngraph::snippets::op::MemoryAccess({x}), m_type(type), m_src_type(src_type) { + set_output_size(is_with_compensations() ? 2 : 1); + constructor_validate_and_infer_types(); set_input_port_descriptor({0, offset_in}, 0); set_output_port_descriptor({0, offset_out0}, 0); - if (with_comp) { + if (is_with_compensations()) { set_output_port_descriptor({0, offset_out1}, 1); - set_output_size(2); - } else { - set_output_size(1); } - constructor_validate_and_infer_types(); } bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes); MemoryAccess::visit_attributes(visitor); - visitor.on_attribute("with_comp", m_with_comp); visitor.on_attribute("src_type", m_src_type); return true; } void intel_cpu::BrgemmCopyB::validate_and_infer_types() { INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); + MemoryAccess::validate_and_infer_types(); const auto element_type = get_input_element_type(0); NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), @@ -44,7 +42,7 @@ void intel_cpu::BrgemmCopyB::validate_and_infer_types() { const auto pshape = ngraph::snippets::utils::get_port_planar_shape(input_value(0)); if (pshape.is_dynamic()) { set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()}); - if (m_with_comp) { + if (is_with_compensations()) { set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); } return; @@ -58,7 +56,7 @@ void intel_cpu::BrgemmCopyB::validate_and_infer_types() { set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)), ov::Dimension(rnd_up(N, N_blk))}); - if (m_with_comp) { + if (is_with_compensations()) { set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))}); } } @@ -66,8 +64,14 @@ void intel_cpu::BrgemmCopyB::validate_and_infer_types() { std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_src_type, m_with_comp, + return std::make_shared(new_args.at(0), m_src_type, m_type, get_offset_in(), get_offset_out(), - m_with_comp ? get_offset_comp() : 0); + is_with_compensations() ? 
get_offset_compensations() : 0); +} + +size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const { + OPENVINO_ASSERT(is_with_compensations() && get_output_size() == 2, + "The offset for compensations is available only in BrgemmCopyB with compensations and 2 outputs!"); + return get_output_offset(1); } diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp index da118d5cd35cc9..9d33f098e00cc3 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp @@ -17,16 +17,22 @@ namespace intel_cpu { class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess { public: OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", MemoryAccess); - BrgemmCopyB(const Output& x, const element::Type src_type, const bool with_comp = false, - const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu); + + enum Type { + OnlyRepacking, // Just data repacking - one output + WithCompensations, // Repack data and calculate compensations - 2 outputs (needed for BrgemmCPU with compensations) + }; + + BrgemmCopyB(const Output& x, const element::Type src_type, const Type type = Type::OnlyRepacking, + const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu); BrgemmCopyB() = default; - size_t get_offset_in() const { return get_input_port_descriptor(0).m_offset; } - size_t get_offset_out() const { return get_output_port_descriptor(0).m_offset; } - size_t get_offset_comp() const { return get_output_port_descriptor(1).m_offset; } + size_t get_offset_in() const { return get_input_offset(0); } + size_t get_offset_out() const { return get_output_offset(0); } + size_t get_offset_compensations() const; element::Type get_src_element_type() const { return m_src_type; } - bool is_with_comp() const { return m_with_comp; } + bool is_with_compensations() const { return m_type == Type::WithCompensations; } bool visit_attributes(AttributeVisitor& visitor) override; void validate_and_infer_types() override; @@ -34,7 +40,7 @@ class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess { std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; private: - bool m_with_comp = false; + Type m_type; element::Type m_src_type; // src element type of the corresponding BRGEMM }; diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp index d889a623914b7d..ceddc612d831b5 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp @@ -13,50 +13,43 @@ namespace ov { namespace intel_cpu { -BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, bool transposed_a, bool transposed_b, const bool with_comp, +BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type type, const size_t offset_a, const size_t offset_b, const size_t offset_c) - : Brgemm(), m_with_comp(with_comp) { + : Brgemm(), m_type(type) { // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call set_arguments({A, B}); set_output_size(1); + constructor_validate_and_infer_types(); set_input_port_descriptor({0, offset_a}, 0); set_input_port_descriptor({0, offset_b}, 1); set_output_port_descriptor({0, offset_c}, 0); - m_transposed_a = 
transposed_a; - m_transposed_b = transposed_b; - constructor_validate_and_infer_types(); } -BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, - bool transposed_a, bool transposed_b, const bool with_comp, +BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c) - : Brgemm(), m_with_comp(with_comp) { + : Brgemm(), m_type(type) { set_arguments({A, B, scratch}); set_output_size(1); + constructor_validate_and_infer_types(); set_input_port_descriptor({0, offset_a}, 0); set_input_port_descriptor({0, offset_b}, 1); set_output_port_descriptor({0, offset_c}, 0); set_input_port_descriptor({0, offset_scratch}, 2); - m_transposed_a = transposed_a; - m_transposed_b = transposed_b; - constructor_validate_and_infer_types(); -} - -bool BrgemmCPU::visit_attributes(AttributeVisitor& visitor) { - MemoryAccess::visit_attributes(visitor); - visitor.on_attribute("transposed_a", m_transposed_a); - visitor.on_attribute("transposed_b", m_transposed_b); - visitor.on_attribute("with_comp", m_with_comp); - return true; } void BrgemmCPU::validate_and_infer_types() { INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types); + MemoryAccess::validate_and_infer_types(); // If no leading dimensions are provided, assume dense row-major inputs-outputs NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), "BrgemmCPU currently supports only static shapes."); - const auto brgemm_copy = get_brgemm_copy(); + OPENVINO_ASSERT(implication(one_of(m_type, Type::Floating, Type::WithDataRepacking), get_input_size() == 2), + "BrgemmCPU expects 2 inputs in cases, when input precisions are f32|f32, u8|i8 or bf16|bf16 (non-AMX system)"); + OPENVINO_ASSERT(implication(one_of(m_type, Type::WithCompensations, Type::AMX), get_input_size() == 3), + "BrgemmCPU expects 3 inputs with input precisions i8|i8 and bf16|bf16 on AMX system"); + + const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr; std::vector planar_input_shapes = { ngraph::snippets::utils::get_port_planar_shape(input_value(0)), ngraph::snippets::utils::get_port_planar_shape(brgemm_copy ? 
brgemm_copy->input_value(0) : input_value(1)) @@ -68,12 +61,12 @@ void BrgemmCPU::validate_and_infer_types() { get_output_type(), ngraph::snippets::utils::get_reordered_planar_shape(output_shape, output_layout)); - // Verify Scratch input - if (get_input_size() == 3) { + //Additional check for 3rd input + if (one_of(m_type, Type::WithCompensations, Type::AMX)) { const auto shape = get_input_partial_shape(2); NGRAPH_CHECK(shape.is_static(), "BRGEMM Scratch must have static shape"); const auto type = get_input_element_type(2); - if (m_with_comp) { + if (is_with_compensations()) { const auto element_type_b = get_input_element_type(0); const auto shape_b = planar_input_shapes[1].get_shape(); const auto N = *shape_b.rbegin(); @@ -84,8 +77,8 @@ void BrgemmCPU::validate_and_infer_types() { NGRAPH_CHECK(expected_shape == shape.get_shape() && expected_type == type, "BRGEMM Scratch with compensations must have shape {rnd_up(N, N_blk)} and FP32 element type"); } else { - NGRAPH_CHECK(ngraph::shape_size(shape.get_shape()) == 8 * 1024 && type == ov::element::f32, - "BRGEMM Scratch for space workplace must be static, have FP32 element type and 8x1024 shape size"); + NGRAPH_CHECK(ngraph::shape_size(shape.get_shape()) == SCRATCH_BYTE_SIZE && type == ov::element::u8, + "BRGEMM Scratch for space workplace must be static, have U8 element type and size is equal to " + std::to_string(SCRATCH_BYTE_SIZE)); } } } @@ -94,23 +87,27 @@ std::shared_ptr BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a INTERNAL_OP_SCOPE(BrgemmCPU_clone_with_new_inputs); check_new_args_count(this, new_args); std::shared_ptr new_node = nullptr; - if (new_args.size() == 2) { - new_node = std::make_shared(new_args.at(0), new_args.at(1), - m_transposed_a, m_transposed_b, m_with_comp, + if (!is_with_scratchpad()) { + new_node = std::make_shared(new_args.at(0), new_args.at(1), m_type, get_offset_a(), get_offset_b(), get_offset_c()); } else { - new_node = std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), - m_transposed_a, m_transposed_b, m_with_comp, + new_node = std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_type, get_offset_a(), get_offset_b(), get_offset_scratch(), get_offset_c()); } return new_node; } std::shared_ptr BrgemmCPU::get_brgemm_copy() const { - if (const auto buffer = ov::as_type_ptr(get_input_node_shared_ptr(1))) { + OPENVINO_ASSERT(one_of(m_type, Type::WithDataRepacking, Type::WithCompensations, Type::AMX), "Brgemm doesn't need BrgemmCopyB"); + if (const auto buffer = ov::as_type_ptr(get_input_node_shared_ptr(1))) { return ov::as_type_ptr(buffer->get_input_node_shared_ptr(0)); } - return nullptr; + throw ov::Exception("BrgemmCopyB hasn't been found!"); +} + +size_t BrgemmCPU::get_offset_scratch() const { + OPENVINO_ASSERT(is_with_scratchpad() && get_input_size() == 3, "Offset of scratchpad must be only in Brgemm with scratchpad on 3rd input"); + return get_input_offset(2); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp index c17b034868a63b..551fe9b7151405 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp @@ -19,22 +19,36 @@ namespace intel_cpu { class BrgemmCPU : public ngraph::snippets::op::Brgemm { public: OPENVINO_OP("BrgemmCPU", "SnippetsOpset", ngraph::snippets::op::Brgemm); - BrgemmCPU(const Output& A, const Output& B, bool transposed_a = 
false, bool transposed_b = false, const bool with_comp = false,
+
+    enum Type {
+        Floating,           // f32|f32
+        WithDataRepacking,  // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
+        WithCompensations,  // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
+        AMX,                // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
+    };
+
+    BrgemmCPU(const Output& A, const Output& B, const Type type,
               const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0);
-    BrgemmCPU(const Output& A, const Output& B, const Output& scratch,
-              bool transposed_a = false, bool transposed_b = false, const bool with_comp = false,
+    BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type,
               const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0);
     BrgemmCPU() = default;
 
-    bool visit_attributes(AttributeVisitor& visitor) override;
     void validate_and_infer_types() override;
     std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override;
 
-    size_t get_offset_scratch() const { return get_input_port_descriptor(2).m_offset; }
+    Type get_type() const { return m_type; }
+    bool is_with_compensations() const { return m_type == Type::WithCompensations; }
+    bool is_with_data_repacking() const { return m_type != Type::Floating; }
+    bool is_amx() const { return m_type == Type::AMX; }
+    bool is_with_scratchpad() const { return is_with_compensations() || is_amx(); }
+
+    size_t get_offset_scratch() const;
     std::shared_ptr get_brgemm_copy() const;
 
+    constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;
+
 private:
-    bool m_with_comp = false;  // compensations
+    Type m_type;
 };
 
 } // namespace intel_cpu
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
index f003581c1e3eb0..9d792f35264066 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp
@@ -63,24 +63,6 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias,
                              ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                          MatMul::getTestCaseName);
 
-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul,
-                         ::testing::Combine(
-                             ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}}),
-                             ::testing::ValuesIn(precisions()),
-                             ::testing::Values(1), // Subgraph;
-                             ::testing::Values(1), // Tokenized MatMul+Bias
-                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
-                         ExplicitTransposeMatMul::getTestCaseName);
-
-INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias,
-                         ::testing::Combine(
-                             ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}),
-                             ::testing::ValuesIn(precisions()),
-                             ::testing::Values(1), // Subgraph;
-                             ::testing::Values(1), // Tokenized MatMul+Bias
-                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
-                         MatMul::getTestCaseName);
-
 } // namespace
 } // namespace snippets
 } // namespace test
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp
index bcebd4af2eebc1..6423f5a3db418f 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp
@@ -43,7 +43,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
                              ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                          TransposeMatMul::getTestCaseName);
 
-// TODO: FuseTransposeToBrgemm supports fusing only if Transpose is before Parameter in cases when Transpose is on input
+// TODO: At the moment FuseTransposeToBrgemm fuses a Transpose on an input only when the Transpose immediately follows a Parameter.
+//       Please uncomment this test case once the branch Parameter->FQ->Transpose->MatMul[0th input] is supported.
 // INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ,
 //                          ::testing::Combine(
 //                              ::testing::ValuesIn(transpose_input_shapes),
@@ -84,7 +85,6 @@ namespace transpose_output {
 std::vector> transpose_input_shapes{
     {{2, 1, 49, 13}, {1, 2, 13, 39}}
 };
-// TODO: Propagate shape through Brgemm with Transpose down
 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
                          ::testing::Combine(
                              ::testing::ValuesIn(transpose_input_shapes),
@@ -95,7 +95,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
                              ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                          TransposeMatMul::getTestCaseName);
 
-// TODO: Propagate shape through Brgemm with Transpose down
+// TODO: At the moment the branch MatMul[output]->Transpose->FQ is not supported.
+//       Please uncomment this test case once support is added.
 // INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ,
 //                          ::testing::Combine(
 //                              ::testing::ValuesIn(transpose_input_shapes),
@@ -107,6 +108,45 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul,
 //                          TransposeMatMulFQ::getTestCaseName);
 } // namespace transpose_output
 
+namespace explicit_transpose {
+static inline std::vector> precisions(bool only_fp32 = true) {
+    std::vector> prc = {
+        {element::f32, element::f32},
+    };
+    if (!only_fp32) {
+        // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
+        if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
+            prc.emplace_back(std::vector{element::i8, element::i8});
+            prc.emplace_back(std::vector{element::u8, element::i8});
+        }
+        // In Snippets MatMul BF16 is supported only on bf16/AMX platforms
+        if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
+            prc.emplace_back(std::vector{element::bf16, element::bf16});
+        }
+    }
+    return prc;
+}
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul,
+                         ::testing::Combine(
+                             ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}}),
+                             ::testing::Values(1), // Transpose on second input
+                             ::testing::ValuesIn(precisions()),
+                             ::testing::Values(1), // Subgraph;
+                             ::testing::Values(1), // Tokenized MatMul+Bias
+                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         ExplicitTransposeMatMul::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias,
+                         ::testing::Combine(
+                             ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}),
+                             ::testing::Values(1), // Transpose on second input
+                             ::testing::ValuesIn(precisions()),
+                             ::testing::Values(1), // Subgraph;
+                             ::testing::Values(1), // Tokenized MatMul+Bias
+                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
+                         ExplicitTransposeMatMulBias::getTestCaseName);
+} // namespace explicit_transpose
+
 } // namespace
 } // namespace snippets
 } // namespace test
diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp
index 9ce8cf8a6258b9..3e2a0ab015e988 100644
--- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp
+++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp
@@ -37,16 +37,6 @@ class MatMulBias : public MatMul {
     void SetUp() override;
 };
 
-class ExplicitTransposeMatMul : public MatMul {
-protected:
-    void SetUp() override;
-};
-
-class ExplicitTransposeMatMulBias : public MatMul {
-protected:
-    void SetUp() override;
-};
-
 } // namespace snippets
 } // namespace test
 } // namespace ov
\ No newline at end of file
diff --git a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp
index 6be2324e938fad..6eadc733042151 100644
--- a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp
+++ b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp
@@ -33,6 +33,16 @@ class TransposeMatMulFQ : public TransposeMatMul {
     void SetUp() override;
 };
 
+class ExplicitTransposeMatMul : public TransposeMatMul {
+protected:
+    void SetUp() override;
+};
+
+class ExplicitTransposeMatMulBias : public TransposeMatMul {
+protected:
+    void SetUp() override;
+};
+
 } // namespace snippets
 } // namespace test
 } // namespace ov
\ No newline at end of file
diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp
index 5c2a85bf92730e..06a37e2fd1ffed 100644
--- a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp
+++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp
@@ -71,34 +71,6 @@ void MatMulBias::SetUp() {
     }
 }
 
-void ExplicitTransposeMatMul::SetUp() {
-    std::vector input_shapes;
-    std::vector elem_types;
-    std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
-    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));
-
-    auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes);
-    function = f.getOriginal();
-    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
-        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
-                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
-    }
-}
-
-void ExplicitTransposeMatMulBias::SetUp() {
-    std::vector input_shapes;
-    std::vector elem_types;
-    std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
-    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));
-
-    auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes);
-    function = f.getOriginal();
-    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
-        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
-                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
-    }
-}
-
 TEST_P(MatMul, CompareWithRefImpl) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
@@ -117,18 +89,6 @@ TEST_P(MatMulBias, CompareWithRefImpl) {
     validateNumSubgraphs();
 }
 
-TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) {
-    SKIP_IF_CURRENT_TEST_IS_DISABLED()
-    run();
-    validateNumSubgraphs();
-}
-
-TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) {
-    SKIP_IF_CURRENT_TEST_IS_DISABLED()
-    run();
-    validateNumSubgraphs();
-}
-
 } // namespace snippets
 } // namespace test
 } // namespace ov
diff --git a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp
index 60fb17c5788a37..f3fc23c2ce4714 100644
--- a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp
+++ b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp
@@ -19,11 +19,10 @@ std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo input_shapes;
+    size_t transpose_position;
+    std::vector elem_types;
+    std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));
+
+    auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes);
+    function = f.getOriginal();
+    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
+        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
+                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
+    }
+}
+
+void ExplicitTransposeMatMulBias::SetUp() {
+    std::vector input_shapes;
+    size_t transpose_position;
+    std::vector elem_types;
+    std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam();
+    init_input_shapes(static_partial_shapes_to_test_representation(input_shapes));
+
+    auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes);
+    function = f.getOriginal();
+    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
+        configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
+                              InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
+    }
+}
+
 TEST_P(TransposeMatMul, CompareWithRefImpl) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
@@ -75,6 +104,18 @@ TEST_P(TransposeMatMulFQ, CompareWithRefImpl) {
     validateNumSubgraphs();
 }
 
+TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    run();
+    validateNumSubgraphs();
+}
+
+TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED()
+    run();
+    validateNumSubgraphs();
+}
+
 } // namespace snippets
 } // namespace test
 } // namespace ov
diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp
index 985a38e9750682..6c818b6078cdc6 100644
--- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp
+++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp
@@ -195,7 +195,7 @@ std::shared_ptr SoftmaxLoweredFunction::initLowered() const {
     horizon_sum->add_control_dependency(loop_sum_end);
 
     const auto size_exp = std::make_shared(ov::element::i32, ov::Shape{2});
-    const auto buffer_exp = std::make_shared(loop_sum_end->output(0), size_exp);
+    const auto buffer_exp = std::make_shared(loop_sum_end->output(0));
 
     loop_sum_begin->add_control_dependency(vector_buffer_sum);
     loop_sum_begin->add_control_dependency(horizon_max);
@@ -305,7 +305,7 @@ std::shared_ptr AddSoftmaxLoweredFunction::initLowered() const {
     /* =========================================== */
 
     const auto size_add = std::make_shared(ov::element::i32, ov::Shape{2});
-    const auto buffer_add = std::make_shared(loop_max_end->output(0), size_add);
+    const auto buffer_add = std::make_shared(loop_max_end->output(0));
 
     /* === Sub + Exp + ReduceSum decomposition === */
 
@@ -334,7 +334,7 @@ std::shared_ptr AddSoftmaxLoweredFunction::initLowered() const {
     horizon_sum->add_control_dependency(loop_sum_end);
 
     const auto size_exp = std::make_shared(ov::element::i32, ov::Shape{2});
-    const auto buffer_exp = std::make_shared(loop_sum_end->output(0), size_exp);
+    const auto buffer_exp = std::make_shared(loop_sum_end->output(0));
 
     loop_sum_begin->add_control_dependency(vector_buffer_sum);
     loop_sum_begin->add_control_dependency(horizon_max);
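
For readers of this patch, a minimal sketch (not part of the diff) of how input precisions could map to the new BrgemmCPU::Type, following the enum comments in brgemm_cpu.hpp above. The Precision enum, the helper name, and the amx_available flag are illustrative assumptions only; in the real code the choice is presumably made inside the BrgemmToBrgemmCPU transformation touched by this patch.

    // Illustrative sketch: pick a Brgemm kernel type from the input precisions,
    // mirroring the BrgemmCPU::Type comments (f32|f32 -> Floating, etc.).
    #include <cstdint>

    enum class Precision : std::uint8_t { f32, bf16, u8, i8 };
    enum class BrgemmType { Floating, WithDataRepacking, WithCompensations, AMX };

    inline BrgemmType select_brgemm_type(Precision a, Precision b, bool amx_available) {
        if (a == Precision::f32 && b == Precision::f32)
            return BrgemmType::Floating;           // f32|f32: no repacking, no scratchpad
        if (amx_available)
            return BrgemmType::AMX;                // i8|i8 or bf16|bf16 on AMX: BrgemmCopyB + scratchpad
        if (a == Precision::i8 && b == Precision::i8)
            return BrgemmType::WithCompensations;  // i8|i8 off AMX: BrgemmCopyB + compensations
        return BrgemmType::WithDataRepacking;      // u8|i8 or bf16|bf16 off AMX: BrgemmCopyB only
    }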