diff --git a/src/common/low_precision_transformations/include/low_precision/lstm_support.hpp b/src/common/low_precision_transformations/include/low_precision/lstm_support.hpp deleted file mode 100644 index e4df2053145574..00000000000000 --- a/src/common/low_precision_transformations/include/low_precision/lstm_support.hpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (C) 2018-2022 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include -#include -#include "low_precision/layer_transformation.hpp" - -namespace ngraph { -namespace pass { -namespace low_precision { - -class LP_TRANSFORMATIONS_API MoveFakeQuantize : public LayerTransformation { -public: - NGRAPH_RTTI_DECLARATION; - MoveFakeQuantize(const Params& params = Params()); - bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; - bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; - bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; -}; - -} // namespace low_precision -} // namespace pass -} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index e91373b1e0feea..4ee25afc53106b 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -69,6 +69,7 @@ #include "low_precision/unsqueeze.hpp" #include "low_precision/variadic_split.hpp" #include "low_precision/move_fake_quantize.hpp" +#include "low_precision/lstm.hpp" // cleanup transformations #include "low_precision/convert.hpp" @@ -202,6 +203,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p prerequisites->add_matcher(supportedTypes); prerequisites->add_matcher(); prerequisites->add_matcher(); + prerequisites->add_matcher(); manager.register_pass(); diff --git a/src/common/low_precision_transformations/src/lstm_support.cpp b/src/common/low_precision_transformations/src/lstm_support.cpp deleted file mode 100644 index b807559a36d9a7..00000000000000 --- a/src/common/low_precision_transformations/src/lstm_support.cpp +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (C) 2018-2022 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "low_precision/move_fake_quantize.hpp" - -#include -#include - -#include -#include -#include -#include - -#include "low_precision/concat.hpp" -#include "low_precision/network_helper.hpp" -/* -namespace ngraph { -namespace pass { -namespace low_precision { - -NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::MoveFakeQuantize, "MoveFakeQuantize", 0); - -MoveFakeQuantize::MoveFakeQuantize(const Params& params) : LayerTransformation(params) { - const auto concat = ngraph::pattern::wrap_type(pattern::consumers_count(1)); - const auto operation = ngraph::pattern::wrap_type({ concat }); - const auto input_low = ngraph::pattern::wrap_type(); - const auto input_high = ngraph::pattern::wrap_type(); - const auto output_low = ngraph::pattern::wrap_type(); - const auto output_high = ngraph::pattern::wrap_type(); - const auto fq_with_operation = ngraph::pattern::wrap_type({ operation, - input_low, - input_high, - output_low, - output_high}); - const auto fq = ngraph::pattern::wrap_type({ concat, - input_low, - input_high, - output_low, - output_high }); - - ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { - auto op = m.get_match_root(); - if (transformation_callback(op)) { - return false; - } - - return transform(*context, m); - }; - - auto m = std::make_shared( - std::make_shared(OutputVector{fq, fq_with_operation}), - "MoveFakeQuantize"); - this->register_matcher(m, callback); -} - -bool MoveFakeQuantize::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { - const auto fq = m.get_match_root(); - if (!canBeTransformed(context, fq)) { - return false; - } - - const auto operation = fq->get_input_node_shared_ptr(0); - std::shared_ptr concat; - bool without_operation = true; - const std::string fq_original_name = fq->get_friendly_name(); - std::string operation_original_name; - if (is_type(operation)) { - concat = operation; - } else { - operation_original_name = operation->get_friendly_name(); - concat = operation->get_input_node_shared_ptr(0); - without_operation = false; - } - - if (!ConcatTransformation::isQuantizedStatic(concat)) { - return false; - } - - std::vector> curr_constants(4); - bool multi_chanels = false; - const auto concat_node = as_type_ptr(concat); - if (concat_node == nullptr) { - return false; - } - const auto concat_axis = concat_node->get_concatenation_axis(); - for (size_t i = 0; i < 4; i++) { - curr_constants[i] = as_type_ptr(fq->get_input_node_shared_ptr(i + 1)); - if (!multi_chanels && curr_constants[i]->get_shape().size() > concat_axis && curr_constants[i]->get_shape()[concat_axis] != 1) { - multi_chanels = true; - } - } - - // it's impossible to split fq constants by channel if number of channels is dynamic - if (multi_chanels && fq->get_input_partial_shape(0)[concat_axis].is_dynamic()) { - return false; - } - - std::vector>> new_constants; - if (multi_chanels) { - new_constants = NetworkHelper::splitConstantsBeforeConcat(concat, curr_constants); - } - - const auto convert_q = fq->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); - if (convert_q == nullptr) { - return false; - } - - const bool q_dq = is_type(convert_q); - std::vector> newNodes; - for (size_t i = 0; i < concat->get_input_size(); ++i) { - ov::Output parent_output; - if (without_operation) { - parent_output = concat->get_input_source_output(i); - } else { - auto fq_input = operation->clone_with_new_inputs({concat->get_input_source_output(i)}); - fq_input->set_friendly_name(operation_original_name + "_" + std::to_string(i + 1)); - parent_output = fq_input->output(0); - } - - const std::shared_ptr new_fq = multi_chanels ? - fq->clone_with_new_inputs({parent_output, - new_constants[0][new_constants[0].size() == 1 ? 0 : i], - new_constants[1][new_constants[1].size() == 1 ? 0 : i], - new_constants[2][new_constants[2].size() == 1 ? 0 : i], - new_constants[3][new_constants[3].size() == 1 ? 0 : i] }) : - fq->clone_with_new_inputs({parent_output, - fq->get_input_node_ptr(1)->clone_with_new_inputs({}), - fq->get_input_node_ptr(2)->clone_with_new_inputs({}), - fq->get_input_node_ptr(3)->clone_with_new_inputs({}), - fq->get_input_node_ptr(4)->clone_with_new_inputs({}) }); - - ngraph::copy_runtime_info(fq, new_fq); - new_fq->set_friendly_name(fq_original_name + "_" + std::to_string(i + 1)); - if (q_dq) { - auto new_convert_q = convert_q->clone_with_new_inputs({new_fq}); - ngraph::copy_runtime_info(convert_q, new_convert_q); - new_convert_q->set_friendly_name(convert_q->get_friendly_name() + "_" + std::to_string(i + 1)); - newNodes.push_back(new_convert_q); - } else { - newNodes.push_back(new_fq); - } - } - - auto newConcat = concat->clone_with_new_inputs(ngraph::OutputVector(newNodes.begin(), newNodes.end())); - newConcat->set_friendly_name(concat->get_friendly_name()); - NetworkHelper::copyInfo(concat, newConcat); - if (q_dq) { - auto dq = NetworkHelper::getDequantizationBelow(convert_q); - moveDequantizationBefore(context, newConcat, dq, false); - return true; - } - replace_node(fq, newConcat); - updateOutput(context, newConcat, fq); - - return true; -} - -bool MoveFakeQuantize::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { - auto operation = layer->get_input_node_shared_ptr(0); - std::shared_ptr concat; - if (is_type(operation)) { - concat = operation; - } else { - concat = operation->get_input_node_shared_ptr(0); - } - if (!ConcatTransformation::isQuantizedStatic(concat)) { - return false; - } - const auto convert_q_target_inputs = layer->output(0).get_target_inputs(); - if (convert_q_target_inputs.empty()) { - return false; - } - const auto convert_q = convert_q_target_inputs.begin()->get_node()->shared_from_this(); - bool q_dq = is_type(convert_q); - if (q_dq && (convert_q->get_output_size() != 1 || layer->get_output_size() != 1)) { - return false; - } - return true; -} - -bool MoveFakeQuantize::isPrecisionPreserved(std::shared_ptr) const noexcept { - return true; -} - -} // namespace low_precision -} // namespace pass -} // namespace ngraph - -*/ diff --git a/src/tests/functional/inference_engine/lp_transformations/lstm_support_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/lstm_support_transformation.cpp deleted file mode 100644 index 0385eea0d02130..00000000000000 --- a/src/tests/functional/inference_engine/lp_transformations/lstm_support_transformation.cpp +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright (C) 2018-2022 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common_test_utils/ngraph_test_utils.hpp" -#include "layer_transformation.hpp" -#include "lpt_ngraph_functions/common/builders.hpp" -#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" -#include "lpt_ngraph_functions/lstm_function.hpp" -#include "simple_low_precision_transformer.hpp" - -using namespace testing; -using namespace ngraph; -using namespace ngraph::pass; -using namespace ngraph::builder::subgraph; - -namespace { - -class LSTMTransformationActualValues { -public: - std::vector fakeQuantizes; - std::vector converts; - std::vector dequantizations; -}; - -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationActualValues& values) { - return out << "_" << values.fakeQuantizes[0] << "_" << values.converts[0].outPrecision << "_" - << values.dequantizations[0]; -} - -class LSTMTransformationResultValues { -public: - std::vector fakeQuantizes; - std::vector converts; - std::vector dequantizations; - ngraph::element::Type precisionAfterOperation; - ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; -}; - -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationResultValues& values) { - return out << "_" << values.fakeQuantizes[0] << "_" << values.converts[0].outPrecision << "_" - << values.dequantizations[0]; -} - -class LSTMTransformationTestValues { -public: - LSTMTransformationTestValues() = default; - LSTMTransformationTestValues(const TestTransformationParams& params, - const bool multiChannels, - const std::int64_t axis, - const LSTMTransformationActualValues& actual, - const LSTMTransformationResultValues& result, - const bool addNotPrecisionPreservedOperation = false, - const bool checkIntervalsAlignmentAttributes = true) - : params(params), - multiChannels(multiChannels), - axis(axis), - actual(actual), - result(result), - addNotPrecisionPreservedOperation(addNotPrecisionPreservedOperation), - checkIntervalsAlignmentAttributes(checkIntervalsAlignmentAttributes) {} - - TestTransformationParams params; - bool multiChannels; - std::int64_t axis; - LSTMTransformationActualValues actual; - LSTMTransformationResultValues result; - // add not precision preserved operation to set output precision for FakeQuantize - // don't set to 'true' by default to keep test cases with tested operation as output - bool addNotPrecisionPreservedOperation; - bool checkIntervalsAlignmentAttributes; -}; - -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationTestValues& values) { - return out << "_" << values.multiChannels << "_" << values.actual << "_" << values.result; -} - -typedef std::tuple, LSTMTransformationTestValues> - LSTMTransformationParams; - -class LSTMTransformation : public LayerTransformation, public testing::WithParamInterface { -public: - void SetUp() override { - const ngraph::element::Type precision = std::get<0>(GetParam()); - const std::vector shapes = std::get<1>(GetParam()); - LSTMTransformationTestValues testValues = std::get<2>(GetParam()); - - // dequantization output precision depends on input precision - // to avoid huge amount of tests cases let's define dequantization output precision as input precision - /*if (!testValues.actual.dequantization1.multiply.empty()) { - testValues.actual.dequantization1.multiply.outPrecision = precision; - } - if (!testValues.actual.dequantization2.multiply.empty()) { - testValues.actual.dequantization2.multiply.outPrecision = precision; - }*/ - - actualFunction = ngraph::builder::subgraph::LSTMFunction::get(precision, - shapes, - testValues.actual.fakeQuantizes, - testValues.actual.converts, - testValues.actual.dequantizations, - {}, - ngraph::element::undefined, - {}); - ngraph::pass::VisualizeTree("C:\\Users\\ndemasho\\rep\\Visual\\test.actual.dot") - .run_on_function(actualFunction); - auto supportedPrecisionsOnActivation = std::vector( - {ngraph::pass::low_precision::OperationPrecisionRestriction::create( - {{0, testValues.params.precisionsOnActivations}})}); - - auto quantizationRestrictions = - testValues.multiChannels - ? std::vector() - : std::vector( - {ngraph::pass::low_precision::OperationPerTensorQuantizationRestriction::create< - ngraph::opset1::AvgPool>()}); - - const auto params = TestTransformationParams::toParams(testValues.params); - SimpleLowPrecisionTransformer transformer(supportedPrecisionsOnActivation, quantizationRestrictions); - transformer.commonGraphRewrite - ->add_matcher(params); - transformer.commonGraphRewrite->add_matcher(params); - transformer.transform(actualFunction); - ngraph::pass::VisualizeTree("C:\\Users\\ndemasho\\rep\\Visual\\test.transform.dot") - .run_on_function(actualFunction); - { - ngraph::pass::Manager standaloneCleanupManager; - standaloneCleanupManager - .register_pass(); - standaloneCleanupManager.run_passes(actualFunction); - } - - { - ngraph::pass::Manager standaloneCleanupManager; - standaloneCleanupManager - .register_pass(); - standaloneCleanupManager.run_passes(actualFunction); - } - - // dequantization output precision depends on input precision - // to avoid huge amount of tests cases let's define dequantization output precision as input precision - if (!testValues.result.dequantizationAfter.multiply.empty()) { - testValues.result.dequantizationAfter.multiply.outPrecision = precision; - } - - if (!testValues.params.updatePrecisions && (precision == ngraph::element::f32) && - !testValues.result.dequantizationAfter.convert.empty()) { - testValues.result.dequantizationAfter.convert = {}; - } - - IntervalsAlignmentSharedValue::Interval interval{-1.28f, 2.55f}; - - referenceFunction = ngraph::builder::subgraph::LSTMFunction::get(precision, - shapes, - testValues.result.fakeQuantizes, - testValues.result.converts, - testValues.result.dequantizations, - {PrecisionPreservedAttribute(true), - IntervalsAlignmentAttribute(interval, 256), - QuantizationAlignmentAttribute(false)}, - testValues.result.precisionAfterOperation, - testValues.result.dequantizationAfter); - ngraph::pass::VisualizeTree("C:\\Users\\ndemasho\\rep\\Visual\\test.reference.dot") - .run_on_function(referenceFunction); - } - - static std::string getTestCaseName(testing::TestParamInfo obj) { - const ngraph::element::Type precision = std::get<0>(obj.param); - const std::vector shapes = std::get<1>(obj.param); - const LSTMTransformationTestValues testValues = std::get<2>(obj.param); - - std::ostringstream result; - result << LayerTransformation::getTestCaseNameByParams(precision, shapes[0], testValues.params) << "_" - << (testValues.multiChannels ? "multiChannels_" : "notMultiChannels_") << "axis_" << testValues.axis - << "_" << testValues.actual << "_" << testValues.result << "_"; - return result.str(); - } -}; - -TEST_P(LSTMTransformation, CompareFunctions) { - actualFunction->validate_nodes_and_infer_types(); - auto res = compare_functions(actualFunction, referenceFunction, true, true, false, true, false); - ASSERT_TRUE(res.first) << res.second; - - ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; - - LSTMTransformationTestValues testValues = std::get<2>(GetParam()); - const auto actualFakeQuantizes = LayerTransformation::get(actualFunction); - if (testValues.axis == 1) { - ASSERT_TRUE(checkIfOutputAttributesSharedValuesAreTheSame(actualFakeQuantizes)) - << "PrecisionsAttribute are not the same"; - - if (testValues.checkIntervalsAlignmentAttributes) { - auto operations = LayerTransformation::get(actualFunction); - operations.insert(operations.end(), actualFakeQuantizes.begin(), actualFakeQuantizes.end()); - ASSERT_TRUE(checkIfAttributesSharedValuesAreTheSame(operations)) - << "IntervalsAlignmentAttribute are not the same"; - } - } -} - -const std::vector precisions = { - ngraph::element::f32, - // ngraph::element::f16 -}; - -namespace testValues1 { -const std::vector> shapes = {{{1, 16}, {1, 128}, {1, 128}}}; - -const std::vector testValues = { - {LayerTransformation::createParamsU8I8(), - true, - 1, - {{{256ul, {}, {0.f}, {2550.f}, {0.f}, {2550.f}}}, {{}}, {{}}}, - {{{256ul, - {}, - {0.f}, - {2550.f}, - {0.f}, - {255.f}, - ngraph::element::u8, - {IntervalsAlignmentAttribute(IntervalsAlignmentSharedValue::Interval{0.f, 2.55f}, 256ul)}}}, - {{}}, - {{}}}, - true}, -}; -INSTANTIATE_TEST_SUITE_P( - smoke_LPT, - LSTMTransformation, - ::testing::Combine( - ::testing::ValuesIn(precisions), - ::testing::ValuesIn(shapes), - ::testing::ValuesIn(testValues)), - LSTMTransformation::getTestCaseName); -} // namespace testValues1 -} // namespace \ No newline at end of file diff --git a/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp index 811501771c26a8..0cd6b96d0750bf 100644 --- a/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/lstm_transformation.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include @@ -31,19 +31,7 @@ using namespace ngraph::builder::subgraph; namespace { -class LSTMTransformationActualValues { -public: - std::vector fakeQuantizes; - std::vector converts; - std::vector dequantizations; -}; - -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationActualValues& values) { - return out << "_" << values.fakeQuantizes[0] << "_" << values.converts[0].outPrecision << "_" - << values.dequantizations[0]; -} - -class LSTMTransformationResultValues { +class LSTMTransformationValues { public: std::vector fakeQuantizes; std::vector converts; @@ -52,7 +40,7 @@ class LSTMTransformationResultValues { ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; }; -inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationResultValues& values) { +inline std::ostream& operator<<(std::ostream& out, const LSTMTransformationValues& values) { return out << "_" << values.fakeQuantizes[0] << "_" << values.converts[0].outPrecision << "_" << values.dequantizations[0]; } @@ -64,8 +52,8 @@ class LSTMTransformationTestValues { bool multiChannels, const bool bias, const LSTMFunction::LSTMType type, - const LSTMTransformationActualValues& actual, - const LSTMTransformationResultValues& result, + const LSTMTransformationValues& actual, + const LSTMTransformationValues& result, const bool addNotPrecisionPreservedOperation = false, const bool checkIntervalsAlignmentAttributes = true) : params(params), @@ -81,8 +69,8 @@ class LSTMTransformationTestValues { bool multiChannels; bool bias; LSTMFunction::LSTMType type; - LSTMTransformationActualValues actual; - LSTMTransformationResultValues result; + LSTMTransformationValues actual; + LSTMTransformationValues result; // add not precision preserved operation to set output precision for FakeQuantize // don't set to 'true' by default to keep test cases with tested operation as output bool addNotPrecisionPreservedOperation; @@ -103,15 +91,6 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam const std::vector shapes = std::get<1>(GetParam()); LSTMTransformationTestValues testValues = std::get<2>(GetParam()); - // dequantization output precision depends on input precision - // to avoid huge amount of tests cases let's define dequantization output precision as input precision - /*if (!testValues.actual.dequantization1.multiply.empty()) { - testValues.actual.dequantization1.multiply.outPrecision = precision; - } - if (!testValues.actual.dequantization2.multiply.empty()) { - testValues.actual.dequantization2.multiply.outPrecision = precision; - }*/ - actualFunction = ngraph::builder::subgraph::LSTMFunction::get(precision, shapes, testValues.type, @@ -136,11 +115,14 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam ngraph::opset1::AvgPool>()}); const auto params = TestTransformationParams::toParams(testValues.params); - SimpleLowPrecisionTransformer transformer(supportedPrecisionsOnActivation, quantizationRestrictions); - transformer.commonGraphRewrite - ->add_matcher(params); - transformer.commonGraphRewrite->add_matcher(params); - transformer.transform(actualFunction); + /*SimpleLowPrecisionTransformer transform; + transform.add(params); + transform.transform(actualFunction);*/ + + ov::pass::Manager manager; + manager.register_pass(params); + manager.run_passes(actualFunction); + ngraph::pass::VisualizeTree("C:\\Users\\ndemasho\\rep\\Visual\\test.transform.dot") .run_on_function(actualFunction); { @@ -177,9 +159,7 @@ class LSTMTransformation : public LayerTransformation, public testing::WithParam testValues.result.fakeQuantizes, testValues.result.converts, testValues.result.dequantizations, - {PrecisionPreservedAttribute(true), - IntervalsAlignmentAttribute(interval, 256), - QuantizationAlignmentAttribute(false)}, + {}, testValues.result.precisionAfterOperation, testValues.result.dequantizationAfter); ngraph::pass::VisualizeTree("C:\\Users\\ndemasho\\rep\\Visual\\test.reference.dot") @@ -245,12 +225,12 @@ const std::vector testValues = { { { {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {255ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {}, {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}} }, { {ngraph::element::u8}, - {ngraph::element::u8}, + {}, {ngraph::element::u8} }, { @@ -292,12 +272,12 @@ const std::vector testValues = { { { {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {255ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {}, {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}} }, { {ngraph::element::u8}, - {ngraph::element::u8}, + {}, {ngraph::element::u8} }, { diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp index d669b421dd02e3..cfe8a3505eed24 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/lstm_function.hpp @@ -43,7 +43,7 @@ std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptr if (fqOnData.empty()) { parent = input; } else { - std::shared_ptr fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input, inputPrecision, fqOnData); + std::shared_ptr fakeQuantize1 = makeFakeQuantize(input, inputPrecision, fqOnData); fakeQuantize1->set_friendly_name("fakeQuantize_" + friendly_name); parent = fakeQuantize1; } diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp index 67bd9404a5202c..71a75146d8e394 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/lstm_function.cpp @@ -42,11 +42,11 @@ std::shared_ptr LSTMFunction::get( fqOnDatas[0], converts[0], dequantizations[0]); - std::shared_ptr squeeze_x; + std::shared_ptr squeeze_X; if (type == LSTMType::Cell) { auto squeeze_pattern = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}); - squeeze_x = std::make_shared(parent_X, squeeze_pattern); - squeeze_x->set_friendly_name("squeeze_X"); + squeeze_X = std::make_shared(parent_X, squeeze_pattern); + squeeze_X->set_friendly_name("squeeze_X"); } auto H = std::make_shared(inputPrecision, inputShapes[1]); H->set_friendly_name("H"); @@ -61,7 +61,7 @@ std::shared_ptr LSTMFunction::get( auto w_val = std::vector(512 * 16, 0); auto r_val = std::vector(512 * 128, 0); - auto W = ngraph::opset1::Constant::create(inputPrecision, + auto W = ngraph::opset1::Constant::create(fqOnDatas[1].empty() ? ngraph::element::i8 : inputPrecision, type == LSTMType::Cell ? ngraph::Shape{512, 16} : ngraph::Shape{1, 512, 16}, w_val); std::shared_ptr parent_W = makeQuantizationAndDequantization(W, @@ -70,7 +70,7 @@ std::shared_ptr LSTMFunction::get( fqOnDatas[1], converts[1], dequantizations[1]); - auto R = ngraph::opset1::Constant::create(inputPrecision, + auto R = ngraph::opset1::Constant::create(fqOnDatas[1].empty() ? ngraph::element::i8 : inputPrecision, type == LSTMType::Cell ? ngraph::Shape{512, 128} : ngraph::Shape{1, 512, 128}, r_val); std::shared_ptr parent_R = makeQuantizationAndDequantization(R, @@ -92,7 +92,7 @@ std::shared_ptr LSTMFunction::get( converts[2], dequantizations[2]); if (type == LSTMType::Cell) { - lstm = std::make_shared(squeeze_x, parent_H, C, parent_W, parent_R, parent_B, 128); + lstm = std::make_shared(squeeze_X, parent_H, C, parent_W, parent_R, parent_B, 128); lstm->set_friendly_name("lstm_cell"); } else { auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{1}, {3}); @@ -108,7 +108,7 @@ std::shared_ptr LSTMFunction::get( lstm->set_friendly_name("lstm_sequence"); } } else { - lstm = std::make_shared(squeeze_x, parent_H, C, parent_W, parent_R, 128); + lstm = std::make_shared(squeeze_X, parent_H, C, parent_W, parent_R, 128); lstm->set_friendly_name("lstm_cell"); }