diff --git a/src/common/low_precision_transformations/include/low_precision/lstm.hpp b/src/common/low_precision_transformations/include/low_precision/lstm.hpp new file mode 100644 index 00000000000000..26ab88f959e630 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/lstm.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include "low_precision/layer_transformation.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +class LP_TRANSFORMATIONS_API LSTM : public LayerTransformation { +public: + NGRAPH_RTTI_DECLARATION; + LSTM(const Params& params = Params()); + bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; + bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; +}; + +} // namespace low_precision +} // namespace pass +} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/lstm.cpp b/src/common/low_precision_transformations/src/lstm.cpp new file mode 100644 index 00000000000000..fab16800d03c61 --- /dev/null +++ b/src/common/low_precision_transformations/src/lstm.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/lstm.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +#include "low_precision/concat.hpp" +#include "low_precision/network_helper.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::LSTM, "LSTM", 0); + +LSTM::LSTM(const Params& params) : LayerTransformation(params) { + + const auto X = ngraph::pattern::wrap_type(); + const auto H = ngraph::pattern::wrap_type(); + const auto C = ngraph::pattern::wrap_type(); + const auto W = ngraph::pattern::wrap_type(); + const auto R = ngraph::pattern::wrap_type(); + const auto B = ngraph::pattern::wrap_type(); + const auto lstm_constant_1 = ngraph::pattern::wrap_type(); + const auto lstm_constant_2 = ngraph::pattern::wrap_type(); + const auto input_low = ngraph::pattern::wrap_type(); + const auto input_high = ngraph::pattern::wrap_type(); + const auto output_low = ngraph::pattern::wrap_type(); + const auto output_high = ngraph::pattern::wrap_type(); + const auto input_low2 = ngraph::pattern::wrap_type(); + const auto input_high2 = ngraph::pattern::wrap_type(); + const auto output_low2 = ngraph::pattern::wrap_type(); + const auto output_high2 = ngraph::pattern::wrap_type(); + const auto fq_X = ngraph::pattern::wrap_type({ X, + input_low, + input_high, + output_low, + output_high}); + const auto fq_H = ngraph::pattern::wrap_type({ H, + input_low, + input_high, + output_low, + output_high }); + const auto fq_W = ngraph::pattern::wrap_type({ W, + input_low2, + input_high2, + output_low2, + output_high2 }); + const auto fq_R = ngraph::pattern::wrap_type({ R, + input_low2, + input_high2, + output_low2, + output_high2 }); + const auto fq_B = ngraph::pattern::wrap_type({ B, + input_low, + input_high, + output_low, + output_high }); + const auto squeeze_pattern = ngraph::pattern::wrap_type(); + const auto fq_X_squeeze = ngraph::pattern::wrap_type({fq_X, squeeze_pattern}); + const auto lstm_cell = ngraph::pattern::wrap_type( + {C, fq_H, fq_W, fq_R, fq_X_squeeze}); + const auto lstm_cell_with_bias = ngraph::pattern::wrap_type( + {fq_X_squeeze, fq_H, C, fq_W, fq_R, fq_B}); + const auto lstm_sequence = ngraph::pattern::wrap_type( + {fq_X, fq_H, C, fq_W, fq_R, fq_B}); + + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + auto op = m.get_match_root(); + if (transformation_callback(op)) { + return false; + } + + return transform(*context, m); + }; + + auto m = std::make_shared( + std::make_shared(OutputVector{fq_X_squeeze, lstm_cell, lstm_cell_with_bias, lstm_sequence}), + "LSTM"); + this->register_matcher(m, callback); +} + +bool LSTM::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { + const auto lstm = m.get_match_root(); + if (!canBeTransformed(context, lstm)) { + return false; + } + + return true; +} + +bool LSTM::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { + return true; +} + +bool LSTM::isPrecisionPreserved(std::shared_ptr) const noexcept { + return true; +} + +} // namespace low_precision +} // namespace pass +} // namespace ngraph + + diff --git a/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp index 0cd6b96d0750bf..cf6dedbf655c96 100644 --- a/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp @@ -210,8 +210,40 @@ const std::vector> shapes = {{{1, 1, 16}, {1, const std::vector testValues = { // LSTM Cell {LayerTransformation::createParamsU8I8(), - true, false, + false, + LSTMFunction::LSTMType::Cell, + { + { + {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}} + }, + {{}, {}, {}}, + {{}, {}, {}} + }, + { + { + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {}, + {} + }, + { + {ngraph::element::u8}, + {}, + {} + }, + { + {{element::f32}, {0.f}, {0.01f}}, + {{element::f32}, {0.f}, {0.01f}}, + {{element::f32}, {0.f}, {0.01f}} + } + }, + true}, + // LSTM Cell with bias + {LayerTransformation::createParamsU8I8(), + false, + true, LSTMFunction::LSTMType::Cell, { { @@ -226,12 +258,12 @@ const std::vector testValues = { { {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, {}, - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}} + {} }, { {ngraph::element::u8}, {}, - {ngraph::element::u8} + {} }, { {{element::f32}, {0.f}, {0.01f}}, @@ -257,7 +289,7 @@ const std::vector> shapes = {{{1, 2, 16}, {1, const std::vector testValues = { // LSTM Sequence {LayerTransformation::createParamsU8I8(), - true, + false, false, LSTMFunction::LSTMType::Sequence, { @@ -273,12 +305,12 @@ const std::vector testValues = { { {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, {}, - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}} + {} }, { {ngraph::element::u8}, {}, - {ngraph::element::u8} + {} }, { {{element::f32}, {0.f}, {0.01f}}, diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp index 71a75146d8e394..b3003b456c0f74 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp @@ -42,7 +42,7 @@ std::shared_ptr LSTMFunction::get( fqOnDatas[0], converts[0], dequantizations[0]); - std::shared_ptr squeeze_X; + std::shared_ptr squeeze_X; if (type == LSTMType::Cell) { auto squeeze_pattern = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); squeeze_X = std::make_shared(parent_X, squeeze_pattern);