Skip to content

Commit

Permalink
Updated MHA Custom tests
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 9, 2023
1 parent cba38ae commit 2a4ab82
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions src/plugins/intel_cpu/tests/functional/subgraph_tests/src/mha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ typedef std::tuple<
std::vector<ElementType>, // Input precisions
std::vector<ElementType>, // MatMul input #0 precisions
size_t, // pattern type #
std::string, // Expected node
std::string // Device name
> MHATuple;

Expand Down Expand Up @@ -155,8 +156,9 @@ class MHATest : public testing::WithParamInterface<MHATuple>,
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::string expectedNode;
std::string targetName;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetName) = obj.param;
std::ostringstream results;

results << "IS=(";
Expand All @@ -173,6 +175,7 @@ class MHATest : public testing::WithParamInterface<MHATuple>,
results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_";
}
results << "patternType=" << patternType;
results << "expect=" << expectedNode;
results << "targetDevice=" << targetName;

return results.str();
Expand All @@ -195,7 +198,8 @@ class MHATest : public testing::WithParamInterface<MHATuple>,
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
std::string expectedNode;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam();

init_input_shapes(inputShapes);

Expand Down Expand Up @@ -223,7 +227,8 @@ TEST_P(MHATest, CompareWithRefs) {
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
std::string expectedNode;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam();

if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16())
GTEST_SKIP();
Expand All @@ -232,7 +237,7 @@ TEST_P(MHATest, CompareWithRefs) {
GTEST_SKIP();

run();
CheckNumberOfNodesWithType(compiledModel, "MHA", 1);
CheckNumberOfNodesWithType(compiledModel, expectedNode, 1);
}

namespace {
Expand All @@ -247,11 +252,6 @@ std::vector<std::vector<ngraph::Shape>> inputShapes = {
{{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}},
};

std::vector<std::vector<ElementType>> inputPrecisions = {
{ ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 },
{ ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 },
};

std::vector<std::vector<ElementType>> matMulIn0Precisions = {
{},
};
Expand All @@ -260,15 +260,26 @@ std::vector<size_t> patternTypes = {
0, 1
};

INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHATest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
::testing::ValuesIn(inputPrecisions),
::testing::Values(std::vector<ElementType>{ ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 }),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
::testing::Values("Subgraph"),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHATest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
::testing::Values(std::vector<ElementType>{ ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 }),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
::testing::Values("MHA"), // Snippets don't support BF16 MHA pattern yet
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHATest::getTestCaseName);

} // namespace

static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
Expand Down Expand Up @@ -425,7 +436,8 @@ class MHAQuantTest : public testing::WithParamInterface<MHATuple>,
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::string targetName;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param;
std::string expectedNode;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetName) = obj.param;
std::ostringstream results;

results << "IS=(";
Expand All @@ -445,6 +457,7 @@ class MHAQuantTest : public testing::WithParamInterface<MHATuple>,
results << "MatMulIn0PRC" << std::to_string(i) << "=" << matMulIn0Precisions[i] << "_";
}
results << "patternType=" << patternType;
results << "expect=" << expectedNode;
results << "targetDevice=" << targetName;

return results.str();
Expand Down Expand Up @@ -474,7 +487,8 @@ class MHAQuantTest : public testing::WithParamInterface<MHATuple>,
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
std::string expectedNode;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam();

init_input_shapes(inputShapes);

Expand All @@ -493,7 +507,8 @@ TEST_P(MHAQuantTest, CompareWithRefs) {
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
std::string expectedNode;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, expectedNode, targetDevice) = this->GetParam();

if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16())
GTEST_SKIP();
Expand All @@ -502,7 +517,7 @@ TEST_P(MHAQuantTest, CompareWithRefs) {
GTEST_SKIP();

run();
CheckNumberOfNodesWithType(compiledModel, "MHA", 1);
CheckNumberOfNodesWithType(compiledModel, expectedNode, 1);
}

namespace {
Expand Down Expand Up @@ -538,6 +553,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,
::testing::ValuesIn(inputPrecisionsQuant),
::testing::ValuesIn(matMulIn0PrecisionsQuant),
::testing::ValuesIn(patternTypesQuant),
::testing::Values("MHA"), // Snippets don't support Quantized MHA pattern yet
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);

Expand Down

0 comments on commit 2a4ab82

Please sign in to comment.