From 7634c44d8a69671f51b0db4cca35f93ba0c7a5b1 Mon Sep 17 00:00:00 2001 From: Nikita Demashov Date: Mon, 4 Apr 2022 13:27:28 +0300 Subject: [PATCH] refactoring --- .../include/low_precision/network_helper.hpp | 3 - .../include/low_precision/recurrent_cell.hpp | 3 + .../rt_info/skip_cleanup_attribute.hpp | 8 -- .../src/network_helper.cpp | 15 --- .../src/recurrent_cell.cpp | 121 +++++++++++++----- .../intel_cpu/src/nodes/tensoriterator.h | 7 - src/plugins/intel_cpu/src/plugin.cpp | 13 ++ .../recurrent_cell_transformation.cpp | 61 +++------ .../recurrent_cell_transformation.cpp | 109 +++------------- .../recurrent_cell_transformation.cpp | 109 +++------------- .../recurrent_cell_transformation.cpp | 5 +- .../recurrent_cell_function.hpp | 5 +- .../src/recurrent_cell_function.cpp | 19 ++- 13 files changed, 164 insertions(+), 314 deletions(-) diff --git a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp index 63e0bea7d43bde..08cfc518d69d5c 100644 --- a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp +++ b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp @@ -253,9 +253,6 @@ class LP_TRANSFORMATIONS_API NetworkHelper { float& updatedOutputLowValue, float& updatedOutputHighValue); - static std::shared_ptr fakeQuantizeWraper - (const std::shared_ptr parameter); - private: static std::shared_ptr foldFakeQuantize( const std::shared_ptr& fq, diff --git a/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp index 521e5c3f19a59e..d04319a6067ad1 100644 --- a/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp +++ b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp @@ -20,6 +20,9 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; void propagateSkipCleanupAttribute(std::shared_ptr dequantization_multiply); + static std::shared_ptr wrap_fake_quantize(const std::shared_ptr parameter); + static std::shared_ptr wrap_quantization(const std::shared_ptr parameter); + static std::shared_ptr wrap_dequantization(const std::shared_ptr parameter, const bool with_subtract); }; } // namespace low_precision diff --git a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp index a41dd0a7e1d94c..7b8bb985b2033e 100644 --- a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp +++ b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp @@ -4,16 +4,8 @@ #pragma once -#include -#include -#include -#include - #include -#include -#include -#include "low_precision/lpt_visibility.hpp" #include "low_precision/rt_info/attribute_parameters.hpp" namespace ngraph { diff --git a/src/common/low_precision_transformations/src/network_helper.cpp b/src/common/low_precision_transformations/src/network_helper.cpp index be462b8ebbebac..2ede193b4f96fd 100644 --- a/src/common/low_precision_transformations/src/network_helper.cpp +++ b/src/common/low_precision_transformations/src/network_helper.cpp @@ -1982,21 +1982,6 @@ void NetworkHelper::insertDequantizationAfter( } } } - -std::shared_ptr NetworkHelper::fakeQuantizeWraper( - const std::shared_ptr parameter) { - const auto input_low = ngraph::pattern::wrap_type(); - const auto input_high = ngraph::pattern::wrap_type(); - const auto output_low = ngraph::pattern::wrap_type(); - const auto output_high = ngraph::pattern::wrap_type(); - return ngraph::pattern::wrap_type({ - parameter, - input_low, - input_high, - output_low, - output_high}); -} - } // namespace low_precision } // namespace pass } // namespace ngraph diff --git a/src/common/low_precision_transformations/src/recurrent_cell.cpp b/src/common/low_precision_transformations/src/recurrent_cell.cpp index 0a3657af986500..bc85af73e055d5 100644 --- a/src/common/low_precision_transformations/src/recurrent_cell.cpp +++ b/src/common/low_precision_transformations/src/recurrent_cell.cpp @@ -28,47 +28,38 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : const auto R = ngraph::pattern::wrap_type(); const auto B = ngraph::pattern::wrap_type(); - const auto fq_X = NetworkHelper::fakeQuantizeWraper(X); - const auto fq_H = NetworkHelper::fakeQuantizeWraper(H); - const auto fq_W = NetworkHelper::fakeQuantizeWraper(W); - const auto fq_R = NetworkHelper::fakeQuantizeWraper(R); + const auto fq_X = wrap_fake_quantize(X); + const auto fq_H = wrap_fake_quantize(H); + const auto fq_W = wrap_fake_quantize(W); + const auto fq_R = wrap_fake_quantize(R); - const auto dequantization_convert_X = ngraph::pattern::wrap_type({ngraph::pattern::any_input()}); - const auto dequantization_convert_H = ngraph::pattern::wrap_type({ngraph::pattern::any_input()}); - const auto subtract_constant = ngraph::pattern::wrap_type(); - const auto dequantization_subtract_X = ngraph::pattern::wrap_type( - {dequantization_convert_X, subtract_constant}); - const auto dequantization_subtract_H = ngraph::pattern::wrap_type( - {dequantization_convert_H, subtract_constant}); - const auto multiply_constant = ngraph::pattern::wrap_type(); - const auto dequantization_multiply_X = ngraph::pattern::wrap_type( - {dequantization_subtract_X, multiply_constant}); + const auto quantization_X = wrap_quantization(X); + const auto quantization_H = wrap_quantization(H); + + const auto dequantization_X = wrap_dequantization(quantization_X, true); + const auto dequantization_H = wrap_dequantization(quantization_H, true); - const auto dequantization_multiply_without_subtract_X = ngraph::pattern::wrap_type( - {dequantization_convert_X, multiply_constant}); - const auto dequantization_multiply_H = ngraph::pattern::wrap_type( - {dequantization_subtract_H, multiply_constant}); - const auto dequantization_multiply_without_subtract_H = ngraph::pattern::wrap_type( - {dequantization_convert_H, multiply_constant}); + const auto dequantization_without_subtract_X = wrap_dequantization(quantization_X, false); + const auto dequantization_without_subtract_H = wrap_dequantization(quantization_H, false); const auto lstm_cell = ngraph::pattern::wrap_type( {fq_X, fq_H, C, fq_W, fq_R, B}); const auto lstm_cell_with_dequantizations = ngraph::pattern::wrap_type( - {dequantization_multiply_X, dequantization_multiply_H, C, fq_W, fq_R, B}); + {dequantization_X, dequantization_H, C, fq_W, fq_R, B}); const auto lstm_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type( - {dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, C, fq_W, fq_R, B}); + {dequantization_without_subtract_X, dequantization_without_subtract_H, C, fq_W, fq_R, B}); const auto gru_cell = ngraph::pattern::wrap_type({fq_X, fq_H, fq_W, fq_R, B}); const auto gru_cell_with_dequantizations = ngraph::pattern::wrap_type( - {dequantization_multiply_X, dequantization_multiply_X, fq_W, fq_R, B}); + {dequantization_X, dequantization_H, fq_W, fq_R, B}); const auto gru_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type( - {dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, fq_W, fq_R, B}); + {dequantization_without_subtract_X, dequantization_without_subtract_H, fq_W, fq_R, B}); const auto rnn_cell = ngraph::pattern::wrap_type({fq_X, fq_H, fq_W, fq_R, B}); const auto rnn_cell_with_dequantizations = ngraph::pattern::wrap_type( - {dequantization_multiply_X, dequantization_multiply_X, fq_W, fq_R, B}); + {dequantization_X, dequantization_H, fq_W, fq_R, B}); const auto rnn_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type( - {dequantization_multiply_without_subtract_X, dequantization_multiply_without_subtract_H, fq_W, fq_R, B}); + {dequantization_without_subtract_X, dequantization_without_subtract_H, fq_W, fq_R, B}); ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto op = m.get_match_root(); @@ -94,7 +85,7 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : } bool RecurrentCellTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { - const auto lstm = m.get_match_root(); + const auto lstm = m.get_match_root(); if (!canBeTransformed(context, lstm)) { return false; } @@ -118,13 +109,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ngra updatePrecisions); std::shared_ptr new_fq = std::get<0>(QDQ); std::shared_ptr deq_multiply = std::get<1>(QDQ); - if (deq_multiply == nullptr || new_fq == nullptr) { - return false; - } auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0); - if (is_type(multiply_parent)) { - return false; - } ov::disable_constant_folding(multiply_parent); propagateSkipCleanupAttribute(deq_multiply); this->register_new_node(new_fq); @@ -142,7 +127,39 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ngra return true; } -bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { +bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr lstm) const { + std::shared_ptr W, R; + if (is_type(lstm)) { + W = lstm->get_input_node_shared_ptr(3); + R = lstm->get_input_node_shared_ptr(4); + } else { + W = lstm->get_input_node_shared_ptr(2); + R = lstm->get_input_node_shared_ptr(3); + } + for (auto fq_on_weight : {W, R}) { + auto fq_node = as_type_ptr(fq_on_weight); + const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq_node); + const auto precisionsAttribute = getAttributeFromOutput(fq_on_weight); + const auto precisions = precisionsAttribute.empty() + ? defaultPrecisions + : precisionsAttribute.as().value(); + const DataPrecision dataPrecision = getDataPrecision(fq_on_weight, quantizationDetails, precisions); + auto QDQ = NetworkHelper::decomposeFakeQuantize(fq_node, + dataPrecision.precision, + dataPrecision.min, + dataPrecision.max, + dataPrecision.hasZeroPoint, + updatePrecisions); + std::shared_ptr new_fq = std::get<0>(QDQ); + std::shared_ptr deq_multiply = std::get<1>(QDQ); + if (deq_multiply == nullptr || new_fq == nullptr) { + return false; + } + auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0); + if (is_type(multiply_parent)) { + return false; + } + } return true; } @@ -160,6 +177,42 @@ void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr< } } +std::shared_ptr RecurrentCellTransformation::wrap_fake_quantize( + const std::shared_ptr parameter) { + const auto input_low = ngraph::pattern::wrap_type(); + const auto input_high = ngraph::pattern::wrap_type(); + const auto output_low = ngraph::pattern::wrap_type(); + const auto output_high = ngraph::pattern::wrap_type(); + return ngraph::pattern::wrap_type({ + parameter, + input_low, + input_high, + output_low, + output_high}); +} + +std::shared_ptr RecurrentCellTransformation::wrap_quantization( + const std::shared_ptr parameter) { + const auto quantization_fake_quantize = wrap_fake_quantize(parameter); + const auto quantization_convert = ngraph::pattern::wrap_type( + {quantization_fake_quantize}); + return quantization_convert; +} + +std::shared_ptr RecurrentCellTransformation::wrap_dequantization( + const std::shared_ptr parameter, + const bool with_subtract) { + const auto dequantization_convert = ngraph::pattern::wrap_type({parameter}); + const auto subtract_constant = ngraph::pattern::wrap_type(); + const auto dequantization_subtract = ngraph::pattern::wrap_type( + {dequantization_convert, subtract_constant}); + const auto multiply_constant = ngraph::pattern::wrap_type(); + const auto multiply_parent = with_subtract ? dequantization_subtract : dequantization_convert; + const auto dequantization_multiply = ngraph::pattern::wrap_type( + {multiply_parent, multiply_constant}); + return dequantization_multiply; +} + } // namespace low_precision } // namespace pass } // namespace ngraph diff --git a/src/plugins/intel_cpu/src/nodes/tensoriterator.h b/src/plugins/intel_cpu/src/nodes/tensoriterator.h index 0bd50f64f4c30a..64379c650df39a 100644 --- a/src/plugins/intel_cpu/src/nodes/tensoriterator.h +++ b/src/plugins/intel_cpu/src/nodes/tensoriterator.h @@ -108,13 +108,6 @@ class TensorIterator : public Node { void setExtManager(const ExtensionManager::Ptr& extMgr) { ext_mng = extMgr; } - Graph getSubGraph() const { - return sub_graph; - } - std::shared_ptr getOriginalOp() const { - return ngraphOp; - } - protected: // needShapeInfer() should return false // because we cannot resolve the output dimensions before the inference is completed diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index e18f5968585bad..6efb7e7effbbd6 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -88,6 +88,7 @@ #include #include #include +#include #include #include #include @@ -461,6 +462,18 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr {0, {ngraph::element::u8, ngraph::element::i8}}, {1, {ngraph::element::i8}} }), + PrecisionsRestriction::create({ + {0, {ngraph::element::u8}}, + {1, {ngraph::element::i8}}, + }), + PrecisionsRestriction::create({ + {0, {ngraph::element::u8}}, + {1, {ngraph::element::i8}}, + }), + PrecisionsRestriction::create({ + {0, {ngraph::element::u8}}, + {1, {ngraph::element::i8}}, + }), }); auto quantizationRestrictions = std::vector({ diff --git a/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp index 068081b49c3993..e7c1d0f62ac344 100644 --- a/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp @@ -4,8 +4,7 @@ #include -#include -#include +#include #include #include #include @@ -24,6 +23,7 @@ #include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" #include "lpt_ngraph_functions/recurrent_cell_function.hpp" #include "simple_low_precision_transformer.hpp" +#include using namespace testing; using namespace ngraph; @@ -69,18 +69,12 @@ class RecurrentCellTransformationTestValues { : params(params), type(type), actual(actual), - result(result), - addNotPrecisionPreservedOperation(addNotPrecisionPreservedOperation), - checkIntervalsAlignmentAttributes(checkIntervalsAlignmentAttributes) {} + result(result) {} TestTransformationParams params; RecurrentCellFunction::RNNType type; RecurrentCellTransformationValues actual; RecurrentCellTransformationValues result; - // add not precision preserved operation to set output precision for FakeQuantize - // don't set to 'true' by default to keep test cases with tested operation as output - bool addNotPrecisionPreservedOperation; - bool checkIntervalsAlignmentAttributes; }; inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationTestValues& values) { @@ -119,10 +113,7 @@ class RecurrentCellTransformation : public LayerTransformation, public testing:: testValues.actual.dequantization_H, testValues.actual.dequantization_W, testValues.actual.dequantization_R - }, - {}, - ngraph::element::undefined, - {}); + }); const auto params = TestTransformationParams::toParams(testValues.params); @@ -143,13 +134,6 @@ class RecurrentCellTransformation : public LayerTransformation, public testing:: testValues.result.dequantizationAfter.multiply.outPrecision = precision; } - if (!testValues.params.updatePrecisions && (precision == ngraph::element::f32) && - !testValues.result.dequantizationAfter.convert.empty()) { - testValues.result.dequantizationAfter.convert = {}; - } - - IntervalsAlignmentSharedValue::Interval interval{-1.28f, 2.55f}; - referenceFunction = ngraph::builder::subgraph::RecurrentCellFunction::get(precision, activations_shapes, @@ -172,10 +156,7 @@ class RecurrentCellTransformation : public LayerTransformation, public testing:: testValues.result.dequantization_H, testValues.result.dequantization_W, testValues.result.dequantization_R - }, - {}, - testValues.result.precisionAfterOperation, - testValues.result.dequantizationAfter); + }); } static std::string getTestCaseName(testing::TestParamInfo obj) { @@ -231,11 +212,11 @@ const std::vector testValues = { {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, }, @@ -297,19 +278,19 @@ const std::vector testValues = { // W { 255ul, {{4, 1}, {4, 1}, {4, 1}, {4, 1}}, - {0.f, 0.f, 0.f, 0.f}, - {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}, - {0.f, 0.f, 0.f, 0.f}, - {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}}, {}, {{}, {}, {}}, // R { 255ul, {{4, 1}, {4, 1}, {4, 1}, {4, 1}}, - {0.f, 0.f, 0.f, 0.f}, - {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}, - {0.f, 0.f, 0.f, 0.f}, - {2.55f / 1.f, 2.55f / 2.f, 2.55f / 3.f, 2.55f / 4.f}}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}}, {}, {{}, {}, {}}, }, @@ -387,11 +368,11 @@ const std::vector testValues = { {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, }, @@ -469,11 +450,11 @@ const std::vector testValues = { {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, }, @@ -551,11 +532,11 @@ const std::vector testValues = { {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, }, diff --git a/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp index ea58433aa67d1a..8373be4fda0224 100644 --- a/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp +++ b/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp @@ -40,11 +40,11 @@ const std::vector param {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, @@ -70,11 +70,11 @@ const std::vector param {0.01f}, }, // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, @@ -99,85 +99,6 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, namespace testValues2 { -const std::vector params = { - // LSTMSequence - { - // X - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // H - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence, - "TensorIterator", - "U8" - }, - // asymmetrical FQ on weights - { - // X - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // H - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence, - "TensorIterator", - "FP32" - } -}; - -const std::vector> activations_shapes = {{{1, 1, 16}, {1, 1, 128}, {1, 1, 128}}}; -const std::vector> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}}; - -INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, - ::testing::Combine( - ::testing::ValuesIn(netPrecisions), - ::testing::ValuesIn(activations_shapes), - ::testing::ValuesIn(weights_shapes), - ::testing::Values(CommonTestUtils::DEVICE_CPU), - ::testing::ValuesIn(trasformationParamValues), - ::testing::ValuesIn(params)), - RecurrentCellTransformation::getTestCaseName); -} // namespace testValues2 - -namespace testValues3 { - const std::vector params = { // GRU { @@ -198,11 +119,11 @@ const std::vector param {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, @@ -228,11 +149,11 @@ const std::vector param {0.01f}, }, // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, @@ -253,9 +174,9 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(trasformationParamValues), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); -} // namespace testValues3 +} // namespace testValues2 -namespace testValues4 { +namespace testValues3 { const std::vector params = { // RNNCell @@ -277,11 +198,11 @@ const std::vector param {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, @@ -307,11 +228,11 @@ const std::vector param {0.01f}, }, // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, @@ -332,4 +253,4 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(trasformationParamValues), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); -} // namespace testValues4 +} // namespace testValues3 diff --git a/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp index 3c55febc67979c..2d47018f459ced 100644 --- a/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp +++ b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp @@ -40,11 +40,11 @@ const std::vector param {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, @@ -70,11 +70,11 @@ const std::vector param {0.01f}, }, // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, @@ -99,85 +99,6 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, namespace testValues2 { -const std::vector params = { - // LSTMSequence - { - // X - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // H - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence, - "TensorIterator", - "U8" - }, - // asymmetrical FQ on weights - { - // X - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // H - {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, - {ngraph::element::u8}, - { - {ngraph::element::f32}, - {}, - {0.01f}, - }, - // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, - {}, - {{}, {}, {}}, - ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence, - "TensorIterator", - "FP32" - } -}; - -const std::vector> activations_shapes = {{{1, 1, 16}, {1, 1, 128}, {1, 1, 128}}}; -const std::vector> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}}; - -INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, - ::testing::Combine( - ::testing::ValuesIn(netPrecisions), - ::testing::ValuesIn(activations_shapes), - ::testing::ValuesIn(weights_shapes), - ::testing::Values(CommonTestUtils::DEVICE_GPU), - ::testing::ValuesIn(trasformationParamValues), - ::testing::ValuesIn(params)), - RecurrentCellTransformation::getTestCaseName); -} // namespace testValues2 - -namespace testValues3 { - const std::vector params = { // GRU { @@ -198,11 +119,11 @@ const std::vector param {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, @@ -228,11 +149,11 @@ const std::vector param {0.01f}, }, // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, @@ -253,9 +174,9 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(trasformationParamValues), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); -} // namespace testValues3 +} // namespace testValues2 -namespace testValues4 { +namespace testValues3 { const std::vector params = { // RNNCell @@ -277,11 +198,11 @@ const std::vector param {0.01f}, }, // W - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {255ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, @@ -307,11 +228,11 @@ const std::vector param {0.01f}, }, // W - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, // R - {256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f}}, + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, {}, {{}, {}, {}}, ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, @@ -332,4 +253,4 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(trasformationParamValues), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); -} // namespace testValues4 +} // namespace testValues3 diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp index 0b98432726c713..b320bcff56fc64 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp @@ -66,10 +66,7 @@ void RecurrentCellTransformation::SetUp() { param.dequantization_H, param.dequantization_W, param.dequantization_R - }, - {}, - ngraph::element::undefined, - {}); + }); } void RecurrentCellTransformation::Run() { diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp index fd114488a6dfec..e131fd83a705e8 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp @@ -26,10 +26,7 @@ class RecurrentCellFunction { const RNNType type, const std::vector& fqOnDatas, const std::vector& converts, - const std::vector& dequantizations, - const std::vector& concatAttributes, - const ngraph::element::Type precisionAfterOperation, - const DequantizationOperations& dequantizationAfter); + const std::vector& dequantizations); }; std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptr input, diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp index 02a7cd5f562358..c2f1b8cb1dcd43 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp @@ -30,11 +30,8 @@ std::shared_ptr RecurrentCellFunction::get( const RNNType type, const std::vector& fqOnDatas, const std::vector& converts, - const std::vector& dequantizations, - const std::vector& concatAttributes, - const ngraph::element::Type precisionAfterOperation, - const DequantizationOperations& dequantizationAfter) { - auto X = std::make_shared(inputPrecision, inputActivationsShapes[0]); + const std::vector& dequantizations) { + auto X = std::make_shared(inputPrecision, inputActivationsShapes[0]); X->set_friendly_name("X"); std::shared_ptr parent_X = makeQuantizationAndDequantization(X, inputPrecision, @@ -42,7 +39,7 @@ std::shared_ptr RecurrentCellFunction::get( fqOnDatas[0], converts[0], dequantizations[0]); - auto H = std::make_shared(inputPrecision, inputActivationsShapes[1]); + auto H = std::make_shared(inputPrecision, inputActivationsShapes[1]); H->set_friendly_name("H"); std::shared_ptr parent_H = makeQuantizationAndDequantization(H, inputPrecision, @@ -50,10 +47,10 @@ std::shared_ptr RecurrentCellFunction::get( fqOnDatas[1], converts[1], dequantizations[1]); - auto C = std::make_shared(inputPrecision, inputActivationsShapes[2]); + auto C = std::make_shared(inputPrecision, inputActivationsShapes[2]); C->set_friendly_name("C"); - auto W = ngraph::opset5::Constant::create(fqOnDatas[2].empty() ? ngraph::element::u8 : inputPrecision, + auto W = ngraph::opset1::Constant::create(fqOnDatas[2].empty() ? ngraph::element::i8 : inputPrecision, inputWeightsShapes[0], {1}); std::shared_ptr parent_W = makeQuantizationAndDequantization(W, @@ -62,7 +59,7 @@ std::shared_ptr RecurrentCellFunction::get( fqOnDatas[2], converts[2], dequantizations[2]); - auto R = ngraph::opset5::Constant::create(fqOnDatas[2].empty() ? ngraph::element::u8 : inputPrecision, + auto R = ngraph::opset1::Constant::create(fqOnDatas[2].empty() ? ngraph::element::i8 : inputPrecision, inputWeightsShapes[1], {1}); std::shared_ptr parent_R = makeQuantizationAndDequantization(R, @@ -71,8 +68,8 @@ std::shared_ptr RecurrentCellFunction::get( fqOnDatas[3], converts[3], dequantizations[3]); - auto B = ngraph::opset5::Constant::create(inputPrecision, inputWeightsShapes[2], {1}); - auto seq_lengths = ngraph::opset5::Constant::create(element::i32, Shape{1}, {3}); + auto B = ngraph::opset1::Constant::create(inputPrecision, inputWeightsShapes[2], {1}); + auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {3}); std::shared_ptr rnn_layer; switch (type) {