Skip to content

Commit

Permalink
added transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemashov committed Feb 28, 2022
1 parent c600c89 commit 954485e
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <ngraph/ngraph.hpp>
#include "low_precision/layer_transformation.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

/**
 * @brief Low-precision layer transformation for LSTM-related sub-graphs.
 *
 * Derives from LayerTransformation and overrides its transform / canBeTransformed /
 * isPrecisionPreserved hooks.
 */
class LP_TRANSFORMATIONS_API LSTM : public LayerTransformation {
public:
NGRAPH_RTTI_DECLARATION;
// Constructs the transformation; `params` defaults to a default-constructed Params.
LSTM(const Params& params = Params());
// Handles the match described by `m`; the bool result reports the outcome to the caller.
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
// Tells whether `layer` is eligible for this transformation.
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
// Tells whether `layer` is treated as precision preserving; never throws (noexcept).
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

} // namespace low_precision
} // namespace pass
} // namespace ngraph
113 changes: 113 additions & 0 deletions src/common/low_precision_transformations/src/lstm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "low_precision/lstm.hpp"

#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset1.hpp>

#include <memory>
#include <ngraph/node.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/op/or.hpp>

#include "low_precision/concat.hpp"
#include "low_precision/network_helper.hpp"

namespace ngraph {
namespace pass {
namespace low_precision {

// RTTI definition matching NGRAPH_RTTI_DECLARATION in the class; the type is registered
// under the name "LSTM" (the trailing 0 is presumably the version id - confirm against the macro).
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::LSTM, "LSTM", 0);

/// Builds the patterns for quantized LSTM sub-graphs and registers the matcher.
///
/// The patterns describe FakeQuantize operations placed on the X / H Parameter inputs and on
/// the W / R / B Constant inputs, an optional Squeeze after the X FakeQuantize, and the
/// LSTMCell / LSTMSequence operations consuming them. The matcher root is an Or over the
/// squeezed X branch and the three LSTM alternatives.
///
/// @param params  Settings forwarded unchanged to the LayerTransformation base class.
LSTM::LSTM(const Params& params) : LayerTransformation(params) {
    using ngraph::pattern::wrap_type;

    // Every FakeQuantize range operand and every weight-like input is matched as a Constant.
    const auto constant = [] { return wrap_type<ngraph::opset1::Constant>(); };

    // Inputs matched as Parameters.
    const auto X = wrap_type<ngraph::opset1::Parameter>();
    const auto H = wrap_type<ngraph::opset1::Parameter>();
    const auto C = wrap_type<ngraph::opset1::Parameter>();

    // Inputs matched as Constants (B is the operand used by the with-bias pattern).
    const auto W = constant();
    const auto R = constant();
    const auto B = constant();

    // Range operands shared by the FakeQuantize patterns on X, H and B.
    const auto input_low = constant();
    const auto input_high = constant();
    const auto output_low = constant();
    const auto output_high = constant();

    // Range operands shared by the FakeQuantize patterns on W and R.
    const auto input_low2 = constant();
    const auto input_high2 = constant();
    const auto output_low2 = constant();
    const auto output_high2 = constant();

    const auto fq_X = wrap_type<opset1::FakeQuantize>({X, input_low, input_high, output_low, output_high});
    const auto fq_H = wrap_type<opset1::FakeQuantize>({H, input_low, input_high, output_low, output_high});
    const auto fq_W = wrap_type<opset1::FakeQuantize>({W, input_low2, input_high2, output_low2, output_high2});
    const auto fq_R = wrap_type<opset1::FakeQuantize>({R, input_low2, input_high2, output_low2, output_high2});
    const auto fq_B = wrap_type<opset1::FakeQuantize>({B, input_low, input_high, output_low, output_high});

    // For the cell variants the quantized X goes through a Squeeze with a Constant axes operand.
    const auto squeeze_pattern = wrap_type<ngraph::opset5::Constant>();
    const auto fq_X_squeeze = wrap_type<ngraph::opset5::Squeeze>({fq_X, squeeze_pattern});

    // Operands are listed in the order X, H, C, W, R - the same order the with-bias and
    // sequence patterns below use. (Previously they were permuted as C, H, W, R, X.)
    // NOTE(review): confirm that bias-less LSTMCell nodes really expose five inputs; if a
    // default bias input is always attached, this alternative may never match.
    const auto lstm_cell = wrap_type<opset5::LSTMCell>({fq_X_squeeze, fq_H, C, fq_W, fq_R});
    const auto lstm_cell_with_bias = wrap_type<opset5::LSTMCell>({fq_X_squeeze, fq_H, C, fq_W, fq_R, fq_B});
    const auto lstm_sequence = wrap_type<opset5::LSTMSequence>({fq_X, fq_H, C, fq_W, fq_R, fq_B});

    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
        auto op = m.get_match_root();
        // Give the externally supplied callback a chance to veto the node.
        if (transformation_callback(op)) {
            return false;
        }

        return transform(*context, m);
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(
        std::make_shared<pattern::op::Or>(OutputVector{fq_X_squeeze, lstm_cell, lstm_cell_with_bias, lstm_sequence}),
        "LSTM");
    this->register_matcher(m, callback);
}

/// Matcher callback body for the LSTM patterns.
///
/// @param context  Transformation context, forwarded to canBeTransformed().
/// @param m        Matcher holding the current match; only its root node is inspected.
/// @return true when the matched root passes canBeTransformed(), false otherwise.
///
/// NOTE(review): no graph rewrite is performed here yet, so a `true` result does not
/// correspond to an actual modification - confirm this is intended.
bool LSTM::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
    const auto matchedRoot = m.get_match_root();
    return canBeTransformed(context, matchedRoot);
}

/// Eligibility check for a matched layer.
///
/// Both arguments are currently ignored: every layer is reported as transformable.
///
/// @return always true.
bool LSTM::canBeTransformed(const TransformationContext& /*context*/,
                            std::shared_ptr<Node> /*layer*/) const {
    return true;
}

/// Reports whether the given layer is treated as precision preserving.
///
/// The argument is ignored; the answer is unconditionally true.
///
/// @return always true.
bool LSTM::isPrecisionPreserved(std::shared_ptr<Node> /*layer*/) const noexcept {
    return true;
}

} // namespace low_precision
} // namespace pass
} // namespace ngraph


Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,40 @@ const std::vector<std::vector<ngraph::PartialShape>> shapes = {{{1, 1, 16}, {1,
const std::vector<LSTMTransformationTestValues> testValues = {
// LSTM Cell
{LayerTransformation::createParamsU8I8(),
true,
false,
false,
LSTMFunction::LSTMType::Cell,
{
{
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}},
{256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}
},
{{}, {}, {}},
{{}, {}, {}}
},
{
{
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{}
},
{
{ngraph::element::u8},
{},
{}
},
{
{{element::f32}, {0.f}, {0.01f}},
{{element::f32}, {0.f}, {0.01f}},
{{element::f32}, {0.f}, {0.01f}}
}
},
true},
// LSTM Cell with bias
{LayerTransformation::createParamsU8I8(),
false,
true,
LSTMFunction::LSTMType::Cell,
{
{
Expand All @@ -226,12 +258,12 @@ const std::vector<LSTMTransformationTestValues> testValues = {
{
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}
{}
},
{
{ngraph::element::u8},
{},
{ngraph::element::u8}
{}
},
{
{{element::f32}, {0.f}, {0.01f}},
Expand All @@ -257,7 +289,7 @@ const std::vector<std::vector<ngraph::PartialShape>> shapes = {{{1, 2, 16}, {1,
const std::vector<LSTMTransformationTestValues> testValues = {
// LSTM Sequence
{LayerTransformation::createParamsU8I8(),
true,
false,
false,
LSTMFunction::LSTMType::Sequence,
{
Expand All @@ -273,12 +305,12 @@ const std::vector<LSTMTransformationTestValues> testValues = {
{
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
{},
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}
{}
},
{
{ngraph::element::u8},
{},
{ngraph::element::u8}
{}
},
{
{{element::f32}, {0.f}, {0.01f}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ std::shared_ptr<ngraph::Function> LSTMFunction::get(
fqOnDatas[0],
converts[0],
dequantizations[0]);
std::shared_ptr<ov::op::v0::Squeeze> squeeze_X;
std::shared_ptr<opset5::Squeeze> squeeze_X;
if (type == LSTMType::Cell) {
auto squeeze_pattern = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
squeeze_X = std::make_shared<opset5::Squeeze>(parent_X, squeeze_pattern);
Expand Down

0 comments on commit 954485e

Please sign in to comment.