Skip to content

Commit

Permalink
SplitDimensionM: heuristic update
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 2, 2025
1 parent c14134a commit 321e549
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/docs/mha_optimization_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The supported by decomposition Transpose orders are defined by `TokenizeMHASnipp

[SplitDimensionM](../src/pass/split_dimension_m.cpp) splits M dimension of MHA in 2 parts (`batch_m` and `new_m`) by inserting Reshape on A input of the first Matmul and output of the second Matmul (the rest Subgraph's inputs are reshaped by Unsqueeze-like reshapes in order not to break subgraph semantic).
This optimization increases parallel work amount by `batch_m` times thus enabling a more efficient parallel execution in some cases.
The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::get_splited_dimensions` method.
The splitting is performed based on heuristic algorithm which can be found in `SplitDimensionM::split` method.

Let's consider an example of the transformation:

Expand Down
15 changes: 14 additions & 1 deletion src/common/snippets/include/snippets/pass/split_dimension_m.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,24 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {

private:
static std::shared_ptr<ov::op::v0::MatMul> get_matmul(const std::shared_ptr<op::Subgraph>& subgraph);
static std::pair<size_t, size_t> get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Contains splitM approaches allowing to get the batch ideally divisible by optimal_parallelism_work_amount
*/
static std::pair<size_t, size_t> split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Splits m_dim to minimize kernel_m in order to reduce waiting time for idle threads at the last parallel loop iteration.
*/
static std::pair<size_t, size_t> split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);
/**
* @brief Splits m_dim to get the batch in (optimal_parallelism_work_amount, 2 * optimal_parallelism_work_amount) interval
*/
static std::pair<size_t, size_t> split_fallback_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount);

void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);

size_t m_concurrency;

static const size_t min_kernel_m;
};
} // namespace pass
} // namespace snippets
Expand Down
88 changes: 58 additions & 30 deletions src/common/snippets/src/pass/split_dimension_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

#include "snippets/pass/split_dimension_m.hpp"

#include "snippets/utils/utils.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils/utils.hpp"

namespace {
size_t get_dim_M(const ov::Shape& shape) {
Expand All @@ -26,50 +26,69 @@ bool is_prime_number(size_t value) {
namespace ov {
namespace snippets {
namespace pass {

const size_t SplitDimensionM::min_kernel_m = 32;

bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>& node) {
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(node);
return matmul && !matmul->get_transpose_a() && !matmul->is_dynamic();
}

std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };

std::pair<size_t, size_t> SplitDimensionM::split_ideally(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
// Ideal case #1: M can be split on the parts one of which complements the batch dimension to the optimal parallel work amount
// In this case, each thread will execute the Snippets kernel once
const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) {
splited.first = lower_bound;
splited.second = m_dim / lower_bound;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0)
return std::make_pair(lower_bound, m_dim / lower_bound);

// Ideal case #2: M is divisible by optimal parallel work amount, and the new_m_dim is big enough
// In this case, each thread will execute the Snippets kernel 'batch_dim' times
if (m_dim % optimal_parallelism_work_amount == 0) {
const auto new_m_dim = m_dim / optimal_parallelism_work_amount;
const size_t min_kernel_m = 64;
if (new_m_dim >= min_kernel_m) {
splited.first = optimal_parallelism_work_amount;
splited.second = new_m_dim;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
if (new_m_dim >= min_kernel_m)
return std::make_pair(optimal_parallelism_work_amount, new_m_dim);
}

return std::make_pair(1, m_dim);
}

std::pair<size_t, size_t> SplitDimensionM::split_fallback_increase_parallel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };
const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
size_t divisor_1 = m_dim / divisor_0;
if (divisor_1 * divisor_0 == m_dim) {
splited.first = divisor_0;
splited.second = divisor_1;
break;
}
if (divisor_1 * divisor_0 == m_dim)
return divisor_0 * batch_dim >= optimal_parallelism_work_amount ? std::make_pair(divisor_0, divisor_1) : splited;
}
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}

std::pair<size_t, size_t> SplitDimensionM::split_minimize_kernel_wa(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
// This heuristic minimizes 'm_kernel' (=> maximizes 'm_batch') with a limitation that 'm_kernel >= min_kernel_m'.
// In other words, it tries to find 'm_kernel' bigger than 'min_kernel_m' and at the same time as close as possible to this value.
std::pair<size_t, size_t> best_result = {1, m_dim};
for (size_t divisor = 2; divisor < std::sqrt(m_dim); ++divisor) {
if (m_dim % divisor != 0)
continue;
// If divisor is more than 'min_kernel_m', divisor becomes 'm_kernel',
// guaranteeing the most optimal implementation from 'm_kernel' minimization perspective.
if (divisor >= min_kernel_m)
return std::make_pair(m_dim / divisor, divisor);

// If divisor is less than 'min_kernel_m', divisor becomes m_batch.
// However, it is not guaranteed that the current 'm_kernel = m_dim / divisor' is minimized, as one of the next divisors can be more optimal.
// So in this case the best result is remembered
const size_t m_kernel = m_dim / divisor;
if (m_kernel >= min_kernel_m) {
best_result.first = divisor;
best_result.second = m_kernel;
}
}
if (best_result.first * batch_dim >= optimal_parallelism_work_amount)
return best_result;
return std::make_pair(1, m_dim);
}

bool SplitDimensionM::can_be_optimized(const std::shared_ptr<const ov::Node>& node, size_t concurrency) {
if (!is_supported_matmul(node))
return false;
Expand Down Expand Up @@ -131,16 +150,25 @@ bool SplitDimensionM::split(const ov::Shape& shape, size_t optimal_parallelism_w
if (is_prime_number(m_dim))
return false;

auto is_optimized = [&](size_t batch_dim) {
return batch_dim >= optimal_parallelism_work_amount;
};

// We skip optimization if the current batch is optimal for concurrency
if (is_optimized(batch_dim))
if (batch_dim % optimal_parallelism_work_amount == 0)
return false;

std::tie(batch_m_dim, new_m_dim) = get_splited_dimensions(batch_dim, m_dim, optimal_parallelism_work_amount);
return is_optimized(batch_dim * batch_m_dim);
auto split_is_done = [&batch_m_dim]() {
return batch_m_dim != 1;
};

std::tie(batch_m_dim, new_m_dim) = split_ideally(batch_dim, m_dim, optimal_parallelism_work_amount);
if (split_is_done())
return true;

std::tie(batch_m_dim, new_m_dim) = split_minimize_kernel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
if (split_is_done())
return true;
// If all the previous heuristics failed, fallback heuristic is used, which reflects the old splitting behavior
if (batch_dim < optimal_parallelism_work_amount)
std::tie(batch_m_dim, new_m_dim) = split_fallback_increase_parallel_wa(batch_dim, m_dim, optimal_parallelism_work_amount);
return split_is_done();
}

void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim) {
Expand Down
2 changes: 2 additions & 0 deletions src/common/snippets/tests/src/utils/split_dim_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ const std::vector<SplitDimensionMParams> split_dimension_cases = {
{InputData{25, 50, 40}, ReferenceData{true, 2, 25}},
{InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}},
{InputData{5, 16384, 32}, ReferenceData{true, 32, 512}},
{InputData{48, 4097, 32}, ReferenceData{true, 17, 241}},
{InputData{48, 6600, 32}, ReferenceData{true, 200, 33}},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,
Expand Down

0 comments on commit 321e549

Please sign in to comment.