Skip to content

Commit

Permalink
[LPT] MoveFakeQuantizeTransformation: fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Aug 4, 2021
1 parent db573b4 commit 7c66559
Show file tree
Hide file tree
Showing 9 changed files with 525 additions and 546 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace low_precision {

class LP_TRANSFORMATIONS_API MoveFakeQuantize : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
MoveFakeQuantize(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

<<<<<<< HEAD
#include "low_precision/move_fake_quantize.hpp"
=======
#include "low_precision/move_fake_quatize.hpp"
>>>>>>> d3f2779243103cbd159d46e486d9f7ea0ffaec27

#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset1.hpp>
Expand All @@ -18,72 +14,81 @@
#include "low_precision/network_helper.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

/// Builds the transformation and registers its matcher.
///
/// The pattern accepts any `opset1::FakeQuantize` node. When a match is found the
/// registered callback first consults `transformation_callback` for the matched root
/// and, unless that returns true, forwards the match to `transform()`.
///
/// @param params  Settings handed to the `LayerTransformation` base class.
MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(params) {
    const auto fakeQuantizePattern = ngraph::pattern::wrap_type<opset1::FakeQuantize>();
    const auto fakeQuantizeMatcher =
        std::make_shared<ngraph::pattern::Matcher>(fakeQuantizePattern, "MoveFakeQuantize");

    ngraph::graph_rewrite_callback onMatch = [this](pattern::Matcher& matched) {
        const auto root = matched.get_match_root();
        // A true result from transformation_callback means the node must be left untouched.
        return transformation_callback(root) ? false : transform(*context, matched);
    };

    this->register_matcher(fakeQuantizeMatcher, onMatch);
}

/// Moves the matched FakeQuantize from after a Concat (optionally separated from it by
/// one Relu) onto every input branch of that Concat.
///
/// Supported sub-graphs:
///   Concat -> FakeQuantize           becomes  FakeQuantize(x_i)       -> Concat
///   Concat -> Relu -> FakeQuantize   becomes  FakeQuantize(Relu(x_i)) -> Concat
///
/// Any other producer of the FakeQuantize data input is left untouched: the matcher
/// registered in the constructor fires for every FakeQuantize, so the shape of the
/// sub-graph has to be validated here.
///
/// @param context  Transformation context (not used by this implementation).
/// @param m        Matcher whose root is the FakeQuantize to move.
/// @return true when the graph was rewritten, false when the pattern does not apply.
bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
    const std::shared_ptr<ngraph::Node> fqNode = m.get_match_root();
    const auto fq = as_type_ptr<opset1::FakeQuantize>(fqNode);
    if (fq == nullptr) {
        return false;
    }

    const std::shared_ptr<ngraph::Node> operation = fq->get_input_node_shared_ptr(0);
    const bool throughRelu = is_type<opset1::Relu>(operation);
    if (!throughRelu && !is_type<opset1::Concat>(operation)) {
        return false;
    }

    // With an intermediate Relu the Concat is one step further up; make sure it really is one.
    const std::shared_ptr<ngraph::Node> concat =
        throughRelu ? operation->get_input_node_shared_ptr(0) : operation;
    if (!is_type<opset1::Concat>(concat)) {
        return false;
    }

    // NOTE(review): the same range constants are attached to every branch; presumably only
    // per-tensor ranges (or ranges not varying along the concat axis) are valid here -- confirm.
    ngraph::OutputVector quantizedBranches;
    const size_t branchCount = concat->get_input_size();
    quantizedBranches.reserve(branchCount);
    for (size_t i = 0; i < branchCount; ++i) {
        // input_value() keeps the producer's actual output index (not always output 0).
        ngraph::Output<ngraph::Node> data = concat->input_value(i);
        if (throughRelu) {
            data = std::make_shared<ngraph::opset1::Relu>(data)->output(0);
        }

        const auto branchFq = std::make_shared<opset1::FakeQuantize>(
            data,
            fq->input_value(1),
            fq->input_value(2),
            fq->input_value(3),
            fq->input_value(4),
            fq->get_levels());
        quantizedBranches.push_back(branchFq->output(0));
    }

    auto new_concat = concat->clone_with_new_inputs(quantizedBranches);
    auto& rtInfo = new_concat->get_rt_info();
    // NOTE(review): hard-coded name and rt_info entry look test-specific -- confirm they are intended.
    new_concat->set_friendly_name("output");
    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

    // NOTE(review): other consumers of the original Concat (or Relu) are redirected to the
    // quantized Concat as well -- confirm the pattern is only expected with a single consumer.
    replace_node(concat, new_concat);
    replace_node(fqNode, new_concat);
    return true;
}

/// Reports whether this transformation preserves precision for the given layer.
/// The answer is unconditionally true; `layer` is not inspected.
bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}

} // namespace low_precision
} // namespace pass
namespace pass {
namespace low_precision {

// Out-of-class RTTI definition for MoveFakeQuantize, registered under the type name
// "MoveFakeQuantize". Presumably pairs with the NGRAPH_RTTI_DECLARATION inside the class
// declaration, and the trailing 0 is presumably the type version -- confirm against the macro.
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::MoveFakeQuantize, "MoveFakeQuantize", 0);

/// Builds the transformation and registers its matcher.
///
/// The pattern accepts any `opset1::FakeQuantize` node. When a match is found the
/// registered callback first consults `transformation_callback` for the matched root
/// and, unless that returns true, forwards the match to `transform()`.
///
/// @param params  Settings handed to the `LayerTransformation` base class.
MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(params) {
    const auto fakeQuantizePattern = ngraph::pattern::wrap_type<opset1::FakeQuantize>();
    const auto fakeQuantizeMatcher =
        std::make_shared<ngraph::pattern::Matcher>(fakeQuantizePattern, "MoveFakeQuantize");

    ngraph::graph_rewrite_callback onMatch = [this](pattern::Matcher& matched) {
        const auto root = matched.get_match_root();
        // A true result from transformation_callback means the node must be left untouched.
        return transformation_callback(root) ? false : transform(*context, matched);
    };

    this->register_matcher(fakeQuantizeMatcher, onMatch);
}

/// Moves the matched FakeQuantize from after a Concat (optionally separated from it by
/// one Relu) onto every input branch of that Concat.
///
/// Supported sub-graphs:
///   Concat -> FakeQuantize           becomes  FakeQuantize(x_i)       -> Concat
///   Concat -> Relu -> FakeQuantize   becomes  FakeQuantize(Relu(x_i)) -> Concat
///
/// @param context  Transformation context, forwarded to `updateOutput`.
/// @param m        Matcher whose root is the FakeQuantize to move.
/// @return true when the graph was rewritten, false when the pattern does not apply.
bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
    const std::shared_ptr<ngraph::Node> fqNode = m.get_match_root();
    const auto fq = as_type_ptr<opset1::FakeQuantize>(fqNode);
    if (fq == nullptr) {
        return false;
    }

    const std::shared_ptr<ngraph::Node> operation = fq->get_input_node_shared_ptr(0);
    const bool throughRelu = is_type<opset1::Relu>(operation);

    // TODO: temporary to enable other transformations <= update matcher instead this validation
    if (!throughRelu && !is_type<opset1::Concat>(operation)) {
        return false;
    }

    // With an intermediate Relu the Concat is one step further up; a Relu fed by anything
    // else must not be rewritten, otherwise an unrelated node would be cloned below.
    const std::shared_ptr<ngraph::Node> concat =
        throughRelu ? operation->get_input_node_shared_ptr(0) : operation;
    if (!is_type<opset1::Concat>(concat)) {
        return false;
    }

    // One FakeQuantize (preceded by a fresh Relu when needed) per Concat input, so Concats
    // with more than two inputs keep all of their branches.
    // NOTE(review): the same range constants are attached to every branch; presumably only
    // per-tensor ranges (or ranges not varying along the concat axis) are valid here -- confirm.
    ngraph::OutputVector quantizedBranches;
    const size_t branchCount = concat->get_input_size();
    quantizedBranches.reserve(branchCount);
    for (size_t i = 0; i < branchCount; ++i) {
        // input_value() keeps the producer's actual output index (not always output 0).
        ngraph::Output<ngraph::Node> data = concat->input_value(i);
        if (throughRelu) {
            data = std::make_shared<ngraph::opset1::Relu>(data)->output(0);
        }

        const auto branchFq = std::make_shared<opset1::FakeQuantize>(
            data,
            fq->input_value(1),
            fq->input_value(2),
            fq->input_value(3),
            fq->input_value(4),
            fq->get_levels());
        quantizedBranches.push_back(branchFq->output(0));
    }

    const auto new_concat = concat->clone_with_new_inputs(quantizedBranches);

    // NOTE(review): other consumers of the original Concat (or Relu) are redirected to the
    // quantized Concat as well -- confirm the pattern is only expected with a single consumer.
    replace_node(concat, new_concat);
    replace_node(fqNode, new_concat);

    updateOutput(context, new_concat, fqNode);
    return true;
}

/// Reports whether this transformation preserves precision for the given layer.
/// The answer is unconditionally true; `layer` is not inspected.
bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return true;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph
4 changes: 0 additions & 4 deletions inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,6 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {

manager.run_passes(nGraphFunc);

//ngraph::pass:
ngraph::pass::VisualizeTree("C:/Users/ndemasho/rep/Visual/cpu.common.svg").run_on_function(nGraphFunc);

using namespace ngraph::pass::low_precision;
if (useLpt) {
OV_ITT_SCOPE(FIRST_INFERENCE, MKLDNNPlugin::itt::domains::MKLDNN_LT, "LowPrecisionTransformations");
Expand Down Expand Up @@ -370,7 +367,6 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
return MultiplyToGroupConvolutionTransformation::isDynamicOrScalar(node);
});
lptManager.run_passes(nGraphFunc);
ngraph::pass::VisualizeTree("C:/Users/ndemasho/rep/Visual/cpu.commonAfter.svg").run_on_function(nGraphFunc);
}

ngraph::pass::Manager postLPTPassManager;
Expand Down
Loading

0 comments on commit 7c66559

Please sign in to comment.