Skip to content

Commit

Permalink
[LPT] ConcatTransformation fix
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Aug 20, 2021
1 parent 468964b commit 14e65a5
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions inference-engine/src/low_precision_transformations/src/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,25 +183,53 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
return false;
}

const auto axis = concat->get_axis();
const auto outPShape = concat->get_output_partial_shape(0);
const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outPShape.rank());
const auto& axis = concat->get_axis();
const auto& outPShape = concat->get_output_partial_shape(0);
const auto& outRank = outPShape.rank();
if (outRank.is_dynamic()) {
return false;
}

const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outRank);

if (normalizedAxis != 1ul) {
return false;
}

if (outPShape.rank().is_dynamic() || outPShape[normalizedAxis].is_dynamic()) {
if (outPShape[normalizedAxis].is_dynamic()) {
return false;
}

auto checkConstShape = [&normalizedAxis, &outRank](const std::shared_ptr<opset1::Constant>& constant) {
const size_t rankValue = outRank.get_length();
Shape constantShape = constant->get_shape();

while (constantShape.size() < rankValue) {
constantShape.insert(constantShape.begin(), 1ul);
}

for (size_t i = 0; i < constantShape.size(); ++i) {
bool dqNotByConcatAxis = (constantShape[i] != 1ul) && (i != normalizedAxis);
if (dqNotByConcatAxis) {
return false;
}
}

return true;
};

element::Type precision;
for (size_t i = 0ul; i < concat->get_input_size(); i++) {
const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(concat, i);
if (dequantization.empty() || (updatePrecisions && !dequantization.isLowPrecision())) {
return false;
}

if ((dequantization.subtract != nullptr) && (!checkConstShape(dequantization.subtractConstant)) ||
(dequantization.multiply != nullptr) && (!checkConstShape(dequantization.multiplyConstant))) {
return false;
}

if (precision == element::undefined) {
precision = dequantization.data.get_element_type();
} else if (precision != dequantization.data.get_element_type()) {
Expand Down

0 comments on commit 14e65a5

Please sign in to comment.