-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Dequantize (Sub, Mul) to FakeQuantize (#4189)
* Dequantize (Sub, Mul) to FakeQuantize * disable for CPU/GPU
- Loading branch information
Evgenya Stepyreva
authored
Feb 10, 2021
1 parent
a327b72
commit a313c0c
Showing
5 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
30 changes: 30 additions & 0 deletions
30
...ions/include/transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
#include <memory> | ||
|
||
#include <transformations_visibility.hpp> | ||
#include <ngraph/pass/graph_rewrite.hpp> | ||
|
||
namespace ngraph { | ||
namespace pass { | ||
|
||
class TRANSFORMATIONS_API WeightsDequantizeToFakeQuantize; | ||
|
||
} // namespace pass | ||
} // namespace ngraph | ||
|
||
/** | ||
* @ingroup ie_transformation_common_api | ||
* @brief WeightsDequantizeToFakeQuantize transformation replaces the weights | ||
* dequantization sub-graph Constant -> Convert -> [Subtract] -> Multiply | ||
* with Constant -> Convert -> FakeQuantize. The Subtract is optional in the | ||
* matched pattern. | ||
*/ | ||
class ngraph::pass::WeightsDequantizeToFakeQuantize: public ngraph::pass::MatcherPass { | ||
public: | ||
// RTTI declaration macro; presumably paired with NGRAPH_RTTI_DEFINITION in the .cpp file. | ||
NGRAPH_RTTI_DECLARATION; | ||
// Builds the pattern and registers the matcher callback (defined in the .cpp file). | ||
WeightsDequantizeToFakeQuantize(); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
...rmations/src/transformations/common_optimizations/weights_dequantize_to_fake_quantize.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// Copyright (C) 2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp" | ||
|
||
#include <ngraph/opsets/opset6.hpp> | ||
#include <ngraph/pattern/op/wrap_type.hpp> | ||
#include <ngraph/pattern/op/or.hpp> | ||
#include <ngraph/rt_info.hpp> | ||
#include "itt.hpp" | ||
|
||
NGRAPH_RTTI_DEFINITION(ngraph::pass::WeightsDequantizeToFakeQuantize, "WeightsDequantizeToFakeQuantize", 0); | ||
|
||
/**
 * Builds the matcher for weights dequantization sub-graphs of the form
 *
 *     Constant(i8) -> Convert -> [Subtract(zero_point)] -> Multiply(scale)
 *
 * and registers a callback that replaces the output of the matched Multiply with
 *
 *     Constant(i8) -> Convert -> FakeQuantize(in_low, in_high, out_low, out_high, levels)
 *
 * where out_low/out_high = (in_low/in_high - zero_point) * scale.
 * The Subtract is optional; when it is absent a scalar zero is used as the zero point.
 *
 * The callback returns false (graph left untouched) when the weights constant is
 * not of type i8 or holds no elements, and true once the replacement is done.
 */
ngraph::pass::WeightsDequantizeToFakeQuantize::WeightsDequantizeToFakeQuantize() {
    MATCHER_SCOPE(WeightsDequantizeToFakeQuantize);

    const auto weights = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto convert = ngraph::pattern::wrap_type<ngraph::opset6::Convert>({weights});
    const auto sub_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto sub = ngraph::pattern::wrap_type<ngraph::opset6::Subtract>({convert, sub_c});

    // Multiply consumes either Convert directly or Convert followed by Subtract.
    const auto sub_or_convert = std::make_shared<pattern::op::Or>(OutputVector{convert, sub});

    const auto mul_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto mul = ngraph::pattern::wrap_type<ngraph::opset6::Multiply>({sub_or_convert, mul_c});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        const auto pattern_map = m.get_pattern_map();

        const auto weights_node = as_type_ptr<opset6::Constant>(pattern_map.at(weights));
        const auto convert_node = pattern_map.at(convert);
        const auto multiply_node = pattern_map.at(mul);
        const auto scale_node = pattern_map.at(mul_c);
        if (!weights_node || !convert_node || !multiply_node || !scale_node) {
            return false;
        }

        // The raw buffer is interpreted as int8_t below, so any other element type
        // would be misread; an empty buffer would make min_element return `end`.
        const auto weights_count = shape_size(weights_node->get_shape());
        if (weights_node->get_element_type() != ngraph::element::i8 || weights_count == 0) {
            return false;
        }

        // -128 present  -> full range   [-128, 127], 256 levels
        // -128 absent   -> narrow range [-127, 127], 255 levels
        const auto *data = weights_node->get_data_ptr<int8_t>();
        const int8_t weights_minimum = *std::min_element(data, data + weights_count);
        const int64_t levels = (weights_minimum == static_cast<int8_t>(-128)) ? 256 : 255;
        const int64_t in_low = -(levels / 2);
        const int64_t in_high = levels + in_low - 1;

        const auto input_low = opset6::Constant::create(convert_node->get_element_type(), {}, {in_low});
        const auto input_high = opset6::Constant::create(convert_node->get_element_type(), {}, {in_high});

        // Subtract and its constant are only in the map when the Or matched the
        // Subtract branch; `at()` would throw for the Convert-only branch.
        const auto sub_it = pattern_map.find(sub);
        const auto sub_c_it = pattern_map.find(sub_c);
        const bool has_zero_point = sub_c_it != pattern_map.end() && sub_c_it->second;

        std::shared_ptr<Node> zero_point;
        if (has_zero_point) {
            zero_point = sub_c_it->second;
        } else {
            zero_point = opset6::Constant::create(convert_node->get_element_type(), {}, {0});
        }

        const auto output_low = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_low, zero_point), scale_node);
        const auto output_high = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_high, zero_point), scale_node);

        auto fq = std::make_shared<opset6::FakeQuantize>(
                convert_node, input_low, input_high, output_low, output_high, levels);

        // Collect the matched graph nodes (not the pattern placeholders) that the
        // FakeQuantize stands in for.
        NodeVector nodes_to_copy_RT_info_from{multiply_node, scale_node};
        if (has_zero_point)
            nodes_to_copy_RT_info_from.push_back(zero_point);
        if (sub_it != pattern_map.end() && sub_it->second)
            nodes_to_copy_RT_info_from.push_back(sub_it->second);

        // copy_runtime_info takes (from, to): propagate info from the replaced
        // nodes onto the new FakeQuantize.
        ngraph::copy_runtime_info(nodes_to_copy_RT_info_from, fq);
        multiply_node->output(0).replace(fq->output(0));

        // erase() by key is a no-op when the key is absent, so no count() check is needed.
        convert_node->get_rt_info().erase("DISABLED_CONSTANT_FOLDING");
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "WeightsDequantizeToFakeQuantize");
    register_matcher(m, callback);
}