From 3056b53056d6319666f3fc250bebefb0c4b1a91e Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Sat, 27 Jul 2024 01:05:27 +0100 Subject: [PATCH] [LPT] Quantized LSTMSequence & GRUSequence extended support (#25654) ### Details: - *Low Precision Transformations: Quantized LSTMSequence & GRUSequence extended support* ### Tickets: - Current implementation for: *CVS-146067* - Will be changed in feature request: *CVS-147588* --- .../include/low_precision/broadcast.hpp | 30 +++ .../include/low_precision/recurrent_cell.hpp | 5 +- .../src/broadcast.cpp | 77 +++++++ .../src/layer_transformation.cpp | 1 + .../src/low_precision.cpp | 2 + .../src/markup_precisions.cpp | 5 + .../src/recurrent_cell.cpp | 205 +++++++++++++----- .../tests/broadcast_transformation.cpp | 197 +++++++++++++++++ .../recurrent_cell_transformation.cpp | 4 +- .../recurrent_cell_transformation.cpp | 4 +- .../recurrent_cell_transformation.hpp | 1 + .../recurrent_cell_transformation.cpp | 14 +- .../include/ov_lpt_models/broadcast.hpp | 29 +++ .../include/ov_lpt_models/recurrent_cell.hpp | 8 +- .../ov_lpt_models/src/broadcast.cpp | 62 ++++++ .../ov_lpt_models/src/recurrent_cell.cpp | 40 +++- 16 files changed, 608 insertions(+), 76 deletions(-) create mode 100644 src/common/low_precision_transformations/include/low_precision/broadcast.hpp create mode 100644 src/common/low_precision_transformations/src/broadcast.cpp create mode 100644 src/common/low_precision_transformations/tests/broadcast_transformation.cpp create mode 100644 src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/broadcast.hpp create mode 100644 src/tests/ov_helpers/ov_lpt_models/src/broadcast.cpp diff --git a/src/common/low_precision_transformations/include/low_precision/broadcast.hpp b/src/common/low_precision_transformations/include/low_precision/broadcast.hpp new file mode 100644 index 00000000000000..39ba4052535c29 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/broadcast.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "transparent_base_transformation.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +/** + * @ingroup ov_transformation_common_api + * @brief BroadcastTransformation propagates dequantization operations through Broadcast operation. + * + * For more details about the transformation, refer to + * [BroadcastTransformation](@ref openvino_docs_OV_UG_lpt_BroadcastTransformation) page + * in the OpenVINO Developer Guide. + */ +class LP_TRANSFORMATIONS_API BroadcastTransformation : public TransparentBaseTransformation { +public: + OPENVINO_RTTI("BroadcastTransformation", "0"); + BroadcastTransformation(const Params& params = Params()); + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; +}; + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp index 8a305db307c612..22aaf3281c2b94 100644 --- a/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp +++ b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -23,6 +23,9 @@ class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransform static std::shared_ptr wrap_fake_quantize(const std::shared_ptr parameter); static std::shared_ptr wrap_quantization(const std::shared_ptr parameter); static std::shared_ptr wrap_dequantization(const std::shared_ptr parameter, const bool with_subtract); + +private: + void propagate(TransformationContext& context, const std::shared_ptr node); }; } // namespace low_precision diff --git a/src/common/low_precision_transformations/src/broadcast.cpp b/src/common/low_precision_transformations/src/broadcast.cpp new file mode 100644 index 00000000000000..5e78ca0ef50996 --- /dev/null +++ b/src/common/low_precision_transformations/src/broadcast.cpp @@ -0,0 +1,77 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/broadcast.hpp" + +#include + +#include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset3.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "low_precision/network_helper.hpp" + +#include "itt.hpp" + +using namespace ov::pass::low_precision; + +BroadcastTransformation::BroadcastTransformation(const Params& params) : TransparentBaseTransformation(params) { + MATCHER_SCOPE(BroadcastTransformation); + auto broadcast1 = pattern::wrap_type({ + pattern::wrap_type(), + ov::pass::pattern::any_input(), + ov::pass::pattern::any_input() }); + + auto broadcast3 = pattern::wrap_type({ + pattern::wrap_type(), + ov::pass::pattern::any_input(), + ov::pass::pattern::any_input() }); + + const auto matcher = std::make_shared(ov::OutputVector{ broadcast1, broadcast3 }); + + ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + auto op = m.get_match_root(); + if (transformation_callback(op)) { + return false; + } + return transform(*context, m); + }; + + auto m = std::make_shared(matcher, matcher_name); + this->register_matcher(m, callback); +} + +bool BroadcastTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { + if (!LayerTransformation::canBeTransformed(context, layer)) { + return false; + } + + const auto& dequantization = NetworkHelper::getDequantization(layer, defaultPrecisions); + if (dequantization.empty()) { + return false; + } + + if (dequantization.isPerTensor()) { + return true; + } + + const auto& inputShape = layer->get_input_partial_shape(0); + if (inputShape.rank().is_dynamic() || inputShape[dequantization.channelDimIndex].is_dynamic()) { + return false; + } + + const auto targetShapeConstant = ov::as_type_ptr(layer->get_input_node_shared_ptr(1)); + const auto& targetShape = targetShapeConstant->cast_vector(); + if (targetShape[dequantization.channelDimIndex] != inputShape[dequantization.channelDimIndex].get_length()) { + return false; + } + + const auto axesMappingConstant = ov::as_type_ptr(layer->get_input_node_shared_ptr(2)); + const auto& axesMapping = axesMappingConstant->cast_vector(); + if (static_cast(axesMapping[dequantization.channelDimIndex]) != dequantization.channelDimIndex) { + return false; + } + + return true; +} diff --git a/src/common/low_precision_transformations/src/layer_transformation.cpp b/src/common/low_precision_transformations/src/layer_transformation.cpp index a4c0133c5813c3..4ec573c0f2a6ea 100644 --- a/src/common/low_precision_transformations/src/layer_transformation.cpp +++ b/src/common/low_precision_transformations/src/layer_transformation.cpp @@ -401,6 +401,7 @@ std::shared_ptr LayerTransformation::moveDequantizationAfter( const FakeQuantizeDequantization& dequantization, const bool updateOutputPrecision, const bool moveSubtract) const { + OPENVINO_ASSERT(!dequantization.empty()); const auto result = ov::pass::low_precision::NetworkHelper::moveDequantizationAfter(operation, dequantization, updateOutputPrecision, diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index bba12f7e389be8..6435f47d12ffec 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -44,6 +44,7 @@ #include "low_precision/assign_and_read_value.hpp" #include "low_precision/avg_pool.hpp" #include "low_precision/batch_to_space.hpp" +#include "low_precision/broadcast.hpp" #include "low_precision/clamp.hpp" #include "low_precision/convolution.hpp" #include "low_precision/convolution_backprop_data.hpp" @@ -240,6 +241,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr() }, // TODO: there are conditions { name() }, + { name() }, + { name() }, { name() }, { name() }, { name() }, @@ -192,6 +195,8 @@ bool ov::pass::low_precision::MarkupPrecisions::isSupported(const std::shared_pt { name() }, { name() }, { name() }, + { name() }, + { name() }, { name() }, { name() }, // ? diff --git a/src/common/low_precision_transformations/src/recurrent_cell.cpp b/src/common/low_precision_transformations/src/recurrent_cell.cpp index 7fd40cf2071a0f..cec96044502596 100644 --- a/src/common/low_precision_transformations/src/recurrent_cell.cpp +++ b/src/common/low_precision_transformations/src/recurrent_cell.cpp @@ -1,17 +1,19 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include "low_precision/recurrent_cell.hpp" -#include "openvino/pass/pattern/op/wrap_type.hpp" -#include "openvino/opsets/opset1.hpp" - #include + #include "openvino/core/node.hpp" #include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset2.hpp" +#include "openvino/opsets/opset3.hpp" #include "openvino/opsets/opset5.hpp" +#include "openvino/opsets/opset12.hpp" #include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" #include "low_precision/network_helper.hpp" #include "low_precision/rt_info/disable_cleanup_attribute.hpp" @@ -21,50 +23,14 @@ namespace pass { namespace low_precision { RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : LayerTransformation(params) { - const auto X = ov::pass::pattern::any_input(); - const auto H = ov::pass::pattern::any_input(); const auto C = ov::pass::pattern::any_input(); const auto S = ov::pass::pattern::any_input(); - const auto W = ov::pass::pattern::wrap_type(); - const auto R = ov::pass::pattern::wrap_type(); const auto B = ov::pass::pattern::wrap_type(); - const auto H_as_const = ov::pass::pattern::wrap_type(); - - const auto fq_X = wrap_fake_quantize(X); - const auto fq_H = wrap_fake_quantize(H); - const auto fq_W = wrap_fake_quantize(W); - const auto fq_R = wrap_fake_quantize(R); - - const auto dequantization_X = wrap_dequantization(ov::pass::pattern::any_input(), true); - const auto dequantization_H = wrap_dequantization(ov::pass::pattern::any_input(), true); - const auto dequantization_W = wrap_dequantization(ov::pass::pattern::any_input(), true); - const auto dequantization_R = wrap_dequantization(ov::pass::pattern::any_input(), true); - - const auto dequantization_without_subtract_X = wrap_dequantization(ov::pass::pattern::any_input(), false); - const auto dequantization_without_subtract_H = wrap_dequantization(ov::pass::pattern::any_input(), false); - const auto dequantization_without_subtract_W = wrap_dequantization(ov::pass::pattern::any_input(), false); - const auto dequantization_without_subtract_R = wrap_dequantization(ov::pass::pattern::any_input(), false); - - auto X_in = std::make_shared( - OutputVector{ - fq_X, dequantization_X, dequantization_without_subtract_X - }); - - auto H_in = std::make_shared( - OutputVector{ - H_as_const, fq_H, dequantization_H, dequantization_without_subtract_H - }); - - auto W_in = std::make_shared( - OutputVector{ - fq_W, dequantization_W, dequantization_without_subtract_W - }); - - auto R_in = std::make_shared( - OutputVector{ - fq_R, dequantization_R, dequantization_without_subtract_R - }); + auto X_in = ov::pass::pattern::any_input(); + auto H_in = ov::pass::pattern::any_input(); + auto W_in = ov::pass::pattern::any_input(); + auto R_in = ov::pass::pattern::any_input(); const auto lstm_seq = ov::pass::pattern::wrap_type( {X_in, H_in, C, S, W_in, R_in, B}); @@ -91,8 +57,134 @@ RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : this->register_matcher(m, callback); } +namespace { + +std::shared_ptr find_fake_quantize_upper(const std::shared_ptr& parent) { + if (auto fq = as_type_ptr(parent)) { + return fq; + } + + if (!NetworkHelper::isPrecisionPreserved(parent)) { + return nullptr; + } + + return find_fake_quantize_upper(parent->get_input_node_shared_ptr(0)); +} + +template +std::string name() { + return Operation::get_type_info_static().name; +} + +bool isSupportedForPerChannelQuantization(const std::shared_ptr& node) { + static const std::unordered_set supportedForPerChannelQuantization = { + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() }, + { name() } + }; + + return supportedForPerChannelQuantization.find(node->get_type_name()) != supportedForPerChannelQuantization.end(); +} + +std::vector> get_supported_precisions(std::shared_ptr lstm) { + // pair fields: + // 0 - input number, + // 1 - input type, `element::undefined` - any precision + if (is_type(lstm)) { + return std::vector>{ {0, element::u8}, { 1, element::u8 }, { 4, element::undefined }, { 5, element::undefined } }; + } else if (is_type(lstm)) { + return std::vector>{ {0, element::u8}, { 1, element::u8 }, { 3, element::undefined }, { 4, element::undefined } }; + } + + OPENVINO_THROW("unsupported operation type: ", lstm->get_type_name()); +} + +} // namespace + +void RecurrentCellTransformation::propagate(TransformationContext& context, const std::shared_ptr node) { + if (!isSupportedForPerChannelQuantization(node)) { + return; + } + + const auto& normalized_node = NetworkHelper::separateInStandaloneBranch(node, defaultPrecisions); + auto dequantization = NetworkHelper::getDequantization(node, defaultPrecisions); + if (dequantization.empty()) { + return; + } + const auto& new_node = moveDequantizationAfter(context, normalized_node, dequantization); + + const auto& new_dequantization = NetworkHelper::getDequantizationBelow(new_node); + if (new_dequantization.empty()) { + return; + } + + for (auto output : new_dequantization.multiply->outputs()) { + for (auto input : output.get_target_inputs()) { + auto child = input.get_node()->shared_from_this(); + propagate(context, child); + } + } +} + bool RecurrentCellTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) { const auto lstm = m.get_match_root(); + const auto inputs = get_supported_precisions(lstm); + for (const auto& input : inputs) { + const auto& parent = lstm->get_input_node_shared_ptr(input.first); + if (!isSupportedForPerChannelQuantization(parent)) { + continue; + } + + const auto& fq = find_fake_quantize_upper(parent); + if (fq != nullptr) { + const auto& quantizationDetails = QuantizationDetails::getDetails(fq); + if ((quantizationDetails.inputLowValues.size() != 1) || (quantizationDetails.inputHighValues.size() != 1) || + (quantizationDetails.outputLowValues.size() != 1) || (quantizationDetails.outputHighValues.size() != 1)) { + continue; + } + + const auto& precisionsAttribute = getAttributeFromOutput(fq); + const auto& precisions = precisionsAttribute.empty() ? + defaultPrecisions : + precisionsAttribute.as().value(); + const auto& dataPrecision = getDataPrecision(fq, quantizationDetails, precisions); + if (dataPrecision.empty() || ((input.second != element::undefined) && (dataPrecision.precision != input.second))) { + return false; + } + + auto result = NetworkHelper::decomposeFakeQuantize( + fq, + dataPrecision.precision, + dataPrecision.min, + dataPrecision.max, + dataPrecision.hasZeroPoint, + updatePrecisions); + auto multiply = std::get<1>(result); + + for (const auto& output : multiply->outputs()) { + for (const auto& input : output.get_target_inputs()) { + const auto input_node = input.get_node(); + propagate(context, input_node->shared_from_this()); + } + } + } + } + if (!canBeTransformed(context, lstm)) { return false; } @@ -154,18 +246,21 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov:: } bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr lstm) const { - std::shared_ptr W, R; - - if (is_type(lstm)) { - W = lstm->get_input_node_shared_ptr(4); - R = lstm->get_input_node_shared_ptr(5); - } else if (is_type(lstm)) { - W = lstm->get_input_node_shared_ptr(3); - R = lstm->get_input_node_shared_ptr(4); - } else { - return false; - } + const auto inputs = get_supported_precisions(lstm); + for (const auto& index : inputs) { + const auto& input = lstm->get_input_node_ptr(index.first); + if (as_type(input) || as_type(input)) { + continue; + } + const auto dequantization = NetworkHelper::getDequantization(lstm, defaultPrecisions, index.first); + if (dequantization.empty()) { + continue; + } + if ((index.second != element::undefined) && (dequantization.data.get_element_type() != index.second)) { + return false; + } + } return true; } diff --git a/src/common/low_precision_transformations/tests/broadcast_transformation.cpp b/src/common/low_precision_transformations/tests/broadcast_transformation.cpp new file mode 100644 index 00000000000000..7745f38143d440 --- /dev/null +++ b/src/common/low_precision_transformations/tests/broadcast_transformation.cpp @@ -0,0 +1,197 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "layer_transformation.hpp" + +#include +#include + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "low_precision/broadcast.hpp" +#include "ov_lpt_models/broadcast.hpp" +#include "simple_low_precision_transformer.hpp" + +namespace { +using namespace ov::pass; +using namespace ov::builder::subgraph; +using namespace ov::opset1; +using namespace ov; + +class BroadcastTransformationTestValues { +public: + class Pattern { + public: + ov::element::Type precisionBeforeDequantization; + ov::builder::subgraph::DequantizationOperations dequantizationBefore; + ov::builder::subgraph::DequantizationOperations dequantizationAfter; + }; + + TestTransformationParams params; + Shape tagetShape; + Shape axesMapping; + Pattern actual; + Pattern expected; +}; + +typedef std::tuple< + ov::PartialShape, + bool, + BroadcastTransformationTestValues> BroadcastTransformationParams; + +class BroadcastTransformation : public LayerTransformation, public testing::WithParamInterface { +public: + void SetUp() override { + const ov::PartialShape inputShape = std::get<0>(GetParam()); + const bool v1 = std::get<1>(GetParam()); + const BroadcastTransformationTestValues testValues = std::get<2>(GetParam()); + + // batch update support + auto tagetShape = testValues.tagetShape; + tagetShape[0] = inputShape[0].get_length(); + + actualFunction = BroadcastFunction::get( + v1, + inputShape, + testValues.actual.precisionBeforeDequantization, + testValues.actual.dequantizationBefore, + tagetShape, + testValues.axesMapping, + testValues.actual.dequantizationAfter); + + SimpleLowPrecisionTransformer transform; + transform.add(testValues.params); + transform.transform(actualFunction); + + referenceFunction = BroadcastFunction::get( + v1, + inputShape, + testValues.expected.precisionBeforeDequantization, + testValues.expected.dequantizationBefore, + tagetShape, + testValues.axesMapping, + testValues.expected.dequantizationAfter); + } + + static std::string getTestCaseName(testing::TestParamInfo obj) { + const ov::PartialShape inputShape = std::get<0>(obj.param); + const bool v1 = std::get<1>(obj.param); + const BroadcastTransformationTestValues testValues = std::get<2>(obj.param); + + std::ostringstream result; + result << + v1 << "_" << + inputShape << "_" << + testValues.tagetShape << "_" << + testValues.axesMapping << "_" << + testValues.actual.precisionBeforeDequantization << "_" << + testValues.actual.dequantizationBefore << "_" << + testValues.actual.dequantizationAfter << "_" << + testValues.expected.precisionBeforeDequantization << "_" << + testValues.expected.dequantizationBefore << "_" << + testValues.expected.dequantizationAfter; + return result.str(); + } +}; + +TEST_P(BroadcastTransformation, CompareFunctions) { + actualFunction->validate_nodes_and_infer_types(); + + auto res = compare_functions(actualFunction, referenceFunction, true); + ASSERT_TRUE(res.first) << res.second; + + ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; +} + +namespace hw_broadcast { +const std::vector inputShapes = { + { 1, 3, 1, 1 }, + { 4, 3, 1, 1 }, +}; + +const std::vector testValues = { + { + LayerTransformation::createParamsU8I8(), + { 1, 3, 9, 9}, + { 0, 1, 2, 3 }, + { + ov::element::u8, + {{ov::element::f32}, {0.1f}, {0.2f}}, + {{}, {}, {}}, + }, + { + ov::element::u8, + {{}, {}, {}}, + {{ov::element::f32}, {0.1f}, {0.2f}} + } + }, + { + LayerTransformation::createParamsU8I8(), + { 1, 3, 9, 9 }, + { 0, 1, 2, 3 }, + { + ov::element::u8, + { + {ov::element::f32}, + {{0.1f, 0.2f, 0.3f}}, + {{0.4f, 0.5f, 0.6f}} + } + }, + { + ov::element::u8, + { {}, {}, {}}, + { + {ov::element::f32}, + {{0.1f, 0.2f, 0.3f}}, + {{0.4f, 0.5f, 0.6f}} + } + } + } +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + BroadcastTransformation, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn({ true, false }), + ::testing::ValuesIn(testValues)), + BroadcastTransformation::getTestCaseName); +} // hw_broadcast + +namespace chw_broadcast { +const std::vector inputShapes = { + { 1, 1, 1, 1 } +}; + +const std::vector testValues = { + { + LayerTransformation::createParamsU8I8(), + { 1, 9, 9, 9}, + { 0, 1, 2, 3 }, + { + ov::element::u8, + {{ov::element::f32}, {0.1f}, {0.2f}}, + {{}, {}, {}}, + }, + { + ov::element::u8, + {{}, {}, {}}, + {{ov::element::f32}, {0.1f}, {0.2f}} + } + } +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + BroadcastTransformation, + ::testing::Combine( + ::testing::ValuesIn(inputShapes), + ::testing::ValuesIn({ true, false }), + ::testing::ValuesIn(testValues)), + BroadcastTransformation::getTestCaseName); +} // chw_broadcast + +} // namespace diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp index ae5c19559e5a7b..066d81d1f37f36 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -92,6 +92,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(weights_shapes), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ true, false }), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); } // namespace testValues1 @@ -171,6 +172,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(weights_shapes), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ true, false }), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); } // namespace testValues2 diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp index afda5292e69c60..85f8d79e7ace31 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -94,6 +94,7 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(weights_shapes), ::testing::Values(ov::test::utils::DEVICE_GPU), ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ true, false }), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); } // namespace testValues1 @@ -174,6 +175,7 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, ::testing::ValuesIn(weights_shapes), ::testing::Values(ov::test::utils::DEVICE_GPU), ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn({ true, false }), ::testing::ValuesIn(params)), RecurrentCellTransformation::getTestCaseName); } // namespace testValues2 diff --git a/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp b/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp index d0452c9da1b638..82a8795698bb36 100644 --- a/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp +++ b/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp @@ -42,6 +42,7 @@ typedef std::tuple< std::vector, std::string, ov::pass::low_precision::LayerTransformation::Params, + bool, // use precision transparent operations RecurrentCellTransformationParam >RecurrentCellTransformationParams; diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp index e94663bf2b8596..692a00877c3368 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp @@ -21,14 +21,16 @@ std::string RecurrentCellTransformation::getTestCaseName(testing::TestParamInfo< std::string targetDevice; RecurrentCellTransformationParam param; ov::pass::low_precision::LayerTransformation::Params params; - std::tie(netPrecision, activationsShape, weightsShape, targetDevice, params, param) = obj.param; + bool addPrecisionTransparentOperations; + std::tie(netPrecision, activationsShape, weightsShape, targetDevice, params, addPrecisionTransparentOperations, param) = obj.param; std::ostringstream result; result << get_test_case_name_by_params(netPrecision, activationsShape[0], targetDevice, params) << "FQ_X_" << param.fakeQuantize_X << "_" << "DQ_X_" << param.dequantization_X << "_" << "FQ_W_" << param.fakeQuantize_W << "_" << - "DQ_W_" << param.dequantization_W; + "DQ_W_" << param.dequantization_W << "_" << + "PTO" << addPrecisionTransparentOperations; return result.str(); } @@ -37,9 +39,10 @@ void RecurrentCellTransformation::SetUp() { std::vector activations_shapes; std::vector weights_shapes; RecurrentCellTransformationParam param; + bool addPrecisionTransparentOperations; ov::pass::low_precision::LayerTransformation::Params params; - std::tie(precision, activations_shapes, weights_shapes, targetDevice, params, param) = this->GetParam(); + std::tie(precision, activations_shapes, weights_shapes, targetDevice, params, addPrecisionTransparentOperations, param) = this->GetParam(); init_input_shapes(activations_shapes); @@ -64,13 +67,14 @@ void RecurrentCellTransformation::SetUp() { param.dequantization_H, param.dequantization_W, param.dequantization_R - }); + }, + addPrecisionTransparentOperations); } void RecurrentCellTransformation::run() { LayerTransformation::run(); - const auto params = std::get<5>(GetParam()); + const auto params = std::get<6>(GetParam()); const auto actualPrecision = get_runtime_precision_by_type(params.layerName); auto expectedPrecision = params.expectedKernelType; if (expectedPrecision == "FP32" && std::get<0>(GetParam()) == ov::element::f16) { diff --git a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/broadcast.hpp b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/broadcast.hpp new file mode 100644 index 00000000000000..4384fecd089ea6 --- /dev/null +++ b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/broadcast.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "low_precision/layer_transformation.hpp" +#include "ov_lpt_models/common/dequantization_operations.hpp" + +namespace ov { +namespace builder { +namespace subgraph { + +class BroadcastFunction { +public: + static std::shared_ptr get( + const bool v1, + const ov::PartialShape& inputShape, + const ov::element::Type precisionBeforeDequantization, + const ov::builder::subgraph::DequantizationOperations& dequantizationBefore, + const Shape& tagetShape, + const Shape& axesMapping, + const ov::builder::subgraph::DequantizationOperations& dequantizationAfter); +}; + +} // namespace subgraph +} // namespace builder +} // namespace ov diff --git a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/recurrent_cell.hpp b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/recurrent_cell.hpp index da98410c55d13c..57ffdedc4c0eb6 100644 --- a/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/recurrent_cell.hpp +++ b/src/tests/ov_helpers/ov_lpt_models/include/ov_lpt_models/recurrent_cell.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -25,7 +25,8 @@ class RecurrentCellFunction { const RNNType type, const std::vector& fqOnDatas, const std::vector& converts, - const std::vector& dequantizations); + const std::vector& dequantizations, + const bool addPrecisionTransparentOperations = false); }; std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptr input, @@ -33,7 +34,8 @@ std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptr +std::shared_ptr make_broadcast(const std::shared_ptr& parent, const Shape& tagetShape, const Shape& axesMapping) { + return std::make_shared( + parent, + std::make_shared(ov::element::i32, Shape{ tagetShape.size() }, tagetShape), + std::make_shared(ov::element::i32, Shape{ axesMapping.size() }, axesMapping)); +} +} // namespace + +std::shared_ptr BroadcastFunction::get( + const bool v1, + const ov::PartialShape& inputShape, + const ov::element::Type precisionBeforeDequantization, + const ov::builder::subgraph::DequantizationOperations& dequantizationBefore, + const Shape& tagetShape, + const Shape& axesMapping, + const ov::builder::subgraph::DequantizationOperations& dequantizationAfter) { + const auto input = std::make_shared(precisionBeforeDequantization, inputShape); + std::shared_ptr parent = input; + + if (!dequantizationBefore.empty()) { + parent = makeDequantization(parent, dequantizationBefore); + } + + parent = v1 ? + make_broadcast(parent, tagetShape, axesMapping) : + make_broadcast(parent, tagetShape, axesMapping); + parent->set_friendly_name("broadcast"); + + if (!dequantizationAfter.empty()) { + parent = makeDequantization(parent, dequantizationAfter); + } + + const std::shared_ptr result = std::make_shared(parent); + + const std::shared_ptr function = std::make_shared( + ov::ResultVector{ result }, + std::vector> { input }, + "BroadcastTransformation"); + return function; +} + +} // namespace subgraph +} // namespace builder +} // namespace ov diff --git a/src/tests/ov_helpers/ov_lpt_models/src/recurrent_cell.cpp b/src/tests/ov_helpers/ov_lpt_models/src/recurrent_cell.cpp index 7be3fca1217403..7a3537c91f3824 100644 --- a/src/tests/ov_helpers/ov_lpt_models/src/recurrent_cell.cpp +++ b/src/tests/ov_helpers/ov_lpt_models/src/recurrent_cell.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2022 Intel Corporation +// Copyright (C) 2022-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -30,7 +30,8 @@ std::shared_ptr RecurrentCellFunction::get( const RNNType type, const std::vector& fqOnDatas, const std::vector& converts, - const std::vector& dequantizations) { + const std::vector& dequantizations, + const bool addPrecisionTransparentOperations) { auto X = std::make_shared(inputPrecision, inputActivationsShapes[0]); X->set_friendly_name("X"); std::shared_ptr parent_X = makeQuantizationAndDequantization(X, @@ -46,7 +47,8 @@ std::shared_ptr RecurrentCellFunction::get( H->get_friendly_name(), fqOnDatas[1], converts[1], - dequantizations[1]); + dequantizations[1], + addPrecisionTransparentOperations); auto C = std::make_shared(inputPrecision, inputActivationsShapes[2]); C->set_friendly_name("C"); @@ -58,7 +60,8 @@ std::shared_ptr RecurrentCellFunction::get( W->get_friendly_name(), fqOnDatas[2], converts[2], - dequantizations[2]); + dequantizations[2], + addPrecisionTransparentOperations); auto R = ov::opset1::Constant::create(fqOnDatas[2].empty() ? ov::element::i8 : inputPrecision, inputWeightsShapes[1], {1}); @@ -127,12 +130,20 @@ std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptr parent; - if (fqOnData.empty()) { - parent = input; - } else { - std::shared_ptr fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input, inputPrecision, fqOnData); + const DequantizationOperations& dequantization, + const bool addPrecisionTransparentOperations) { + std::shared_ptr parent = input; + if (addPrecisionTransparentOperations) { + auto shape = input->get_output_shape(0); + std::swap(shape[shape.size() - 2], shape[shape.size() - 1]); + parent = std::make_shared( + parent, + std::make_shared(element::u32, Shape({ shape.size() }), shape), + true); + } + + if (!fqOnData.empty()) { + std::shared_ptr fakeQuantize1 = makeFakeQuantizeTypeRelaxed(parent, inputPrecision, fqOnData); fakeQuantize1->set_friendly_name("fakeQuantize_" + friendly_name); parent = fakeQuantize1; } @@ -142,6 +153,15 @@ std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptrget_output_shape(0); + parent = std::make_shared( + parent, + std::make_shared(element::u32, Shape({ shape.size() }), shape), + true); + } + return parent; }