Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dequantize (Sub, Mul) to FakeQuantize #4189

Merged
merged 4 commits into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions inference-engine/src/cldnn_engine/cldnn_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <transformations/control_flow/unroll_tensor_iterator.hpp>

#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
Expand Down Expand Up @@ -279,6 +280,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();

pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();

Expand Down
2 changes: 2 additions & 0 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>

#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
Expand Down Expand Up @@ -228,6 +229,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
pass_config->disable<ngraph::pass::ConvertMod>();
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();

pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API WeightsDequantizeToFakeQuantize;

} // namespace pass
} // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief WeightsDequantizeToFakeQuantize transformation replaces a weights
 * dequantization subgraph over constant weights:
 *   Constant(i8) -> Convert -> [Subtract(zero_point)] -> Multiply(scale)
 * (the Subtract is optional) with an equivalent FakeQuantize operation,
 * so that downstream plugins see a uniform quantization representation.
 */
class ngraph::pass::WeightsDequantizeToFakeQuantize: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    WeightsDequantizeToFakeQuantize();
};
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);

Expand All @@ -69,6 +70,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::

// TODO: move to KMB
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::WeightsDequantizeToFakeQuantize>();

manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"

#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"

NGRAPH_RTTI_DEFINITION(ngraph::pass::WeightsDequantizeToFakeQuantize, "WeightsDequantizeToFakeQuantize", 0);

// Replaces the weights dequantization subgraph
//   Constant(i8) -> Convert -> [Subtract(zero_point)] -> Multiply(scale)
// (Subtract is optional) with a FakeQuantize node producing identical values.
ngraph::pass::WeightsDequantizeToFakeQuantize::WeightsDequantizeToFakeQuantize() {
    MATCHER_SCOPE(WeightsDequantizeToFakeQuantize);

    const auto weights = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto convert = ngraph::pattern::wrap_type<ngraph::opset6::Convert>({weights});
    const auto sub_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto sub = ngraph::pattern::wrap_type<ngraph::opset6::Subtract>({convert, sub_c});

    // The zero-point Subtract is optional: match either Convert or Convert->Subtract.
    const auto sub_or_convert = std::make_shared<pattern::op::Or>(OutputVector{convert, sub});

    const auto mul_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto mul = ngraph::pattern::wrap_type<ngraph::opset6::Multiply>({sub_or_convert, mul_c});

    ngraph::matcher_pass_callback callback;
    callback = [=](ngraph::pattern::Matcher &m) {
        auto pattern_map = m.get_pattern_map();

        const auto &weights_node = as_type_ptr<opset6::Constant>(pattern_map.at(weights));
        const auto &convert_node = pattern_map.at(convert);
        const auto &multiply_node = pattern_map.at(mul);
        const auto &scale_node = pattern_map.at(mul_c);
        if (!weights_node || !convert_node || !multiply_node || !scale_node) {
            return false;
        }

        // get_data_ptr<int8_t>() is only meaningful for i8 weights; bail out
        // instead of reinterpreting the raw buffer of another element type.
        if (weights_node->get_element_type() != element::i8) {
            return false;
        }

        // Deduce the number of quantization levels from the actual weight range:
        // if -128 is present the full i8 range (256 levels) is used, otherwise
        // the symmetric 255-level range [-127, 127].
        const auto *data = weights_node->get_data_ptr<int8_t>();
        const int8_t weights_minimum = *std::min_element(data, data + shape_size(weights_node->get_shape()));
        int64_t levels = (weights_minimum == static_cast<int8_t>(-128)) ? 256 : 255;
        int64_t in_low = -(levels / 2), in_high = levels + in_low - 1;

        const auto &input_low = opset6::Constant::create(convert_node->get_element_type(), {}, {in_low});
        const auto &input_high = opset6::Constant::create(convert_node->get_element_type(), {}, {in_high});

        // The Subtract branch of the Or is optional, so sub/sub_c may be absent
        // from the pattern map; pattern_map.at() on them would throw
        // std::out_of_range for the Convert->Multiply (no zero point) case.
        std::shared_ptr<Node> zero_point;
        if (pattern_map.count(sub_c)) {
            zero_point = pattern_map.at(sub_c);
        } else {
            zero_point = opset6::Constant::create(convert_node->get_element_type(), {}, {0});
        }

        // output = (input - zero_point) * scale, applied to both range bounds.
        const auto &output_low = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_low, zero_point), scale_node);
        const auto &output_high = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_high, zero_point), scale_node);

        auto fq = std::make_shared<opset6::FakeQuantize>(
                convert_node, input_low, input_high, output_low, output_high, levels);

        NodeVector nodes_to_copy_RT_info_from{multiply_node, scale_node, zero_point};
        if (pattern_map.count(sub))
            // Copy from the matched Subtract graph node, not the pattern node.
            nodes_to_copy_RT_info_from.push_back(pattern_map.at(sub));

        // copy_runtime_info(from, to): RT info flows from the replaced nodes
        // onto the newly created FakeQuantize.
        ngraph::copy_runtime_info(nodes_to_copy_RT_info_from, fq);
        multiply_node->output(0).replace(fq->output(0));

        // The Convert now feeds a FakeQuantize and may be folded safely.
        if (convert_node->get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
            convert_node->get_rt_info().erase("DISABLED_CONSTANT_FOLDING");
        return true;
    };

    // Use matcher_name declared by MATCHER_SCOPE instead of a duplicated literal.
    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
    register_matcher(m, callback);
}