Skip to content

Commit

Permalink
[LPT][TESTS] added test-case with Split and unsupported Concat
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed May 31, 2021
1 parent 2c672b6 commit 9ef76e4
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SplitTransformationTestValues {
ngraph::pass::low_precision::LayerTransformation::Params params;
Actual actual;
Expected expected;
bool addUnsupportedConcat;
};

inline std::ostream& operator<<(std::ostream& os,
Expand Down Expand Up @@ -74,7 +75,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara
testValues.actual.precisionBeforeDequantization,
testValues.actual.dequantization,
testValues.splitedAxis,
testValues.numSplits);
testValues.numSplits,
testValues.addUnsupportedConcat);

SimpleLowPrecisionTransformer transformer;
transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params.setSupportAsymmetricQuantization(true));
Expand All @@ -88,7 +90,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara
testValues.expected.precisionAfterOperation,
testValues.expected.dequantizationAfter,
testValues.splitedAxis,
testValues.numSplits);
testValues.numSplits,
testValues.addUnsupportedConcat);
}

static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) {
Expand Down Expand Up @@ -429,6 +432,22 @@ const std::vector<SplitTransformationTestValues> testValues = {
}
}
},
// issue #56781: unsupported Concat after Split
{
ngraph::Shape({ 1, 4, 3, 3 }), std::int64_t{2}, size_t{3},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}}
},
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}},
ngraph::element::f32,
{}
},
true
},
// no dequantization
{
ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class SplitFunction {
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int64_t splitedAxis,
const size_t numSplits);
const size_t numSplits,
const bool addUnsupportedConcat = false);

static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type originalFunctionPrecision,
Expand All @@ -42,7 +43,8 @@ class SplitFunction {
const ngraph::element::Type precisionAfterOperation,
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
const int64_t splitedAxis,
const size_t numSplit);
const size_t numSplit,
const bool addUnsupportedConcat = false);
};
} // namespace subgraph
} // namespace builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int64_t splitedAxis,
const size_t numSplits) {
const size_t numSplits,
const bool addUnsupportedConcat) {
const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
precisionBeforeDequantization,
ngraph::Shape(inputShape));
Expand All @@ -35,8 +36,14 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
const std::shared_ptr<Node> split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);

ngraph::ResultVector results;
for (size_t i = 0; i < numSplits; ++i) {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));

if (addUnsupportedConcat) {
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
results.push_back(std::make_shared<opset1::Result>(concat));
} else {
for (size_t i = 0; i < numSplits; ++i) {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
}
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction");
}
Expand Down Expand Up @@ -77,7 +84,8 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
const ngraph::element::Type precisionAfterOperation,
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
const int64_t splitedAxis,
const size_t numSplit) {
const size_t numSplit,
const bool addUnsupportedConcat) {
const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
inputPrecision,
ngraph::Shape(inputShape));
Expand All @@ -89,17 +97,23 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);

ngraph::ResultVector results;
for (size_t i = 0; i < numSplit; ++i) {
if (!dequantizationAfter.empty()) {
auto dequantizationStructure = dequantizationAfter[i];
if (!dequantizationStructure.multiply.empty()) {
dequantizationStructure.multiply.outPrecision = precision;
if (addUnsupportedConcat) {
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
results.push_back(std::make_shared<opset1::Result>(concat));
} else {
for (size_t i = 0; i < numSplit; ++i) {
if (!dequantizationAfter.empty()) {
auto dequantizationStructure = dequantizationAfter[i];
if (!dequantizationStructure.multiply.empty()) {
dequantizationStructure.multiply.outPrecision = precision;
}
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
} else {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
} else {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
}

return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitTransformation");
}

Expand Down

0 comments on commit 9ef76e4

Please sign in to comment.