Skip to content

Commit

Permalink
[LPT] SplitTransformation WA for unsupported concat (openvinotoolkit#…
Browse files Browse the repository at this point in the history
…5833)

* [LPT] SplitTransformation WA for unsupported concat

* [LPT][TESTS] added test-case with Split and unsupported Concat
  • Loading branch information
v-Golubev authored Jun 1, 2021
1 parent b2abf25 commit c980a0b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 17 deletions.
15 changes: 14 additions & 1 deletion inference-engine/src/low_precision_transformations/src/split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,20 @@ bool SplitTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) cons
}

bool SplitTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
return (!NetworkHelper::getDequantization(layer).empty()) && LayerTransformation::canBeTransformed(context, layer);
if (!LayerTransformation::canBeTransformed(context, layer) || NetworkHelper::getDequantization(layer).empty()) {
return false;
}

const auto consumers = NetworkHelper::consumers(layer);
const auto concat = as_type_ptr<opset1::Concat>(consumers[0]);

// WA to avoid propagation of dequantization if after Split all consumers are the same unsupported Concat
if (concat && concat->get_axis() != 1ul) {
const size_t id = consumers[0]->get_instance_id();
return std::any_of(consumers.begin(), consumers.end(), [&](const std::shared_ptr<Node>& node) { return node->get_instance_id() != id; });
}

return true;
}

} // namespace low_precision
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SplitTransformationTestValues {
ngraph::pass::low_precision::LayerTransformation::Params params;
Actual actual;
Expected expected;
bool addUnsupportedConcat;
};

inline std::ostream& operator<<(std::ostream& os,
Expand Down Expand Up @@ -74,7 +75,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara
testValues.actual.precisionBeforeDequantization,
testValues.actual.dequantization,
testValues.splitedAxis,
testValues.numSplits);
testValues.numSplits,
testValues.addUnsupportedConcat);

SimpleLowPrecisionTransformer transformer;
transformer.add<ngraph::pass::low_precision::SplitTransformation, ngraph::opset1::Split>(testValues.params.setSupportAsymmetricQuantization(true));
Expand All @@ -88,7 +90,8 @@ class SplitTransformation : public LayerTransformation, public testing::WithPara
testValues.expected.precisionAfterOperation,
testValues.expected.dequantizationAfter,
testValues.splitedAxis,
testValues.numSplits);
testValues.numSplits,
testValues.addUnsupportedConcat);
}

static std::string getTestCaseName(testing::TestParamInfo<SplitTransformationParams> obj) {
Expand Down Expand Up @@ -429,6 +432,22 @@ const std::vector<SplitTransformationTestValues> testValues = {
}
}
},
// issue #56781: unsupported Concat after Split
{
ngraph::Shape({ 1, 4, 3, 3 }), std::int64_t{2}, size_t{3},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}}
},
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {3.f}},
ngraph::element::f32,
{}
},
true
},
// no dequantization
{
ngraph::Shape({ 1, 3, 4, 4 }), std::int64_t{2}, size_t{2},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class SplitFunction {
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int64_t splitedAxis,
const size_t numSplits);
const size_t numSplits,
const bool addUnsupportedConcat = false);

static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type originalFunctionPrecision,
Expand All @@ -42,7 +43,8 @@ class SplitFunction {
const ngraph::element::Type precisionAfterOperation,
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
const int64_t splitedAxis,
const size_t numSplit);
const size_t numSplit,
const bool addUnsupportedConcat = false);
};
} // namespace subgraph
} // namespace builder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int64_t splitedAxis,
const size_t numSplits) {
const size_t numSplits,
const bool addUnsupportedConcat) {
const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
precisionBeforeDequantization,
ngraph::Shape(inputShape));
Expand All @@ -35,8 +36,14 @@ std::shared_ptr<ngraph::Function> SplitFunction::getOriginal(
const std::shared_ptr<Node> split = std::make_shared<ngraph::opset1::Split>(dequantizationOp, constant, numSplits);

ngraph::ResultVector results;
for (size_t i = 0; i < numSplits; ++i) {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));

if (addUnsupportedConcat) {
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
results.push_back(std::make_shared<opset1::Result>(concat));
} else {
for (size_t i = 0; i < numSplits; ++i) {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
}
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitFunction");
}
Expand Down Expand Up @@ -77,7 +84,8 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
const ngraph::element::Type precisionAfterOperation,
const std::vector<ngraph::builder::subgraph::DequantizationOperations>& dequantizationAfter,
const int64_t splitedAxis,
const size_t numSplit) {
const size_t numSplit,
const bool addUnsupportedConcat) {
const std::shared_ptr<op::v0::Parameter> input = std::make_shared<ngraph::opset1::Parameter>(
inputPrecision,
ngraph::Shape(inputShape));
Expand All @@ -89,17 +97,23 @@ std::shared_ptr<ngraph::Function> SplitFunction::getReference(
split = std::make_shared<ngraph::opset1::Split>(deqBefore, constant, numSplit);

ngraph::ResultVector results;
for (size_t i = 0; i < numSplit; ++i) {
if (!dequantizationAfter.empty()) {
auto dequantizationStructure = dequantizationAfter[i];
if (!dequantizationStructure.multiply.empty()) {
dequantizationStructure.multiply.outPrecision = precision;
if (addUnsupportedConcat) {
const auto concat = std::make_shared<opset1::Concat>(split->outputs(), 2ul);
results.push_back(std::make_shared<opset1::Result>(concat));
} else {
for (size_t i = 0; i < numSplit; ++i) {
if (!dequantizationAfter.empty()) {
auto dequantizationStructure = dequantizationAfter[i];
if (!dequantizationStructure.multiply.empty()) {
dequantizationStructure.multiply.outPrecision = precision;
}
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
} else {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
results.push_back(std::make_shared<ngraph::opset1::Result>(makeDequantization(split->output(i), dequantizationAfter[i])));
} else {
results.push_back(std::make_shared<ngraph::opset1::Result>(split->output(i)));
}
}

return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "SplitTransformation");
}

Expand Down

0 comments on commit c980a0b

Please sign in to comment.