-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
mandrono
committed
Dec 30, 2020
1 parent
0850479
commit d275a15
Showing
11 changed files
with
717 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
inference-engine/src/transformations/include/transformations/op_conversions/fq_decomposition.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <transformations_visibility.hpp> | ||
#include <ngraph/pass/graph_rewrite.hpp> | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API FakeQuantizeDecomposition; | ||
|
||
} // namespace pass | ||
} // namespace ngraph | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief FakeQuantizeDecomposition transformation into sub-graph | ||
*/ | ||
class ngraph::pass::FakeQuantizeDecomposition: public ngraph::pass::MatcherPass { | ||
public: | ||
NGRAPH_RTTI_DECLARATION; | ||
FakeQuantizeDecomposition(); | ||
}; |
86 changes: 86 additions & 0 deletions
86
inference-engine/src/transformations/src/transformations/op_conversions/fq_decomposition.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// Copyright (C) 2020 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/op_conversions/fq_decomposition.hpp" | ||
|
||
#include <ngraph/opsets/opset5.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
#include <ngraph/builder/autobroadcast.hpp> | ||
|
||
NGRAPH_RTTI_DEFINITION(ngraph::pass::FakeQuantizeDecomposition, "FakeQuantizeDecomposition", 0); | ||
|
||
// Registers a matcher that rewrites every matched opset5 FakeQuantize node into
// the sub-graph
//
//   round(clamp(x) * isc - input_low * isc) * osc + output_low
//
// with isc = (levels - 1) / (input_high - input_low)
// and  osc = (output_high - output_low) / (levels - 1).
//
// The callback returns false (graph left untouched) when the matched node is
// not a FakeQuantize or when the transformation callback returns true for it.
ngraph::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
    auto fq_pattern = ngraph::pattern::wrap_type<ngraph::opset5::FakeQuantize>();

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto fq = std::dynamic_pointer_cast<ngraph::opset5::FakeQuantize>(
            m.get_pattern_value_map().at(fq_pattern).get_node_shared_ptr());

        if (fq == nullptr) {
            return false;
        }
        if (m_transformation_callback(fq)) {
            return false;
        }

        // FakeQuantize inputs: 0 - data, 1/2 - input range, 3/4 - output range.
        ngraph::Output<ngraph::Node> x{fq->input_value(0)};
        ngraph::Output<ngraph::Node> in_lo{fq->input_value(1)};
        ngraph::Output<ngraph::Node> in_hi{fq->input_value(2)};
        ngraph::Output<ngraph::Node> out_lo{fq->input_value(3)};
        ngraph::Output<ngraph::Node> out_hi{fq->input_value(4)};

        // Every node created below is recorded here so that the runtime info
        // of the original FakeQuantize can be copied onto all of them.
        ngraph::NodeVector created_nodes;

        // The arithmetic is carried out in the element type of input_low;
        // the data is converted first when its type differs.
        const auto compute_type = in_lo.get_element_type();
        if (x.get_element_type() != compute_type) {
            x = std::make_shared<ngraph::opset5::Convert>(x, compute_type);
            created_nodes.push_back(x.get_node_shared_ptr());
        }

        // clamp(x) = min(max(x, input_low), input_high)
        auto lower_bounded = std::make_shared<ngraph::opset5::Maximum>(x, in_lo);
        auto clamped = std::make_shared<ngraph::opset5::Minimum>(lower_bounded, in_hi);
        created_nodes.push_back(lower_bounded);
        created_nodes.push_back(clamped);

        // (levels - 1) as a scalar constant of the computation type
        auto steps = std::make_shared<ngraph::opset5::Constant>(compute_type, ngraph::Shape{}, fq->get_levels() - 1);
        created_nodes.push_back(steps);

        // Input scale:  isc = (levels - 1) / (input_high - input_low)
        // Input shift:  ish = input_low * isc
        auto in_range = std::make_shared<ngraph::opset5::Subtract>(in_hi, in_lo);
        created_nodes.push_back(in_range);
        auto in_scale = std::make_shared<ngraph::opset5::Divide>(steps, in_range);
        created_nodes.push_back(in_scale);
        auto in_shift = std::make_shared<ngraph::opset5::Multiply>(in_lo, in_scale);
        created_nodes.push_back(in_shift);

        // clamp(x) * isc - ish
        auto scaled = std::make_shared<ngraph::opset5::Multiply>(clamped, in_scale);
        created_nodes.push_back(scaled);
        auto shifted = std::make_shared<ngraph::opset5::Subtract>(scaled, in_shift);
        created_nodes.push_back(shifted);

        // Round to the nearest quantization step, ties go to the even value.
        auto quantized = std::make_shared<ngraph::opset5::Round>(shifted, ngraph::opset5::Round::RoundMode::HALF_TO_EVEN);
        created_nodes.push_back(quantized);

        // Output scale: osc = (output_high - output_low) / (levels - 1)
        auto out_range = std::make_shared<ngraph::opset5::Subtract>(out_hi, out_lo);
        created_nodes.push_back(out_range);
        auto out_scale = std::make_shared<ngraph::opset5::Divide>(out_range, steps);
        created_nodes.push_back(out_scale);

        // quantized * osc + output_low
        auto rescaled = std::make_shared<ngraph::opset5::Multiply>(quantized, out_scale);
        created_nodes.push_back(rescaled);
        std::shared_ptr<ngraph::Node> replacement = std::make_shared<ngraph::opset5::Add>(rescaled, out_lo);
        created_nodes.push_back(replacement);

        // Keep the element type produced by the original FakeQuantize.
        const auto expected_type = fq->get_output_element_type(0);
        if (replacement->get_output_element_type(0) != expected_type) {
            replacement = std::make_shared<ngraph::opset5::Convert>(replacement, expected_type);
            created_nodes.push_back(replacement);
        }

        // The last node of the sub-graph takes over the name of the matched
        // node before it replaces that node in the graph.
        replacement->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info(fq, created_nodes);
        ngraph::replace_node(m.get_match_root(), replacement);
        return true;
    };

    auto matcher = std::make_shared<ngraph::pattern::Matcher>(fq_pattern, "FakeQuantizeDecomposition");
    register_matcher(matcher, callback);
}
Oops, something went wrong.