
[LPT] DisableCleanup usage
eshoguli committed Sep 28, 2023
1 parent 006e037 commit 79cba9a
Showing 3 changed files with 43 additions and 32 deletions.
@@ -71,9 +71,18 @@ class ov::pass::low_precision::LowPrecision : public ov::pass::ModelPass {
     static bool isFunctionQuantized(const std::shared_ptr<const ov::Model>& model);
     static bool isFQLevelsPresent(const std::shared_ptr<const ov::Model>& model, const std::set<size_t>& levels);
 
+    template <typename T, class... Args>
+    std::shared_ptr<T> add_main(Args&&... args) {
+        const auto tr = std::make_shared<T>(std::forward<Args>(args)...);
+        added_main.push_back(tr);
+        return tr;
+    }
+
 protected:
     std::vector<PrecisionsRestriction> precisionRestrictions;
     std::vector<QuantizationGranularityRestriction> quantizationRestrictions;
     // remove
     LayerTransformation::Params params;
+
+    std::vector<std::shared_ptr<MatcherPass>> added_main;
 };
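A note on the new hook: from the caller side it could be exercised roughly as in the sketch below. This is not part of the commit; it assumes the usual ov::pass::Manager flow and borrows PadTransformation (which run_on_model already registers by default) purely as a concrete MatcherPass-derived type to forward.

#include "low_precision/low_precision.hpp"
#include "low_precision/pad.hpp"
#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"

void run_lpt_with_extra_matcher(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // register_pass() returns the constructed pass, so add_main() can be called before run_passes().
    const auto lpt = manager.register_pass<ov::pass::low_precision::LowPrecision>();
    // Queued into added_main; run_on_model() later moves it into the "common" GraphRewrite
    // via the new for-loop (see the low_precision.cpp hunk below).
    // PadTransformation is assumed here to be constructible with default Params.
    lpt->add_main<ov::pass::low_precision::PadTransformation>();
    manager.run_passes(model);
}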
@@ -273,6 +273,10 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
     ADD_MATCHER(common, UnsqueezeTransformation, params)
     ADD_MATCHER(common, VariadicSplitTransformation, params)
 
+    for (const auto& tr : added_main) {
+        common->add_matcher(tr);
+    }
+
     std::shared_ptr<ov::pass::GraphRewrite> cleanup = manager.register_pass<ov::pass::GraphRewrite>();
     ADD_MATCHER(cleanup, EliminateFakeQuantizeTransformation, params)
     ADD_MATCHER(cleanup, FoldConvertTransformation, params)
62 changes: 30 additions & 32 deletions src/common/low_precision_transformations/src/multiply.cpp
@@ -15,6 +15,7 @@
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 
 #include "low_precision/common/ie_lpt_exception.hpp"
+#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
 #include "low_precision/network_helper.hpp"
 #include "itt.hpp"
 
@@ -61,31 +62,36 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
     // before: y = (deq_scales1 * (x1 - zero_point1)) * (deq_scales2 * (x2 - zero_point2))
     // after : y = deq_scales1 * deq_scales2 * (x1 - zero_point1) * (x2 - zero_point2)
 
-    if ((dequantization1.empty() && (ov::is_type<ov::opset1::Constant>(dequantization1.data.get_node()))) ||
-        (dequantization2.empty() && (ov::is_type<ov::opset1::Constant>(dequantization2.data.get_node())))) {
-        // one input is constant
-        auto new_scales_values = fold<ov::opset1::Multiply>(
-            dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant,
-            dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant);
-
-        if (!ov::is_type<ov::opset1::Constant>(new_scales_values)) {
-            return false;
-        }
-
-        const auto create_input = [&new_scales_values](const FakeQuantizeDequantization& dequantization) -> Output<Node> {
-            if (dequantization.empty()) {
-                return new_scales_values;
-            }
-
-            if (dequantization.subtract == nullptr) {
-                return dequantization.data;
-            }
-
-            const auto subtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
-            return subtract == nullptr ? dequantization.data : subtract;
-        };
-        const Output<Node> in1 = create_input(dequantization1);
-        const Output<Node> in2 = create_input(dequantization2);
+    auto new_scales_values = fold<ov::opset1::Multiply>(
+        dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant,
+        dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant);
+
+    if (!ov::is_type<ov::opset1::Constant>(new_scales_values)) {
+        return false;
+    }
+
+    const auto init_input = [&new_scales_values](const FakeQuantizeDequantization& dequantization) -> Output<Node> {
+        if (dequantization.empty()) {
+            return new_scales_values;
+        }
+
+        if (dequantization.subtract == nullptr) {
+            return dequantization.data;
+        }
+
+        const auto subtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
+        if (subtract != nullptr) {
+            DisableCleanupAttribute::create(subtract);
+        }
+
+        return subtract == nullptr ? dequantization.data : subtract;
+    };
+
+    if ((dequantization1.empty() && (ov::is_type<ov::opset1::Constant>(dequantization1.data.get_node()))) ||
+        (dequantization2.empty() && (ov::is_type<ov::opset1::Constant>(dequantization2.data.get_node())))) {
+        // one input is constant
+        const Output<Node> in1 = init_input(dequantization1);
+        const Output<Node> in2 = init_input(dequantization2);
 
         const auto new_multiply = (in1.get_element_type() == multiply->get_output_element_type(0)) &&
             (in2.get_element_type() == multiply->get_output_element_type(0)) ?
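To make the scale folding in the before/after comment concrete, a worked instance with illustrative numbers (not from the commit): take deq_scales1 = 0.5, zero_point1 = 3, deq_scales2 = 0.25, zero_point2 = 0. Then

    before: y = (0.5 * (x1 - 3)) * (0.25 * (x2 - 0))
    after : y = (0.5 * 0.25) * (x1 - 3) * (x2 - 0) = 0.125 * (x1 - 3) * (x2 - 0)

Both forms produce the same y; the folded constant 0.125 is what new_scales_values holds, so the subtract stays on each input branch and a single multiply applies the combined scale afterwards.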
@@ -102,18 +108,8 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
         return true;
     }
 
-    auto new_scales_values = fold<ov::opset1::Multiply>(dequantization1.multiplyConstant, dequantization2.multiplyConstant);
-    if (!ov::is_type<ov::opset1::Constant>(new_scales_values)) {
-        return false;
-    }
-
-    const Output<Node> in1 = dequantization1.subtract == nullptr ?
-        dequantization1.data :
-        NetworkHelper::optimizeSubtract(dequantization1.subtract);
-
-    const Output<Node> in2 = dequantization2.subtract == nullptr ?
-        dequantization2.data :
-        NetworkHelper::optimizeSubtract(dequantization2.subtract);
+    Output<Node> in1 = init_input(dequantization1);
+    Output<Node> in2 = init_input(dequantization2);
 
     // in1 & in2 can have different input types
     const auto new_multiply = (in1.get_element_type() == deqPrecision) &&
@@ -125,6 +121,8 @@ bool MultiplyTransformation::transform(TransformationContext& context, ov::pass:
             ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(),
             ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get());
 
+    DisableCleanupAttribute::create(new_multiply);
+
     auto new_scales = (new_multiply->get_output_element_type(0) == multiply->get_output_element_type(0)) &&
         (new_scales_values->get_output_element_type(0) == multiply->get_output_element_type(0)) ?
         std::make_shared<ov::opset1::Multiply>(new_multiply, new_scales_values) :
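The marking only matters because cleanup matchers are expected to consult it and skip protected nodes. Below is a rough sketch of such a guard; it is not from this commit, and it assumes the attribute ends up in the node's rt_info as an ov::Any holding a DisableCleanupAttribute (LPT's own lookup helpers may wrap this differently).

#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
#include "openvino/core/node.hpp"

namespace ov {
namespace pass {
namespace low_precision {

// Sketch only: returns true when cleanup passes should leave the node alone.
inline bool cleanup_disabled(const std::shared_ptr<ov::Node>& node) {
    for (const auto& entry : node->get_rt_info()) {
        if (entry.second.is<DisableCleanupAttribute>()) {
            return true;
        }
    }
    return false;
}

}  // namespace low_precision
}  // namespace pass
}  // namespace ov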
