[LPT] MoveFakeQuantize (openvinotoolkit#6723)
* add move_fake_quantize_for_concat_transformation, mfk and mfk_function

* fix relu_transformation.cpp

* backup

* add change

* add cpu test

* [LPT] MoveFakeQuantizeTransformation: fixes

* get InferenceEngine::NotImplemented

* fix ieFuncTests

* try without new cpu_test

* fix cpuFuncTests and ieFuncTests

* fix tests

* fix lin

* add cpu test

* fix link and matcher in move_fake_quantize.cpp

* update matcher

* add gpu test

* naming fix

* move_fake_quantize.cpp add set_fr_name for new_concat

* naming new fq fix

* fix NetworkHelper::copyInfo naming

* concat.cpp naming fix

* gpu tests fix

* rm network_helper changes

* rm extra output

* resolve conversations

* resolve other conversations

* add multi inputs for concat

* fix lin

* fix move_fake_qunatize naming

* rm maxpool from mfk_function

* mkldnn update

* fix style

* rm extra change

* fix concat matcher

* rm mkldnn_plugin changes

* fix conversations

* fix interval

* fix and add isQuantizedStatic, add attribute and negative tests

* add negative plugin tests

* fix style:

Co-authored-by: Edward Shogulin <[email protected]>
2 people authored and akuporos committed Sep 29, 2021
1 parent 6a55fb8 commit 80b074b
Showing 11 changed files with 957 additions and 3 deletions.
low_precision/move_fake_quantize.hpp (new file)
@@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/ngraph.hpp>
#include "low_precision/layer_transformation.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

class LP_TRANSFORMATIONS_API MoveFakeQuantize : public LayerTransformation {
public:
    NGRAPH_RTTI_DECLARATION;
    MoveFakeQuantize(const Params& params = Params());
    bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
    bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

} // namespace low_precision
} // namespace pass
} // namespace ngraph
concat.cpp
@@ -138,6 +138,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
        const auto convert = convertNodes[0]->clone_with_new_inputs({ newConcat });

        NetworkHelper::copyInfo({ concat, convert }, convert);
+       convert->set_friendly_name(concat->get_friendly_name() + "/DequantizationConvert");
        lastDequantization = convert;
    }

@@ -150,6 +151,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
                ngraph::pass::low_precision::fold<ngraph::opset1::Concat>(subtractNodes, 1)));

        NetworkHelper::copyInfo({ concat, subtract }, subtract);
+       subtract->set_friendly_name(concat->get_friendly_name() + "/DequantizationSubtract");
        lastDequantization = subtract;
    }

@@ -163,6 +165,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
                layerDequantizations[0].multiply->get_output_element_type(0));

        NetworkHelper::copyInfo({ concat, multiply }, multiply);
+       multiply->set_friendly_name(concat->get_friendly_name() + "/DequantizationMultyply");
        lastDequantization = multiply;
    }

@@ -325,13 +328,12 @@ bool ConcatTransformation::isQuantizedStatic(const std::shared_ptr<const Node>&
        return false;
    }

-   const auto axis = concat->get_axis();
    const auto outputRank = concat->get_output_partial_shape(0).rank();
-   if (axis < 0 && outputRank.is_dynamic()) {
+   if (outputRank.is_dynamic()) {
        return false;
    }

-   const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), axis, outputRank);
+   const size_t normalizedAxis = ngraph::normalize_axis(concat->get_friendly_name(), concat->get_axis(), outputRank);
    return normalizedAxis == 1ul;
}

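The isQuantizedStatic change above tightens the guard: the old code rejected the concat only when the axis was negative and the output rank dynamic, while the new code rejects any dynamic output rank before the axis is normalized. A minimal sketch of the helper it relies on (standalone usage assumed; "concat" is only a description string used for error reporting, not a node in the graph):

// normalize_axis from <ngraph/validation_util.hpp> maps a possibly negative
// axis onto [0, rank); with a dynamic rank there is nothing to map against,
// which is why isQuantizedStatic now bails out up front.
#include <iostream>
#include <ngraph/validation_util.hpp>

int main() {
    const ngraph::Rank rank(4);  // static rank, e.g. an NCHW Concat output
    const auto normalized = ngraph::normalize_axis("concat", -3, rank);
    std::cout << normalized << std::endl;  // prints 1, the channel axis
    return 0;
}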
low_precision.cpp
@@ -66,6 +66,7 @@
#include "low_precision/transpose.hpp"
#include "low_precision/unsqueeze.hpp"
#include "low_precision/variadic_split.hpp"
+#include "low_precision/move_fake_quantize.hpp"

// cleanup transformations
#include "low_precision/convert.hpp"
@@ -197,6 +198,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_function(std::shared_ptr<
    prerequisites->add_matcher<PullReshapeThroughDequantization>(supportedTypes);
    prerequisites->add_matcher<PullTransposeThroughDequantization>(supportedTypes);
    prerequisites->add_matcher<ngraph::pass::LinOpSequenceFusion>();
+   prerequisites->add_matcher<ngraph::pass::low_precision::MoveFakeQuantize>();

    manager.register_pass<TypeRelaxedReplacer>();

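Registering MoveFakeQuantize in the prerequisites GraphRewrite means it runs before the markup passes and the main layer transformations, so the FakeQuantize layers already sit on the Concat inputs by the time ConcatTransformation is matched. A minimal sketch of how the whole pipeline might be invoked (assumes an already-built ngraph::Function named function; the restriction lists a real plugin passes to LowPrecision are omitted and the defaults are used instead):

#include <memory>
#include <ngraph/pass/manager.hpp>
#include "low_precision/low_precision.hpp"

void run_lpt(const std::shared_ptr<ngraph::Function>& function) {
    ngraph::pass::Manager manager;
    // LowPrecision::run_on_function builds the prerequisites GraphRewrite shown
    // above (now containing MoveFakeQuantize) and then the main LPT matchers.
    manager.register_pass<ngraph::pass::low_precision::LowPrecision>();
    manager.run_passes(function);
}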
move_fake_quantize.cpp (new file)
@@ -0,0 +1,107 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/move_fake_quantize.hpp"

#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset1.hpp>

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/or.hpp>

#include "low_precision/concat.hpp"
#include "low_precision/network_helper.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::MoveFakeQuantize, "MoveFakeQuantize", 0);

MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(params) {
    const auto concat = ngraph::pattern::wrap_type<opset1::Concat>(pattern::consumers_count(1));
    const auto operation = ngraph::pattern::wrap_type<opset1::Relu>({ concat });
    const auto input_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto input_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto output_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto output_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto fq_with_operation = ngraph::pattern::wrap_type<opset1::FakeQuantize>({ operation,
        input_low,
        input_high,
        output_low,
        output_high});
    const auto fq = ngraph::pattern::wrap_type<opset1::FakeQuantize>({ concat,
        input_low,
        input_high,
        output_low,
        output_high });

    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
        auto op = m.get_match_root();
        if (transformation_callback(op)) {
            return false;
        }

        return transform(*context, m);
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(
        std::make_shared<pattern::op::Or>(OutputVector{fq, fq_with_operation}),
        "MoveFakeQuantize");
    this->register_matcher(m, callback);
}

bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
    auto fq = m.get_match_root();
    auto operation = fq->get_input_node_shared_ptr(0);
    std::shared_ptr<ngraph::Node> concat;
    bool only_concat = true;
    std::string fq_original_name = fq->get_friendly_name(), operation_original_name;
    if (is_type<opset1::Concat>(operation)) {
        concat = operation;
    } else {
        operation_original_name = operation->get_friendly_name();
        concat = operation->get_input_node_shared_ptr(0);
        only_concat = false;
    }
    if (!ConcatTransformation::isQuantizedStatic(concat)) {
        return false;
    }
    std::vector<std::shared_ptr<ngraph::Node>> fqs;
    size_t input_size = concat->get_input_size();
    for (size_t i{ 0 }; i < input_size; ++i) {
        std::shared_ptr<ngraph::Node> fq_input;
        if (only_concat) {
            fq_input = concat->get_input_node_shared_ptr(i);
        } else {
            auto input = concat->get_input_node_shared_ptr(i);
            fq_input = operation->clone_with_new_inputs({ input });
            fq_input->set_friendly_name(operation_original_name + "_" + std::to_string(i + 1));
        }
        auto newFq = fq->clone_with_new_inputs({ fq_input,
            fq->get_input_node_shared_ptr(1),
            fq->get_input_node_shared_ptr(2),
            fq->get_input_node_shared_ptr(3),
            fq->get_input_node_shared_ptr(4) });
        newFq->set_friendly_name(fq_original_name + "_" + std::to_string(i + 1));
        fqs.push_back(newFq);
    }
    ngraph::copy_runtime_info(fq, fqs);
    auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(fqs.begin(), fqs.end()));
    newConcat->set_friendly_name(concat->get_friendly_name());
    replace_node(fq, newConcat);
    NetworkHelper::copyInfo(concat, newConcat);
    updateOutput(context, newConcat, fq);
    return true;
}

bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
    return true;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
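The matcher accepts either Concat -> FakeQuantize or Concat -> Relu -> FakeQuantize (the Or pattern above), and transform() clones the FakeQuantize, plus the Relu when present, onto every Concat input, then rebuilds the Concat on top of the per-branch FakeQuantize nodes. A minimal sketch of a graph this matcher would hit (shapes, interval values and names are illustrative, not taken from the commit's tests):

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>

using namespace ngraph;

std::shared_ptr<Function> make_concat_relu_fq_model() {
    // Two branches concatenated along the channel axis (axis = 1), the only
    // axis ConcatTransformation::isQuantizedStatic accepts.
    auto input1 = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 3, 16, 16});
    auto input2 = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 3, 16, 16});
    auto concat = std::make_shared<opset1::Concat>(OutputVector{input1, input2}, 1);
    auto relu = std::make_shared<opset1::Relu>(concat);

    // One FakeQuantize after the Relu; MoveFakeQuantize is expected to replace
    // it with a Relu -> FakeQuantize pair on each branch feeding a new Concat.
    auto il = opset1::Constant::create(element::f32, Shape{}, {0.f});
    auto ih = opset1::Constant::create(element::f32, Shape{}, {2.55f});
    auto ol = opset1::Constant::create(element::f32, Shape{}, {0.f});
    auto oh = opset1::Constant::create(element::f32, Shape{}, {2.55f});
    auto fq = std::make_shared<opset1::FakeQuantize>(relu, il, ih, ol, oh, 256);

    return std::make_shared<Function>(NodeVector{fq}, ParameterVector{input1, input2});
}

After the pass, each branch's FakeQuantize keeps the original intervals, and the clones are named <original_fq_name>_1, <original_fq_name>_2, matching the set_friendly_name calls in transform().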