Skip to content

Commit

Permalink
added skip_cleanup_attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemashov committed Mar 14, 2022
1 parent 7cddf80 commit bfa8601
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <set>
#include <unordered_set>
#include <vector>

#include <ngraph/node.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/variant.hpp>

#include "low_precision/lpt_visibility.hpp"
#include "low_precision/rt_info/attribute_parameters.hpp"
#include "low_precision/rt_info/shared_value_attribute.hpp"

namespace ngraph {
/**
* @ingroup ie_transformation_common_api
* @brief PrecisionsAttribute defines precision which is required for input/output port or an operation.
*/
class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public SharedAttribute<bool> {
public:
OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute, 0);
SkipCleanupAttribute(const bool skip);

static ov::Any create(const std::shared_ptr<ngraph::Node>& node, const bool skip);
// vizualize shared attributes details in VizualizeTree pass
std::string to_string() const override;
};
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "low_precision/rt_info/intervals_alignment_attribute.hpp"
#include "low_precision/fake_quantize.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"


namespace ngraph {
namespace pass {
Expand Down Expand Up @@ -112,6 +114,13 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma
return false;
}

auto skip = getAttribute<SkipCleanupAttribute>(fq);
if (!skip.empty()) {
if (skip.as<SkipCleanupAttribute>().value()) {
return false;
}
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <ngraph/pattern/op/wrap_type.hpp>
#include "low_precision/fake_quantize.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"

namespace ngraph {
namespace pass {
Expand Down Expand Up @@ -116,6 +117,12 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma
if (fq->get_output_target_inputs(0).size() != 1) {
return false;
}
auto skip = getAttribute<SkipCleanupAttribute>(fq);
if (!skip.empty()) {
if (skip.as<SkipCleanupAttribute>().value()) {
return false;
}
}

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,9 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p

std::shared_ptr<ngraph::pass::GraphRewrite> cleanup = manager.register_pass<ngraph::pass::GraphRewrite>();
cleanup->add_matcher<ngraph::pass::low_precision::FoldConvertTransformation>(params);
/* cleanup->add_matcher<ngraph::pass::low_precision::FuseConvertTransformation>(params);
cleanup->add_matcher<ngraph::pass::low_precision::FuseConvertTransformation>(params);
cleanup->add_matcher<ngraph::pass::low_precision::FuseSubtractToFakeQuantizeTransformation>(params);
cleanup->add_matcher<ngraph::pass::low_precision::FuseMultiplyToFakeQuantizeTransformation>(params);*/
cleanup->add_matcher<ngraph::pass::low_precision::FuseMultiplyToFakeQuantizeTransformation>(params);
// WA: precision restrictions for groupConv must be propagated to MultiplyToGroupConvolution transformation
cleanup->add_matcher<ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation>(
params,
Expand Down
17 changes: 11 additions & 6 deletions src/common/low_precision_transformations/src/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "low_precision/concat.hpp"
#include "low_precision/network_helper.hpp"
#include "../include/low_precision/rt_info/skip_cleanup_attribute.hpp"

namespace ngraph {
namespace pass {
Expand Down Expand Up @@ -80,6 +81,8 @@ LSTM::LSTM(const Params& params) : LayerTransformation(params) {
const auto dequantization_without_subtract_squeeze = ngraph::pattern::wrap_type<ngraph::opset5::Reshape, ngraph::opset5::Squeeze>(
{dequantization_multiply_without_subtract_X, squeeze_constant});
const auto lstm_cell = ngraph::pattern::wrap_type<ngraph::opset5::LSTMCell>(
{fq_X, fq_H, C, fq_W, fq_R, B});
const auto lstm_cell_squeeze = ngraph::pattern::wrap_type<ngraph::opset5::LSTMCell>(
{squeeze, fq_H, C, fq_W, fq_R, B});
const auto lstm_cell_with_dequantizations = ngraph::pattern::wrap_type<ngraph::opset5::LSTMCell>(
{dequantization_squeeze, dequantization_multiply_H, C, fq_W, fq_R, B});
Expand All @@ -96,8 +99,9 @@ LSTM::LSTM(const Params& params) : LayerTransformation(params) {
};

auto m = std::make_shared<ngraph::pattern::Matcher>(
std::make_shared<pattern::op::Or>(
OutputVector{lstm_cell, lstm_cell_with_dequantizations, lstm_cell_with_dequantizations_without_subtract}),
std::make_shared<pattern::op::Or>(OutputVector{lstm_cell,
lstm_cell_squeeze, lstm_cell_with_dequantizations,
lstm_cell_with_dequantizations_without_subtract}),
"LSTM");
this->register_matcher(m, callback);
}
Expand All @@ -114,6 +118,7 @@ bool LSTM::transform(TransformationContext& context, ngraph::pattern::Matcher& m
for (size_t parentIndex = 0ul; parentIndex < lstm->get_input_size(); parentIndex++) {
auto fq = lstm->get_input_node_shared_ptr(parentIndex);
if (is_type<ngraph::opset1::FakeQuantize>(fq)) {
SkipCleanupAttribute::create(fq, true);
auto fq_parent = fq->get_input_node_shared_ptr(0);
if (is_type<ngraph::opset5::Constant>(fq_parent)) {
auto fq_node = as_type_ptr<ngraph::opset1::FakeQuantize>(fq);
Expand All @@ -130,12 +135,12 @@ bool LSTM::transform(TransformationContext& context, ngraph::pattern::Matcher& m
dataPrecision.hasZeroPoint,
updatePrecisions);
std::shared_ptr<ngraph::Node> new_fq = std::get<0>(QDQ);
std::shared_ptr<ngraph::Node> dequantize = std::get<1>(QDQ);
this->register_new_node(new_fq);
if (dequantize == nullptr || new_fq == nullptr) {
std::shared_ptr<ngraph::Node> deq_multiply = std::get<1>(QDQ);
if (deq_multiply == nullptr || new_fq == nullptr) {
return false;
}
updateOutput(context, dequantize, new_fq);
this->register_new_node(new_fq);
updateOutput(context, deq_multiply, new_fq);
} else {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/rt_info/skip_cleanup_attribute.hpp"

#include <memory>
#include <string>
#include <unordered_map>
#include <iterator>
#include <vector>

#include <ngraph/opsets/opset1.hpp>
#include "low_precision/network_helper.hpp"
#include "low_precision/layer_transformation.hpp"

using namespace ngraph;
using namespace ov;

SkipCleanupAttribute::SkipCleanupAttribute(const bool skip)
:
SharedAttribute(skip) {
}

ov::Any SkipCleanupAttribute::create(
const std::shared_ptr<ngraph::Node>& node,
const bool skip) {
auto& rt = node->get_rt_info();
return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute(skip));
}

std::string SkipCleanupAttribute::to_string() const {
std::stringstream ss;
ss << "SkipCleanup: {";
attribute ? ss << "True" : ss << "False";
ss << "}";
return ss.str();
}
Loading

0 comments on commit bfa8601

Please sign in to comment.