diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp index aa0c4bbe28e835..be387f03c5876c 100644 --- a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp @@ -25,6 +25,7 @@ typedef std::tuple< std::vector, // Input precisions std::vector, // MatMul input #0 precisions size_t, // pattern type # + std::string, // Expected node std::string // Device name > MHATuple; @@ -155,8 +156,9 @@ class MHATest : public testing::WithParamInterface, std::vector inputPrecisions; std::vector matMulIn0Precisions; size_t patternType; + std::string expectedNode; std::string targetName; - std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetName) = obj.param; std::ostringstream results; results << "IS=("; @@ -173,6 +175,7 @@ class MHATest : public testing::WithParamInterface, results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_"; } results << "patternType=" << patternType; + results << "expect=" << expectedNode; results << "targetDevice=" << targetName; return results.str(); @@ -195,7 +198,8 @@ class MHATest : public testing::WithParamInterface, std::vector inputPrecisions; std::vector matMulIn0Precisions; size_t patternType; - std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + std::string expectedNode; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam(); init_input_shapes(inputShapes); @@ -223,7 +227,8 @@ TEST_P(MHATest, CompareWithRefs) { std::vector inputPrecisions; std::vector matMulIn0Precisions; size_t patternType; - std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + std::string expectedNode; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam(); if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16()) GTEST_SKIP(); @@ -232,7 +237,7 @@ TEST_P(MHATest, CompareWithRefs) { GTEST_SKIP(); run(); - CheckNumberOfNodesWithType(compiledModel, "MHA", 1); + CheckNumberOfNodesWithType(compiledModel, expectedNode, 1); } namespace { @@ -247,11 +252,6 @@ std::vector> inputShapes = { {{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}}, }; -std::vector> inputPrecisions = { - { ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 }, - { ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 }, -}; - std::vector> matMulIn0Precisions = { {}, }; @@ -260,15 +260,26 @@ std::vector patternTypes = { 0, 1 }; -INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHATest, ::testing::Combine( ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)), - ::testing::ValuesIn(inputPrecisions), + ::testing::Values(std::vector{ ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 }), ::testing::ValuesIn(matMulIn0Precisions), ::testing::ValuesIn(patternTypes), + ::testing::Values("Subgraph"), ::testing::Values(CommonTestUtils::DEVICE_CPU)), MHATest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest, + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)), + ::testing::Values(std::vector{ ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 }), + ::testing::ValuesIn(matMulIn0Precisions), + ::testing::ValuesIn(patternTypes), + ::testing::Values("MHA"), // Snippets don't support BF16 MHA pattern yet + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MHATest::getTestCaseName); + } // namespace static std::shared_ptr initMHAQuantSubgraph0(std::vector& inputDynamicShapes, std::vector& inputPrecisions, @@ -425,7 +436,8 @@ class MHAQuantTest : public testing::WithParamInterface, std::vector matMulIn0Precisions; size_t patternType; std::string targetName; - std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param; + std::string expectedNode; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetName) = obj.param; std::ostringstream results; results << "IS=("; @@ -445,6 +457,7 @@ class MHAQuantTest : public testing::WithParamInterface, results << "MatMulIn0PRC" << std::to_string(i) << "=" << matMulIn0Precisions[i] << "_"; } results << "patternType=" << patternType; + results << "expect=" << expectedNode; results << "targetDevice=" << targetName; return results.str(); @@ -474,7 +487,8 @@ class MHAQuantTest : public testing::WithParamInterface, std::vector inputPrecisions; std::vector matMulIn0Precisions; size_t patternType; - std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + std::string expectedNode; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam(); init_input_shapes(inputShapes); @@ -493,7 +507,8 @@ TEST_P(MHAQuantTest, CompareWithRefs) { std::vector inputPrecisions; std::vector matMulIn0Precisions; size_t patternType; - std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam(); + std::string expectedNode; + std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam(); if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16()) GTEST_SKIP(); @@ -502,7 +517,7 @@ TEST_P(MHAQuantTest, CompareWithRefs) { GTEST_SKIP(); run(); - CheckNumberOfNodesWithType(compiledModel, "MHA", 1); + CheckNumberOfNodesWithType(compiledModel, expectedNode, 1); } namespace { @@ -538,6 +553,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest, ::testing::ValuesIn(inputPrecisionsQuant), ::testing::ValuesIn(matMulIn0PrecisionsQuant), ::testing::ValuesIn(patternTypesQuant), + ::testing::Values("MHA"), // Snippets don't support Quantized MHA pattern yet ::testing::Values(CommonTestUtils::DEVICE_CPU)), MHAQuantTest::getTestCaseName);