Commit: Correct MHA tokenization

v-Golubev committed Nov 18, 2024
1 parent df6e734 commit e1c3ed7
Showing 5 changed files with 49 additions and 13 deletions.
10 changes: 8 additions & 2 deletions src/common/snippets/include/snippets/pass/tokenization.hpp
@@ -66,10 +66,10 @@ class SnippetsTokenization : public ov::pass::ModelPass {
      */
     struct Config {
         Config(size_t concurrency, size_t data_ptr_gpr_count, bool split_m_dimension, bool enable_transpose_on_output,
-               bool dyn_mha_token, std::set<size_t> mha_transpose_ranks)
+               bool dyn_mha_token, std::set<size_t> mha_transpose_ranks, ov::pass::param_callback mha_tokenize_mm_b_input_callback = nullptr)
             : m_concurrency(concurrency), m_data_ptr_gpr_count(data_ptr_gpr_count), m_split_m_dimension(split_m_dimension),
               m_mha_token_enable_transpose_on_output(enable_transpose_on_output), m_is_dynamic_mha_token_enabled(dyn_mha_token),
-              m_mha_supported_transpose_ranks(std::move(mha_transpose_ranks)) {
+              m_mha_supported_transpose_ranks(std::move(mha_transpose_ranks)), m_mha_tokenize_mm_b_input_callback(mha_tokenize_mm_b_input_callback) {
             OPENVINO_ASSERT(concurrency > 0, "Concurrency should be greater than 0");
             OPENVINO_ASSERT(data_ptr_gpr_count > 0, "data_ptr_gpr_count should be greater than 0");
         }
@@ -102,6 +102,10 @@ class SnippetsTokenization : public ov::pass::ModelPass {
             return m_mha_supported_transpose_ranks;
         }
 
+        bool mha_tokenize_mm_b_input_callback(const std::shared_ptr<const ov::Node>& node) const {
+            return m_mha_tokenize_mm_b_input_callback ? m_mha_tokenize_mm_b_input_callback(node) : false;
+        }
+
     private:
         size_t m_concurrency = 0;
         // The number of gpr that can be used as data pointers for data nodes (Parameter (and non-Scalar Constants),
@@ -121,6 +125,8 @@ class SnippetsTokenization : public ov::pass::ModelPass {
         // Note that in general Snippets support Transpose of any ranks.
         // But at the moment Transpose is used only in MHA pattern where 3D and 4D tensors are supported.
         std::set<size_t> m_mha_supported_transpose_ranks = { 3, 4 };
+
+        ov::pass::param_callback m_mha_tokenize_mm_b_input_callback = nullptr;
     };
 
     OPENVINO_RTTI("SnippetsTokenization", "0");
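For context, here is how a caller could wire up the new optional argument. This is a minimal sketch, not code from the commit: the helper name and the concrete values are assumptions; only the constructor signature, the ov::pass::param_callback type, and the { 3, 4 } transpose-rank default come from the diff above.

#include <memory>
#include <set>

#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"
#include "snippets/pass/tokenization.hpp"

// Hypothetical plugin-side predicate: claim MatMul B inputs with u8 x i8 operands
// so that they stay outside the tokenized Subgraph.
static bool claim_u8i8_mm_b_input(const std::shared_ptr<const ov::Node>& node) {
    return node->get_input_element_type(0) == ov::element::u8 &&
           node->get_input_element_type(1) == ov::element::i8;
}

ov::snippets::pass::SnippetsTokenization::Config make_tokenization_config() {
    return ov::snippets::pass::SnippetsTokenization::Config(
        /*concurrency=*/8,          // assumed value
        /*data_ptr_gpr_count=*/12,  // assumed value
        /*split_m_dimension=*/true,
        /*enable_transpose_on_output=*/true,
        /*dyn_mha_token=*/true,
        /*mha_transpose_ranks=*/std::set<size_t>{3, 4},
        claim_u8i8_mm_b_input);     // new optional argument, defaults to nullptr
}

Omitting the last argument keeps the previous behavior: the callback member stays nullptr and the mha_tokenize_mm_b_input_callback getter returns false for every node.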
16 changes: 11 additions & 5 deletions src/common/snippets/src/pass/mha_tokenization.cpp
@@ -355,7 +355,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
     // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
     const auto is_transposed_b_0 = matmul0->get_transpose_b();
     bool has_matmul0_has_ops_on_input = false;
-    while (is_supported_intermediate_op(parent)) {
+
+    const bool support_mm0_b_input_tokenization = !config.mha_tokenize_mm_b_input_callback(matmul0);
+    while (support_mm0_b_input_tokenization && is_supported_intermediate_op(parent)) {
         // All supported ops have only one output port
         if (parent->get_output_target_inputs(0).size() != 1)
             break;
@@ -404,12 +406,16 @@
         }
     };
 
-    const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+    if (support_mm0_b_input_tokenization) {
+        const auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
+        tokenize_transpose(transpose1, is_transposed_b_0, get_decomposed_transpose_order(pattern_rank), ordered_ops.begin());
+    }
     const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
-    const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
-    tokenize_transpose(transpose1, is_transposed_b_0, get_decomposed_transpose_order(pattern_rank), ordered_ops.begin());
     tokenize_transpose(transpose0, matmul0->get_transpose_a(), get_fusion_transpose_order(pattern_rank), ordered_ops.begin());
-    tokenize_transpose(transpose2, matmul1->get_transpose_b(), get_fusion_transpose_order(pattern_rank), ordered_ops.end());
+    if (!config.mha_tokenize_mm_b_input_callback(matmul1)) {
+        const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
+        tokenize_transpose(transpose2, matmul1->get_transpose_b(), get_fusion_transpose_order(pattern_rank), ordered_ops.end());
+    }
     ordered_ops.push_back(matmul1);
 
     bool are_ops_after_matmul1 = false;
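The pass now consults the callback independently for matmul0 (the whole B-input chain, including Transpose1) and matmul1 (Transpose2 only). When no callback is registered, Config's getter returns false and both B inputs remain tokenizable, preserving the old behavior. A toy, self-contained illustration of that nullptr default, using a stand-in Node type rather than the real ov::Node:

#include <functional>
#include <iostream>
#include <memory>

struct Node {};  // stand-in for ov::Node; names here are illustrative only
using ParamCallback = std::function<bool(const std::shared_ptr<const Node>&)>;

// Mirrors Config::mha_tokenize_mm_b_input_callback(): an unset callback means
// "nothing is claimed", so the pass keeps tokenizing MatMul B inputs as before.
bool mm_b_input_claimed(const ParamCallback& cb, const std::shared_ptr<const Node>& node) {
    return cb ? cb(node) : false;
}

int main() {
    const auto matmul = std::make_shared<const Node>();
    const ParamCallback unset;  // default-constructed, i.e. no callback registered
    const ParamCallback claim_all = [](const std::shared_ptr<const Node>&) { return true; };

    std::cout << std::boolalpha
              << mm_b_input_claimed(unset, matmul) << '\n'       // false: tokenize the B input
              << mm_b_input_claimed(claim_all, matmul) << '\n';  // true: leave it outside the Subgraph
    return 0;
}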
@@ -953,9 +953,21 @@ void Transformations::MainSnippets(void) {
     bool split_m_dimension = !ignoreCallback;
     // [122706] Some 3D MHA Patterns have perf regressions when Transpose op is tokenized
     std::set<size_t> mha_supported_transpose_ranks = { 4 };
+
+    // Note: this is a temporary WA, avoiding MatMul B input tokenization in the cases when the CPU plugin
+    // can repack the B input on its own. It will be removed when a plugin-specific SubgraphPass is implemented.
+    auto mha_tokenize_mm_b_input_callback = [this](const std::shared_ptr<const ov::Node>& node) {
+        const auto& input_type_0 = node->get_input_element_type(0);
+        const auto& input_type_1 = node->get_input_element_type(1);
+
+        const bool u8i8_repacking_wo_compensations = input_type_0 == ov::element::u8 && input_type_1 == ov::element::i8;
+        const bool bf16_repacking = input_type_0 == ov::element::f32 && input_type_1 == ov::element::f32 &&
+                                    config.inferencePrecision == ov::element::bf16;
+        return u8i8_repacking_wo_compensations || bf16_repacking;
+    };
     snippets::pass::SnippetsTokenization::Config tokenization_config(concurrency, data_ptr_gpr_count, split_m_dimension,
                                                                      mha_token_enable_transpose_on_output, is_dynamic_mha_token_enabled,
-                                                                     mha_supported_transpose_ranks);
+                                                                     mha_supported_transpose_ranks, mha_tokenize_mm_b_input_callback);
 
     ov::pass::Manager snippetsManager("CPU:Snippets");
     snippetsManager.set_per_pass_validation(false);
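The callback returns true for exactly the two repacking patterns this WA targets; every other MatMul keeps the current tokenization behavior. A standalone restatement of the decision (hypothetical helper name; the inference precision is passed explicitly instead of being read from the plugin config):

#include "openvino/core/type/element_type.hpp"

// Restates the CPU callback above as a pure function over element types.
bool skip_mm_b_input_tokenization(const ov::element::Type& in0,
                                  const ov::element::Type& in1,
                                  const ov::element::Type& inference_precision) {
    // u8 x i8 MatMul: the CPU plugin repacks the B input without compensations.
    const bool u8i8_repacking_wo_compensations = in0 == ov::element::u8 && in1 == ov::element::i8;
    // f32 x f32 MatMul executed with bf16 inference precision: the B input is repacked to bf16.
    const bool bf16_repacking = in0 == ov::element::f32 && in1 == ov::element::f32 &&
                                inference_precision == ov::element::bf16;
    return u8i8_repacking_wo_compensations || bf16_repacking;
}

// skip_mm_b_input_tokenization(u8,  i8,  f32)  -> true
// skip_mm_b_input_tokenization(f32, f32, bf16) -> true
// skip_mm_b_input_tokenization(i8,  i8,  f32)  -> false: the B input is still tokenized

This matches the updated test expectations below, where the affected patterns expect one extra Transpose left outside the Subgraph.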
@@ -3,9 +3,9 @@
//

#include "common_test_utils/common_utils.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"
#include "common_test_utils/node_builders/constant.hpp"
#include "common_test_utils/node_builders/fake_quantize.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"
#include "internal_properties.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "utils/cpu_test_utils.hpp"
@@ -666,15 +666,27 @@ std::vector<std::vector<ElementType>> matMulIn0PrecisionsQuant = {
     {ElementType::i8, ElementType::u8},
 };
 
-INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0,
+INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0_i8i8,
                          MHAQuantTest,
                          ::testing::Combine(::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)),
                                             ::testing::ValuesIn(inputPrecisionsQuant),
-                                            ::testing::ValuesIn(matMulIn0PrecisionsQuant),
+                                            ::testing::Values(std::vector<ElementType>{ElementType::i8, ElementType::i8}),
                                             ::testing::Values(0),
                                             ::testing::Values(ExpectedNodes{
                                                 {"Subgraph", 5},    // FQs on inputs x 3 + MHA + Deq Mul
-                                                {"Transpose", 1}}), // Transpose between MHA and Deq Mul
+                                                {"Transpose", 1}}), // Transpose between MHA and Deq Mul + Extracted transpose on B input of 2nd MM
                                             ::testing::Values(ov::test::utils::DEVICE_CPU)),
                          MHAQuantTest::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant_Pattern0_i8u8,
+                         MHAQuantTest,
+                         ::testing::Combine(::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)),
+                                            ::testing::ValuesIn(inputPrecisionsQuant),
+                                            ::testing::Values(std::vector<ElementType>{ElementType::i8, ElementType::u8}),
+                                            ::testing::Values(0),
+                                            ::testing::Values(ExpectedNodes{
+                                                {"Subgraph", 5},    // FQs on inputs x 3 + MHA + Deq Mul
+                                                {"Transpose", 2}}), // Transpose between MHA and Deq Mul + Extracted transpose on B input of 2nd MM
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU)),
+                         MHAQuantTest::getTestCaseName);
 
@@ -48,7 +48,7 @@ INSTANTIATE_TEST_SUITE_P(
         ::testing::Values(ov::element::f32),
         ::testing::Values(false), // The graph doesn't contain Multiply
         ::testing::Values(MHA::default_thread_count),
-        ::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul
+        ::testing::Values(7), // FQx3 on inputs + MHA + Transpose on output + Transpose on Matmul's B input + Deq Mul
         ::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul
         ::testing::Values(ov::test::utils::DEVICE_CPU),
         ::testing::Values(CPUTestUtils::empty_plugin_config)),
