Skip to content

Commit

Permalink
renamed transformation & refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemashov committed Mar 31, 2022
1 parent 6153d2b commit c8ef697
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 93 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand All @@ -12,10 +12,10 @@ namespace ngraph {
namespace pass {
namespace low_precision {

class LP_TRANSFORMATIONS_API LSTMTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransformation {
public:
OPENVINO_RTTI("LSTMTransformation", "0");
LSTMTransformation(const Params& params = Params());
OPENVINO_RTTI("RecurrentCellTransformation", "0");
RecurrentCellTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@
#include "low_precision/rt_info/shared_value_attribute.hpp"

namespace ngraph {
/**
* @ingroup ie_transformation_common_api
* @brief PrecisionsAttribute defines precision which is required for input/output port or an operation.
*/
class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public SharedAttribute<bool> {
class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public ov::RuntimeAttribute {
bool skip;

public:
OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute, 0);
SkipCleanupAttribute(const bool skip);

static ov::Any create(const std::shared_ptr<ngraph::Node>& node, const bool skip);
// vizualize shared attributes details in VizualizeTree pass
std::string to_string() const override;
const bool value() const;
};
} // namespace ngraph
12 changes: 5 additions & 7 deletions src/common/low_precision_transformations/src/fuse_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph
}

bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
auto skip = getAttribute<SkipCleanupAttribute>(op);
if (!skip.empty() && skip.as<SkipCleanupAttribute>().value()) {
return false;
}

const auto convert = ov::as_type_ptr<opset1::Convert>(op->get_input_node_shared_ptr(0));
// issue #40395
if (convert == nullptr) {
Expand All @@ -125,13 +130,6 @@ bool FuseConvertTransformation::canBeTransformed(const TransformationContext& co
return false;
}

auto skip = getAttribute<SkipCleanupAttribute>(op);
if (!skip.empty()) {
if (skip.as<SkipCleanupAttribute>().value()) {
return false;
}
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma
return false;
}

auto skip = getAttribute<SkipCleanupAttribute>(operation);
if (!skip.empty() && skip.as<SkipCleanupAttribute>().value()) {
return false;
}

const auto parent = operation->get_input_node_shared_ptr(0);
auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
const auto convert = ov::as_type_ptr<opset1::Convert>(parent);
Expand All @@ -115,13 +120,6 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma
return false;
}

auto skip = getAttribute<SkipCleanupAttribute>(operation);
if (!skip.empty()) {
if (skip.as<SkipCleanupAttribute>().value()) {
return false;
}
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma
return false;
}

auto skip = getAttribute<SkipCleanupAttribute>(operation);
if (!skip.empty() && skip.as<SkipCleanupAttribute>().value()) {
return false;
}

const auto children = operation->get_output_target_inputs(0);

for (const auto& target : children) {
Expand All @@ -119,12 +124,6 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma
if (fq->get_output_target_inputs(0).size() != 1) {
return false;
}
auto skip = getAttribute<SkipCleanupAttribute>(operation);
if (!skip.empty()) {
if (skip.as<SkipCleanupAttribute>().value()) {
return false;
}
}

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include "low_precision/normalize_l2.hpp"
#include "low_precision/pad.hpp"
#include "low_precision/prelu.hpp"
#include "low_precision/recurrent_cell.hpp"
#include "low_precision/reduce_max.hpp"
#include "low_precision/reduce_mean.hpp"
#include "low_precision/reduce_min.hpp"
Expand All @@ -69,7 +70,6 @@
#include "low_precision/unsqueeze.hpp"
#include "low_precision/variadic_split.hpp"
#include "low_precision/move_fake_quantize.hpp"
#include "low_precision/lstm.hpp"

// cleanup transformations
#include "low_precision/convert.hpp"
Expand Down Expand Up @@ -221,14 +221,14 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p
common->add_matcher<ngraph::pass::low_precision::FakeQuantizeTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::InterpolateTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::GroupConvolutionTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::LSTMTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::MatMulTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::MaxPoolTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::MultiplyTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::MVNTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::NormalizeL2Transformation>(params);
common->add_matcher<ngraph::pass::low_precision::PadTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::PReluTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::RecurrentCellTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::ReduceMaxTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::ReduceMeanTransformation>(params);
common->add_matcher<ngraph::pass::low_precision::ReduceMinTransformation>(params);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Copyright (C) 2018-2022 Intel Corporation
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/lstm.hpp"
#include "low_precision/recurrent_cell.hpp"

#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset1.hpp>
Expand All @@ -13,15 +13,14 @@
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/or.hpp>

#include "low_precision/concat.hpp"
#include "low_precision/network_helper.hpp"
#include "../include/low_precision/rt_info/skip_cleanup_attribute.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

LSTMTransformation::LSTMTransformation(const Params& params) : LayerTransformation(params) {
RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : LayerTransformation(params) {
const auto X = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>();
const auto H = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>();
const auto C = ngraph::pattern::wrap_type<ngraph::opset5::Parameter>();
Expand Down Expand Up @@ -125,7 +124,7 @@ LSTMTransformation::LSTMTransformation(const Params& params) : LayerTransformati
this->register_matcher(m, callback);
}

bool LSTMTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
bool RecurrentCellTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
const auto lstm = m.get_match_root();
if (!canBeTransformed(context, lstm)) {
return false;
Expand Down Expand Up @@ -178,15 +177,15 @@ bool LSTMTransformation::transform(TransformationContext& context, ngraph::patte
return true;
}

bool LSTMTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
return true;
}

bool LSTMTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noexcept {
bool RecurrentCellTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noexcept {
return true;
}

void LSTMTransformation::propagateSkipCleanupAttribute(std::shared_ptr<Node> multiply) {
void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr<Node> multiply) {
SkipCleanupAttribute::create(multiply, true);
auto multiply_parent = multiply->get_input_node_shared_ptr(0);
SkipCleanupAttribute::create(multiply_parent, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
using namespace ngraph;
using namespace ov;

SkipCleanupAttribute::SkipCleanupAttribute(const bool skip)
:
SharedAttribute(skip) {
}
SkipCleanupAttribute::SkipCleanupAttribute(const bool skip) : skip(skip) {}

ov::Any SkipCleanupAttribute::create(
const std::shared_ptr<ngraph::Node>& node,
Expand All @@ -29,10 +26,14 @@ ov::Any SkipCleanupAttribute::create(
return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute(skip));
}

const bool SkipCleanupAttribute::value() const {
return skip;
}

std::string SkipCleanupAttribute::to_string() const {
std::stringstream ss;
ss << "SkipCleanup: {";
attribute ? ss << "True" : ss << "False";
skip ? ss << "True" : ss << "False";
ss << "}";
return ss.str();
}
Loading

0 comments on commit c8ef697

Please sign in to comment.