Skip to content

Commit

Permalink
Dequantize (Sub, Mul) to FakeQuantize (#4189)
Browse files Browse the repository at this point in the history
* Dequantize (Sub, Mul) to FakeQuantize

* disable for CPU/GPU
  • Loading branch information
Evgenya Stepyreva authored Feb 10, 2021
1 parent a327b72 commit a313c0c
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 0 deletions.
2 changes: 2 additions & 0 deletions inference-engine/src/cldnn_engine/cldnn_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <transformations/control_flow/unroll_tensor_iterator.hpp>

#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
Expand Down Expand Up @@ -279,6 +280,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();

pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();

Expand Down
2 changes: 2 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>

#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
Expand Down Expand Up @@ -228,6 +229,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
pass_config->disable<ngraph::pass::ConvertMod>();
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();

pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

// Forward declaration so consumers can reference the pass (e.g. in
// PassConfig::disable<>) without pulling in the full definition.
namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API WeightsDequantizeToFakeQuantize;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief WeightsDequantizeToFakeQuantize transformation replaces a weights
 * dequantization sub-graph
 *     Constant(i8) -> Convert -> (optional Subtract(zero_point)) -> Multiply(scale)
 * with an equivalent
 *     Constant(i8) -> Convert -> FakeQuantize
 * expression, so that quantization information is represented uniformly.
 */
class ngraph::pass::WeightsDequantizeToFakeQuantize: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    WeightsDequantizeToFakeQuantize();
};
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);

Expand All @@ -69,6 +70,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::

// TODO: move to KMB
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::WeightsDequantizeToFakeQuantize>();

manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"

#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::pass::WeightsDequantizeToFakeQuantize, "WeightsDequantizeToFakeQuantize", 0);

ngraph::pass::WeightsDequantizeToFakeQuantize::WeightsDequantizeToFakeQuantize() {
    MATCHER_SCOPE(WeightsDequantizeToFakeQuantize);

    // Pattern: Constant -> Convert -> (optional Subtract(zero_point)) -> Multiply(scale).
    // The Or node makes the Subtract branch optional.
    const auto weights = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto convert = ngraph::pattern::wrap_type<ngraph::opset6::Convert>({weights});
    const auto sub_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto sub = ngraph::pattern::wrap_type<ngraph::opset6::Subtract>({convert, sub_c});

    const auto sub_or_convert = std::make_shared<pattern::op::Or>(OutputVector{convert, sub});

    const auto mul_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto mul = ngraph::pattern::wrap_type<ngraph::opset6::Multiply>({sub_or_convert, mul_c});

    ngraph::matcher_pass_callback callback;
    callback = [=](ngraph::pattern::Matcher &m) {
        const auto &pattern_map = m.get_pattern_map();

        const auto &weights_node = as_type_ptr<opset6::Constant>(pattern_map.at(weights));
        const auto &convert_node = pattern_map.at(convert);
        const auto &multiply_node = pattern_map.at(mul);
        const auto &scale_node = pattern_map.at(mul_c);
        if (!weights_node || !convert_node || !multiply_node || !scale_node) {
            return false;
        }

        // The levels computation below reinterprets the raw constant data as
        // int8 -- bail out for any other element type instead of misreading it.
        if (weights_node->get_element_type() != element::i8) {
            return false;
        }

        // 256 levels when the full int8 range [-128, 127] is occupied,
        // otherwise the symmetric 255-level range [-127, 127].
        const auto *data = weights_node->get_data_ptr<int8_t>();
        const int8_t weights_minimum = *std::min_element(data, data + shape_size(weights_node->get_shape()));
        int64_t levels = (weights_minimum == static_cast<int8_t>(-128)) ? 256 : 255;
        int64_t in_low = -(levels / 2), in_high = levels + in_low - 1;

        const auto &input_low = opset6::Constant::create(convert_node->get_element_type(), {}, {in_low});
        const auto &input_high = opset6::Constant::create(convert_node->get_element_type(), {}, {in_high});

        // The Subtract branch is optional; get_pattern_map() only contains
        // labels that actually matched, so probe with count() -- calling at()
        // unconditionally would throw for the Convert->Multiply (no zero point)
        // case and the original null-check could never trigger.
        const bool has_zero_point = pattern_map.count(sub_c) != 0;
        std::shared_ptr<Node> zero_point;
        if (has_zero_point)
            zero_point = pattern_map.at(sub_c);
        else
            zero_point = opset6::Constant::create(convert_node->get_element_type(), {}, {0});

        // out_low  = (in_low  - zero_point) * scale
        // out_high = (in_high - zero_point) * scale
        const auto &output_low = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_low, zero_point), scale_node);
        const auto &output_high = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_high, zero_point), scale_node);

        auto fq = std::make_shared<opset6::FakeQuantize>(
                convert_node, input_low, input_high, output_low, output_high, levels);

        NodeVector nodes_to_copy_RT_info_from{multiply_node, scale_node, zero_point};
        if (has_zero_point)
            // Copy from the matched Subtract node, not the pattern label.
            nodes_to_copy_RT_info_from.push_back(pattern_map.at(sub));

        // RT info flows FROM the nodes being replaced TO the new FakeQuantize.
        ngraph::copy_runtime_info(nodes_to_copy_RT_info_from, fq);
        multiply_node->output(0).replace(fq->output(0));

        // The dequantization is now expressed as FakeQuantize, so the Convert
        // over the weights may be constant-folded again.
        if (convert_node->get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
            convert_node->get_rt_info().erase("DISABLED_CONSTANT_FOLDING");
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "WeightsDequantizeToFakeQuantize");
    register_matcher(m, callback);
}

0 comments on commit a313c0c

Please sign in to comment.