
SplitDimensionM: heuristic update
v-Golubev committed Dec 23, 2024
1 parent 9ff5942 commit b358b87
Showing 3 changed files with 32 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/docs/mha_optimization_guide.md
@@ -65,7 +65,7 @@ The supported by decomposition Transpose orders are defined by `TokenizeMHASnipp

[SplitDimensionM](../src/pass/split_dimension_m.cpp) splits the M dimension of MHA into 2 parts (`batch_m` and `new_m`) by inserting a Reshape on the A input of the first Matmul and on the output of the second Matmul (the rest of the Subgraph's inputs are reshaped by Unsqueeze-like reshapes in order not to break the subgraph semantics).
This optimization increases the parallel work amount by a factor of `batch_m`, thus enabling more efficient parallel execution in some cases.
The splitting is performed based on a heuristic algorithm which can be found in the `SplitDimensionM::get_splited_dimensions` method.
The splitting is performed based on a heuristic algorithm which can be found in the `SplitDimensionM::split` method.
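
As a rough, hypothetical illustration of the shape effect only (the shapes and the chosen `batch_m` below are made up, and the snippet is not the pass implementation), splitting M on the A input enlarges the parallel domain like this:

```cpp
// Hypothetical illustration: how splitting M by a chosen batch_m factor
// enlarges the parallel domain of an MHA Subgraph. Not OpenVINO code.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Assumed A-input layout [B0, B1, M, K]; all values are made up.
    std::vector<size_t> a_shape = {2, 12, 128, 64};
    const size_t batch_m = 4;                   // factor chosen by the heuristic
    const size_t new_m = a_shape[2] / batch_m;  // 128 / 4 = 32

    // Reshape inserted on the A input: [2, 12, 128, 64] -> [2, 12, 4, 32, 64]
    std::vector<size_t> reshaped = {a_shape[0], a_shape[1], batch_m, new_m, a_shape[3]};
    for (size_t d : reshaped)
        std::cout << d << ' ';
    // Parallel work amount grows from B0*B1 = 24 to B0*B1*batch_m = 96.
    std::cout << "\nparallel work: " << a_shape[0] * a_shape[1] << " -> "
              << a_shape[0] * a_shape[1] * batch_m << "\n";
}
```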

Let's consider an example of the transformation:

3 changes: 2 additions & 1 deletion src/common/snippets/include/snippets/pass/split_dimension_m.hpp
@@ -67,7 +67,8 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {

private:
static std::shared_ptr<ov::op::v0::MatMul> get_matmul(const std::shared_ptr<op::Subgraph>& subgraph);
static std::pair<size_t, size_t> get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
static std::pair<size_t, size_t> compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
static std::pair<size_t, size_t> compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);

void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);

41 changes: 29 additions & 12 deletions src/common/snippets/src/pass/split_dimension_m.cpp
@@ -4,8 +4,8 @@

#include "snippets/pass/split_dimension_m.hpp"

#include "snippets/utils/utils.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils/utils.hpp"

namespace {
size_t get_dim_M(const ov::Shape& shape) {
@@ -31,7 +31,7 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>&
return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic();
}

std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> SplitDimensionM::compute_conservative_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };

// Ideal case #1: M can be split into parts, one of which complements the batch dimension up to the optimal parallel work amount
@@ -70,6 +70,25 @@ std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_d
return splited;
}
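
The conservative heuristic's body is mostly collapsed in this diff view. As a hedged sketch of the "ideal case" idea named in the comment above (my own reconstruction, not the committed code, which may cover more cases): pick a `batch_m` that complements `batch_dim` up to the optimal parallel work amount and also divides M evenly.

```cpp
// Sketch only: pick batch_m so that batch_dim * batch_m equals the optimal
// work amount and batch_m divides M; otherwise leave the shape unsplit.
#include <cstddef>
#include <iostream>
#include <utility>

std::pair<size_t, size_t> conservative_sketch(size_t batch_dim, size_t m_dim, size_t optimal_wa) {
    std::pair<size_t, size_t> splited = {1, m_dim};
    if (optimal_wa % batch_dim == 0) {
        const size_t batch_m = optimal_wa / batch_dim;
        if (batch_m > 1 && m_dim % batch_m == 0)
            splited = {batch_m, m_dim / batch_m};
    }
    return splited;
}

int main() {
    // Hypothetical case: batch = 8, M = 1024, 32 threads -> batch_m = 4, new_m = 256.
    const auto [batch_m, new_m] = conservative_sketch(8, 1024, 32);
    std::cout << batch_m << " x " << new_m << "\n";  // prints "4 x 256"
}
```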

std::pair<size_t, size_t> SplitDimensionM::compute_aggressive_heuristic(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };
// If M dim is big enough and batch_dim is not ideally divisible by optimal_parallelism_work_amount,
// it is better to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop
// iteration. At the same time, kernel_m mustn't be less than min_kernel_m.
const bool big_m_dim = m_dim >= 4000;
if (big_m_dim && batch_dim % optimal_parallelism_work_amount != 0) {
const size_t min_kernel_m = 32;
for (size_t divisor = min_kernel_m; divisor < std::sqrt(m_dim); ++divisor) {
if (m_dim % divisor == 0) {
splited.first = m_dim / divisor;
splited.second = divisor;
break;
}
}
}
return splited;
}
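
To trace the loop above on one concrete, made-up case: with `m_dim = 4096` (>= 4000) and `min_kernel_m = 32`, the first candidate divisor 32 already divides M, so the result is `batch_m = 128`, `kernel_m = 32`. A standalone copy of the same loop for checking (the class method itself is private):

```cpp
// Standalone re-run of the loop above, just to trace one case; the real
// logic lives in SplitDimensionM::compute_aggressive_heuristic.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <utility>

int main() {
    const size_t m_dim = 4096;       // hypothetical M, large enough (>= 4000)
    const size_t min_kernel_m = 32;  // same lower bound as in the pass
    std::pair<size_t, size_t> splited = {1, m_dim};
    for (size_t divisor = min_kernel_m; divisor < std::sqrt(m_dim); ++divisor) {
        if (m_dim % divisor == 0) {
            splited = {m_dim / divisor, divisor};  // {batch_m, kernel_m}
            break;
        }
    }
    std::cout << splited.first << " x " << splited.second << "\n";  // prints "128 x 32"
}
```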

bool SplitDimensionM::can_be_optimized(const std::shared_ptr<const ov::Node>& node, size_t concurrency) {
if (!is_supported_matmul(node))
return false;
@@ -128,19 +147,17 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w
const auto batch_dim =
std::accumulate(shape.rbegin() + 2, shape.rend(), size_t(1), std::multiplies<size_t>()); // B (batch)
const auto m_dim = get_dim_M(shape); // M
batch_m_dim = 1, new_m_dim = m_dim;
if (is_prime_number(m_dim))
return false;

auto is_optimized = [&](size_t batch_dim) {
return batch_dim >= optimal_parallelism_work_amount;
};

// We skip optimization if the current batch is optimal for concurrency
if (is_optimized(batch_dim))
return false;

std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount);
return is_optimized(batch_dim * batch_m_dim);
if (batch_dim < optimal_parallelism_work_amount) {
std::tie(batch_m_dim, new_m_dim) = compute_conservative_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
} else {
std::tie(batch_m_dim, new_m_dim) = compute_aggressive_heuristic(batch_dim, m_dim, optimal_parallelism_work_amount);
}
bool optimized = batch_m_dim != 1;
return optimized;
}

void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) {