forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LPT] MoveFakeQuantize (openvinotoolkit#6723)
* add move_fake_quantize_for_concat_transformation, mfk and mfk_function
* fix relu_transformation.cpp
* backup
* add change
* add cpu test
* [LPT] MoveFakeQuantizeTransformation: fixes
* get InferenceEngine::NotImplemented
* fix ieFuncTests
* try without new cpu_test
* fix cpuFuncTests and ieFuncTests
* fix tests
* fix lin
* add cpu test
* fix link and matcher in move_fake_quantize.cpp
* update matcher
* add gpu test
* naming fix
* move_fake_quantize.cpp add set_fr_name for new_concat
* naming new fq fix
* fix NetworkHelper::copyInfo naming
* concat.cpp naming fix
* gpu tests fix
* rm network_helper changes
* rm extra output
* resolve conversations
* resolve other conversations
* add multi inputs for concat
* fix lin
* fix move_fake_quantize naming
* rm maxpool from mfk_function
* mkldnn update
* fix style
* rm extra change
* fix concat matcher
* rm mkldnn_plugin changes
* fix conversations
* fix interval
* fix and add isQuantizedStatic, add attribute and negative tests
* add negative plugin tests
* fix style

Co-authored-by: Edward Shogulin <[email protected]>
- Loading branch information
Showing
11 changed files
with
957 additions
and
3 deletions.
There are no files selected for viewing
25 changes: 25 additions & 0 deletions
25
...nce-engine/src/low_precision_transformations/include/low_precision/move_fake_quantize.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <ngraph/ngraph.hpp> | ||
#include "low_precision/layer_transformation.hpp" | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
namespace low_precision { | ||
|
||
/**
 * @brief Low-precision transformation that relocates a FakeQuantize found
 * after a Concat (directly, or behind a single Relu) so that every Concat
 * input receives its own copy of that FakeQuantize.
 */
class LP_TRANSFORMATIONS_API MoveFakeQuantize : public LayerTransformation {
public:
    NGRAPH_RTTI_DECLARATION;

    /// @brief Registers the matcher for the supported Concat/FakeQuantize patterns.
    /// @param params Layer transformation parameters forwarded to LayerTransformation.
    MoveFakeQuantize(const Params& params = Params());

    /// @brief Rewrites the matched sub-graph.
    /// @param context Transformation context, forwarded to updateOutput().
    /// @param m Matcher whose match root is the FakeQuantize to move.
    /// @return true when the graph was changed, false when the match was rejected.
    bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) override;

    /// @brief Reports whether the layer preserves precision.
    /// @param layer Layer being queried.
    bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};
|
||
} // namespace low_precision | ||
} // namespace pass | ||
} // namespace ngraph |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
inference-engine/src/low_precision_transformations/src/move_fake_quantize.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "low_precision/move_fake_quantize.hpp" | ||
|
||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
#include <ngraph/opsets/opset1.hpp> | ||
|
||
#include <memory> | ||
#include <ngraph/ngraph.hpp> | ||
#include <ngraph/opsets/opset1.hpp> | ||
#include <ngraph/pattern/op/or.hpp> | ||
|
||
#include "low_precision/concat.hpp" | ||
#include "low_precision/network_helper.hpp" | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
namespace low_precision { | ||
|
||
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::MoveFakeQuantize, "MoveFakeQuantize", 0); | ||
|
||
/// Builds the pattern matcher for this transformation.
///
/// Two shapes are accepted, both rooted at a FakeQuantize whose four limit
/// inputs are Constants:
///   * Concat -> FakeQuantize
///   * Concat -> Relu -> FakeQuantize
/// In both shapes the Concat must have exactly one consumer.
MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(params) {
    const auto concat_pattern = ngraph::pattern::wrap_type<opset1::Concat>(pattern::consumers_count(1));
    const auto relu_pattern = ngraph::pattern::wrap_type<opset1::Relu>({ concat_pattern });

    // The limit constants are shared between both FakeQuantize patterns.
    const auto in_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto in_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto out_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto out_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();

    // Creates a FakeQuantize pattern fed by `data` and the shared limit constants.
    const auto make_fq_pattern = [&](const std::shared_ptr<ngraph::Node>& data) {
        return ngraph::pattern::wrap_type<opset1::FakeQuantize>({ data, in_low, in_high, out_low, out_high });
    };
    const auto fq_after_relu_pattern = make_fq_pattern(relu_pattern);
    const auto fq_after_concat_pattern = make_fq_pattern(concat_pattern);

    // Skip the match when the user-provided transformation callback asks for it;
    // otherwise delegate to transform().
    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& matcher) {
        const auto root = matcher.get_match_root();
        return !transformation_callback(root) && transform(*context, matcher);
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(
        std::make_shared<pattern::op::Or>(OutputVector{ fq_after_concat_pattern, fq_after_relu_pattern }),
        "MoveFakeQuantize");
    this->register_matcher(matcher, callback);
}
|
||
/// Moves the matched FakeQuantize in front of the Concat that feeds it.
///
/// The match root is a FakeQuantize whose data input is either a Concat or a
/// single-input operation that itself consumes a Concat. For every Concat
/// input a copy of the FakeQuantize (and, when present, of the intermediate
/// operation) is created, named "<original name>_<1-based input index>", and a
/// new Concat over those copies replaces the original FakeQuantize.
///
/// @param context Transformation context, forwarded to updateOutput().
/// @param m       Matcher whose match root is the FakeQuantize to move.
/// @return false if ConcatTransformation::isQuantizedStatic() rejects the
///         Concat (the graph is left untouched), true after the rewrite.
bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
    auto fq = m.get_match_root();
    auto operation = fq->get_input_node_shared_ptr(0);

    // Either `Concat -> FQ` or `Concat -> operation -> FQ`.
    const bool only_concat = is_type<opset1::Concat>(operation);
    std::shared_ptr<ngraph::Node> concat = only_concat ? operation : operation->get_input_node_shared_ptr(0);
    if (!ConcatTransformation::isQuantizedStatic(concat)) {
        return false;
    }

    const std::string fq_original_name = fq->get_friendly_name();
    const std::string operation_original_name = only_concat ? std::string() : operation->get_friendly_name();

    const size_t input_size = concat->get_input_size();
    std::vector<std::shared_ptr<ngraph::Node>> fqs;
    fqs.reserve(input_size);
    for (size_t i = 0; i < input_size; ++i) {
        const std::string suffix = "_" + std::to_string(i + 1);

        // input_value() keeps the producer's output port. Going through
        // get_input_node_shared_ptr() would reconnect to the producer's default
        // output and mis-wire inputs coming from a multi-output node.
        ngraph::Output<ngraph::Node> fq_input = concat->input_value(i);
        if (!only_concat) {
            auto new_operation = operation->clone_with_new_inputs(ngraph::OutputVector{ fq_input });
            new_operation->set_friendly_name(operation_original_name + suffix);
            fq_input = new_operation->output(0);
        }

        auto newFq = fq->clone_with_new_inputs(ngraph::OutputVector{
            fq_input,
            fq->input_value(1),
            fq->input_value(2),
            fq->input_value(3),
            fq->input_value(4) });
        newFq->set_friendly_name(fq_original_name + suffix);
        fqs.push_back(newFq);
    }

    ngraph::copy_runtime_info(fq, fqs);
    auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(fqs.begin(), fqs.end()));
    newConcat->set_friendly_name(concat->get_friendly_name());
    replace_node(fq, newConcat);
    NetworkHelper::copyInfo(concat, newConcat);
    updateOutput(context, newConcat, fq);
    return true;
}
|
||
/// Always reports the layer as precision preserving; the argument is not inspected.
bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr<Node> /*layer*/) const noexcept {
    return true;
}
|
||
} // namespace low_precision | ||
} // namespace pass | ||
} // namespace ngraph |
Oops, something went wrong.