Skip to content

Commit

Permalink
Tests improvement (openvinotoolkit#5704)
Browse files Browse the repository at this point in the history
* [LPT] Test: concat with convolution neighbor and convolution after

* [LPT] elementwise fuse to FakeQuantize

* [LPT] plugin test build fix

Co-authored-by: Edward Shogulin <[email protected]>
  • Loading branch information
2 people authored and yekruglov committed Jun 7, 2021
1 parent fb48112 commit 9f5267b
Show file tree
Hide file tree
Showing 7 changed files with 280 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <low_precision/transformer.hpp>
#include <low_precision/concat.hpp>
#include <low_precision/concat_multi_channels.hpp>
#include <low_precision/convolution.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"
#include "lpt_ngraph_functions/concat_function.hpp"
Expand Down Expand Up @@ -65,6 +66,8 @@ class ConcatTransformationTestValues {
bool multiChannels;
ConcatTransformationActualValues actual;
ConcatTransformationResultValues result;
std::string neighborType;
std::string additionalLayer;
};

inline std::ostream& operator<<(std::ostream& out, const ConcatTransformationTestValues& values) {
Expand All @@ -89,15 +92,25 @@ class ConcatWithNeighborsTransformation : public LayerTransformation, public tes
shape,
testValues.actual.fakeQuantize1,
testValues.actual.fakeQuantize2,
testValues.actual.fakeQuantize3);
testValues.actual.fakeQuantize3,
testValues.neighborType,
testValues.additionalLayer);

SimpleLowPrecisionTransformer transform;
SimpleLowPrecisionTransformer transformBranchSpecific;
if (testValues.multiChannels) {
transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
transformBranchSpecific.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
} else {
transform.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
transformBranchSpecific.add<ngraph::pass::low_precision::ConcatTransformation, ngraph::opset1::Concat>(testValues.params);
}
if (testValues.additionalLayer == "convolution" || testValues.neighborType == "convolution") {
transformBranchSpecific.add<ngraph::pass::low_precision::ConvolutionTransformation, ngraph::opset1::Convolution>(testValues.params);
}
transformBranchSpecific.transform(actualFunction);
if (testValues.additionalLayer == "convolution" || testValues.neighborType == "convolution") {
SimpleLowPrecisionTransformer transformConvolution;
transformConvolution.add<ngraph::pass::low_precision::ConvolutionTransformation, ngraph::opset1::Convolution>(testValues.params);
transformConvolution.transform(actualFunction);
}
transform.transform(actualFunction);

referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithNeighbors(
precision,
Expand All @@ -109,7 +122,9 @@ class ConcatWithNeighborsTransformation : public LayerTransformation, public tes
testValues.result.dequantizationBefore,
testValues.result.precisionAfterOp,
testValues.result.dequantizationAfter1,
testValues.result.dequantizationAfter2);
testValues.result.dequantizationAfter2,
testValues.neighborType,
testValues.additionalLayer);
}

static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
Expand Down Expand Up @@ -157,7 +172,9 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::u8,
{ ngraph::element::f32, {}, { 0.01f } },
{ ngraph::element::f32, {}, { 0.01f } }
}
},
"concat",
""
},
// U8: concat multi channels
{
Expand All @@ -177,7 +194,9 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::u8,
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ ngraph::element::f32, {}, {{ 0.005f, 0.005f, 0.005f, 0.00333f, 0.00333f, 0.00333f }} }
}
},
"concat",
""
},
// U8: concat multi channels with subtract
{
Expand All @@ -197,7 +216,9 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::u8,
{ ngraph::element::f32, {{ 0.f, 0.f, 0.f, -255.f, -255.f, -255.f }}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ ngraph::element::f32, { -255.f }, { 0.005f } }
}
},
"concat",
""
},
// I8: concat
{
Expand All @@ -217,7 +238,9 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::i8,
{ ngraph::element::f32, {}, { 0.01f } },
{ ngraph::element::f32, {}, { 0.01f } }
}
},
"concat",
""
},
// I8: concat multi channels
{
Expand All @@ -237,7 +260,9 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::i8,
{ ngraph::element::f32, {}, {{ 0.01f, 0.01f, 0.01f, 0.005f, 0.005f, 0.005f }} },
{ ngraph::element::f32, {}, {{ 0.005f, 0.005f, 0.005f, 0.00333f, 0.00333f, 0.00333f }} }
}
},
"concat",
""
},
// mixed: U8 + I8: concat multi channels
{
Expand All @@ -257,7 +282,9 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::u8,
{ ngraph::element::f32, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } },
{ ngraph::element::f32, { 128.f }, { 0.01f } }
}
},
"concat",
""
},
// not update precisions
{
Expand All @@ -277,7 +304,38 @@ const std::vector<ConcatTransformationTestValues> testValues = {
ngraph::element::f32,
{ {}, {{ 0.f, 0.f, 0.f, 128.f, 128.f, 128.f }}, { 0.01f } },
{ {}, { 128.f }, { 0.01f } }
}
},
"concat",
""
},
// convolution neighbor and additional layer
// different precisions on FQ, u8 have to be chosen
{
LayerTransformation::createParamsU8I8(),
true,
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {-12.8f}, {12.7f} },
{}
},
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {128.f}, {154.f} },
{ 256ul, ngraph::Shape({}), {-1.28f}, {1.27f}, {0.f}, {255.f} },
{},
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{
{},
{{ 128.f, 128.f, 128.f, 128.f, 128.f, 128.f }, ngraph::element::f32, { 1, 6, 1, 1 }, false},
{{0.1f}, ngraph::element::f32, { 1, 1, 1, 1 } } },
{
{},
{{128.f, 128.f, 128.f}, ngraph::element::f32, { 1, 3, 1, 1 }, false},
{{0.1f}, ngraph::element::f32, { 1, 1, 1, 1 } } }
},
"convolution",
"convolution"
},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,20 @@ class FuseMultiplyToFakeQuantizeTransformationTestValues {
Expected expected;
};

typedef std::tuple<size_t, FuseMultiplyToFakeQuantizeTransformationTestValues> FuseMultiplyToFakeQuantizeTransformationTestParams;

class FuseMultiplyToFakeQuantizeTransformation : public LayerTransformation,
public testing::WithParamInterface<FuseMultiplyToFakeQuantizeTransformationTestValues> {
public testing::WithParamInterface<FuseMultiplyToFakeQuantizeTransformationTestParams> {
public:
void SetUp() override {
const FuseMultiplyToFakeQuantizeTransformationTestValues testValues = GetParam();
const size_t quantizationLevel = std::get<0>(GetParam());
FuseMultiplyToFakeQuantizeTransformationTestValues testValues = std::get<1>(GetParam());
if (!testValues.actual.fakeQuantizeOnData.empty()) {
testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}
if (!testValues.expected.fakeQuantizeOnData.empty()) {
testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}

actualFunction = ngraph::builder::subgraph::FuseMultiplyToFakeQuantizeFunction::get(
testValues.inputShape,
Expand All @@ -65,8 +74,15 @@ class FuseMultiplyToFakeQuantizeTransformation : public LayerTransformation,
testValues.expected.dequantization);
}

static std::string getTestCaseName(testing::TestParamInfo<FuseMultiplyToFakeQuantizeTransformationTestValues> obj) {
const FuseMultiplyToFakeQuantizeTransformationTestValues testValues = obj.param;
static std::string getTestCaseName(testing::TestParamInfo<FuseMultiplyToFakeQuantizeTransformationTestParams> obj) {
const size_t quantizationLevel = std::get<0>(obj.param);
FuseMultiplyToFakeQuantizeTransformationTestValues testValues = std::get<1>(obj.param);
if (!testValues.actual.fakeQuantizeOnData.empty()) {
testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}
if (!testValues.expected.fakeQuantizeOnData.empty()) {
testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}

std::ostringstream result;
result << testValues.params.updatePrecisions << "_" <<
Expand All @@ -83,6 +99,8 @@ TEST_P(FuseMultiplyToFakeQuantizeTransformation, CompareFunctions) {
ASSERT_TRUE(res.first) << res.second;
}

std::vector<size_t> quantizationLevels = { 256ul, 128ul };

const std::vector<FuseMultiplyToFakeQuantizeTransformationTestValues> testValues = {
{
Shape{1, 3, 16, 16},
Expand Down Expand Up @@ -125,7 +143,9 @@ const std::vector<FuseMultiplyToFakeQuantizeTransformationTestValues> testValues
INSTANTIATE_TEST_CASE_P(
smoke_LPT,
FuseMultiplyToFakeQuantizeTransformation,
::testing::ValuesIn(testValues),
::testing::Combine(
::testing::ValuesIn(quantizationLevels),
::testing::ValuesIn(testValues)),
FuseMultiplyToFakeQuantizeTransformation::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,26 @@ class FuseSubtractToFakeQuantizeTransformationTestValues {
Expected expected;
};

typedef std::tuple<size_t, FuseSubtractToFakeQuantizeTransformationTestValues> FuseSubtractToFakeQuantizeTransformationTestParams;

class FuseSubtractToFakeQuantizeTransformation : public LayerTransformation,
public testing::WithParamInterface<FuseSubtractToFakeQuantizeTransformationTestValues> {
public testing::WithParamInterface<FuseSubtractToFakeQuantizeTransformationTestParams> {
public:
void SetUp() override {
const FuseSubtractToFakeQuantizeTransformationTestValues testValues = GetParam();
const size_t quantizationLevel = get<0>(GetParam());
FuseSubtractToFakeQuantizeTransformationTestValues testValues = get<1>(GetParam());
if (!testValues.actual.fakeQuantizeOnData.empty()) {
testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}
if (!testValues.actual.fakeQuantizeOnData2.empty()) {
testValues.actual.fakeQuantizeOnData2.quantizationLevel = quantizationLevel;
}
if (!testValues.expected.fakeQuantizeOnData.empty()) {
testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}
if (!testValues.expected.fakeQuantizeOnData2.empty()) {
testValues.expected.fakeQuantizeOnData2.quantizationLevel = quantizationLevel;
}

actualFunction = testValues.actual.fakeQuantizeOnData2.empty() ?
ngraph::builder::subgraph::FuseSubtractToFakeQuantizeFunction::get(
Expand Down Expand Up @@ -84,8 +99,21 @@ class FuseSubtractToFakeQuantizeTransformation : public LayerTransformation,
testValues.expected.dequantization2);
}

static std::string getTestCaseName(testing::TestParamInfo<FuseSubtractToFakeQuantizeTransformationTestValues> obj) {
const FuseSubtractToFakeQuantizeTransformationTestValues testValues = obj.param;
static std::string getTestCaseName(testing::TestParamInfo<FuseSubtractToFakeQuantizeTransformationTestParams> obj) {
const size_t quantizationLevel = get<0>(obj.param);
FuseSubtractToFakeQuantizeTransformationTestValues testValues = get<1>(obj.param);
if (!testValues.actual.fakeQuantizeOnData.empty()) {
testValues.actual.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}
if (!testValues.actual.fakeQuantizeOnData2.empty()) {
testValues.actual.fakeQuantizeOnData2.quantizationLevel = quantizationLevel;
}
if (!testValues.expected.fakeQuantizeOnData.empty()) {
testValues.expected.fakeQuantizeOnData.quantizationLevel = quantizationLevel;
}
if (!testValues.expected.fakeQuantizeOnData2.empty()) {
testValues.expected.fakeQuantizeOnData2.quantizationLevel = quantizationLevel;
}

std::ostringstream result;
result << testValues.params.updatePrecisions << "_" <<
Expand All @@ -104,6 +132,8 @@ TEST_P(FuseSubtractToFakeQuantizeTransformation, CompareFunctions) {
ASSERT_TRUE(res.first) << res.second;
}

std::vector<size_t> quantizationLevels = { 256ul, 128ul };

const std::vector<FuseSubtractToFakeQuantizeTransformationTestValues> testValues = {
{
Shape{1, 3, 16, 16},
Expand Down Expand Up @@ -190,7 +220,9 @@ const std::vector<FuseSubtractToFakeQuantizeTransformationTestValues> testValues
INSTANTIATE_TEST_CASE_P(
smoke_LPT,
FuseSubtractToFakeQuantizeTransformation,
::testing::ValuesIn(testValues),
::testing::Combine(
::testing::ValuesIn(quantizationLevels),
::testing::ValuesIn(testValues)),
FuseSubtractToFakeQuantizeTransformation::getTestCaseName);

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ void ConcatWithNeighborsGraphTransformation::SetUp() {
inputShape,
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f / 2.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f / 3.f} });
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f / 3.f} },
"concat",
"");

validate();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class DequantizationOperations {
bool operator==(const Subtract& value) const noexcept {
return equal(value);
}

void erase() {
isEmpty = true;
}
Subtract& setConstantPrecision(const ngraph::element::Type& precision);

std::vector<float> values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class ConcatFunction {
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2,
const FakeQuantizeOnData& fqOnData3);
const FakeQuantizeOnData& fqOnData3,
const std::string& neighborType,
const std::string& additionalLayer);

static std::shared_ptr<ngraph::Function> getOriginalWithIntermediate(
const ngraph::element::Type precision,
Expand Down Expand Up @@ -128,7 +130,9 @@ class ConcatFunction {
const DequantizationOperations& dequantizationBefore,
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationOperations1,
const DequantizationOperations& dequantizationOperations2);
const DequantizationOperations& dequantizationOperations2,
const std::string& neighborType,
const std::string& additionalLayer);

static std::shared_ptr<ngraph::Function> getReferenceWithIntermediate(
const ngraph::element::Type precision,
Expand Down
Loading

0 comments on commit 9f5267b

Please sign in to comment.