FakeQuantize decomposition
mandrono committed Jan 15, 2021
1 parent a280a3a commit b914368
Showing 11 changed files with 794 additions and 15 deletions.
13 changes: 10 additions & 3 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
@@ -57,6 +57,8 @@
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <transformations/op_conversions/fq_decomposition.hpp>
#include <transformations/utils/utils.hpp>

#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
@@ -226,13 +228,17 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
transformer.transform(nGraphFunc);
}

bool keep_constant_inputs = ::ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc);

ngraph::pass::Manager legacyManager;

legacyManager.register_pass<ngraph::pass::FakeQuantizeDecomposition>();
legacyManager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
legacyManager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
// not legacy actually, but it should be the last transformation in the pipeline
legacyManager.register_pass<ngraph::pass::UnrollTensorIterator>();

auto legacyPassConfig = legacyManager.get_pass_config();

legacyPassConfig->set_callback<ngraph::pass::AddMultiplyFusion>([](const_node_ptr &node) -> bool {
if (auto mul_op = std::dynamic_pointer_cast<const ngraph::opset1::Multiply>(node)) {
auto add_op = std::dynamic_pointer_cast<const ngraph::opset1::Add>(mul_op->get_input_node_shared_ptr(0));
@@ -247,15 +253,16 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
return false;
});

- legacyManager.get_pass_config()->set_callback<ngraph::pass::UnrollTensorIterator>([](const_node_ptr &node) -> bool {
+ legacyPassConfig->set_callback<ngraph::pass::UnrollTensorIterator>([](const_node_ptr &node) -> bool {
// UnrollTI transformation is disabled by default, is turned on by LowLatency transformation
return node->get_rt_info().count("UNROLL_TI") == 0;
});

legacyManager.run_passes(nGraphFunc);

OV_ITT_TASK_CHAIN(taskChain, MKLDNNPlugin::itt::domains::MKLDNN_LT, "Transformation", "convertFunctionToICNNNetwork");

- clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
+ clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork, keep_constant_inputs);

OV_ITT_TASK_NEXT(taskChain, "ConvertIOPrecision");

32 changes: 32 additions & 0 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_quantize_node.cpp
@@ -402,6 +402,38 @@ void MKLDNNQuantizeNode::initSupportedPrimitiveDescriptors() {
}
}

void MKLDNNQuantizeNode::filterSupportedPrimitiveDescriptors() {
MKLDNNNode::filterSupportedPrimitiveDescriptors();
filterSupportedDescriptors();
}

void MKLDNNQuantizeNode::filterSupportedDescriptors() {
if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty()) {
if (inputMemoryFormatsFilter.size() > 1 || outputMemoryFormatsFilter.size() > 1) {
THROW_IE_EXCEPTION << "Incorrect number of input or output memory formats for Quantize node";
}
auto itd = descs.begin();
while (itd != descs.end()) {
bool isSuitableDesc = true;
if (!inputMemoryFormatsFilter.empty()) {
auto src_fmt = std::shared_ptr<mkldnn::quantization_forward::desc>(*itd)->data.src_desc.format;
if (src_fmt != inputMemoryFormatsFilter[0])
isSuitableDesc = false;
}
if (!outputMemoryFormatsFilter.empty()) {
auto dst_fmt = std::shared_ptr<mkldnn::quantization_forward::desc>(*itd)->data.dst_desc.format;
if (dst_fmt != outputMemoryFormatsFilter[0])
isSuitableDesc = false;
}
if (!isSuitableDesc) {
itd = descs.erase(itd);
} else {
itd++;
}
}
}
}

void MKLDNNQuantizeNode::createPrimitive() {
if (prim)
return;
@@ -25,6 +25,9 @@ class MKLDNNQuantizeNode : public MKLDNNNode {
bool created() const override;
void execute(mkldnn::stream strm) override;

void filterSupportedPrimitiveDescriptors() override;
void filterSupportedDescriptors();

size_t getAxis() const { return axis; }

bool isBinarization() const { return quantizeAlgorithm == mkldnn::algorithm::binarization_depthwise; }
@@ -0,0 +1,29 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API FakeQuantizeDecomposition;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief FakeQuantizeDecomposition decomposes a FakeQuantize layer into a sub-graph of simpler operations if:
* 1. the 'data' input has rank > 5
* 2. any 'ranges' input (input_low, input_high, output_low, output_high) has rank > 5 or is not a Constant
* 3. any 'ranges' input has more than one dimension != 1, or its single dimension != 1 is neither batch nor channel
*/
class ngraph::pass::FakeQuantizeDecomposition: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
FakeQuantizeDecomposition();
};
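
For illustration (hypothetical shapes, not from the commit): a 'ranges' constant of shape {1, 3, 1, 1} keeps the native FakeQuantize, because its single non-unit dimension sits on the channel axis; shapes such as {1, 1, 4, 1} (non-unit dimension on a spatial axis) or {1, 3, 4, 1} (two non-unit dimensions) fall under condition 3 and trigger the decomposition.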
@@ -0,0 +1,147 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/fq_decomposition.hpp"

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/builder/autobroadcast.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::FakeQuantizeDecomposition, "FakeQuantizeDecomposition", 0);

bool isNeedToDecomposFQ(std::shared_ptr<ngraph::opset1::FakeQuantize> fq) {
if (fq->get_input_node_shared_ptr(0)->get_shape().size() > 5)
return true;

for (size_t i = 1; i < fq->get_input_size(); i++) {
std::shared_ptr<ngraph::Node> node = fq->get_input_node_shared_ptr(i);
if (!std::dynamic_pointer_cast<const ngraph::opset1::Constant>(node) ||
node->get_shape().size() > 5)
return true;
}

for (size_t i = 1; i < fq->get_input_size(); i++) {
auto node = fq->get_input_node_shared_ptr(i);

size_t count_not_unit_axis = 0;
auto shape = node->get_shape();
if (ngraph::shape_size(shape) != 1) {
size_t not_unit_axis = 0;
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] > 1) {
not_unit_axis = i;
count_not_unit_axis++;
}
}
if (count_not_unit_axis > 1 || not_unit_axis > 1)
return true;
}
}

return false;
}

/**
* Expression from the specification:
* if x <= min(input_low, input_high):
*     output = output_low
* elif x > max(input_low, input_high):
*     output = output_high
* else:
*     output = round((x - input_low) / (input_high - input_low) * (levels-1)) / (levels-1) * (output_high - output_low) + output_low
*
* Expanding the brackets inside round() gives:
*     round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low))
* and the trailing division by (levels-1) followed by multiplication by (output_high - output_low)
* folds into a single multiplication by (output_high - output_low) / (levels-1), so:
*
* output = round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)) * (output_high - output_low) / (levels-1) + output_low
*/
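// Worked numeric check of the rewrite above (illustrative values, not part of the commit):
// x = 0.5, input_low = 0, input_high = 1, output_low = -1, output_high = 1, levels = 256
// specification: round((0.5 - 0) / (1 - 0) * 255) / 255 * (1 - (-1)) + (-1)
//              = round(127.5) / 255 * 2 - 1 = 128 / 255 * 2 - 1 ≈ 0.0039
// rewritten:    round(0.5 * 255 / 1 - 0 * 255 / 1) * (1 - (-1)) / 255 + (-1)
//              = 128 * 2 / 255 - 1 ≈ 0.0039 (identical; 128 comes from HALF_TO_EVEN rounding of 127.5)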
ngraph::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>();

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto fake_quantize_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fake_quantize).get_node_shared_ptr());

if (fake_quantize_node == nullptr || !isNeedToDecomposFQ(fake_quantize_node)) {
return false;
}

Output<Node> data{fake_quantize_node->input_value(0)};
Output<Node> input_low{fake_quantize_node->input_value(1)};
Output<Node> input_high{fake_quantize_node->input_value(2)};
Output<Node> output_low{fake_quantize_node->input_value(3)};
Output<Node> output_high{fake_quantize_node->input_value(4)};
auto input_type = data.get_element_type();

ngraph::NodeVector decomp_ops;
if (input_type != input_low.get_element_type()) {
input_type = input_low.get_element_type();
data = std::make_shared<ngraph::opset1::Convert>(data, input_type);
decomp_ops.push_back(data.get_node_shared_ptr());
}

// substituting x = input_low or x = input_high into the formula yields output = output_low or output = output_high
// respectively, so clamping x to [input_low, input_high] covers both boundary branches
auto max = std::make_shared<ngraph::opset1::Maximum>(data, input_low);
auto min = std::make_shared<ngraph::opset1::Minimum>(max, input_high);
decomp_ops.push_back(max);
decomp_ops.push_back(min);

// (levels-1)
auto levels_minus_one = std::make_shared<ngraph::opset1::Constant>(input_type, Shape{}, fake_quantize_node->get_levels() - 1);
decomp_ops.push_back(levels_minus_one);
// (input_high - input_low)
auto subInHighLow = std::make_shared<ngraph::opset1::Subtract>(input_high, input_low);
// (levels-1) / (input_high - input_low)
auto isc = std::make_shared<ngraph::opset1::Divide>(levels_minus_one, subInHighLow);
// input_low * (levels-1) / (input_high - input_low)
auto ish = std::make_shared<ngraph::opset1::Multiply>(input_low, isc);
decomp_ops.push_back(subInHighLow);
decomp_ops.push_back(isc);
decomp_ops.push_back(ish);

// x * (levels-1) / (input_high - input_low)
auto after_isc_apply = std::make_shared<ngraph::opset1::Multiply>(min, isc);
// x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)
auto after_ish_apply = std::make_shared<ngraph::opset1::Subtract>(after_isc_apply, ish);
decomp_ops.push_back(after_isc_apply);
decomp_ops.push_back(after_ish_apply);

// round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low))
auto round = std::make_shared<ngraph::opset5::Round>(after_ish_apply, ngraph::opset5::Round::RoundMode::HALF_TO_EVEN);
decomp_ops.push_back(round);

// (output_high - output_low)
auto subOutHighLow = std::make_shared<ngraph::opset1::Subtract>(output_high, output_low);
// (output_high - output_low) / (levels-1)
auto osc = std::make_shared<ngraph::opset1::Divide>(subOutHighLow, levels_minus_one);
decomp_ops.push_back(subOutHighLow);
decomp_ops.push_back(osc);

// round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)) * (output_high - output_low) / (levels-1)
auto after_osc_apply = std::make_shared<ngraph::opset1::Multiply>(round, osc);
// round(x * (levels-1) / (input_high - input_low) - input_low * (levels-1) / (input_high - input_low)) * (output_high - output_low) / (levels-1) +
// output_low
std::shared_ptr<Node> result = std::make_shared<ngraph::opset1::Add>(after_osc_apply, output_low);
decomp_ops.push_back(after_osc_apply);
decomp_ops.push_back(result);

if (result->get_output_element_type(0) != fake_quantize_node->get_output_element_type(0)) {
result = std::make_shared<ngraph::opset1::Convert>(result, fake_quantize_node->get_output_element_type(0));
decomp_ops.push_back(result);
}

result->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(fake_quantize_node, decomp_ops);
ngraph::replace_node(m.get_match_root(), result);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(fake_quantize, "FakeQuantizeDecomposition");
register_matcher(m, callback);
}
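
A minimal usage sketch of the new pass (illustrative, not part of the commit: the graph, shapes, and values below are hypothetical, and only the ngraph opset1 and pass-manager APIs already used in this diff are assumed):

#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/manager.hpp>
#include "transformations/op_conversions/fq_decomposition.hpp"

int main() {
    // FakeQuantize whose 'ranges' have a non-unit spatial axis, so condition 3 above applies
    auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 4, 4});
    auto il = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 4, 1}, {0.f, 0.f, 0.f, 0.f});
    auto ih = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 4, 1}, {1.f, 1.f, 1.f, 1.f});
    auto ol = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {0.f});
    auto oh = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {1.f});
    auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(data, il, ih, ol, oh, 256);
    auto func = std::make_shared<ngraph::Function>(ngraph::NodeVector{fq}, ngraph::ParameterVector{data});

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::FakeQuantizeDecomposition>();
    manager.run_passes(func);  // the FakeQuantize is replaced by the Max/Min/Mul/Sub/Round/Add sub-graph
    return 0;
}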