Skip to content

Commit

Permalink
[Snippets][CPU] Added test
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Aug 21, 2024
1 parent 14a16ae commit fa36a30
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 49 deletions.
1 change: 0 additions & 1 deletion src/common/snippets/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,6 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken

// mark the Subgraph as Completed to not allow Snippets to include any nodes into the MHA Subgraph in common Tokenization
SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed);
std::cout << "tokenized\n";

return true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D,
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand Down Expand Up @@ -82,7 +83,40 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMHA_4D,
::testing::Combine(::testing::ValuesIn(inputShapes_4D_dynamic),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false}),
::testing::Values(false),
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

std::vector<std::vector<ov::test::InputShape>> inputShapes_4D_dynamic_with_mul{
{
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {1, 70, 3, 19}, {1, 128, 3, 64}, {1, 68, 6, 87}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 49, 1, 19}, {1, 128, 1, 64}, {2, 13, 6, 87}}},
{PartialShape{1}, {{1}, {1}, {1}, {1} }},
{PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {1, 1, 70, 49}, {2, 1, 128, 128}, {1, 1, 68, 13}}},
{PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {1, 49, 3, 19}, {1, 128, 3, 64}, {2, 13, 6, 87}}},
},
{
{PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
{PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
{PartialShape{1}, {{1}, {1}, {1}, {1}, {1}}},
{PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
{PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
}
};


INSTANTIATE_TEST_SUITE_P(smoke_Snippets_DynMHA_4D_Wil_Dynamic_Mul,
MHA,
::testing::Combine(::testing::ValuesIn(inputShapes_4D_dynamic_with_mul),
::testing::ValuesIn(precision_f32(5)),
::testing::Values(ov::element::f32),
::testing::Values(true),
::testing::Values(false),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand All @@ -97,6 +131,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_3D,
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(5), // [122706]: Subgraph + 4 Transpose
::testing::Values(2), // decomposed Transpose + MHA
Expand All @@ -113,6 +148,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::Values(true),
::testing::Values(true),
::testing::Values(4), // 4 Threads
::testing::Values(6), // Subgraph + 4 Reshapes on inputs and 1 Reshape on output
::testing::Values(1),
Expand All @@ -128,6 +164,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::Values(true),
::testing::Values(true),
::testing::Values(4), // 4 Threads
::testing::Values(10), // Subgraph + 4 Reshapes on inputs and 1 Reshape on output + 4 Transposes
::testing::Values(1), // MHA
Expand Down Expand Up @@ -169,6 +206,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::Values(false),
::testing::Values(true),
::testing::Values(4), // 4 Threads
::testing::Values(1),
::testing::Values(1),
Expand Down Expand Up @@ -198,6 +236,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::Values(false),
::testing::Values(true),
::testing::Values(4), // 4 Threads
::testing::Values(5), // Subgraph + 4 Transpose
::testing::Values(2), // MHA + one of the transposes is executed via Subgraph (because callback is disabled)
Expand All @@ -211,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
::testing::ValuesIn(precision_bf16(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(7), // MHA + 5 Converts + 1 Transpose on output
::testing::Values(6), // MHA + 5 Converts on inputs and output
Expand All @@ -224,6 +264,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({false}),
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(7),
::testing::Values(6),
Expand All @@ -239,6 +280,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false}), // Need to support True for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand All @@ -262,6 +304,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(6)),
::testing::Values(ov::element::f32),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(2), // Less + MHA
::testing::Values(2),
Expand All @@ -283,6 +326,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<ov::element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(true), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand All @@ -297,6 +341,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand All @@ -311,6 +356,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand Down Expand Up @@ -340,6 +386,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand All @@ -354,6 +401,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_bf16(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
Expand All @@ -368,6 +416,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_bf16(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
Expand All @@ -382,6 +431,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
Expand All @@ -396,6 +446,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
Expand All @@ -411,6 +462,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul
::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul
Expand All @@ -426,6 +478,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(9), // FQx2 on inputs + MHA + Transpose on output + 4 Reshapes + Deq Mul
::testing::Values(4), // FQx2 on inputs + MHA + Deq Mul
Expand All @@ -439,6 +492,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQAfterMatMul_4D,
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(3), // MHA + Transpose on output + Deq Mul
::testing::Values(2), // MHA + Deq Mul
Expand All @@ -456,6 +510,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(7), // Transposex2 + Subgraphsx5
::testing::Values(5), // MHA + Deq Mul on output + Deqs on inputs + 2 xFQ on inputs
Expand All @@ -472,6 +527,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(1),
::testing::Values(1),
Expand All @@ -498,6 +554,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(2),
::testing::Values(1),
Expand All @@ -520,6 +577,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // False is not supported for graph builder in tests
::testing::Values(true),
::testing::Values(MHA::default_thread_count),
::testing::Values(3), // Extracted Add + Extracted Reshape + MHA
::testing::Values(2), // Extracted Add + MHA
Expand Down
2 changes: 2 additions & 0 deletions src/tests/functional/plugin/shared/include/snippets/mha.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ typedef std::tuple<std::vector<InputShape>, // Input shapes
std::vector<ov::element::Type>, // Input Element types
ov::element::Type, // Inference precision
bool, // With Multiply
bool, // True if second input of Mul is Const
size_t, // Thread count
size_t, // Expected num nodes
size_t, // Expected num subgraphs
Expand All @@ -38,6 +39,7 @@ class MHA : public testing::WithParamInterface<ov::test::snippets::MHAParams>,
virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph();

bool m_with_mul = false;
bool m_is_mul_const = true;
size_t m_thread_count;
std::vector<ov::element::Type> m_input_types;
};
Expand Down
Loading

0 comments on commit fa36a30

Please sign in to comment.