Skip to content

Commit

Permalink
FakeQuantize decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
mandrono committed Dec 30, 2020
1 parent 0850479 commit d275a15
Show file tree
Hide file tree
Showing 11 changed files with 717 additions and 15 deletions.
55 changes: 52 additions & 3 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <transformations/op_conversions/fq_decomposition.hpp>
#include <transformations/utils/utils.hpp>

#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
Expand Down Expand Up @@ -226,13 +228,17 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
transformer.transform(nGraphFunc);
}

bool keep_constant_inputs = ::ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc);

ngraph::pass::Manager legacyManager;

legacyManager.register_pass<ngraph::pass::FakeQuantizeDecomposition>();
legacyManager.register_pass<ngraph::pass::ConvertOpSet1ToLegacy>();
legacyManager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
// not legacy actually, but it should be the last transformation in the transformation pipeline
legacyManager.register_pass<ngraph::pass::UnrollTensorIterator>();

auto legacyPassConfig = legacyManager.get_pass_config();

legacyPassConfig->set_callback<ngraph::pass::AddMultiplyFusion>([](const_node_ptr &node) -> bool {
if (auto mul_op = std::dynamic_pointer_cast<const ngraph::opset1::Multiply>(node)) {
auto add_op = std::dynamic_pointer_cast<const ngraph::opset1::Add>(mul_op->get_input_node_shared_ptr(0));
Expand All @@ -247,15 +253,58 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf)
return false;
});

legacyManager.get_pass_config()->set_callback<ngraph::pass::UnrollTensorIterator>([](const_node_ptr &node) -> bool {
legacyPassConfig->set_callback<ngraph::pass::UnrollTensorIterator>([](const_node_ptr &node) -> bool {
// UnrollTI transformation is disabled by default, is turned on by LowLatency transformation
return node->get_rt_info().count("UNROLL_TI") == 0;
});

auto initAxisIdx = [](const std::shared_ptr<ngraph::Node> node) -> int {
int axisIdx = 0, numberOfNonUnit = 0;

for (size_t i = 0; i < node->get_shape().size(); i++) {
if (node->get_shape()[i] > 1) {
axisIdx = i;
numberOfNonUnit++;
}
}
return numberOfNonUnit > 1 ? -1 : axisIdx;
};
auto isSupportedFQ = [initAxisIdx](const_node_ptr &node) {
std::set<int> quantizationParamsAxisesIdxs;
std::set<size_t> quantizationParamsAxisesSizes;
for (size_t i = 1; i < node->get_input_size(); i++) {
auto inNode = node->get_input_node_shared_ptr(i);
auto axis = initAxisIdx(inNode);
if (axis == -1)
return false;
if (inNode->get_shape().size() != 0 && inNode->get_shape()[axis] != 1) {
quantizationParamsAxisesIdxs.insert(axis);
quantizationParamsAxisesSizes.insert(inNode->get_shape()[axis]);
}
}
return (quantizationParamsAxisesIdxs.size() <= 1 && quantizationParamsAxisesSizes.size() <= 1);
};

legacyPassConfig->set_callback<ngraph::pass::FakeQuantizeDecomposition>([isSupportedFQ](const_node_ptr &node) -> bool {
if (auto fq_op = std::dynamic_pointer_cast<const ngraph::opset1::FakeQuantize>(node)) {
if (node->get_input_node_shared_ptr(0)->get_shape().size() > 5)
return false;
for (size_t i = 1; i < fq_op->get_input_size(); i++) {
if (!std::dynamic_pointer_cast<const ngraph::opset1::Constant>(fq_op->get_input_node_shared_ptr(i)) ||
node->get_input_node_shared_ptr(i)->get_shape().size() > 5)
return false;
}
return isSupportedFQ(fq_op);
}

return true;
});

legacyManager.run_passes(nGraphFunc);

OV_ITT_TASK_CHAIN(taskChain, MKLDNNPlugin::itt::domains::MKLDNN_LT, "Transformation", "convertFunctionToICNNNetwork");

clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork);
clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork, keep_constant_inputs);

OV_ITT_TASK_NEXT(taskChain, "ConvertIOPrecision");

Expand Down
32 changes: 32 additions & 0 deletions inference-engine/src/mkldnn_plugin/nodes/mkldnn_quantize_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,38 @@ void MKLDNNQuantizeNode::initSupportedPrimitiveDescriptors() {
}
}

void MKLDNNQuantizeNode::filterSupportedPrimitiveDescriptors() {
MKLDNNNode::filterSupportedPrimitiveDescriptors();
filterSupportedDescriptors();
}

void MKLDNNQuantizeNode::filterSupportedDescriptors() {
if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty()) {
if (inputMemoryFormatsFilter.size() > 1 || outputMemoryFormatsFilter.size() > 1) {
THROW_IE_EXCEPTION << "Incorrect number of input or output memory formats for Quantize node";
}
auto itd = descs.begin();
while (itd != descs.end()) {
bool isSuitableDesc = true;
if (!inputMemoryFormatsFilter.empty()) {
auto src_fmt = std::shared_ptr<mkldnn::quantization_forward::desc>(*itd)->data.src_desc.format;
if (src_fmt != inputMemoryFormatsFilter[0])
isSuitableDesc = false;
}
if (!outputMemoryFormatsFilter.empty()) {
auto dst_fmt = std::shared_ptr<mkldnn::quantization_forward::desc>(*itd)->data.dst_desc.format;
if (dst_fmt != outputMemoryFormatsFilter[0])
isSuitableDesc = false;
}
if (!isSuitableDesc) {
itd = descs.erase(itd);
} else {
itd++;
}
}
}
}

void MKLDNNQuantizeNode::createPrimitive() {
if (prim)
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class MKLDNNQuantizeNode : public MKLDNNNode {
bool created() const override;
void execute(mkldnn::stream strm) override;

void filterSupportedPrimitiveDescriptors() override;
void filterSupportedDescriptors();

size_t getAxis() const { return axis; }

bool isBinarization() const { return quantizeAlgorithm == mkldnn::algorithm::binarization_depthwise; }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API FakeQuantizeDecomposition;

} // namespace pass
} // namespace ngraph

/**
* @ingroup ie_transformation_common_api
* @brief FakeQuantizeDecomposition transformation into sub-graph
*/
class ngraph::pass::FakeQuantizeDecomposition: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
FakeQuantizeDecomposition();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/fq_decomposition.hpp"

#include <ngraph/opsets/opset5.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/builder/autobroadcast.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::FakeQuantizeDecomposition, "FakeQuantizeDecomposition", 0);

ngraph::pass::FakeQuantizeDecomposition::FakeQuantizeDecomposition() {
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset5::FakeQuantize>();

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto fake_quantize_node = std::dynamic_pointer_cast<ngraph::opset5::FakeQuantize>(pattern_to_output.at(fake_quantize).get_node_shared_ptr());

if (fake_quantize_node == nullptr || m_transformation_callback(fake_quantize_node)) {
return false;
}

Output<Node> data{fake_quantize_node->input_value(0)};
Output<Node> input_low{fake_quantize_node->input_value(1)};
Output<Node> input_high{fake_quantize_node->input_value(2)};
Output<Node> output_low{fake_quantize_node->input_value(3)};
Output<Node> output_high{fake_quantize_node->input_value(4)};
auto input_type = data.get_element_type();

ngraph::NodeVector decomp_ops;
if (input_type != input_low.get_element_type()) {
input_type = input_low.get_element_type();
data = std::make_shared<ngraph::opset5::Convert>(data, input_type);
decomp_ops.push_back(data.get_node_shared_ptr());
}

auto max = std::make_shared<ngraph::opset5::Maximum>(data, input_low);
auto min = std::make_shared<ngraph::opset5::Minimum>(max, input_high);
decomp_ops.push_back(max);
decomp_ops.push_back(min);

auto levels_minus_one = std::make_shared<ngraph::opset5::Constant>(input_type, Shape{}, fake_quantize_node->get_levels() - 1);
decomp_ops.push_back(levels_minus_one);
// input scale and shift
auto subInHighLow = std::make_shared<ngraph::opset5::Subtract>(input_high, input_low);
auto isc = std::make_shared<ngraph::opset5::Divide>(levels_minus_one, subInHighLow);
auto ish = std::make_shared<ngraph::opset5::Multiply>(input_low, isc);
decomp_ops.push_back(subInHighLow);
decomp_ops.push_back(isc);
decomp_ops.push_back(ish);

auto after_isc_apply = std::make_shared<ngraph::opset5::Multiply>(min, isc);
auto after_ish_apply = std::make_shared<ngraph::opset5::Subtract>(after_isc_apply, ish);
decomp_ops.push_back(after_isc_apply);
decomp_ops.push_back(after_ish_apply);

auto round = std::make_shared<ngraph::opset5::Round>(after_ish_apply, ngraph::opset5::Round::RoundMode::HALF_TO_EVEN);
decomp_ops.push_back(round);

// output scale and shift
auto subOutHighLow = std::make_shared<ngraph::opset5::Subtract>(output_high, output_low);
auto osc = std::make_shared<ngraph::opset5::Divide>(subOutHighLow, levels_minus_one);
decomp_ops.push_back(subOutHighLow);
decomp_ops.push_back(osc);

auto after_osc_apply = std::make_shared<ngraph::opset5::Multiply>(round, osc);
std::shared_ptr<Node> result = std::make_shared<ngraph::opset5::Add>(after_osc_apply, output_low);
decomp_ops.push_back(after_osc_apply);
decomp_ops.push_back(result);

if (result->get_output_element_type(0) != fake_quantize_node->get_output_element_type(0)) {
result = std::make_shared<ngraph::opset5::Convert>(result, fake_quantize_node->get_output_element_type(0));
decomp_ops.push_back(result);
}

result->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info(fake_quantize_node, decomp_ops);
ngraph::replace_node(m.get_match_root(), result);
return true;
};

auto m = std::make_shared<ngraph::pattern::Matcher>(fake_quantize, "FakeQuantizeDecomposition");
register_matcher(m, callback);
}
Loading

0 comments on commit d275a15

Please sign in to comment.