diff --git a/inference-engine/src/low_precision_transformations/src/mvn.cpp b/inference-engine/src/low_precision_transformations/src/mvn.cpp index 383688c2f2ff06..781e26d7e64f39 100644 --- a/inference-engine/src/low_precision_transformations/src/mvn.cpp +++ b/inference-engine/src/low_precision_transformations/src/mvn.cpp @@ -79,9 +79,6 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s } } - const auto scalesConst = ov::as_type_ptr(NetworkHelper::getConstantInput(mvn->get_input_node_shared_ptr(0))); - bool isScalarScales = NetworkHelper::isScalarLike(scalesConst); - AxisSet reduction_axes; if (ov::is_type(mvn)) { reduction_axes = ov::as_type_ptr(mvn)->get_reduction_axes(); @@ -106,6 +103,7 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s } } + bool isScalarScales = NetworkHelper::isScalarLike(dequantization.multiplyConstant); return perTensor && isScalarScales; } @@ -128,13 +126,10 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter } FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(mvn); - auto scalesConst = ov::as_type_ptr(dequantization.multiply->get_input_node_shared_ptr(1)); - if (scalesConst == nullptr) { - scalesConst = ov::as_type_ptr(dequantization.multiply->get_input_node_shared_ptr(0)); - } + const auto scalesConst = dequantization.multiplyConstant; + const auto type = scalesConst->get_element_type(); auto newScalesConst = scalesConst; - const auto type = scalesConst->get_output_element_type(0); if (normalizeVariance) { switch (type) { case ngraph::element::Type_t::f16: { @@ -150,6 +145,7 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter } } } + std::shared_ptr newMVN; if (ov::is_type(mvn)) { newMVN = mvn->copy_with_new_inputs({dequantization.data});