Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Sep 15, 2023
1 parent 91887f3 commit 104cf66
Show file tree
Hide file tree
Showing 6 changed files with 483 additions and 34 deletions.
72 changes: 57 additions & 15 deletions src/common/low_precision_transformations/src/multiply.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down Expand Up @@ -44,7 +44,6 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
return false;
}

// TODO: normalizeDequantization + fold_fake_quantizes + foldDequantization <= ???
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1));

Expand All @@ -65,12 +64,8 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
fold_fake_quantizes(multiply, 1ul);

const auto dequantization1 = NetworkHelper::foldDequantization(multiply, 0, defaultPrecisions);
if (dequantization1.multiplyConstant == nullptr) {
return false;
}

const auto dequantization2 = NetworkHelper::foldDequantization(multiply, 1, defaultPrecisions);
if (dequantization2.multiplyConstant == nullptr) {
if ((dequantization1.multiplyConstant == nullptr) && (dequantization2.multiplyConstant == nullptr)) {
return false;
}

Expand All @@ -79,7 +74,50 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
// X1` = X1 - SH1
// X2` = X2 - SH2
// SC1' = SC1 * SC2

if ((dequantization1.empty() && (ov::is_type<ov::opset1::Constant>(dequantization1.data.get_node()))) ||
(dequantization2.empty() && (ov::is_type<ov::opset1::Constant>(dequantization2.data.get_node())))) {
// one input is constant
auto new_scales_values = fold<ov::opset1::Multiply>(
dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant,
dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant);

if (!ov::is_type<ov::opset1::Constant>(new_scales_values)) {
return false;
}

const Output<Node> in1 = dequantization1.empty() ?
new_scales_values :
dequantization1.subtract == nullptr ?
dequantization1.data :
NetworkHelper::optimizeSubtract(dequantization1.subtract);

const Output<Node> in2 = dequantization2.empty() ?
new_scales_values :
dequantization2.subtract == nullptr ?
dequantization2.data :
NetworkHelper::optimizeSubtract(dequantization2.subtract);

auto const new_multiply = (in1.get_element_type() == multiply->get_output_element_type(0)) &&
(in2.get_element_type() == multiply->get_output_element_type(0)) ?
std::make_shared<ov::opset1::Multiply>(in1, in2) :
std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ deqPrecision, deqPrecision },
std::vector<ov::element::Type>{ multiply->get_output_element_type(0) },
ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(),
ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get());

replace_node(multiply, new_multiply);
updateOutput(context, new_multiply, multiply);

return true;
}

auto new_scales_values = fold<ov::opset1::Multiply>(dequantization1.multiplyConstant, dequantization2.multiplyConstant);
if (!ov::is_type<ov::opset1::Constant>(new_scales_values)) {
return false;
}

const Output<Node> in1 = dequantization1.subtract == nullptr ?
dequantization1.data :
NetworkHelper::optimizeSubtract(dequantization1.subtract);
Expand All @@ -89,19 +127,23 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
NetworkHelper::optimizeSubtract(dequantization2.subtract);

// in1 & in2 can have different input types
auto const new_multiply = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ element::f32, element::f32 },
std::vector<ov::element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(in1, element::f32).get(),
ov::op::TemporaryReplaceOutputType(in2, element::f32).get());
const auto new_multiply = (in1.get_element_type() == deqPrecision) &&
(in2.get_element_type() == deqPrecision) ?
std::make_shared<ov::opset1::Multiply>(in1, in2) :
std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ deqPrecision, deqPrecision },
std::vector<ov::element::Type>{ deqPrecision },
ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(),
ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get());

NetworkHelper::copyInfo(multiply, newMultiply);

auto new_scales = new_multiply->get_output_element_type(0) != multiply->get_output_element_type(0) ?
auto new_scales = (new_multiply->get_output_element_type(0) == multiply->get_output_element_type(0)) &&
(new_scales_values->get_output_element_type(0) == multiply->get_output_element_type(0)) ?
std::make_shared<ov::opset1::Multiply>(new_multiply, new_scales_values) :
std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
ov::opset1::Multiply(new_multiply, new_scales_values),
multiply->get_output_element_type(0)) :
std::make_shared<ov::opset1::Multiply>(new_multiply, new_scales_values);
multiply->get_output_element_type(0));

replace_node(multiply, new_scales);
updateOutput(context, new_scales, multiply);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "low_precision/interpolate.hpp"
#include "low_precision/mat_mul.hpp"
#include "low_precision/max_pool.hpp"
#include "low_precision/multiply.hpp"
#include "low_precision/multiply_partial.hpp"
#include "low_precision/mvn.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/normalize_l2.hpp"
Expand Down Expand Up @@ -361,7 +361,7 @@ TEST(LPT, AvoidDequantizationToShapeOfPropagationMultiplyTransformation) {

auto f = std::make_shared<Model>(ResultVector{result1, result2}, ParameterVector{input1, input2});
pass::Manager m;
m.register_pass<ov::pass::low_precision::MultiplyTransformation>();
m.register_pass<ov::pass::low_precision::MultiplyPartialTransformation>();
m.run_passes(f);

auto dqBeforeShapeOf = ov::pass::low_precision::NetworkHelper::getDequantization(result2->get_input_node_shared_ptr(0));
Expand Down
Loading

0 comments on commit 104cf66

Please sign in to comment.