From b660da8ee272268dcce88cfb81802d1eb4c8424d Mon Sep 17 00:00:00 2001 From: Ivan Novoselov Date: Fri, 21 Jun 2024 11:04:31 +0100 Subject: [PATCH] Integrate recompilation infrastructure into RuntimeConfigurator (#24955) ### Details: - *Integrate dynamic executors recompilation infrastructure into RuntimeConfigurator* - *Allow RuntimeConfigurator to recompile dynamic kernel executors in runtime* - *Employ this approach to enable dynamic MatMul tests (fp32)* ### Tickets: - *143257* --- .../snippets/kernel_executor_table.hpp | 114 ++++++++++-- .../snippets/lowered/linear_ir_builder.hpp | 7 +- .../snippets/include/snippets/op/brgemm.hpp | 5 +- .../include/snippets/runtime_configurator.hpp | 6 +- .../include/snippets/target_machine.hpp | 2 - .../snippets/include/snippets/utils.hpp | 38 ++++ .../snippets/src/kernel_executor_table.cpp | 39 ++++ .../src/lowered/linear_ir_builder.cpp | 3 +- src/common/snippets/src/op/brgemm.cpp | 28 +-- src/common/snippets/src/op/subgraph.cpp | 11 +- .../snippets/src/pass/collapse_subgraph.cpp | 4 +- .../snippets/src/pass/matmul_to_brgemm.cpp | 11 +- .../snippets/src/runtime_configurator.cpp | 3 +- src/common/snippets/src/utils.cpp | 18 ++ .../snippets/cpu_kernel_executor_table.hpp | 18 +- .../snippets/cpu_runtime_configurator.cpp | 7 +- .../emitters/snippets/x64/cpu_generator.cpp | 3 +- .../x64/jit_brgemm_copy_b_emitter.cpp | 2 +- .../snippets/x64/jit_brgemm_emitter.cpp | 91 ++-------- .../snippets/x64/jit_brgemm_emitter.hpp | 3 - .../snippets/x64/kernel_executors/brgemm.cpp | 167 ++++++++++-------- .../snippets/x64/kernel_executors/brgemm.hpp | 67 ++++--- .../src/emitters/snippets/x64/verbose.cpp | 2 +- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 8 +- .../snippets/x64/op/brgemm_cpu.cpp | 15 +- .../x64/pass/brgemm_to_brgemm_cpu.cpp | 12 +- .../pass/set_brgemm_cpu_blocking_params.cpp | 28 +-- .../snippets/matmul.cpp | 69 ++++++-- .../x64/lowered/brgemm_blocking.cpp | 7 +- .../plugin/shared/include/snippets/matmul.hpp | 14 +- .../plugin/shared/src/snippets/matmul.cpp | 35 ++-- 31 files changed, 517 insertions(+), 320 deletions(-) create mode 100644 src/common/snippets/src/kernel_executor_table.cpp diff --git a/src/common/snippets/include/snippets/kernel_executor_table.hpp b/src/common/snippets/include/snippets/kernel_executor_table.hpp index 757e51f015a8f3..bfff0d9d4f778d 100644 --- a/src/common/snippets/include/snippets/kernel_executor_table.hpp +++ b/src/common/snippets/include/snippets/kernel_executor_table.hpp @@ -4,8 +4,10 @@ #pragma once -#include "snippets/lowered/expression.hpp" - +#include "snippets/lowered/linear_ir.hpp" +#if defined(SNIPPETS_DEBUG_CAPS) && !defined(_WIN32) +#include +#endif namespace ov { namespace snippets { @@ -23,8 +25,38 @@ class KernelExecutorBase { * while dynamic kernels will be completed only in runtime, when all the shapes are known. */ virtual bool is_completed() const = 0; + + /*** Return deep copy of the config */ + virtual std::shared_ptr clone() const = 0; + + /*** Compute hash for fast comparison operations or caching support */ + virtual size_t hash() const = 0; + + bool operator==(const GenericConfig& rhs) const { return hash() == rhs.hash(); } + bool operator!=(const GenericConfig& rhs) const { return hash() != rhs.hash(); } + virtual ~GenericConfig() = default; + /** serialize config for debug purposes */ +#ifdef SNIPPETS_DEBUG_CAPS + virtual std::string to_string() const = 0; +#endif }; + /** + * @brief Update current kernel config in accordance with the passed expression. 
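As an aside on the `GenericConfig` contract introduced above (completeness check, deep `clone()`, `hash()`, and equality defined through the hash), the snippet below is a minimal standalone sketch of what a derived config is expected to provide. It deliberately mirrors the interface outside of the real `ov::snippets` headers; `GenericConfigLike`, `ExampleConfig`, and its `M/N/K` fields are illustrative names, not part of this patch.

```cpp
#include <cstddef>
#include <functional>
#include <memory>

// Standalone mock of the GenericConfig contract (illustration only).
struct GenericConfigLike {
    virtual bool is_completed() const = 0;
    virtual std::shared_ptr<GenericConfigLike> clone() const = 0;
    virtual size_t hash() const = 0;
    bool operator==(const GenericConfigLike& rhs) const { return hash() == rhs.hash(); }
    bool operator!=(const GenericConfigLike& rhs) const { return hash() != rhs.hash(); }
    virtual ~GenericConfigLike() = default;
};

struct ExampleConfig : public GenericConfigLike {
    // Hypothetical runtime-defined dimensions; 0 means "not known yet".
    size_t M = 0, N = 0, K = 0;

    bool is_completed() const override { return M != 0 && N != 0 && K != 0; }

    // Deep copy: the executor and the table clone configs before mutating or snapshotting them.
    std::shared_ptr<GenericConfigLike> clone() const override {
        return std::make_shared<ExampleConfig>(*this);
    }

    // Equality is expressed through hash(), so the hash must cover every field
    // that influences kernel compilation.
    size_t hash() const override {
        size_t seed = 0;
        for (size_t v : {M, N, K})
            seed ^= std::hash<size_t>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        return seed;
    }
};

int main() {
    ExampleConfig a;
    a.M = a.N = a.K = 32;
    auto b = a.clone();
    return (a == *b && a.is_completed()) ? 0 : 1;
}
```

Because `operator==` is defined via `hash()`, a derived config must fold every compilation-relevant field into the hash, which is exactly what the CPU-side `BrgemmKernelConfig::compute_hash()` later in this patch does.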
Corresponding kernel is recompiled if necessary. + * This method should be called to update KernelExecutor based on runtime info (e.g. shapes) available through expression ptr + */ + virtual void update_by_expression(const ov::snippets::lowered::ExpressionPtr& expr) = 0; + /** + * @brief Replace current kernel config with the provided value. Corresponding kernel is recompiled if necessary. + * This method should be called to restore a saved state of the executor, that was configured using update_by_expression(). + */ + virtual void update_by_config(const std::shared_ptr& new_config) = 0; + + virtual std::shared_ptr get_config() const = 0; + /** serialize for debug purposes */ +#ifdef SNIPPETS_DEBUG_CAPS + virtual std::string to_string() const = 0; +#endif virtual ~KernelExecutorBase() = default; private: @@ -38,17 +70,47 @@ template c) : KernelExecutorBase(), m_config{std::move(c)} {} - /** - * @brief check current config and recompile kernel if necessary. Use kernel caching to avoid redundant recompilations. - * This method must be called only for complete configs. It's the user responsibility to check is_completed() before calling. - */ - virtual void update_kernel() = 0; + + // Note: override when final is redundant, but needed to avoid warnings on some compilers + void update_by_expression(const ov::snippets::lowered::ExpressionPtr& expr) override final { // NOLINT + m_config = std::static_pointer_cast(m_config->clone()); + update_config(expr, m_config); + OPENVINO_ASSERT(m_config && m_config->is_completed(), "Failed to update kernel config in update_by_expression"); + update_kernel(m_config, m_kernel); + OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor"); + } + void update_by_config(const std::shared_ptr& new_config) override final { // NOLINT + if (*m_config == *new_config) + return; + m_config = std::static_pointer_cast(std::const_pointer_cast(new_config)); + OPENVINO_ASSERT(m_config && m_config->is_completed(), "Failed to update kernel config in get_config"); + update_kernel(m_config, m_kernel); + OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor"); + } + std::shared_ptr get_config() const override { return m_config; } + std::shared_ptr get_kernel() const { return m_kernel; } +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const override { + std::string type_name = typeid(KernelType).name(); +#ifndef _WIN32 + int status; + std::unique_ptr demangled_name( + abi::__cxa_demangle(type_name.c_str(), nullptr, nullptr, &status), + std::free); + type_name = demangled_name.get(); +#endif + return "KernelExecutorType: " + std::string(type_name) + " KernelConfig: " + m_config->to_string(); + } +#endif + protected: - /** - * @brief Takes shared_ptr to compilation config, returns shared_ptr to compiled kernel. - * Should be called only if actual compilation is required. Kernel caching must be implemented in update_kernel(). - */ - virtual std::shared_ptr compile_kernel(const std::shared_ptr& c) const = 0; + /*** Updates stored kernel config based on runtime info from expression (e.g. new input shapes). */ + virtual void update_config(const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr& config) const = 0; + /*** Updates stored kernel in accordance with the passed config. Recompilation of the kernel is + * performed only if necessary, otherwise an appropriate kernel is retrieved from cache. 
*/ + virtual void update_kernel(const std::shared_ptr& c, std::shared_ptr& kernel) const = 0; + +private: /** Contains all the necessary information to compile a desired kernel*/ std::shared_ptr m_config = nullptr; /** Stores pointer to compiled kernel since the last update_kernel() call */ @@ -57,6 +119,7 @@ class KernelExecutor : public snippets::KernelExecutorBase { class KernelExecutorTable { public: + /*** Register KernelExecutor in the KernelExecutorTable so it can be later updated in runtime. */ template::value, bool>::type = true> std::shared_ptr register_kernel(const snippets::lowered::ExpressionPtr& expr, C... args) { @@ -69,10 +132,37 @@ class KernelExecutorTable { OPENVINO_ASSERT(m_table.count(expr), "This expression doesn't have a registered kernel executor"); return m_table.at(expr); } + /*** Updates every registered KernelExecutor in accordance with the corresponding expression */ + void update_state() const { + for (const auto& record : m_table) + record.second->update_by_expression(record.first); + } + + /*** Returns lambda function that contains current state of the table, and restores this state when called */ + std::function get_state_reset() { + auto current_state = get_state(); + return [=]() { reset_state(current_state); }; + } + + /** + * @brief Replace originally registered ExpressionPtr with a new value. + * Note that code emission is performed on a copy of LIR, so all expression pointers visible from emitters won't + * be accessible from RuntimeConfigurator. In order to replace these cloned ExpressionPtrs with the original ones, + * we need to call this method. + */ + void replace_key_expression(const snippets::lowered::ExpressionPtr& from, const snippets::lowered::ExpressionPtr& to); + virtual ~KernelExecutorTable() = default; protected: std::unordered_map> m_table{}; + typedef std::vector>> ExecTableState; + + /*** Restore the table state previously obtained by get_state() */ + void reset_state(const ExecTableState& state); + + /*** Return cumulative state of all the executors in the table. 
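To make the save/restore mechanics of the `KernelExecutorTable` concrete: `get_state_reset()` snapshots cloned configs of all registered executors and returns a functor that rolls the table back when invoked; the plugin code later in this patch stores exactly such a functor in `SubgraphDynamicSpecializedExecutor` and calls it before each execution. The sketch below is a miniature standalone analogue of that pattern only; `MiniExecutorTable` and its `int` "configs" are invented stand-ins, not the real table.

```cpp
#include <functional>
#include <iostream>
#include <unordered_map>

// Miniature analogue of the snapshot/rollback pattern (illustration only):
// the real table stores KernelExecutorBase pointers keyed by ExpressionPtr
// and snapshots cloned configs instead of plain ints.
class MiniExecutorTable {
public:
    void set(int key, int config) { m_table[key] = config; }
    int get(int key) const { return m_table.at(key); }

    // Capture the current state and return a functor that restores it later.
    std::function<void()> get_state_reset() {
        auto saved = m_table;                       // deep copy of the "configs"
        return [this, saved]() { m_table = saved; };
    }

private:
    std::unordered_map<int, int> m_table;
};

int main() {
    MiniExecutorTable table;
    table.set(0, 42);
    auto reset = table.get_state_reset();  // snapshot taken here
    table.set(0, 7);                       // the table was reused, e.g. for other shapes
    reset();                               // roll back to the snapshot
    std::cout << table.get(0) << "\n";     // prints 42
    return 0;
}
```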
The returned ExecTableState object can be passed to reset_state */ + ExecTableState get_state() const; }; using KernelExecutorTablePtr = std::shared_ptr; diff --git a/src/common/snippets/include/snippets/lowered/linear_ir_builder.hpp b/src/common/snippets/include/snippets/lowered/linear_ir_builder.hpp index 969bf21cd27480..afd778047c9279 100644 --- a/src/common/snippets/include/snippets/lowered/linear_ir_builder.hpp +++ b/src/common/snippets/include/snippets/lowered/linear_ir_builder.hpp @@ -29,9 +29,14 @@ class LinearIRBuilder { /** * @brief Make a full copy of LinearIR by rules described in `m_config` * @param linear_ir Linear IR + * @param expression_map expression map * @return clone of `linear_ir` */ - std::shared_ptr clone(const std::shared_ptr& linear_ir) const; + std::shared_ptr clone(const std::shared_ptr& linear_ir, ExpressionMap& expression_map) const; + inline std::shared_ptr clone(const std::shared_ptr& linear_ir) const { + ExpressionMap expression_map; + return clone(linear_ir, expression_map); + } /** * @brief Make a copy of LinearIR range by rules described in `m_config` * @param begin begin iterator of the target range of LinearIR diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index a170b02b15346e..d7b179246366e9 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -55,9 +55,9 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op { protected: ov::element::Type get_output_type() const; std::vector get_planar_input_shapes(const std::vector>& inputs) const; - ov::PartialShape get_output_partial_shape(const std::vector& input_shapes) const; + ov::PartialShape infer_output_partial_shape(const std::vector& input_shapes) const; ov::PartialShape get_planar_output_shape(const ov::PartialShape& output_shape) const; - void compute_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n); + void set_block_size_values(size_t blk_size_m, size_t blk_size_k, size_t blk_size_n); size_t m_M_blk = 0; size_t m_K_blk = 0; size_t m_N_blk = 0; @@ -65,7 +65,6 @@ class Brgemm : virtual public modifier::MemoryAccess, public ov::op::Op { private: void custom_constructor_validate_and_infer_types(std::vector layout_a, std::vector layout_b, std::vector layout_c); - void validate_inputs() const; }; } // namespace op diff --git a/src/common/snippets/include/snippets/runtime_configurator.hpp b/src/common/snippets/include/snippets/runtime_configurator.hpp index 10a2c26b1d843a..059771d961df82 100644 --- a/src/common/snippets/include/snippets/runtime_configurator.hpp +++ b/src/common/snippets/include/snippets/runtime_configurator.hpp @@ -5,6 +5,7 @@ #pragma once #include "snippets/lowered/linear_ir.hpp" +#include "snippets/kernel_executor_table.hpp" #include "snippets/lowered/pass/pass.hpp" namespace ov { @@ -42,7 +43,8 @@ class RuntimeConfig { ov::snippets::VectorDims master_shape = {}; size_t buffer_scratchpad_size = 0; - std::vector buffer_cluster_offsets; + std::vector buffer_cluster_offsets {}; + KernelExecutorTablePtr kernel_executor_table = std::make_shared(); }; /** @@ -60,6 +62,8 @@ class RuntimeConfigurator { * @return updated config */ const std::shared_ptr& get_updated_config(const std::shared_ptr& linear_ir); + /*** Returns pointer to KernelExecutorTable owned by the config */ + const std::shared_ptr& get_kernel_executor_table() const { return m_config->kernel_executor_table; } protected: /** diff --git 
a/src/common/snippets/include/snippets/target_machine.hpp b/src/common/snippets/include/snippets/target_machine.hpp index f514e1a944ca3e..d9d89264fe1926 100644 --- a/src/common/snippets/include/snippets/target_machine.hpp +++ b/src/common/snippets/include/snippets/target_machine.hpp @@ -10,7 +10,6 @@ #include "emitter.hpp" #include "snippets/lowered/expression.hpp" -#include "kernel_executor_table.hpp" namespace ov { namespace snippets { @@ -94,7 +93,6 @@ class TargetMachine { protected: std::map jitters; - std::shared_ptr kernel_executor_table; std::shared_ptr configurator; }; diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 8485b87bb63066..99fb9a3a4196ff 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -243,6 +243,44 @@ std::shared_ptr get_leaf_node_of_first_child_shape_infer_seq(const std */ std::shared_ptr get_leaf_node_of_first_parent_shape_infer_seq(const std::shared_ptr& start_node); +/** + * @brief Calculate leading dimension of the shape that should be read according to the layout + * @param shape original (not reordered) input shape + * @param layout specifies the order in what dimensions of in the input shape should be read + * @return stride of the dimension idx = layout[layout.size() - 2] in the original shape + Example: + Original shape (shape) = [1, 49, 2, 23] + Layout (transpose order) = [2, 0, 1, 3] + + dim_idx = layout.size() - 2 = 2 + // Since layout specifies the order of dimensions in which the shape should be read + dim = layout[dim_idx] = 1 + stride(shape[1]) = shape[2] * shape[3] = 2 * 23 + */ +size_t get_in_leading_dim(const VectorDims& shape, const std::vector& layout); +inline size_t get_in_leading_dim(const lowered::PortDescriptorPtr& pd) { + return get_in_leading_dim(pd->get_shape(), pd->get_layout()); +} +/** + * + * @param shape reordered input shape that is stored according to the layout + * @param layout specifies the order in what the dimensions of the input shape are stored + * @return + Output shape is already transposed, we need to correctly write the data with original shape by the order + Example: + Original transposed shape (shape) = [49, 2, 7, 39] + Layout (transpose order) = [2, 0, 1, 3] + + dim_idx = layout.size() - 2 = 2 + // Since the shape dimensions are already reordered according to the layout + dim = /find dim_idx index in layout/ = 0 + stride(shape[0]) = shape[1] * shape[2] * shape[3] = 2 * 7 * 39 + */ +size_t get_out_leading_dim(const VectorDims& shape, const std::vector& layout); +inline size_t get_out_leading_dim(const lowered::PortDescriptorPtr& pd) { + return get_out_leading_dim(pd->get_shape(), pd->get_layout()); +} + } // namespace utils } // namespace snippets } // namespace ov diff --git a/src/common/snippets/src/kernel_executor_table.cpp b/src/common/snippets/src/kernel_executor_table.cpp new file mode 100644 index 00000000000000..97c30af8e3c56d --- /dev/null +++ b/src/common/snippets/src/kernel_executor_table.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/kernel_executor_table.hpp" + +namespace ov { +namespace snippets { + +void KernelExecutorTable::replace_key_expression(const snippets::lowered::ExpressionPtr& from, const snippets::lowered::ExpressionPtr& to) { + const auto& found = m_table.find(from); + if (found != m_table.end()) { + OPENVINO_ASSERT(m_table.count(to) == 0, "Attempt to replace a value that is 
already in the KernelExecutorTable"); + m_table.insert({to, found->second}); + m_table.erase(found); + } +} + +void KernelExecutorTable::reset_state(const ExecTableState& state) { + OPENVINO_ASSERT(state.size() == m_table.size(), "Invalid state in restore_state: size mismatch"); + auto state_it = state.begin(); + for (const auto& table_record : m_table) { + const auto& state_record = *state_it++; + OPENVINO_ASSERT(table_record.first == state_record.first, "Invalid state in restore_state: expressions mismatch"); + table_record.second->update_by_config(state_record.second); + } +} + +KernelExecutorTable::ExecTableState KernelExecutorTable::get_state() const { + ExecTableState result; + // Note: we need to clone configs when saving the state, since the configs still stored in the table can + // be modified e.g. by calling update_by_expression(); + for (const auto& record : m_table) + result.emplace_back(std::make_pair(record.first, record.second->get_config()->clone())); + return result; +} + +}// namespace snippets +}// namespace ov diff --git a/src/common/snippets/src/lowered/linear_ir_builder.cpp b/src/common/snippets/src/lowered/linear_ir_builder.cpp index 56b87eb2856d14..e5d28375db05f6 100644 --- a/src/common/snippets/src/lowered/linear_ir_builder.cpp +++ b/src/common/snippets/src/lowered/linear_ir_builder.cpp @@ -65,11 +65,10 @@ std::vector> clone_nodes(const std::vector LinearIRBuilder::clone(const std::shared_ptr& linear_ir) const { +std::shared_ptr LinearIRBuilder::clone(const std::shared_ptr& linear_ir, ExpressionMap& expression_map) const { auto cloned = std::make_shared(); cloned->m_config = linear_ir->m_config; - ExpressionMap expression_map; cloned->m_expressions = clone_range(linear_ir->m_expressions.cbegin(), linear_ir->m_expressions.cend(), expression_map); for (const auto& expr : cloned->m_expressions) { cloned->register_expression(expr, true); diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 43af058e87d82d..c69d4193a6d943 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -39,7 +39,7 @@ Brgemm::Brgemm(const Output& A, const Output& B, set_input_offset(offset_a, 0); set_input_offset(offset_b, 1); set_output_offset(offset_c, 0); - compute_block_size_values(blk_size_m, blk_size_k, blk_size_n); + set_block_size_values(blk_size_m, blk_size_k, blk_size_n); custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); } @@ -49,43 +49,33 @@ Brgemm::Brgemm(const Output& A, const Output& B, const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) : MemoryAccess(PortMap{{0, desc_a}, {1, desc_b}}, PortMap{{0, desc_c}}), Op({A, B}) { set_output_size(1); - compute_block_size_values(blk_size_m, blk_size_k, blk_size_n); + set_block_size_values(blk_size_m, blk_size_k, blk_size_n); custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); } -void Brgemm::compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) { - const auto input_shape_0 = snippets::utils::get_planar_pshape(input(0)).get_shape(); - const auto input_shape_1 = snippets::utils::get_planar_pshape(input(1)).get_shape(); - m_M_blk = blk_size_m != 0 ? blk_size_m : *++input_shape_0.rbegin(); - m_K_blk = blk_size_k != 0 ? blk_size_k : *input_shape_0.rbegin(); - m_N_blk = blk_size_n != 0 ? 
blk_size_n : *input_shape_1.rbegin(); +void Brgemm::set_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) { + m_M_blk = blk_size_m; + m_K_blk = blk_size_k; + m_N_blk = blk_size_n; } void Brgemm::custom_constructor_validate_and_infer_types(std::vector layout_a, std::vector layout_b, std::vector layout_c) { INTERNAL_OP_SCOPE(BrgemmCPU_constructor_validate_and_infer_types); - validate_inputs(); // During ctor call, Brgemm doesn't know his port descriptors. // So we use explicit layouts from parameters const auto planar_input_shapes = std::vector{ ov::snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a), ov::snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) }; - auto output_shape = get_output_partial_shape(planar_input_shapes); + auto output_shape = infer_output_partial_shape(planar_input_shapes); set_output_type(0, get_output_type(), ov::snippets::utils::get_planar_pshape(output_shape, layout_c)); } -void Brgemm::validate_inputs() const { - // If no leading dimensions are provided, assume dense row-major inputs-outputs - NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), - "Brgemm currently supports only static shapes."); -} - void Brgemm::validate_and_infer_types() { INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types); - validate_inputs(); const auto planar_input_shapes = get_planar_input_shapes(inputs()); - auto output_shape = get_output_partial_shape(planar_input_shapes); + auto output_shape = infer_output_partial_shape(planar_input_shapes); set_output_type(0, get_output_type(), get_planar_output_shape(output_shape)); } @@ -146,7 +136,7 @@ ov::PartialShape Brgemm::get_planar_output_shape(const ov::PartialShape& output_ return output_shape; } -ov::PartialShape Brgemm::get_output_partial_shape(const std::vector& input_shapes) const { +ov::PartialShape Brgemm::infer_output_partial_shape(const std::vector& input_shapes) const { OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference"); // Note: All majors checks are missed because Brgemm is transformed from MatMul with whole shape infer support diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index ab793c722d1e3e..51e0e3b904ff50 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -530,7 +530,8 @@ snippets::Schedule Subgraph::generate(const void* compile_params) const { // actual code emission // Note: to not corrupt the lowered linear IR for the shape-dependent passes, we have to make a copy OPENVINO_ASSERT(m_linear_ir, "Attempt to call generate, when linear IR was not initialized"); - auto linear_ir = *lowered::LinearIRBuilder().clone(m_linear_ir); + ov::snippets::lowered::ExpressionMap expression_map; + auto linear_ir = *lowered::LinearIRBuilder().clone(m_linear_ir, expression_map); if (is_dynamic()) { ov::snippets::lowered::pass::PassPipeline shape_dependent_pipeline; @@ -542,6 +543,14 @@ snippets::Schedule Subgraph::generate(const void* compile_params) const { auto lowering_result = m_generator->generate(linear_ir, compile_params); + // Note: Since the code emission is performed on a copy of LIR, but RuntimeConfigurator works with the initial instance, + // we need to replace cloned expression pointers to original ones in the KernelExecutorTable + const auto& exec_table = m_generator->get_target_machine()->get_runtime_configurator()->get_kernel_executor_table(); + for 
(const auto& expr : *m_linear_ir) + exec_table->replace_key_expression(expression_map.at(expr.get()), expr); + // Some kernel executors might've been registered during code emission. + // We need to update them, so appropriate kernels will be compiled. + exec_table->update_state(); return {std::move(lowering_result)}; } diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 9afdb340afb5a8..82675ba946e6a5 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -47,8 +47,8 @@ auto is_supported_op(const std::shared_ptr &n) -> bool { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::is_supported_op") auto is_supported_matmul = [](const std::shared_ptr& n) -> bool { const auto& matmul = ov::as_type_ptr(n); - const auto& out_shape = n->get_output_partial_shape(0); - if (!matmul || out_shape.is_dynamic() || out_shape.size() != 4) + const auto& out_rank = n->get_output_partial_shape(0).rank(); + if (!matmul || out_rank.is_dynamic() || out_rank.get_length() != 4) return false; const auto intype_0 = matmul->get_input_element_type(0); const auto intype_1 = matmul->get_input_element_type(1); diff --git a/src/common/snippets/src/pass/matmul_to_brgemm.cpp b/src/common/snippets/src/pass/matmul_to_brgemm.cpp index 22be569125ecf9..37fda8306ed04f 100644 --- a/src/common/snippets/src/pass/matmul_to_brgemm.cpp +++ b/src/common/snippets/src/pass/matmul_to_brgemm.cpp @@ -6,6 +6,7 @@ #include "snippets/itt.hpp" #include "snippets/snippets_isa.hpp" +#include "snippets/utils.hpp" #include "snippets/lowered/port_descriptor.hpp" #include "openvino/core/rt_info.hpp" @@ -17,16 +18,16 @@ namespace snippets { namespace pass { void MatMulToBrgemm::init_ports(const std::shared_ptr& brgemm) const { - auto get_subtensor = [](const ov::Shape& shape) { + auto get_subtensor = []() { return std::vector{ lowered::PortDescriptor::ServiceDimensions::FULL_DIM, lowered::PortDescriptor::ServiceDimensions::FULL_DIM }; }; for (const auto& input : brgemm->inputs()) { - const auto tensor = input.get_shape(); - const auto subtensor = get_subtensor(tensor); + const auto& tensor = utils::pshape_to_vdims(input.get_partial_shape()); + const auto& subtensor = get_subtensor(); lowered::PortDescriptorUtils::set_port_descriptor_ptr(input, std::make_shared(tensor, subtensor)); } - const auto tensor = brgemm->get_output_shape(0); - const auto subtensor = get_subtensor(tensor); + const auto& tensor = utils::pshape_to_vdims(brgemm->get_output_partial_shape(0)); + const auto& subtensor = get_subtensor(); lowered::PortDescriptorUtils::set_port_descriptor_ptr(brgemm->output(0), std::make_shared(tensor, subtensor)); } diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp index 7681e78f1ebb49..f790ec747cd7bf 100644 --- a/src/common/snippets/src/runtime_configurator.cpp +++ b/src/common/snippets/src/runtime_configurator.cpp @@ -30,7 +30,8 @@ void init_data_ptr_shifts(const lowered::UnifiedLoopInfoPtr& unified_loop_info, } } // namespace -RuntimeConfigurator::RuntimeConfigurator(std::shared_ptr c) : m_config(std::move(c)) { +RuntimeConfigurator::RuntimeConfigurator(std::shared_ptr c) : + m_config(std::move(c)) { OPENVINO_ASSERT(m_config, "Runtime config is nullptr!"); } diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index a7f00bbfebcb9c..c2a49b8e0a4d41 100644 --- a/src/common/snippets/src/utils.cpp +++ 
b/src/common/snippets/src/utils.cpp @@ -279,6 +279,24 @@ std::shared_ptr get_leaf_node_of_first_parent_shape_infer_seq(const st return leaf_node; } +size_t get_in_leading_dim(const VectorDims& shape, const std::vector& layout) { + if (layout.empty()) + return shape.back(); + OPENVINO_ASSERT(layout.back() == layout.size() - 1 && layout.size() == shape.size(), + "detected invalid layout values: check that this shape + layout combination is schedulable"); + const auto idx = static_cast(layout[layout.size() - 2]); + return std::accumulate(shape.cbegin() + idx + 1, shape.end(), 1ull, std::multiplies()); +} +size_t get_out_leading_dim(const VectorDims& shape, const std::vector& layout) { + if (layout.empty()) + return shape.back(); + OPENVINO_ASSERT(layout.back() == layout.size() - 1 && layout.size() == shape.size(), + "detected invalid layout values: check that this shape + layout combination is schedulable"); + const auto idx = layout.size() - 2; + const auto dim = std::distance(layout.cbegin(), std::find(layout.cbegin(), layout.cend(), idx)); + return std::accumulate(shape.cbegin() + dim + 1, shape.cend(), 1ull, std::multiplies()); +} + } // namespace utils } // namespace snippets } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp index 14154a79fe67f0..09781e014df946 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp @@ -16,28 +16,24 @@ class CPUKernelExecutor : public snippets::KernelExecutor { CPUKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, std::shared_ptr c) : snippets::KernelExecutor(c), m_kernel_cache(std::move(kernel_cache)) {} struct Key { - explicit Key(std::shared_ptr c) : config{std::move(c)} {} - const std::shared_ptr config; + explicit Key(const std::shared_ptr& c) : config{c} {} + const std::shared_ptr config; size_t hash() const { return config->hash(); } bool operator==(const Key& rhs) const { return *config == *rhs.config; } }; - void update_kernel() override { - OPENVINO_ASSERT(m_config && m_config->is_completed(), "Update kernel was called with invalid config"); + void update_kernel(const std::shared_ptr& config, std::shared_ptr& kernel) const override final { // NOLINT const auto& cache = m_kernel_cache.lock(); OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in CPUKernelExecutor::update_kernel()"); - const auto& lookup_result = cache->getOrCreate(Key(m_config), + const auto& lookup_result = cache->getOrCreate(Key(config), [this](const Key& k) { return compile_kernel(k.config); }); - m_kernel = lookup_result.first; - OPENVINO_ASSERT(m_kernel, "Failed to compile kernel executor"); + kernel = lookup_result.first; } protected: - // Note: this usings are needed because non-dependent names are not looked up in dependent base classes - using snippets::KernelExecutor::m_config; - using snippets::KernelExecutor::m_kernel; - using snippets::KernelExecutor::compile_kernel; + /** Compile kernel managed by KernelExecutor instance. 
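The leading-dimension helpers added to `snippets/src/utils.cpp` above derive a stride from a shape plus a layout (transpose order). The standalone sketch below restates that arithmetic outside of OpenVINO and checks it against the two examples given in the new `utils.hpp` documentation; `in_leading_dim` and `out_leading_dim` are local stand-ins without the validity assertions of the real helpers, not the actual API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <vector>

using VectorDims = std::vector<size_t>;

// Input shape is stored in original (planar) order: take the dimension that the
// layout places second-to-last and compute its stride in the original shape.
size_t in_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout) {
    if (layout.empty())
        return shape.back();
    const size_t idx = layout[layout.size() - 2];
    return std::accumulate(shape.begin() + idx + 1, shape.end(), size_t(1), std::multiplies<size_t>());
}

// Output shape is already stored in layout order: find where the second-to-last
// planar dimension ended up and take its stride in the reordered shape.
size_t out_leading_dim(const VectorDims& shape, const std::vector<size_t>& layout) {
    if (layout.empty())
        return shape.back();
    const size_t idx = layout.size() - 2;
    const auto dim = std::distance(layout.begin(), std::find(layout.begin(), layout.end(), idx));
    return std::accumulate(shape.begin() + dim + 1, shape.end(), size_t(1), std::multiplies<size_t>());
}

int main() {
    // Examples from the utils.hpp documentation in this patch:
    assert(in_leading_dim({1, 49, 2, 23}, {2, 0, 1, 3}) == 2 * 23);       // stride of shape[1]
    assert(out_leading_dim({49, 2, 7, 39}, {2, 0, 1, 3}) == 2 * 7 * 39);  // stride of shape[0]
    return 0;
}
```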
Will be called only if Kernel is not found in the cache */ + virtual std::shared_ptr compile_kernel(const std::shared_ptr& c) const = 0; /** CPU plugin cache implementation is used to avoid redundant recompilations */ ov::intel_cpu::MultiCacheWeakPtr m_kernel_cache; }; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp index 920e2f1720f3a4..b92d70136ab4d5 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp @@ -11,13 +11,16 @@ namespace ov { namespace intel_cpu { -CPURuntimeConfigurator::CPURuntimeConfigurator() : ov::snippets::RuntimeConfigurator(std::make_shared()) {} +CPURuntimeConfigurator::CPURuntimeConfigurator() : ov::snippets::RuntimeConfigurator(std::make_shared()) { +} void CPURuntimeConfigurator::update(const std::shared_ptr& linear_ir) { RuntimeConfigurator::update(linear_ir); - if (linear_ir->is_dynamic()) + if (linear_ir->is_dynamic()) { + get_kernel_executor_table()->update_state(); update_loop_args(linear_ir); + } } void CPURuntimeConfigurator::init_tensor_rank(const std::shared_ptr& linear_ir) const { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index e444d65bab774b..4b000ee1521d43 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -156,7 +156,6 @@ class jit_snippet : public dnnl::impl::cpu::x64::jit_generator { intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::intel_cpu::MultiCacheWeakPtr cache) : TargetMachine(std::make_shared()), h(new jit_snippet()), isa(host_isa), compiled_kernel_cache(std::move(cache)) { - kernel_executor_table = std::make_shared(); // data movement jitters[op::v0::Parameter::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); jitters[op::v0::Result::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); @@ -241,7 +240,7 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho jitters[snippets::op::LoopEnd::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_end_emitter); // Note: jit_brgemm_emitter supports runtime recompilation, so its constructor takes additional arguments jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter, - kernel_executor_table, + configurator->get_kernel_executor_table(), compiled_kernel_cache); jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter); jitters[snippets::op::ReduceMax::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp index 0fec3269efae9b..759eadfe4d0747 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp @@ -43,7 +43,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t size_t leading_dimension = *(original_shape.rbegin()); if (!layout.empty()) { transposed_shape = 
snippets::utils::get_planar_vdims(original_shape, layout); - leading_dimension = jit_brgemm_emitter::get_in_leading_dim(original_shape, layout); + leading_dimension = ov::snippets::utils::get_in_leading_dim(original_shape, layout); } const auto& in_subtensor = in_desc->get_subtensor(); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp index 95fd0e522bfff4..a638da7c2e1a3f 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp @@ -7,6 +7,7 @@ #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include #include +#include "snippets/utils.hpp" using namespace Xbyak; using namespace dnnl::impl; @@ -15,33 +16,6 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -size_t jit_brgemm_emitter::get_in_leading_dim(const VectorDims& shape, const std::vector& layout) { - // Input shape is original, so we need to correctly read this data by order - // Example: - // Original shape (shape) = [1, 49, 2, 23] - // Layout (transpose order) = [2, 0, 1, 3] - // Transposed shape = [2, 1, 49, 23] - // The leading dimension is equal to stride of shape[layout[3]] = 2 x 23 - OV_CPU_JIT_EMITTER_ASSERT(layout.back() == layout.size() - 1 && layout.size() == shape.size(), - "detected invalid layout values: check that this shape + layout combination is schedulable"); - const auto idx = layout[layout.size() - 2]; // `1` in example - return std::accumulate(shape.cbegin() + idx + 1, shape.end(), 1, std::multiplies()); -} -size_t jit_brgemm_emitter::get_out_leading_dim(const VectorDims& shape, const std::vector& layout) { - // Output shape is already transposed, we need to correctly write the data with original shape by the order - // Example: - // Original transposed shape (shape) = [49, 2, 7, 39] - // Layout (transpose order) = [2, 0, 1, 3] - // Before leading dimension with index 3 there is dimension with index 2 in planar layout. - // Since we have non-planar layout, we have to find this before LD dim in transposed order. 
- // In layout 2nd idx is first element, it means, that the leading dimension is equal to stride of shape[0] - OV_CPU_JIT_EMITTER_ASSERT(layout.back() == layout.size() - 1 && layout.size() == shape.size(), - "detected invalid layout values: check that this shape + layout combination is schedulable"); - const auto idx = layout.size() - 2; // 2 in the example - const auto dim = std::distance(layout.cbegin(), std::find(layout.cbegin(), layout.cend(), idx)); // 0 in the example: shape[0] = 49 - return std::accumulate(shape.cbegin() + dim + 1, shape.cend(), 1, std::multiplies()); // shape[1] x shape[2] x shape[3] = 2 x 7 x 39 -} - jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr, const snippets::KernelExecutorTablePtr& kernel_table, @@ -49,67 +23,24 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, jit_emitter(h, isa) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; const auto& brgemm_node = as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(!brgemm_node->is_dynamic(), "Snippets don't support code generation for dynamic Brgemm"); - - std::vector leading_dimensions; - auto get_layout = [](const std::vector& layout, const snippets::VectorDims& io_shape) { - if (!layout.empty()) return layout; - std::vector default_layout(io_shape.size()); - std::iota(default_layout.begin(), default_layout.end(), 0); - return default_layout; - }; - auto init_in_scheduling_params = [&](const snippets::lowered::PortDescriptorPtr& input) { - const auto& layout = get_layout(input->get_layout(), input->get_shape()); - leading_dimensions.push_back(get_in_leading_dim(input->get_shape(), layout)); - }; - auto init_out_scheduling_params = [&](const snippets::lowered::PortDescriptorPtr& output) { - const auto& layout = get_layout(output->get_layout(), output->get_shape()); - leading_dimensions.push_back(get_out_leading_dim(output->get_shape(), layout)); - }; - - const auto& input_0_desc = expr->get_input_port_descriptor(0); - const auto& input_1_desc = expr->get_input_port_descriptor(1); - const auto& output_desc = expr->get_output_port_descriptor(0); - - init_in_scheduling_params(input_0_desc); - if (brgemm_node->is_with_data_repacking()) { - const auto repacking_buffer_shape = brgemm_node->get_brgemm_copy()->get_repacking_buffer_shape(); - OV_CPU_JIT_EMITTER_ASSERT(!repacking_buffer_shape.empty(), "Repacking buffer shape mustn't be empty"); - leading_dimensions.push_back(repacking_buffer_shape.back()); - } else { - init_in_scheduling_params(input_1_desc); - } - init_out_scheduling_params(output_desc); + m_with_scratch = brgemm_node->is_with_scratchpad(); const auto& brg0Prc = brgemm_node->get_input_element_type(0); const auto& brg1Prc = brgemm_node->get_input_element_type(1); - - m_with_scratch = brgemm_node->is_with_scratchpad(); - - const auto& output_subtensor = output_desc->get_subtensor(); - const auto& input_0_subtensor = input_0_desc->get_subtensor(); - const auto& input_1_subtensor = input_1_desc->get_subtensor(); - - OV_CPU_JIT_EMITTER_ASSERT(*(output_subtensor.rbegin() + 1) == *(input_0_subtensor.rbegin() + 1), - "Brgemm has different M dimension subtensors on input0 and output"); - OV_CPU_JIT_EMITTER_ASSERT(*output_subtensor.rbegin() == *input_1_subtensor.rbegin(), - "Brgemm has different N dimension subtensors on input1 and output"); - OV_CPU_JIT_EMITTER_ASSERT(*input_0_subtensor.rbegin() == *(input_1_subtensor.rbegin() + 1), - "Brgemm has different K dimension subtensors on input0 and input1"); - auto 
kernel_config = std::make_shared(brg0Prc, brg1Prc, brgemm_node->get_beta(), brgemm_node->is_amx(), brgemm_node->is_with_compensations()); - - m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); - m_kernel_executor->update(*(output_subtensor.rbegin() + 1), - *output_subtensor.rbegin(), - *input_0_subtensor.rbegin(), - leading_dimensions[0], - leading_dimensions[1], - leading_dimensions[2]); + m_kernel_executor = kernel_table->register_kernel(expr, + compiled_kernel_cache, + kernel_config); + // Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update() + // are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually + // for both static and the 1st dynamic shapes. + OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) && + !snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()), + "Jit emitter is called when the shapes are unknown"); m_load_offset_a = brgemm_node->get_offset_a(); m_load_offset_b = brgemm_node->get_offset_b(); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp index 1950f15f5b4c95..ecadb7105271b3 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp @@ -20,9 +20,6 @@ class jit_brgemm_emitter : public jit_emitter { size_t get_inputs_num() const override { return m_with_scratch ? 3 : 2; } static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); - static size_t get_in_leading_dim(const VectorDims& shape, const std::vector& layout); - static size_t get_out_leading_dim(const VectorDims& shape, const std::vector& layout); - private: void validate_arguments(const std::vector &in, const std::vector &out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index 82770b7688f3f6..367b9f4506a980 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -7,6 +7,7 @@ #include #include "common/utils.hpp" #include "dnnl_extension_utils.h" +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" #define DIM_CAST(X) static_cast(X) #define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) @@ -20,84 +21,83 @@ namespace intel_cpu { BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, float beta, bool is_with_amx, bool is_with_comp, size_t M, size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC) : - dt_in0(DTYPE_CAST(in0_dtype)), dt_in1(DTYPE_CAST(in1_dtype)), - is_with_amx(is_with_amx), is_with_comp(is_with_comp), beta(beta), - M(DIM_CAST(M)), N(DIM_CAST(N)), K(DIM_CAST(K)), - LDA(DIM_CAST(LDA)), LDB(DIM_CAST(LDB)), LDC(DIM_CAST(LDC)) { - bool is_int8 = utils::one_of(dt_in0, data_type::u8, data_type::s8) && - utils::one_of(dt_in1, data_type::u8, data_type::s8); - isa = is_with_amx ? 
+ m_dt_in0(DTYPE_CAST(in0_dtype)), m_dt_in1(DTYPE_CAST(in1_dtype)), + m_is_with_amx(is_with_amx), m_is_with_comp(is_with_comp), m_beta(beta), + m_M(DIM_CAST(M)), m_N(DIM_CAST(N)), m_K(DIM_CAST(K)), + m_LDA(DIM_CAST(LDA)), m_LDB(DIM_CAST(LDB)), m_LDC(DIM_CAST(LDC)) { + bool is_int8 = utils::one_of(m_dt_in0, data_type::u8, data_type::s8) && + utils::one_of(m_dt_in1, data_type::u8, data_type::s8); + m_isa = is_with_amx ? cpu::x64::avx512_core_amx : - dt_in0 == dnnl_data_type_t::dnnl_bf16 ? + m_dt_in0 == dnnl_data_type_t::dnnl_bf16 ? cpu::x64::avx512_core_bf16 : is_int8 ? cpu::x64::avx512_core_vnni : cpu::x64::avx512_core; + m_hash = compute_hash(); } bool BrgemmKernelConfig::is_completed() const { - return !utils::one_of(0, M, N, K, LDA, LDB, LDC); + return !utils::one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC); } -size_t BrgemmKernelConfig::hash() const { +void BrgemmKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC) { + m_M = M; m_N = N; m_K = K; + m_LDA = LDA; m_LDB = LDB; m_LDC = LDC; + m_hash = compute_hash(); +} + +BrgemmKernelConfig::operator amx_tile_config_t() const { + amx_tile_config_t res; + res.M = m_M; res.N = m_N; res.K = m_K; + return res; +} + +size_t BrgemmKernelConfig::compute_hash() const { size_t seed = 0; #define HASH(X) seed = hash_combine(seed, X) - HASH(dt_in0); HASH(dt_in1); - HASH(is_with_amx); HASH(is_with_comp); - HASH(beta); HASH(isa); - HASH(M); HASH(N); HASH(K); - HASH(LDA); HASH(LDB); HASH(LDC); + HASH(m_dt_in0); HASH(m_dt_in1); + HASH(m_is_with_amx); HASH(m_is_with_comp); + HASH(m_beta); HASH(m_isa); + HASH(m_M); HASH(m_N); HASH(m_K); + HASH(m_LDA); HASH(m_LDB); HASH(m_LDC); #undef HASH return seed; } -bool BrgemmKernelConfig::operator==(const BrgemmKernelConfig& rhs) const { -#define EQUAL(X) X == rhs.X - return EQUAL(dt_in0) && EQUAL(dt_in1) && - EQUAL(is_with_amx) && EQUAL(is_with_comp) && - EQUAL(beta) && EQUAL(isa) && - EQUAL(M) && EQUAL(N) && EQUAL(K) && - EQUAL(LDA) && EQUAL(LDB) && EQUAL(LDC); -#undef EQUAL -} -bool BrgemmKernelConfig::operator!=(const BrgemmKernelConfig& rhs) const { - return !(*this == rhs); -} #ifdef SNIPPETS_DEBUG_CAPS std::string BrgemmKernelConfig::to_string() const { std::stringstream ss; #define PRINT(X) ss << #X << " = " << X << "\n" - PRINT(dt_in0); PRINT(dt_in1); - PRINT(is_with_amx); PRINT(is_with_comp); - PRINT(beta); PRINT(isa); - PRINT(M); PRINT(N); PRINT(K); - PRINT(LDA); PRINT(LDB); PRINT(LDC); + PRINT(m_dt_in0); PRINT(m_dt_in1); + PRINT(m_is_with_amx); PRINT(m_is_with_comp); + PRINT(m_beta); PRINT(m_isa); + PRINT(m_M); PRINT(m_N); PRINT(m_K); + PRINT(m_LDA); PRINT(m_LDB); PRINT(m_LDC); #undef PRINT return ss.str(); } #endif -BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, const std::shared_ptr& config) : - CPUKernelExecutor(std::move(kernel_cache), config) { - if (config->is_completed()) - update_kernel(); -} +BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, + const std::shared_ptr& config) : + CPUKernelExecutor(std::move(kernel_cache), config) { } -std::shared_ptr BrgemmKernelExecutor::compile_kernel(const std::shared_ptr& config) const { +std::shared_ptr BrgemmKernelExecutor::compile_kernel(const std::shared_ptr& config) const { OV_CPU_JIT_EMITTER_ASSERT(config, "Invalid config provided for BrgemmKernelDesc::compile_kernel"); cpu::x64::brgemm_t desc; - auto status = brgemm_desc_init(&desc, config->isa, cpu::x64::brgemm_strd, - config->dt_in0, config->dt_in1, + auto status = 
brgemm_desc_init(&desc, config->get_isa(), cpu::x64::brgemm_strd, + config->get_dt_in0(), config->get_dt_in1(), false, false, cpu::x64::brgemm_row_major, 1.f, - config->beta, - config->LDA, config->LDB, config->LDC, - config->M, config->N, config->K, nullptr); + config->get_beta(), + config->get_LDA(), config->get_LDB(), config->get_LDC(), + config->get_M(), config->get_N(), config->get_K(), nullptr); std::shared_ptr compiled_kernel = std::make_shared(); OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "Cannot initialize brgemm descriptor due to invalid params"); - if (config->is_with_amx) { + if (config->is_with_amx()) { status = brgemm_init_tiles(desc, compiled_kernel->palette); OV_CPU_JIT_EMITTER_ASSERT(status == dnnl_success, "Cannot initialize brgemm tiles due to invalid params"); } @@ -110,19 +110,55 @@ std::shared_ptr BrgemmKernelExecutor::compile_kernel(const return compiled_kernel; } -void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* desc, call_args* args) { - const auto& kernel = desc->m_kernel; - OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr compiler kernel"); - - const auto& config = desc->m_config; - if (config->is_with_amx) { - const auto& amx_tile_config = args->amx_tile_config; - if (config->M != amx_tile_config->M || config->K != amx_tile_config->K || config->N != amx_tile_config->N) { - amx_tile_config->M = config->M; - amx_tile_config->K = config->K; - amx_tile_config->N = config->N; - cpu::x64::amx_tile_configure(kernel->palette); +void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr& config) const { + auto get_projected_input_subtensor = [](const snippets::lowered::PortDescriptorPtr& desc) { + // Note: for output shape you will need get_preordered_vdims() + auto shape = snippets::utils::get_planar_vdims(desc->get_shape(), desc->get_layout()); + auto subtensor = desc->get_subtensor(); + OV_CPU_JIT_EMITTER_ASSERT(subtensor.size() <= shape.size() && subtensor.size() == 2, + "Invalid subtensor + shape combination"); + auto shape_it = shape.rbegin(); + for (auto sub_it = subtensor.rbegin(); sub_it != subtensor.rend(); sub_it++, shape_it++) { + *sub_it = std::min(*sub_it, *shape_it); } + return subtensor; + }; + const auto& input_pds = expr->get_input_port_descriptors(); + const auto& output_pds = expr->get_output_port_descriptors(); + OV_CPU_JIT_EMITTER_ASSERT((input_pds.size() == 2 || input_pds.size() == 3) && output_pds.size() == 1, + "Invalid number of in/out port descriptors"); + // Update runtime-defined config fields: + // Matrix A (first input) + const auto LDA = DIM_CAST(snippets::utils::get_in_leading_dim(input_pds[0])); + const auto& in0_subtensor = get_projected_input_subtensor(input_pds[0]); + const auto K = DIM_CAST(*in0_subtensor.rbegin()); + const auto M = DIM_CAST(*++in0_subtensor.rbegin()); + // Matrix B (second input) + // Non float input 1 => with data repacking + auto LDB = DIM_CAST(snippets::utils::get_in_leading_dim(input_pds[1])); + + const auto& brgemm_node = as_type_ptr(expr->get_node()); + OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); + if (brgemm_node->is_with_data_repacking()) { + const auto repacking_buffer_shape = brgemm_node->get_brgemm_copy()->get_repacking_buffer_shape(); + OV_CPU_JIT_EMITTER_ASSERT(!repacking_buffer_shape.empty(), "Repacking buffer shape mustn't be empty"); + LDB = DIM_CAST(repacking_buffer_shape.back()); + } + const auto N = DIM_CAST(*get_projected_input_subtensor(input_pds[1]).rbegin()); + // Matrix C (output) + const 
auto LDC = DIM_CAST(snippets::utils::get_out_leading_dim(output_pds[0])); + config->update(M, N, K, LDA, LDB, LDC); +} + +void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_args* args) { + const auto& kernel = executor->get_kernel(); + const auto& config = std::static_pointer_cast(executor->get_config()); + OV_CPU_JIT_EMITTER_ASSERT(kernel && config, "has nullptr compiler kernel or invalid config"); + + const auto tile_config = args->amx_tile_config; + if (config->is_with_amx() && tile_config && !config->compatible(tile_config)) { + *tile_config = static_cast(*config); + cpu::x64::amx_tile_configure(kernel->palette); } cpu::x64::brgemm_kernel_params_t brgemm_p; @@ -134,28 +170,13 @@ void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* desc, call_args* brgemm_p.ptr_D = args->C; brgemm_p.ptr_buf = args->scratch; brgemm_p.ptr_bias = nullptr; - brgemm_p.do_post_ops = static_cast(config->is_with_comp); - brgemm_p.do_apply_comp = static_cast(config->is_with_comp); + brgemm_p.do_post_ops = static_cast(config->is_with_comp()); + brgemm_p.do_apply_comp = static_cast(config->is_with_comp()); brgemm_p.skip_accm = 0; brgemm_p.BS = 1; // default value OV_CPU_JIT_EMITTER_ASSERT(kernel->compiled_kernel, "has nullptr kernel"); (*kernel->compiled_kernel)(&brgemm_p); } -void BrgemmKernelExecutor::update(size_t M, size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC) { - OV_CPU_JIT_EMITTER_ASSERT(m_config, "update is called for empty kernel config"); -#define CAST(X) m_config->X = DIM_CAST(X) - CAST(M); CAST(N); CAST(K); - CAST(LDA); CAST(LDB); CAST(LDC); -#undef CAST - update_kernel(); -} - -#ifdef SNIPPETS_DEBUG_CAPS -std::string BrgemmKernelExecutor::config_to_string() const { - return m_config ? m_config->to_string() : ""; -} -#endif - } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index 8c426e4ecf86b4..b0fbb468db12f2 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -11,11 +11,7 @@ namespace ov { namespace intel_cpu { -class BrgemmKernelExecutor; -#define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field) - struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig { - friend BrgemmKernelExecutor; public: BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, float beta, bool is_with_amx, bool is_with_comp, @@ -23,19 +19,46 @@ struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig { size_t LDA = 0, size_t LDB = 0, size_t LDC = 0); BrgemmKernelConfig() = default; bool is_completed() const override; - size_t hash() const; - bool operator==(const BrgemmKernelConfig& rhs) const; - bool operator!=(const BrgemmKernelConfig& rhs) const; + size_t hash() const override { return m_hash; } + std::shared_ptr clone() const override { + return std::make_shared(*this); + } + void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC); + + dnnl_data_type_t get_dt_in0() const { return m_dt_in0; } + dnnl_data_type_t get_dt_in1() const { return m_dt_in1; } + + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return m_isa; } + bool is_with_amx() const {return m_is_with_amx; } + bool is_with_comp() const { return m_is_with_comp; } + float get_beta() const { return 
m_beta; } + + dnnl_dim_t get_M() const { return m_M; } + dnnl_dim_t get_N() const { return m_N; } + dnnl_dim_t get_K() const { return m_K; } + + dnnl_dim_t get_LDA() const { return m_LDA; } + dnnl_dim_t get_LDB() const { return m_LDB; } + dnnl_dim_t get_LDC() const { return m_LDC; } + + explicit operator amx_tile_config_t() const; + inline bool compatible(amx_tile_config_t* rhs) const { + return rhs && rhs->M == m_M && rhs->N == m_N && rhs->K == m_K; + } + #ifdef SNIPPETS_DEBUG_CAPS - std::string to_string() const; + std::string to_string() const override; #endif + private: - dnnl_data_type_t dt_in0 {dnnl_f32}, dt_in1 {dnnl_f32}; - bool is_with_amx {false}; - bool is_with_comp {false}; - float beta {0}; - dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef}; - dnnl_dim_t M {0}, N {0}, K {0}, LDA {0}, LDB {0}, LDC {0}; + size_t compute_hash() const; + dnnl_data_type_t m_dt_in0 {dnnl_f32}, m_dt_in1 {dnnl_f32}; + bool m_is_with_amx {false}; + bool m_is_with_comp {false}; + float m_beta {0}; + dnnl::impl::cpu::x64::cpu_isa_t m_isa {dnnl::impl::cpu::x64::isa_undef}; + dnnl_dim_t m_M {0}, m_N {0}, m_K {0}, m_LDA {0}, m_LDB {0}, m_LDC {0}; + size_t m_hash {SIZE_MAX}; }; struct BrgemmCompiledKernel { @@ -54,20 +77,16 @@ class BrgemmKernelExecutor : public CPUKernelExecutor& config); + BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, + const std::shared_ptr& config); /** Function that will be called in runtime to execute the kernel */ - static void execute(const BrgemmKernelExecutor* desc, call_args* args); - - /** Update kernel config using the arguments passed, and recompile the kernel */ - void update(size_t M, size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC); + static void execute(const BrgemmKernelExecutor* executor, call_args* args); - /** print current kernel config for debug purposes */ -#ifdef SNIPPETS_DEBUG_CAPS - std::string config_to_string() const; -#endif protected: - std::shared_ptr compile_kernel(const std::shared_ptr& c) const override; + std::shared_ptr compile_kernel(const std::shared_ptr& c) const override; + void update_config(const ov::snippets::lowered::ExpressionPtr& expr, std::shared_ptr& config) const override; }; +#define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field) } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp index d9c87dbf8b3ae3..8f7e58048e92b9 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp @@ -87,7 +87,7 @@ static std::string init_info_jit_store_memory_emitter(const jit_store_memory_emi std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter) { std::stringstream ss; ss << "Emitter_type_name:jit_brgemm_emitter" - << emitter->m_kernel_executor->config_to_string() + << emitter->m_kernel_executor->to_string() << " m_load_offset_a:" << emitter->m_load_offset_a << " m_load_offset_b:" << emitter->m_load_offset_b << " m_load_offset_scratch:" << emitter->m_load_offset_scratch diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index f499a5a66993e3..eac50bf04dbd82 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -127,6 +127,10 @@ class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { OPENVINO_ASSERT(data_offsets.size() == 
inMemPtrs.size() + outMemPtrs.size(), "Incorrect data offset count!"); OPENVINO_ASSERT(data_offsets.front().size() == m_parallel_exec_domain.size(), "Data offsets with invalid ranks detected"); + // Note: we need to reset KernelExecutorTable to the state that was recorded in the SubgraphDynamicSpecializedExecutor + // constructor because the table might've been used for other shapes + reset_exec_table_state(); + std::vector src_ptrs; std::vector dst_ptrs; init_original_ptrs(inMemPtrs, outMemPtrs, src_ptrs, dst_ptrs); @@ -195,11 +199,13 @@ class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { buffer_offsets = snippet_config->buffer_cluster_offsets; data_offsets = snippet_config->io_data_offsets; loop_args = snippet_config->loop_args; - } + reset_exec_table_state = snippet_config->kernel_executor_table->get_state_reset(); + }; std::vector buffer_offsets = {}; std::vector> data_offsets = {}; std::vector loop_args = {}; + std::function reset_exec_table_state; }; struct SubgraphKey { diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp index d0265b8606d286..d701495ce7b372 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp @@ -26,7 +26,7 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type ty set_input_port_descriptor({0, offset_a}, 0); set_input_port_descriptor({0, offset_b}, 1); set_output_port_descriptor({0, offset_c}, 0); - compute_block_size_values(blk_size_m, blk_size_k, blk_size_n); + set_block_size_values(blk_size_m, blk_size_k, blk_size_n); custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); set_beta(beta); } @@ -43,7 +43,7 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output< set_input_port_descriptor({0, offset_b}, 1); set_output_port_descriptor({0, offset_c}, 0); set_input_port_descriptor({0, offset_scratch}, 2); - compute_block_size_values(blk_size_m, blk_size_k, blk_size_n); + set_block_size_values(blk_size_m, blk_size_k, blk_size_n); custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); set_beta(beta); } @@ -57,7 +57,7 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type ty set_output_size(1); m_input_ports = {{0, desc_a}, {1, desc_b}}; m_output_ports = {{0, desc_c}}; - compute_block_size_values(blk_size_m, blk_size_k, blk_size_n); + set_block_size_values(blk_size_m, blk_size_k, blk_size_n); custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); set_beta(beta); } @@ -71,7 +71,7 @@ BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output< set_output_size(1); m_input_ports = {{0, desc_a}, {1, desc_b}, {2, desc_scratch}}; m_output_ports = {{0, desc_c}}; - compute_block_size_values(blk_size_m, blk_size_k, blk_size_n); + set_block_size_values(blk_size_m, blk_size_k, blk_size_n); custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c)); set_beta(beta); } @@ -87,7 +87,7 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector std::vector{ snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_a), brgemm_copy ? 
snippets::utils::get_planar_pshape(brgemm_copy->input(0)) : snippets::utils::get_planar_pshape(get_input_partial_shape(1), layout_b) }; - auto output_shape = get_output_partial_shape(planar_input_shapes); + auto output_shape = infer_output_partial_shape(planar_input_shapes); set_output_type(0, get_output_type(), snippets::utils::get_planar_pshape(output_shape, layout_c)); // Additional check for 3rd input @@ -100,7 +100,7 @@ void BrgemmCPU::validate_and_infer_types() { const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr; const auto planar_input_shapes = get_planar_input_shapes({input(0), brgemm_copy ? brgemm_copy->input(0) : input(1)}); - auto output_shape = get_output_partial_shape(planar_input_shapes); + auto output_shape = infer_output_partial_shape(planar_input_shapes); set_output_type(0, get_output_type(), get_planar_output_shape(output_shape)); // Additional check for 3rd input @@ -119,9 +119,6 @@ void BrgemmCPU::validate_with_scratchpad() const { } void BrgemmCPU::validate_inputs() const { - // If no leading dimensions are provided, assume dense row-major inputs-outputs - NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), - "BrgemmCPU currently supports only static shapes."); OPENVINO_ASSERT(implication(one_of(m_type, Type::Floating, Type::WithDataRepacking), get_input_size() == 2), "BrgemmCPU expects 2 inputs in cases, when input precisions are f32|f32, u8|i8 or bf16|bf16 (non-AMX system)"); OPENVINO_ASSERT(implication(one_of(m_type, Type::WithCompensations, Type::AMX), get_input_size() == 3), diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp index 80fde9c733ba18..81a9fa0a525013 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.cpp @@ -58,16 +58,12 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { if (!brgemm || brgemm_plugin) OPENVINO_THROW("BrgemmCPU cannot be in body before BrgemmToBrgemmCPU pass"); - if (brgemm->is_dynamic()) { - return false; - } - const auto& brgemm_in0_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(0)); const auto& brgemm_in1_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->input(1)); const auto& brgemm_out_desc = PortDescriptorUtils::get_port_descriptor_ptr(brgemm->output(0)); - const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input(0)).get_shape(); - const auto dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input(1)).get_shape(); + const auto dimsMatMulIn0 = snippets::utils::get_planar_pshape(brgemm->input(0)); + const auto dimsMatMulIn1 = snippets::utils::get_planar_pshape(brgemm->input(1)); const auto K = *dimsMatMulIn0.rbegin(); const auto N = *dimsMatMulIn1.rbegin(); @@ -75,7 +71,9 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { const auto element_type_a = brgemm->get_input_element_type(0); const auto brgemmVNNIFactor = 4 / element_type_a.size(); const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx); - const bool with_amx = isAMXSupported && element_type_a != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0); + const bool with_amx = isAMXSupported && element_type_a != ov::element::f32 && + K.is_static() && K.get_length() % brgemmVNNIFactor == 0 && + N.is_static() && 
N.get_length() % brgemmVNNIFactor == 0; const bool with_comp = element_type_a == ov::element::i8 && !with_amx; const auto offset_a = brgemm->get_offset_a(); diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp index 63e33cd42635f1..c682c2eae85a14 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp @@ -10,12 +10,9 @@ #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include "openvino/core/rt_info.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/matcher.hpp" -#include "cpu/x64/cpu_isa_traits.hpp" - #include "cpu_shape.h" #include "utils/general_utils.h" @@ -30,27 +27,32 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::SetBrgemmCPUBlockingParams") const auto node = m.get_match_root(); auto brgemm = ov::as_type_ptr(node); - if (brgemm->is_dynamic()) { - return false; - } const auto& input_1_precision = brgemm->get_input_element_type(1); // Ticket: 113745 // TODO: extend block size selection heuristics - auto get_block_size_m = [&](const size_t M) { + auto get_block_size_m = [&](const ov::Dimension& M_dim) -> size_t { return 32; }; - auto get_block_size_k = [&](const size_t K) { + auto get_block_size_k = [&](const ov::Dimension& K_dim) -> size_t { + // K blocking is disabled in dynamism by default + if (K_dim.is_dynamic()) + return snippets::utils::get_dynamic_value(); + + const auto K = K_dim.get_length(); if (input_1_precision != ov::element::f32) return K; return K > 1024 ? 1024 : K > 512 ? 512 : K; }; - auto get_block_size_n = [&](const size_t N) { - return input_1_precision != ov::element::f32 ? N : 64; + auto get_block_size_n = [&](const ov::Dimension& N_dim) -> size_t { + // N blocking is disabled in dynamism by default + if (N_dim.is_dynamic()) + return snippets::utils::get_dynamic_value(); + return input_1_precision == ov::element::f32 ? 
64 : N_dim.get_length(); }; - const auto brgemm_in0_dims = snippets::utils::get_planar_pshape(brgemm->input(0)).get_shape(); - const auto brgemm_in1_dims = snippets::utils::get_planar_pshape(brgemm->input(1)).get_shape(); + const auto brgemm_in0_dims = snippets::utils::get_planar_pshape(brgemm->input(0)); + const auto brgemm_in1_dims = snippets::utils::get_planar_pshape(brgemm->input(1)); const auto& M = *++brgemm_in0_dims.rbegin(); const auto& K = *brgemm_in0_dims.rbegin(); const auto& N = *brgemm_in1_dims.rbegin(); @@ -61,7 +63,7 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() { if (brgemm->is_with_data_repacking()) { const auto brgemm_copy_b = brgemm->get_brgemm_copy(); const auto brgemmVNNIFactor = brgemm_copy_b->get_brgemm_vnni_factor(); - OPENVINO_ASSERT(k_blk == K || k_blk % brgemmVNNIFactor == 0, + OPENVINO_ASSERT(K.is_dynamic() || k_blk == static_cast(K.get_length()) || k_blk % brgemmVNNIFactor == 0, "K Block size (", k_blk, "), which is not divisible by brgemmVNNIFactor (", diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 55c6a2aa095cda..96733959205ca7 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -11,19 +11,21 @@ namespace ov { namespace test { namespace snippets { +#define STATIC_SHAPE(...) {{}, {{__VA_ARGS__}}} namespace { -std::vector> input_shapes{ - {{2, 1, 3, 5}, {1, 3, 5, 3}}, - {{3, 1, 32, 14}, {1, 2, 14, 32}}, - {{1, 2, 37, 23}, {2, 1, 23, 37}}, - {{1, 1, 37, 23}, {1, 2, 23, 33}}, - {{1, 1, 32, 23}, {1, 1, 23, 68}}, - {{1, 16, 384, 64}, {1, 16, 64, 384}}, - {{1, 1, 100, 700}, {1, 1, 700, 100}}, - {{1, 1, 100, 1024}, {1, 1, 1024, 100}}, - {{1, 1, 100, 2500}, {1, 1, 2500, 100}}, - {{1, 1, 100, 4500}, {1, 1, 4500, 100}}, +std::vector> input_shapes{ + {STATIC_SHAPE(2, 1, 3, 5), STATIC_SHAPE(1, 3, 5, 3)}, + {STATIC_SHAPE(3, 1, 32, 14), STATIC_SHAPE(1, 2, 14, 32)}, + {STATIC_SHAPE(1, 2, 37, 23), STATIC_SHAPE(2, 1, 23, 37)}, + {STATIC_SHAPE(1, 1, 37, 23), STATIC_SHAPE(1, 2, 23, 33)}, + + {STATIC_SHAPE(1, 1, 32, 23), STATIC_SHAPE(1, 1, 23, 68)}, + {STATIC_SHAPE(1, 16, 384, 64), STATIC_SHAPE(1, 16, 64, 384)}, + {STATIC_SHAPE(1, 1, 100, 700), STATIC_SHAPE(1, 1, 700, 100)}, + {STATIC_SHAPE(1, 1, 100, 1024), STATIC_SHAPE(1, 1, 1024, 100)}, + {STATIC_SHAPE(1, 1, 100, 2500), STATIC_SHAPE(1, 1, 2500, 100)}, + {STATIC_SHAPE(1, 1, 100, 4500), STATIC_SHAPE(1, 1, 4500, 100)}, }; static inline std::vector> quantized_precisions() { @@ -63,6 +65,25 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, ::testing::Values(ov::test::utils::DEVICE_CPU)), MatMul::getTestCaseName); + +std::vector> input_shapes_dynamic{ + { + {PartialShape{-1, -1, -1, -1}, {{2, 1, 32, 64}, {2, 2, 10, 20}, {2, 2, 100, 80}, + {2, 2, 10, 20}, {2, 1, 32, 64}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 3, 64, 128}, {2, 2, 20, 30}, {2, 2, 80, 120}, + {2, 2, 20, 30}, {1, 3, 64, 128}}} + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMatMult, MatMul, + ::testing::Combine( + ::testing::ValuesIn(input_shapes_dynamic), + ::testing::ValuesIn(precisions(true)), + ::testing::Values(1), // MatMul + ::testing::Values(1), // Tokenized MatMul + ::testing::Values(ov::test::utils::DEVICE_CPU)), + MatMul::getTestCaseName); + INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulFQ, MatMulFQ, ::testing::Combine( ::testing::ValuesIn(input_shapes), 
@@ -74,8 +95,12 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulFQ, MatMulFQ, INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, ::testing::Combine( - ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 1, 69, 49}}, - std::vector{{1, 2, 95, 1023}, {1, 2, 1023, 255}, {1, 2, 95, 255}}), + ::testing::Values(std::vector{STATIC_SHAPE(1, 2, 69, 43), + STATIC_SHAPE(2, 1, 43, 49), + STATIC_SHAPE(1, 1, 69, 49)}, + std::vector{STATIC_SHAPE(1, 2, 95, 1023), + STATIC_SHAPE(1, 2, 1023, 255), + STATIC_SHAPE(1, 2, 95, 255)}), ::testing::ValuesIn(precisions(false)), ::testing::Values(1), // Subgraph; ::testing::Values(1), // Tokenized MatMul+Bias @@ -84,9 +109,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBiasQuantized, MatMulBiasQuantized, ::testing::Combine( - ::testing::ValuesIn(std::vector>{ - std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 2, 1, 1}}, - std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 2, 69, 49}}}), + ::testing::ValuesIn(std::vector>{ + std::vector{STATIC_SHAPE(1, 2, 69, 43), + STATIC_SHAPE(2, 1, 43, 49), + STATIC_SHAPE(1, 2, 1, 1)}, + std::vector{STATIC_SHAPE(1, 2, 69, 43), + STATIC_SHAPE(2, 1, 43, 49), + STATIC_SHAPE(1, 2, 69, 49)}}), ::testing::ValuesIn(quantized_precisions()), ::testing::Values(1), // Subgraph ::testing::Values(1), // Tokenized MatMul+Bias @@ -95,7 +124,9 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBiasQuantized, MatMulBiasQuantized INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulQuantized, MatMulQuantized, ::testing::Combine( - ::testing::Values(std::vector{{1, 16, 128, 64}, {1, 16, 64, 128}, {128, 64}}), + ::testing::Values(std::vector{STATIC_SHAPE(1, 16, 128, 64), + STATIC_SHAPE(1, 16, 64, 128), + STATIC_SHAPE(128, 64)}), ::testing::ValuesIn(quantized_precisions()), ::testing::Values(3), // Subgraph + Reshape + Subgraph ::testing::Values(2), // Tokenized [MatMul+FQ+Matmul] and [FQ] @@ -104,7 +135,9 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulQuantized, MatMulQuantized, INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulQuantizedSoftmax, MatMulQuantizedSoftmax, ::testing::Combine( - ::testing::Values(std::vector{{1, 16, 128, 64}, {1, 16, 64, 128}, {128, 64}}), + ::testing::Values(std::vector{STATIC_SHAPE(1, 16, 128, 64), + STATIC_SHAPE(1, 16, 64, 128), + STATIC_SHAPE(128, 64)}), ::testing::ValuesIn(quantized_precisions()), ::testing::Values(3), // Subgraph + Reshape + Subgraph ::testing::Values(2), // Tokenized [MatMul+FQ+Matmul] and [FQ] diff --git a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp index 436da71327f96f..7a276d2fe9ef95 100644 --- a/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp +++ b/src/plugins/intel_cpu/tests/unit/snippets_transformations/x64/lowered/brgemm_blocking.cpp @@ -131,18 +131,21 @@ TEST_F(BrgemmBlockingTest, BlockingIsNotNeeded) { const ov::PartialShape input_shape_a{1, 16, m, k}; const ov::PartialShape input_shape_b{1, 16, k, n}; const auto precision = ov::element::f32; + const std::vector layout{}; { auto data_a = linear_ir->push_node(precision, input_shape_a); auto data_b = linear_ir->push_node(precision, input_shape_b); - auto brgemm = linear_ir->push_node(data_a.second, data_b.second, BrgemmCPU::Type::Floating); + auto brgemm = linear_ir->push_node(data_a.second, data_b.second, BrgemmCPU::Type::Floating, + 0, 0, 0, layout, layout, layout, m, k, n); 
init_expr_descriptors(*brgemm.first); auto result = linear_ir->push_node(brgemm.second); } { auto data_a = linear_ir_ref->push_node(precision, input_shape_a); auto data_b = linear_ir_ref->push_node(precision, input_shape_b); - auto brgemm = linear_ir_ref->push_node(data_a.second, data_b.second, BrgemmCPU::Type::Floating); + auto brgemm = linear_ir_ref->push_node(data_a.second, data_b.second, BrgemmCPU::Type::Floating, + 0, 0, 0, layout, layout, layout, m, k, n); brgemm.second->set_beta(0.f); init_expr_descriptors(*brgemm.first, {{m, k}, {k, n}, {m, n}}); auto result = linear_ir_ref->push_node(brgemm.second); diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp index ad9ee8252e362d..300836117307cf 100644 --- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp @@ -11,7 +11,7 @@ namespace test { namespace snippets { typedef std::tuple< - std::vector, // Input Shapes + std::vector, // Input Shapes std::vector,// Input Element types size_t, // Expected num nodes size_t, // Expected num subgraphs @@ -26,32 +26,32 @@ class MatMul : public testing::WithParamInterface& inputShapes, const std::vector& types); + virtual void init_subgraph(const std::vector& types); }; class MatMulFQ : public MatMul { protected: - void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; + void init_subgraph(const std::vector& types) override; }; class MatMulBias : public MatMul { protected: - void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; + void init_subgraph(const std::vector& types) override; }; class MatMulBiasQuantized : public MatMul { protected: - void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; + void init_subgraph(const std::vector& types) override; }; class MatMulQuantized : public MatMul { protected: - void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; + void init_subgraph(const std::vector& types) override; }; class MatMulQuantizedSoftmax : public MatMul { protected: - void init_subgraph(const std::vector& inputShapes, const std::vector& types) override; + void init_subgraph(const std::vector& types) override; }; } // namespace snippets diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp index e7bf30f51115c6..0da01d83e7d948 100644 --- a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp @@ -12,14 +12,15 @@ namespace test { namespace snippets { std::string MatMul::getTestCaseName(testing::TestParamInfo obj) { - std::vector input_shapes; + std::vector input_shapes; std::vector elem_types; std::string targetDevice; size_t num_nodes, num_subgraphs; std::tie(input_shapes, elem_types, num_nodes, num_subgraphs, targetDevice) = obj.param; std::ostringstream result; for (size_t i = 0; i < input_shapes.size(); i++) - result << "IS[" << i <<"]=" << ov::test::utils::partialShape2str({input_shapes[i]}) << "_"; + result << "IS[" << i << "]=" << input_shapes[i] << "_"; + for (size_t i = 0; i < elem_types.size(); i++) result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "#N=" << num_nodes << "_"; @@ -29,44 +30,44 @@ std::string MatMul::getTestCaseName(testing::TestParamInfo input_shapes; + std::vector input_shapes; std::vector elem_types; 
std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + init_input_shapes(input_shapes); - init_subgraph(input_shapes, elem_types); + init_subgraph(elem_types); if (!configuration.count("SNIPPETS_MODE")) { configuration.insert({"SNIPPETS_MODE", "IGNORE_CALLBACK"}); } } -void MatMul::init_subgraph(const std::vector& inputShapes, const std::vector& types) { - auto f = ov::test::snippets::MatMulFunction(inputShapes, types); +void MatMul::init_subgraph(const std::vector& types) { + auto f = ov::test::snippets::MatMulFunction(inputDynamicShapes, types); function = f.getOriginal(); } -void MatMulFQ::init_subgraph(const std::vector& inputShapes, const std::vector& types) { - auto f = ov::test::snippets::FQMatMulFunction(inputShapes); +void MatMulFQ::init_subgraph(const std::vector& types) { + auto f = ov::test::snippets::FQMatMulFunction(inputDynamicShapes); function = f.getOriginal(); } -void MatMulBias::init_subgraph(const std::vector& inputShapes, const std::vector& types) { - auto f = ov::test::snippets::MatMulBiasFunction(inputShapes, types); +void MatMulBias::init_subgraph(const std::vector& types) { + auto f = ov::test::snippets::MatMulBiasFunction(inputDynamicShapes, types); function = f.getOriginal(); } -void MatMulBiasQuantized::init_subgraph(const std::vector& inputShapes, const std::vector& types) { - auto f = ov::test::snippets::MatMulBiasQuantizedFunction(inputShapes, types); +void MatMulBiasQuantized::init_subgraph(const std::vector& types) { + auto f = ov::test::snippets::MatMulBiasQuantizedFunction(inputDynamicShapes, types); function = f.getOriginal(); } -void MatMulQuantized::init_subgraph(const std::vector& inputShapes, const std::vector& types) { - auto f = ov::test::snippets::MatMulsQuantizedFunction(inputShapes, types); +void MatMulQuantized::init_subgraph(const std::vector& types) { + auto f = ov::test::snippets::MatMulsQuantizedFunction(inputDynamicShapes, types); function = f.getOriginal(); } -void MatMulQuantizedSoftmax::init_subgraph(const std::vector& inputShapes, const std::vector& types) { - auto f = ov::test::snippets::MatMulsQuantizedSoftmaxFunction(inputShapes, types); +void MatMulQuantizedSoftmax::init_subgraph(const std::vector& types) { + auto f = ov::test::snippets::MatMulsQuantizedSoftmaxFunction(inputDynamicShapes, types); function = f.getOriginal(); }
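
Editor's note (illustrative sketch): the subgraph.cpp hunk captures reset_exec_table_state from kernel_executor_table->get_state_reset() so the dynamically specialized executor can roll the table back to the configs recorded in its constructor before each run. Below is a minimal standalone sketch of that capture-and-restore idea; the Table and Config classes here are simplified placeholders for illustration, not the plugin's KernelExecutorTable API.

#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

// Simplified placeholder for a kernel config (real ones hold shapes, LDs, etc.).
struct Config { int M = 0; };

// Simplified placeholder for an executor table: get_state_reset() returns a
// callable that restores every stored config to the snapshot taken now.
class Table {
public:
    std::shared_ptr<Config> add(int M) {
        auto c = std::make_shared<Config>(Config{M});
        m_configs.push_back(c);
        return c;
    }
    std::function<void()> get_state_reset() {
        std::vector<Config> snapshot;
        for (const auto& c : m_configs)
            snapshot.push_back(*c);
        auto configs = m_configs;  // shared_ptrs stay valid independently of the table
        return [configs, snapshot]() {
            for (std::size_t i = 0; i < configs.size(); ++i)
                *configs[i] = snapshot[i];
        };
    }
private:
    std::vector<std::shared_ptr<Config>> m_configs;
};

int main() {
    Table table;
    auto cfg = table.add(32);
    // The executor captures the reset functor at construction time (as in the diff).
    auto reset_exec_table_state = table.get_state_reset();
    cfg->M = 64;                   // the table might've been used for other shapes
    reset_exec_table_state();      // restore the recorded state before exec()
    assert(cfg->M == 32);
    return 0;
}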
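
Editor's note (illustrative sketch): in brgemm_to_brgemm_cpu.cpp the AMX path is now taken only when K and N are statically known and divisible by the VNNI factor, so dynamic dimensions fall back to non-AMX kernels. A compact sketch of that predicate follows; the Dim struct is a stand-in for ov::Dimension, and the hardware/ISA query is reduced to a boolean flag.

#include <cassert>
#include <cstddef>
#include <cstdint>

struct Dim {                       // stand-in for ov::Dimension
    std::int64_t value = -1;       // -1 encodes "dynamic"
    bool is_dynamic() const { return value < 0; }
    bool is_static() const { return !is_dynamic(); }
    std::int64_t get_length() const { return value; }
};

// Mirrors the reworked condition: AMX needs hardware support, a non-f32 input
// type, and statically known K/N that are multiples of the VNNI factor.
bool with_amx(bool amx_supported, std::size_t elem_size_a, bool is_f32, Dim K, Dim N) {
    const std::int64_t vnni_factor = 4 / static_cast<std::int64_t>(elem_size_a);
    return amx_supported && !is_f32 &&
           K.is_static() && K.get_length() % vnni_factor == 0 &&
           N.is_static() && N.get_length() % vnni_factor == 0;
}

int main() {
    // bf16 (2 bytes) with static, even K/N qualifies; any dynamic dim does not.
    assert(with_amx(true, 2, false, Dim{64}, Dim{128}));
    assert(!with_amx(true, 2, false, Dim{-1}, Dim{128}));
    return 0;
}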
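
Editor's note (illustrative sketch): set_brgemm_cpu_blocking_params.cpp switches the block-size lambdas to ov::Dimension so dynamic K/N return the dynamic-value sentinel instead of a concrete block. The following self-contained sketch mirrors those heuristics; Dim and dynamic_block are illustrative stand-ins for ov::Dimension and snippets::utils::get_dynamic_value<size_t>().

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

struct Dim {                      // stand-in for ov::Dimension
    std::int64_t value = -1;      // -1 encodes "dynamic"
    bool is_dynamic() const { return value < 0; }
    std::int64_t get_length() const { assert(!is_dynamic()); return value; }
};

constexpr std::size_t dynamic_block = std::numeric_limits<std::size_t>::max();

// K blocking mirrors the heuristic in SetBrgemmCPUBlockingParams: disabled
// (sentinel) for dynamic K, full K for non-f32 inputs, otherwise capped.
std::size_t block_size_k(const Dim& k, bool input1_is_f32) {
    if (k.is_dynamic())
        return dynamic_block;
    const auto K = static_cast<std::size_t>(k.get_length());
    if (!input1_is_f32)
        return K;
    return K > 1024 ? 1024 : K > 512 ? 512 : K;
}

// N blocking: disabled for dynamic N, fixed 64 for f32, full N otherwise.
std::size_t block_size_n(const Dim& n, bool input1_is_f32) {
    if (n.is_dynamic())
        return dynamic_block;
    return input1_is_f32 ? 64 : static_cast<std::size_t>(n.get_length());
}

int main() {
    assert(block_size_k(Dim{-1}, true) == dynamic_block);   // dynamic K -> sentinel
    assert(block_size_k(Dim{2000}, true) == 1024);           // large static K is capped
    assert(block_size_n(Dim{128}, true) == 64);               // f32 N block is fixed
    assert(block_size_n(Dim{128}, false) == 128);             // non-f32 keeps full N
    return 0;
}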
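
Editor's note (illustrative sketch): the functional tests move from plain PartialShape vectors to InputShape pairs. The STATIC_SHAPE macro wraps a single static target shape, while the new dynamic cases (see smoke_Snippets_DynMatMult) pair a rank-4 pattern of -1s with several target shapes. The sketch below illustrates that encoding with a simplified InputShape stand-in, not the actual ov::test::InputShape type.

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<std::int64_t>;
struct InputShape {
    Shape dynamic_pattern;              // empty => fully static test case
    std::vector<Shape> target_shapes;   // concrete shapes fed at inference time
};

// Mirrors the test macro: a static case has an empty dynamic pattern and a
// single target shape.
#define STATIC_SHAPE(...) InputShape{{}, {{__VA_ARGS__}}}

int main() {
    InputShape s = STATIC_SHAPE(2, 1, 3, 5);
    assert(s.dynamic_pattern.empty() && s.target_shapes.size() == 1);

    // Dynamic case: one rank-4 pattern of -1s with several target shapes,
    // exercising kernel recompilation across inference calls.
    InputShape d{{-1, -1, -1, -1}, {{2, 1, 32, 64}, {2, 2, 10, 20}}};
    assert(d.target_shapes.size() == 2);
    return 0;
}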