From c8ef6977795ac3d8256406e93fe39069e3559e05 Mon Sep 17 00:00:00 2001 From: Nikita Demashov Date: Wed, 30 Mar 2022 21:29:16 +0300 Subject: [PATCH] renamed transformation & refactoring --- .../{lstm.hpp => recurrent_cell.hpp} | 8 +- .../rt_info/skip_cleanup_attribute.hpp | 9 +- .../src/fuse_convert.cpp | 12 +- .../src/fuse_multiply_to_fake_quantize.cpp | 12 +- .../src/fuse_subtract_to_fake_quantize.cpp | 11 +- .../src/low_precision.cpp | 4 +- .../src/{lstm.cpp => recurrent_cell.cpp} | 15 +- .../src/rt_info/skip_cleanup_attribute.cpp | 11 +- ....cpp => recurrent_cell_transformation.cpp} | 162 +++++++++++++----- ...nction.hpp => recurrent_cell_function.hpp} | 4 +- ...nction.cpp => recurrent_cell_function.cpp} | 7 +- 11 files changed, 162 insertions(+), 93 deletions(-) rename src/common/low_precision_transformations/include/low_precision/{lstm.hpp => recurrent_cell.hpp} (72%) rename src/common/low_precision_transformations/src/{lstm.cpp => recurrent_cell.cpp} (94%) rename src/tests/functional/inference_engine/lp_transformations/{lstm_transformation.cpp => recurrent_cell_transformation.cpp} (79%) rename src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/{lstm_function.hpp => recurrent_cell_function.hpp} (96%) rename src/tests/ngraph_helpers/lpt_ngraph_functions/src/{lstm_function.cpp => recurrent_cell_function.cpp} (97%) diff --git a/src/common/low_precision_transformations/include/low_precision/lstm.hpp b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp similarity index 72% rename from src/common/low_precision_transformations/include/low_precision/lstm.hpp rename to src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp index da03976151a530..521e5c3f19a59e 100644 --- a/src/common/low_precision_transformations/include/low_precision/lstm.hpp +++ b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -12,10 +12,10 @@ namespace ngraph { namespace pass { namespace low_precision { -class LP_TRANSFORMATIONS_API LSTMTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransformation { public: - OPENVINO_RTTI("LSTMTransformation", "0"); - LSTMTransformation(const Params& params = Params()); + OPENVINO_RTTI("RecurrentCellTransformation", "0"); + RecurrentCellTransformation(const Params& params = Params()); bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; diff --git a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp index 0e98ea3eeec8f9..a38d26ae1c5dda 100644 --- a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp +++ b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp @@ -18,11 +18,9 @@ #include "low_precision/rt_info/shared_value_attribute.hpp" namespace ngraph { -/** - * @ingroup ie_transformation_common_api - * @brief PrecisionsAttribute defines precision which is required for input/output port or an operation. - */ -class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public SharedAttribute { +class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public ov::RuntimeAttribute { + bool skip; + public: OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute, 0); SkipCleanupAttribute(const bool skip); @@ -30,5 +28,6 @@ class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public SharedAttribute static ov::Any create(const std::shared_ptr& node, const bool skip); // vizualize shared attributes details in VizualizeTree pass std::string to_string() const override; + const bool value() const; }; } // namespace ngraph diff --git a/src/common/low_precision_transformations/src/fuse_convert.cpp b/src/common/low_precision_transformations/src/fuse_convert.cpp index fb06c84ed100f4..5ba25de28964f8 100644 --- a/src/common/low_precision_transformations/src/fuse_convert.cpp +++ b/src/common/low_precision_transformations/src/fuse_convert.cpp @@ -114,6 +114,11 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph } bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr op) const { + auto skip = getAttribute(op); + if (!skip.empty() && skip.as().value()) { + return false; + } + const auto convert = ov::as_type_ptr(op->get_input_node_shared_ptr(0)); // issue #40395 if (convert == nullptr) { @@ -125,13 +130,6 @@ bool FuseConvertTransformation::canBeTransformed(const TransformationContext& co return false; } - auto skip = getAttribute(op); - if (!skip.empty()) { - if (skip.as().value()) { - return false; - } - } - return true; } diff --git a/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp index 0717b817c9621e..81aadf6019770f 100644 --- a/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp @@ -99,6 +99,11 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma return false; } + auto skip = getAttribute(operation); + if (!skip.empty() && skip.as().value()) { + return false; + } + const auto parent = operation->get_input_node_shared_ptr(0); auto fq = ov::as_type_ptr(parent); const auto convert = ov::as_type_ptr(parent); @@ -115,13 +120,6 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma return false; } - auto skip = getAttribute(operation); - if (!skip.empty()) { - if (skip.as().value()) { - return false; - } - } - return true; } diff --git a/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp index a9fd78d4d53da3..9331411e4acd5e 100644 --- a/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp @@ -93,6 +93,11 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma return false; } + auto skip = getAttribute(operation); + if (!skip.empty() && skip.as().value()) { + return false; + } + const auto children = operation->get_output_target_inputs(0); for (const auto& target : children) { @@ -119,12 +124,6 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma if (fq->get_output_target_inputs(0).size() != 1) { return false; } - auto skip = getAttribute(operation); - if (!skip.empty()) { - if (skip.as().value()) { - return false; - } - } return true; } diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index fc915a9054416c..bda7659583daef 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -54,6 +54,7 @@ #include "low_precision/normalize_l2.hpp" #include "low_precision/pad.hpp" #include "low_precision/prelu.hpp" +#include "low_precision/recurrent_cell.hpp" #include "low_precision/reduce_max.hpp" #include "low_precision/reduce_mean.hpp" #include "low_precision/reduce_min.hpp" @@ -69,7 +70,6 @@ #include "low_precision/unsqueeze.hpp" #include "low_precision/variadic_split.hpp" #include "low_precision/move_fake_quantize.hpp" -#include "low_precision/lstm.hpp" // cleanup transformations #include "low_precision/convert.hpp" @@ -221,7 +221,6 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); - common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); @@ -229,6 +228,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); + common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); diff --git a/src/common/low_precision_transformations/src/lstm.cpp b/src/common/low_precision_transformations/src/recurrent_cell.cpp similarity index 94% rename from src/common/low_precision_transformations/src/lstm.cpp rename to src/common/low_precision_transformations/src/recurrent_cell.cpp index fb905ddfc288f7..7de00e56b34369 100644 --- a/src/common/low_precision_transformations/src/lstm.cpp +++ b/src/common/low_precision_transformations/src/recurrent_cell.cpp @@ -1,8 +1,8 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // -#include "low_precision/lstm.hpp" +#include "low_precision/recurrent_cell.hpp" #include #include @@ -13,7 +13,6 @@ #include #include -#include "low_precision/concat.hpp" #include "low_precision/network_helper.hpp" #include "../include/low_precision/rt_info/skip_cleanup_attribute.hpp" @@ -21,7 +20,7 @@ namespace ngraph { namespace pass { namespace low_precision { -LSTMTransformation::LSTMTransformation(const Params& params) : LayerTransformation(params) { +RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : LayerTransformation(params) { const auto X = ngraph::pattern::wrap_type(); const auto H = ngraph::pattern::wrap_type(); const auto C = ngraph::pattern::wrap_type(); @@ -125,7 +124,7 @@ LSTMTransformation::LSTMTransformation(const Params& params) : LayerTransformati this->register_matcher(m, callback); } -bool LSTMTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { +bool RecurrentCellTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { const auto lstm = m.get_match_root(); if (!canBeTransformed(context, lstm)) { return false; @@ -178,15 +177,15 @@ bool LSTMTransformation::transform(TransformationContext& context, ngraph::patte return true; } -bool LSTMTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { +bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { return true; } -bool LSTMTransformation::isPrecisionPreserved(std::shared_ptr) const noexcept { +bool RecurrentCellTransformation::isPrecisionPreserved(std::shared_ptr) const noexcept { return true; } -void LSTMTransformation::propagateSkipCleanupAttribute(std::shared_ptr multiply) { +void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr multiply) { SkipCleanupAttribute::create(multiply, true); auto multiply_parent = multiply->get_input_node_shared_ptr(0); SkipCleanupAttribute::create(multiply_parent, true); diff --git a/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp b/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp index bc05e369012eb7..41e43f95922c5f 100644 --- a/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp +++ b/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp @@ -17,10 +17,7 @@ using namespace ngraph; using namespace ov; -SkipCleanupAttribute::SkipCleanupAttribute(const bool skip) - : - SharedAttribute(skip) { -} +SkipCleanupAttribute::SkipCleanupAttribute(const bool skip) : skip(skip) {} ov::Any SkipCleanupAttribute::create( const std::shared_ptr& node, @@ -29,10 +26,14 @@ ov::Any SkipCleanupAttribute::create( return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute(skip)); } +const bool SkipCleanupAttribute::value() const { + return skip; +} + std::string SkipCleanupAttribute::to_string() const { std::stringstream ss; ss << "SkipCleanup: {"; - attribute ? ss << "True" : ss << "False"; + skip ? ss << "True" : ss << "False"; ss << "}"; return ss.str(); } diff --git a/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp similarity index 79% rename from src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp rename to src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp index b32b866cde7e24..c7afae4b86f51a 100644 --- a/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -6,8 +6,7 @@ #include #include -#include -#include +#include #include #include #include @@ -23,7 +22,7 @@ #include "layer_transformation.hpp" #include "lpt_ngraph_functions/common/builders.hpp" #include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" -#include "lpt_ngraph_functions/lstm_function.hpp" +#include "lpt_ngraph_functions/recurrent_cell_function.hpp" #include "simple_low_precision_transformer.hpp" using namespace testing; @@ -33,7 +32,7 @@ using namespace ngraph::builder::subgraph; namespace { -class LSTMTransformationValues { +class RecurrentCellTransformationValues { public: ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_X; ngraph::builder::subgraph::DequantizationOperations::Convert convert_X; @@ -51,20 +50,20 @@ class LSTMTransformationValues { ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; }; -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationValues& values) { +inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationValues& values) { return out << "_" << values.fakeQuantize_X << "_" << values.convert_X << "_" << values.dequantization_X << "_" << values.fakeQuantize_H << "_" << values.convert_H << "_" << values.dequantization_H << "_" << values.fakeQuantize_W << "_" << values.convert_W << "_" << values.dequantization_W << "_" << values.fakeQuantize_R << "_" << values.convert_R << "_" << values.dequantization_R; } -class LSTMTransformationTestValues { +class RecurrentCellTransformationTestValues { public: - LSTMTransformationTestValues() = default; - LSTMTransformationTestValues(const TestTransformationParams& params, - const LSTMFunction::RNNType type, - const LSTMTransformationValues& actual, - const LSTMTransformationValues& result, + RecurrentCellTransformationTestValues() = default; + RecurrentCellTransformationTestValues(const TestTransformationParams& params, + const RecurrentCellFunction::RNNType type, + const RecurrentCellTransformationValues& actual, + const RecurrentCellTransformationValues& result, const bool addNotPrecisionPreservedOperation = false, const bool checkIntervalsAlignmentAttributes = true) : params(params), @@ -75,31 +74,31 @@ class LSTMTransformationTestValues { checkIntervalsAlignmentAttributes(checkIntervalsAlignmentAttributes) {} TestTransformationParams params; - LSTMFunction::RNNType type; - LSTMTransformationValues actual; - LSTMTransformationValues result; + RecurrentCellFunction::RNNType type; + RecurrentCellTransformationValues actual; + RecurrentCellTransformationValues result; // add not precision preserved operation to set output precision for FakeQuantize // don't set to 'true' by default to keep test cases with tested operation as output bool addNotPrecisionPreservedOperation; bool checkIntervalsAlignmentAttributes; }; -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationTestValues& values) { +inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationTestValues& values) { return out << "_" << values.actual << "_" << values.result; } -typedef std::tuple, std::vector, LSTMTransformationTestValues> - LSTMTransformationParams; +typedef std::tuple, std::vector, RecurrentCellTransformationTestValues> + RecurrentCellTransformationParams; -class LSTMTransformation : public LayerTransformation, public testing::WithParamInterface { +class RecurrentCellTransformation : public LayerTransformation, public testing::WithParamInterface { public: void SetUp() override { const ngraph::element::Type precision = std::get<0>(GetParam()); const std::vector activations_shapes = std::get<1>(GetParam()); const std::vector weights_shapes = std::get<2>(GetParam()); - LSTMTransformationTestValues testValues = std::get<3>(GetParam()); + RecurrentCellTransformationTestValues testValues = std::get<3>(GetParam()); - actualFunction = ngraph::builder::subgraph::LSTMFunction::get(precision, + actualFunction = ngraph::builder::subgraph::RecurrentCellFunction::get(precision, activations_shapes, weights_shapes, testValues.type, @@ -128,7 +127,7 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam const auto params = TestTransformationParams::toParams(testValues.params); SimpleLowPrecisionTransformer transformer; - transformer.commonGraphRewrite->add_matcher(params); + transformer.commonGraphRewrite->add_matcher(params); transformer.transform(actualFunction); SimpleLowPrecisionTransformer clenup_transformer; @@ -151,7 +150,8 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam IntervalsAlignmentSharedValue::Interval interval{-1.28f, 2.55f}; - referenceFunction = ngraph::builder::subgraph::LSTMFunction::get(precision, + referenceFunction = + ngraph::builder::subgraph::RecurrentCellFunction::get(precision, activations_shapes, weights_shapes, testValues.type, @@ -178,11 +178,11 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam testValues.result.dequantizationAfter); } - static std::string getTestCaseName(testing::TestParamInfo obj) { + static std::string getTestCaseName(testing::TestParamInfo obj) { const ngraph::element::Type precision = std::get<0>(obj.param); const std::vector activations_shapes = std::get<1>(obj.param); const std::vector weights_shapes = std::get<2>(obj.param); - const LSTMTransformationTestValues testValues = std::get<3>(obj.param); + const RecurrentCellTransformationTestValues testValues = std::get<3>(obj.param); std::ostringstream result; result << LayerTransformation::getTestCaseNameByParams(precision, activations_shapes[0], testValues.params) @@ -191,7 +191,7 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam } }; -TEST_P(LSTMTransformation, CompareFunctions) { +TEST_P(RecurrentCellTransformation, CompareFunctions) { actualFunction->validate_nodes_and_infer_types(); auto res = compare_functions(actualFunction, referenceFunction); ASSERT_TRUE(res.first) << res.second; @@ -205,14 +205,14 @@ const std::vector precisions = { }; namespace testValues1 { -const std::vector> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}}; +const std::vector> activations_shapes = {{{1, 1}, {1, 1}, {1, 1}}}; -const std::vector> weights_shapes = {{{512, 16}, {512, 128}, {512}}}; +const std::vector> weights_shapes = {{{4, 1}, {4, 1}, {4}}}; -const std::vector testValues = { +const std::vector testValues = { // LSTM Cell {LayerTransformation::createParamsU8I8(), - LSTMFunction::RNNType::LSTMCell, + RecurrentCellFunction::RNNType::LSTMCell, { // X {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, @@ -274,16 +274,90 @@ const std::vector testValues = { }, } }, + // multi-channel fake quantizes on weights + {LayerTransformation::createParamsU8I8(), + RecurrentCellFunction::RNNType::LSTMCell, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + { 255ul, + {{4, 1}, {4, 1}, {4, 1}, {4, 1}}, + {0.f, 0.f, 0.f, 0.f}, + {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}, + {0.f, 0.f, 0.f, 0.f}, + {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}}, + {}, + {{}, {}, {}}, + // R + { 255ul, + {{4, 1}, {4, 1}, {4, 1}, {4, 1}}, + {0.f, 0.f, 0.f, 0.f}, + {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}, + {0.f, 0.f, 0.f, 0.f}, + {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}}, + {}, + {{}, {}, {}}, + }, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {}, + {}, + { + {element::f32}, + {}, + {{0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.01f / 4.f}, ngraph::element::f32, {4, 1}} + }, + // R + {}, + {}, + { + {element::f32}, + {}, + {{0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.01f / 4.f}, ngraph::element::f32, {4, 1}} + }, + } + }, }; INSTANTIATE_TEST_SUITE_P( smoke_LPT, - LSTMTransformation, + RecurrentCellTransformation, ::testing::Combine( ::testing::ValuesIn(precisions), ::testing::ValuesIn(activations_shapes), ::testing::ValuesIn(weights_shapes), ::testing::ValuesIn(testValues)), - LSTMTransformation::getTestCaseName); + RecurrentCellTransformation::getTestCaseName); } // namespace testValues1 namespace testValues2 { @@ -291,10 +365,10 @@ const std::vector> activations_shapes = {{{1, const std::vector> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}}; -const std::vector testValues = { +const std::vector testValues = { // LSTM Sequence {LayerTransformation::createParamsU8I8(), - LSTMFunction::RNNType::LSTMSequence, + RecurrentCellFunction::RNNType::LSTMSequence, { // X {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, @@ -359,13 +433,13 @@ const std::vector testValues = { }; INSTANTIATE_TEST_SUITE_P( smoke_LPT, - LSTMTransformation, + RecurrentCellTransformation, ::testing::Combine( ::testing::ValuesIn(precisions), ::testing::ValuesIn(activations_shapes), ::testing::ValuesIn(weights_shapes), ::testing::ValuesIn(testValues)), - LSTMTransformation::getTestCaseName); + RecurrentCellTransformation::getTestCaseName); } // namespace testValues2 namespace testValues3 { @@ -373,10 +447,10 @@ const std::vector> activations_shapes = {{{2, const std::vector> weights_shapes = {{{9, 3}, {9, 3}, {9}}}; -const std::vector testValues = { +const std::vector testValues = { // GRU {LayerTransformation::createParamsU8I8(), - LSTMFunction::RNNType::GRU, + RecurrentCellFunction::RNNType::GRU, { // X {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, @@ -441,13 +515,13 @@ const std::vector testValues = { }; INSTANTIATE_TEST_SUITE_P( smoke_LPT, - LSTMTransformation, + RecurrentCellTransformation, ::testing::Combine( ::testing::ValuesIn(precisions), ::testing::ValuesIn(activations_shapes), ::testing::ValuesIn(weights_shapes), ::testing::ValuesIn(testValues)), - LSTMTransformation::getTestCaseName); + RecurrentCellTransformation::getTestCaseName); } // namespace testValues3 namespace testValues4 { @@ -455,10 +529,10 @@ const std::vector> activations_shapes = {{{2, const std::vector> weights_shapes = {{{3, 3}, {3, 3}, {9}}}; -const std::vector testValues = { +const std::vector testValues = { // RNNCell {LayerTransformation::createParamsU8I8(), - LSTMFunction::RNNType::RNNCell, + RecurrentCellFunction::RNNType::RNNCell, { // X {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, @@ -523,12 +597,12 @@ const std::vector testValues = { }; INSTANTIATE_TEST_SUITE_P( smoke_LPT, - LSTMTransformation, + RecurrentCellTransformation, ::testing::Combine( ::testing::ValuesIn(precisions), ::testing::ValuesIn(activations_shapes), ::testing::ValuesIn(weights_shapes), ::testing::ValuesIn(testValues)), - LSTMTransformation::getTestCaseName); + RecurrentCellTransformation::getTestCaseName); } // namespace testValues4 } // namespace diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp similarity index 96% rename from src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp rename to src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp index 8119a0bfcf9102..fd114488a6dfec 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -15,7 +15,7 @@ namespace ngraph { namespace builder { namespace subgraph { -class LSTMFunction { +class RecurrentCellFunction { public: enum class RNNType { LSTMCell, LSTMSequence, GRU, RNNCell }; diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp similarity index 97% rename from src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp rename to src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp index 9fd318fa105a26..02a7cd5f562358 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp @@ -1,8 +1,8 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // -#include "lpt_ngraph_functions/lstm_function.hpp" +#include "lpt_ngraph_functions/recurrent_cell_function.hpp" #include #include "ngraph_ops/type_relaxed.hpp" @@ -23,7 +23,7 @@ namespace subgraph { using namespace ngraph::pass; -std::shared_ptr LSTMFunction::get( +std::shared_ptr RecurrentCellFunction::get( const ngraph::element::Type inputPrecision, const std::vector& inputActivationsShapes, const std::vector& inputWeightsShapes, @@ -82,6 +82,7 @@ std::shared_ptr LSTMFunction::get( C, parent_W, parent_R, + B, 128); rnn_layer->set_friendly_name("lstm_cell"); break;