diff --git a/src/common/snippets/docs/mha_optimization_guide.md b/src/common/snippets/docs/mha_optimization_guide.md
index 28245017833a4a..1ea3a4c24c3524 100644
--- a/src/common/snippets/docs/mha_optimization_guide.md
+++ b/src/common/snippets/docs/mha_optimization_guide.md
@@ -65,7 +65,7 @@ The supported by decomposition Transpose orders are defined by `TokenizeMHASnippets`
 [SplitDimensionM](../src/pass/split_dimension_m.cpp) splits M dimension of MHA in 2 parts (`batch_m` and `new_m`) by inserting Reshape on A input of the first Matmul and output of the second Matmul (the rest Subgraph's inputs are reshaped by Unsqueeze-like reshapes in order not to break subgraph semantic).
 This optimization increases parallel work amount by `batch_m` times thus enabling a more efficient parallel execution in some cases.
-The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::get_splited_dimensions` method.
+The splitting is performed based on a heuristic algorithm which can be found in the `SplitDimensionM::split` method.
 Let's consider an example of the transformation:
diff --git a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp
index e9a9a46d3847ff..8107a7684f0abd 100644
--- a/src/common/snippets/include/snippets/pass/split_dimension_m.hpp
+++ b/src/common/snippets/include/snippets/pass/split_dimension_m.hpp
@@ -67,7 +67,8 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {
 private:
     static std::shared_ptr<ov::op::v0::MatMul> get_matmul(const std::shared_ptr<op::Subgraph>& subgraph);
-    static std::pair<size_t, size_t> get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
+    static std::pair<size_t, size_t> compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
+    static std::pair<size_t, size_t> compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
 
     void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);
diff --git a/src/common/snippets/src/pass/split_dimension_m.cpp b/src/common/snippets/src/pass/split_dimension_m.cpp
index ae95a371483163..5c3c0c88088377 100644
--- a/src/common/snippets/src/pass/split_dimension_m.cpp
+++ b/src/common/snippets/src/pass/split_dimension_m.cpp
@@ -4,8 +4,8 @@
 
 #include "snippets/pass/split_dimension_m.hpp"
 
-#include "snippets/utils/utils.hpp"
 #include "snippets/itt.hpp"
+#include "snippets/utils/utils.hpp"
 
 namespace {
 size_t get_dim_M(const ov::Shape& shape) {
@@ -31,7 +31,7 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>&
     return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic();
 }
 
-std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
+std::pair<size_t, size_t> SplitDimensionM::compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
    std::pair<size_t, size_t> splited = { 1, m_dim };
    // Ideal case #1: M can be split on the
parts one of which complements the batch dimension to the optimal parallel work amount
@@ -70,6 +70,25 @@
     return splited;
 }
 
+std::pair<size_t, size_t> SplitDimensionM::compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
+    std::pair<size_t, size_t> splited = { 1, m_dim };
+    // If M dim is big enough and batch_dim is not ideally divisible by optimal_parallelism_work_amount,
+    // it is better to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop
+    // iteration. At the same time, kernel_m mustn't be less than min_kernel_m.
+    const bool big_m_dim = m_dim >= 4000;
+    if (big_m_dim && batch_dim % optimal_parallelism_work_amount != 0) {
+        const size_t min_kernel_m = 32;
+        for (size_t divisor = min_kernel_m; divisor < std::sqrt(m_dim); ++divisor) {
+            if (m_dim % divisor == 0) {
+                splited.first = m_dim / divisor;
+                splited.second = divisor;
+                break;
+            }
+        }
+    }
+    return splited;
+}
+
 bool SplitDimensionM::can_be_optimized(const std::shared_ptr<const ov::Node>& node, size_t concurrency) {
     if (!is_supported_matmul(node))
         return false;
@@ -128,19 +147,17 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_work_amount,
     const auto batch_dim = std::accumulate(shape.rbegin() + 2, shape.rend(), size_t(1), std::multiplies<size_t>());  // B (batch)
     const auto m_dim = get_dim_M(shape);  // M
+    batch_m_dim = 1, new_m_dim = m_dim;
     if (is_prime_number(m_dim))
         return false;
 
-    auto is_optimized = [&](size_t batch_dim) {
-        return batch_dim >= optimal_parallelism_work_amount;
-    };
-
-    // We skip optimization if the current batch is optimal for concurrency
-    if (is_optimized(batch_dim))
-        return false;
-
-    std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount);
-    return is_optimized(batch_dim * batch_m_dim);
+    if (batch_dim < optimal_parallelism_work_amount) {
+        std::tie(batch_m_dim, new_m_dim) =
+            compute_conservative_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
+    } else {
+        std::tie(batch_m_dim, new_m_dim) = compute_aggressive_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
+    }
+    bool optimized = batch_m_dim != 1;
+    return optimized;
 }
 
 void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) {