Skip to content

Commit

Permalink
[LPT] MVNTransformation quick refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Aug 20, 2021
1 parent 14e65a5 commit 5d189ad
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions inference-engine/src/low_precision_transformations/src/mvn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
}
}

const auto scalesConst = ov::as_type_ptr<opset1::Constant>(NetworkHelper::getConstantInput(mvn->get_input_node_shared_ptr(0)));
bool isScalarScales = NetworkHelper::isScalarLike(scalesConst);

AxisSet reduction_axes;
if (ov::is_type<op::MVN>(mvn)) {
reduction_axes = ov::as_type_ptr<op::MVN>(mvn)->get_reduction_axes();
Expand All @@ -106,6 +103,7 @@ bool MVNTransformation::canBeTransformed(const TransformationContext& context, s
}
}

bool isScalarScales = NetworkHelper::isScalarLike(dequantization.multiplyConstant);
return perTensor && isScalarScales;
}

Expand All @@ -128,13 +126,10 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
}

FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(mvn);
auto scalesConst = ov::as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(1));
if (scalesConst == nullptr) {
scalesConst = ov::as_type_ptr<opset1::Constant>(dequantization.multiply->get_input_node_shared_ptr(0));
}
const auto scalesConst = dequantization.multiplyConstant;
const auto type = scalesConst->get_element_type();

auto newScalesConst = scalesConst;
const auto type = scalesConst->get_output_element_type(0);
if (normalizeVariance) {
switch (type) {
case ngraph::element::Type_t::f16: {
Expand All @@ -150,6 +145,7 @@ bool MVNTransformation::transform(TransformationContext &context, ngraph::patter
}
}
}

std::shared_ptr<Node> newMVN;
if (ov::is_type<op::MVN>(mvn)) {
newMVN = mvn->copy_with_new_inputs({dequantization.data});
Expand Down

0 comments on commit 5d189ad

Please sign in to comment.