Skip to content

Commit

Permalink
AddConverToReorder test was changed in accordance with review.
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Dec 28, 2020
1 parent 35600ee commit a23a2cc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,13 @@ class AddConvertToReorderTest : virtual public LayerTestsUtils::LayerTestsCommon
}
std::vector<std::vector<std::uint8_t>> CalculateRefs() override {
// Convert the second input constant precision to i64 to run the reference function
switch (secondConstantType) {
case ngraph::element::Type_t::i8:
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::i64>().run_on_function(function);
break;
case ngraph::element::Type_t::bf16:
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::i64>().run_on_function(function);
break;
default:
// pass
break;
if (ngraph::element::Type_t::i8 == secondConstantType) {
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::i64>().run_on_function(function);
} else if (ngraph::element::Type_t::bf16 == secondConstantType) {
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::i64>().run_on_function(function);
}

return LayerTestsUtils::LayerTestsCommon::CalculateRefs();
}
void CheckElementOfTypeCount(std::string typeName, size_t expectedCount) {
InferenceEngine::CNNNetwork execGraphInfo = executableNetwork.GetExecGraphInfo();
auto function = execGraphInfo.getFunction();
ASSERT_NE(nullptr, function);
size_t actualPermuteCount = 0;
for (const auto &node : function->get_ops()) {
const auto & rtInfo = node->get_rt_info();
auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
auto it = rtInfo.find(paramName);
IE_ASSERT(rtInfo.end() != it);
auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<std::string>>(it->second);
IE_ASSERT(nullptr != value);
return value->get();
};
if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == typeName) {
actualPermuteCount++;
}
}

ASSERT_EQ(expectedCount, actualPermuteCount) << "Unexpected count of the element type '" << typeName << "' ";
}

private:
ngraph::element::Type secondConstantType;
Expand All @@ -84,7 +56,7 @@ namespace {
Parameter[FP32] Constant[BF16]
\ /
\ /
\ Convert[I32] (Is inserted by the MKLDNNGraphOptimizer)
\ Convert[I32] (Is inserted by the MKLDNNGraph)
\ /
Gather[FP32]
|
Expand All @@ -95,16 +67,16 @@ namespace {
TEST_F(AddConvertToReorderTest, smoke_TestAddConvert_CPU) {
BuildGraph(ngraph::element::bf16);
Run();
CheckElementOfTypeCount("Convert", 1);
CheckElementOfTypeCount("Reorder", 0);
CheckNodeOfTypeCount(executableNetwork, "Convert", 1);
CheckNodeOfTypeCount(executableNetwork, "Reorder", 0);
}

/* Test insertion of the Reorder layer if there is one.
Parameter[FP32] Constant[I8]
\ /
\ /
\ Reorder[I32] (Is inserted by the MKLDNNGraphOptimizer)
\ Reorder[I32] (Is inserted by the MKLDNNGraph)
\ /
Gather[FP32]
|
Expand All @@ -114,8 +86,8 @@ TEST_F(AddConvertToReorderTest, smoke_TestAddConvert_CPU) {
TEST_F(AddConvertToReorderTest, smoke_TestAddReorder_CPU) {
BuildGraph(ngraph::element::i8);
Run();
CheckElementOfTypeCount("Convert", 0);
CheckElementOfTypeCount("Reorder", 1);
CheckNodeOfTypeCount(executableNetwork, "Convert", 0);
CheckNodeOfTypeCount(executableNetwork, "Reorder", 1);
}
} // namespace
} // namespace LayerTestsDefinitions
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,25 @@ std::vector<CPUSpecificParams> filterCPUSpecificParams(std::vector<CPUSpecificPa
return paramsVector;
}

void CheckNodeOfTypeCount(InferenceEngine::ExecutableNetwork &execNet, std::string nodeType, size_t expectedCount) {
InferenceEngine::CNNNetwork execGraphInfo = execNet.GetExecGraphInfo();
auto function = execGraphInfo.getFunction();
ASSERT_NE(nullptr, function);
size_t actualNodeCount = 0;
for (const auto &node : function->get_ops()) {
const auto & rtInfo = node->get_rt_info();
auto getExecValue = [&rtInfo](const std::string & paramName) -> std::string {
auto it = rtInfo.find(paramName);
IE_ASSERT(rtInfo.end() != it);
auto value = std::dynamic_pointer_cast<ngraph::VariantImpl<std::string>>(it->second);
IE_ASSERT(nullptr != value);
return value->get();
};
if (getExecValue(ExecGraphInfoSerialization::LAYER_TYPE) == nodeType) {
actualNodeCount++;
}
}

ASSERT_EQ(expectedCount, actualNodeCount) << "Unexpected count of the node type '" << nodeType << "' ";
}
} // namespace CPUTestUtils
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,6 @@ const auto conv_avx512_2D_1x1 = CPUSpecificParams{{nChw16c}, {nChw16c}, {"jit_av

// utility functions
std::vector<CPUSpecificParams> filterCPUSpecificParams(std::vector<CPUSpecificParams>& paramsVector);
void CheckNodeOfTypeCount(InferenceEngine::ExecutableNetwork &execNet, std::string nodeType, size_t expectedCount);

} // namespace CPUTestUtils

0 comments on commit a23a2cc

Please sign in to comment.