diff --git a/inference-engine/src/low_precision_transformations/src/mat_mul.cpp b/inference-engine/src/low_precision_transformations/src/mat_mul.cpp index 212a8e8d11ab14..916d026c832ca8 100644 --- a/inference-engine/src/low_precision_transformations/src/mat_mul.cpp +++ b/inference-engine/src/low_precision_transformations/src/mat_mul.cpp @@ -17,14 +17,16 @@ using namespace ngraph::pass; using namespace ngraph::pass::low_precision; bool MatMulTransformation::transform(TransformationContext &context, ngraph::pattern::Matcher &m) const { - std::shared_ptr matMul = as_type_ptr(m.get_match_root()); + std::shared_ptr matMul = as_type_ptr(m.get_match_root()); if ((matMul == nullptr) || !canBeTransformed(context, matMul)) { return false; } - matMul = as_type_ptr(NetworkHelper::separateInStandaloneBranch(matMul)); + matMul = as_type_ptr(NetworkHelper::separateInStandaloneBranch(matMul)); + + const auto dequantization1 = NetworkHelper::getDequantization(matMul, 0); + auto dequantization2 = NetworkHelper::getDequantization(matMul, 1); - FakeQuantizeDequantization dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1); if (dequantization2.empty()) { const std::shared_ptr fakeQuantize = as_type_ptr(dequantization2.data.get_node_shared_ptr()); @@ -40,21 +42,19 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat dataPrecision.hasZeroPoint, updatePrecisions); - dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1); + dequantization2 = NetworkHelper::getDequantization(matMul, 1); } } - const FakeQuantizeDequantization dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 0); - if (dequantization2.subtract != nullptr) { NetworkHelper::optimizeSubtract(dequantization2.subtract); - dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(matMul, 1); + dequantization2 = NetworkHelper::getDequantization(matMul, 1); } - const std::shared_ptr newMatMul = std::make_shared>( + const std::shared_ptr newMatMul = std::make_shared>( std::vector({ element::f32, element::f32 }), std::vector({}), - ngraph::op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(), - ngraph::op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(), + op::TemporaryReplaceOutputType(dequantization1.data, element::f32).get(), + op::TemporaryReplaceOutputType(dequantization2.data, element::f32).get(), matMul->get_transpose_a(), matMul->get_transpose_b()); NetworkHelper::setOutDataPrecisionForTypeRelaxed(newMatMul, matMul->get_output_element_type(0)); @@ -64,15 +64,15 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat // dequantization with subtract on activations & constant weights if (dequantization1.subtract) { - auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr(dequantization1.subtract->get_input_node_shared_ptr(1))) ? - ngraph::Shape(dequantization1.subtract->get_shape().size(), 1) : - dequantization1.subtract->get_input_node_shared_ptr(1)->get_shape(); + auto broadcastShape = NetworkHelper::isScalarLike(as_type_ptr(dequantization1.subtractConstant)) ? + Shape(dequantization1.subtract->get_shape().size(), 1) : + dequantization1.subtractConstant->get_shape(); const size_t lastIdx = matMul->get_transpose_a() ? broadcastShape.size() - 2 : broadcastShape.size() - 1; broadcastShape[lastIdx] = dequantization1.subtract->get_shape()[lastIdx]; // broadcasted sub const to form [1, ..., 1, Y] const auto broadcastedConst = fold( - dequantization1.subtract->get_input_node_shared_ptr(1), + dequantization1.subtractConstant, opset1::Constant::create(ngraph::element::i32, { broadcastShape.size() }, broadcastShape)); // multiply by weights: [1, ..., 1, Y] x [Y, Z] => [1, ..., 1, Z] @@ -84,7 +84,7 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat const auto newSubtract = std::make_shared(newMatMul, newSubConst); newSubtract->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationSubtract"); - ngraph::copy_runtime_info({ newSubtract, matMul }, newSubtract); + copy_runtime_info({ newSubtract, matMul }, newSubtract); parent = newSubtract; } @@ -100,17 +100,12 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat std::swap(*(transposeConstant.end() - 1), *(transposeConstant.end() - 2)); auto order = opset1::Constant::create(element::u32, Shape{ transposeConstant.size() }, transposeConstant); - std::shared_ptr transposedConstant = fold(node, order); + std::shared_ptr transposedConstant = fold(node, order); return transposedConstant; }; - const auto mulConst1 = matMul->get_transpose_a() ? - transpose(dequantization1.multiply->get_input_node_shared_ptr(1)) : - dequantization1.multiply->get_input_node_shared_ptr(1); - - auto mulConst2 = matMul->get_transpose_b() ? - transpose(dequantization2.multiply->get_input_node_shared_ptr(1)) : - dequantization2.multiply->get_input_node_shared_ptr(1); + const auto mulConst1 = matMul->get_transpose_a() ? transpose(dequantization1.multiplyConstant) : dequantization1.multiplyConstant; + auto mulConst2 = matMul->get_transpose_b() ? transpose(dequantization2.multiplyConstant) : dequantization2.multiplyConstant; if (NetworkHelper::isScalarLike(as_type_ptr(mulConst2))) { mulConst2 = NetworkHelper::toScalar(as_type_ptr(mulConst2)); @@ -125,16 +120,16 @@ bool MatMulTransformation::transform(TransformationContext &context, ngraph::pat mulConst2 = fold( mulConst2, - op::Constant::create(ngraph::element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape)); + op::Constant::create(element::i32, Shape{ unsqueezeConstantShape.size() }, unsqueezeConstantShape)); } } - const auto newMulConst = NetworkHelper::toScalarIfPossible(fold(mulConst1, mulConst2)); + const auto newMulConst = NetworkHelper::toScalarIfPossible(fold(mulConst1, mulConst2)); const std::shared_ptr newMultiply = std::make_shared(parent, newMulConst); newMultiply->set_friendly_name(newMatMul->get_friendly_name() + "/DequantizationMultiply"); replace_node(matMul, newMultiply); - ngraph::copy_runtime_info({ newMultiply, matMul }, newMultiply); + copy_runtime_info({ newMultiply, matMul }, newMultiply); updateOutput(context, newMultiply, matMul); @@ -145,12 +140,12 @@ void MatMulTransformation::registerMatcherIn(GraphRewrite& pass, TransformationC addPattern( pass, context, - make_op_pattern({ make_op_label(), make_op_label() })); + make_op_pattern({ make_op_label(), make_op_label() })); addPattern( pass, context, - make_op_pattern({ make_op_label(), make_op_label() })); + make_op_pattern({ make_op_label(), make_op_label() })); } bool MatMulTransformation::isPrecisionPreserved(std::shared_ptr layer) const noexcept { @@ -167,15 +162,14 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context return false; } - const auto dequantization1 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer); + const auto dequantization1 = NetworkHelper::getDequantization(layer, 0); if (!dequantization1.empty()) { if (updatePrecisions && !dequantization1.isLowPrecision()) { return false; } - const auto mulConst = as_type_ptr(dequantization1.multiply->get_input_node_shared_ptr(1)); - if (!NetworkHelper::isScalarLike(mulConst)) { - const auto constantShape = mulConst->get_shape(); + if (!NetworkHelper::isScalarLike(dequantization1.multiplyConstant)) { + const auto constantShape = dequantization1.multiplyConstant->get_shape(); const auto mulShape = dequantization1.multiply->get_shape(); const size_t columnsIdx = matMul->get_transpose_a() ? mulShape.size() - 2ul : mulShape.size() - 1ul; @@ -186,15 +180,22 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context } } - const auto dequantization2 = ngraph::pass::low_precision::NetworkHelper::getDequantization(layer, 1); + const auto dequantization2 = NetworkHelper::getDequantization(layer, 1); if (!dequantization2.empty()) { - if ((updatePrecisions && !dequantization2.isLowPrecision()) || (dequantization2.subtract)) { + if ((updatePrecisions && !dequantization2.isLowPrecision())) { return false; } - const auto mulConst = as_type_ptr(dequantization2.multiply->get_input_node_shared_ptr(1)); - if (!NetworkHelper::isScalarLike(mulConst)) { - const auto constantShape = mulConst->get_shape(); + if (dequantization2.subtract) { + std::shared_ptr roundedConst = NetworkHelper::round(dequantization2.subtractConstant, dequantization2.data.get_element_type()); + roundedConst = NetworkHelper::toScalarIfPossible(roundedConst); + if (NetworkHelper::isZeroConst(roundedConst)) { + return false; + } + } + + if (!NetworkHelper::isScalarLike(dequantization2.multiplyConstant)) { + const auto constantShape = dequantization2.multiplyConstant->get_shape(); const auto mulShape = dequantization2.multiply->get_shape(); const size_t rowsIdx = matMul->get_transpose_b() ? mulShape.size() - 1ul : mulShape.size() - 2ul; @@ -229,7 +230,7 @@ bool MatMulTransformation::canBeTransformed(const TransformationContext& context } } - if (fakeQuantize == nullptr && dequantization1.subtract) { + if ((!NetworkHelper::isConstantPath(layer->get_input_node_shared_ptr(1))) && (dequantization1.subtract)) { return false; }