diff --git a/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp new file mode 100644 index 00000000000000..d04319a6067ad1 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/recurrent_cell.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include "low_precision/layer_transformation.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransformation { +public: + OPENVINO_RTTI("RecurrentCellTransformation", "0"); + RecurrentCellTransformation(const Params& params = Params()); + bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; + bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; + void propagateSkipCleanupAttribute(std::shared_ptr dequantization_multiply); + static std::shared_ptr wrap_fake_quantize(const std::shared_ptr parameter); + static std::shared_ptr wrap_quantization(const std::shared_ptr parameter); + static std::shared_ptr wrap_dequantization(const std::shared_ptr parameter, const bool with_subtract); +}; + +} // namespace low_precision +} // namespace pass +} // namespace ngraph diff --git a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp new file mode 100644 index 00000000000000..7b8bb985b2033e --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "low_precision/rt_info/attribute_parameters.hpp" + +namespace ngraph { +class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public ov::RuntimeAttribute { +public: + OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute, 0); + static ov::Any create(const std::shared_ptr& node); +}; +} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/fuse_convert.cpp b/src/common/low_precision_transformations/src/fuse_convert.cpp index 875c56efea1001..ddf72fbf084589 100644 --- a/src/common/low_precision_transformations/src/fuse_convert.cpp +++ b/src/common/low_precision_transformations/src/fuse_convert.cpp @@ -13,6 +13,7 @@ #include "low_precision/common/ie_lpt_exception.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" +#include "low_precision/rt_info/skip_cleanup_attribute.hpp" namespace ngraph { namespace pass { @@ -113,6 +114,10 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph } bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr op) const { + if (!getAttribute(op).empty()) { + return false; + } + const auto convert = ov::as_type_ptr(op->get_input_node_shared_ptr(0)); // issue #40395 if (convert == nullptr) { diff --git a/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp index f2e17fb5bfd27b..faf55632fc989a 100644 --- 
a/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp @@ -10,6 +10,7 @@ #include "low_precision/fake_quantize.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" +#include "low_precision/rt_info/skip_cleanup_attribute.hpp" namespace ngraph { namespace pass { @@ -98,6 +99,10 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma return false; } + if (!getAttribute(operation).empty()) { + return false; + } + const auto parent = operation->get_input_node_shared_ptr(0); auto fq = ov::as_type_ptr(parent); const auto convert = ov::as_type_ptr(parent); diff --git a/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp index 1a727681052156..7eb21bad1a5355 100644 --- a/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp @@ -9,6 +9,7 @@ #include "low_precision/fake_quantize.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" +#include "low_precision/rt_info/skip_cleanup_attribute.hpp" namespace ngraph { namespace pass { @@ -92,6 +93,10 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma return false; } + if (!getAttribute(operation).empty()) { + return false; + } + const auto children = operation->get_output_target_inputs(0); for (const auto& target : children) { diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index 507879aa90ffd4..f26084c390c107 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -54,6 +54,7 @@ #include "low_precision/normalize_l2.hpp" #include "low_precision/pad.hpp" #include "low_precision/prelu.hpp" +#include "low_precision/recurrent_cell.hpp" #include "low_precision/reduce_max.hpp" #include "low_precision/reduce_mean.hpp" #include "low_precision/reduce_min.hpp" @@ -227,6 +228,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); + common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); common->add_matcher(params); diff --git a/src/common/low_precision_transformations/src/recurrent_cell.cpp b/src/common/low_precision_transformations/src/recurrent_cell.cpp new file mode 100644 index 00000000000000..5170de885579a3 --- /dev/null +++ b/src/common/low_precision_transformations/src/recurrent_cell.cpp @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/recurrent_cell.hpp" + +#include +#include + +#include +#include +#include +#include +#include + +#include "low_precision/network_helper.hpp" +#include "../include/low_precision/rt_info/skip_cleanup_attribute.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : LayerTransformation(params) { + const auto X = ngraph::pattern::wrap_type(); + const auto H = ngraph::pattern::wrap_type(); + const auto C = ngraph::pattern::wrap_type(); + const auto W = ngraph::pattern::wrap_type(); + const auto R = 
ngraph::pattern::wrap_type(); + const auto B = ngraph::pattern::wrap_type(); + + const auto fq_X = wrap_fake_quantize(X); + const auto fq_H = wrap_fake_quantize(H); + const auto fq_W = wrap_fake_quantize(W); + const auto fq_R = wrap_fake_quantize(R); + + const auto dequantization_X = wrap_dequantization(ngraph::pattern::any_input(), true); + const auto dequantization_H = wrap_dequantization(ngraph::pattern::any_input(), true); + + const auto dequantization_without_subtract_X = wrap_dequantization(ngraph::pattern::any_input(), false); + const auto dequantization_without_subtract_H = wrap_dequantization(ngraph::pattern::any_input(), false); + + const auto lstm_cell = ngraph::pattern::wrap_type( + {fq_X, fq_H, C, fq_W, fq_R, B}); + const auto lstm_cell_with_dequantizations = ngraph::pattern::wrap_type( + {dequantization_X, dequantization_H, C, fq_W, fq_R, B}); + const auto lstm_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type( + {dequantization_without_subtract_X, dequantization_without_subtract_H, C, fq_W, fq_R, B}); + + const auto gru_cell = ngraph::pattern::wrap_type({fq_X, fq_H, fq_W, fq_R, B}); + const auto gru_cell_with_dequantizations = ngraph::pattern::wrap_type( + {dequantization_X, dequantization_H, fq_W, fq_R, B}); + const auto gru_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type( + {dequantization_without_subtract_X, dequantization_without_subtract_H, fq_W, fq_R, B}); + + const auto rnn_cell = ngraph::pattern::wrap_type({fq_X, fq_H, fq_W, fq_R, B}); + const auto rnn_cell_with_dequantizations = ngraph::pattern::wrap_type( + {dequantization_X, dequantization_H, fq_W, fq_R, B}); + const auto rnn_cell_with_dequantizations_without_subtract = ngraph::pattern::wrap_type( + {dequantization_without_subtract_X, dequantization_without_subtract_H, fq_W, fq_R, B}); + + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + auto op = m.get_match_root(); + if (transformation_callback(op)) { + return false; + } + + return transform(*context, m); + }; + + auto m = std::make_shared( + std::make_shared(OutputVector{lstm_cell, + lstm_cell_with_dequantizations, + lstm_cell_with_dequantizations_without_subtract, + gru_cell, + gru_cell_with_dequantizations, + gru_cell_with_dequantizations_without_subtract, + rnn_cell, + rnn_cell_with_dequantizations, + rnn_cell_with_dequantizations_without_subtract}), + "LSTM"); + this->register_matcher(m, callback); +} + +bool RecurrentCellTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) { + const auto lstm = m.get_match_root(); + if (!canBeTransformed(context, lstm)) { + return false; + } + for (size_t parentIndex = 0ul; parentIndex < lstm->get_input_size(); parentIndex++) { + auto lstm_parent = lstm->get_input_node_shared_ptr(parentIndex); + if (is_type(lstm_parent)) { + auto fq_parent = lstm_parent->get_input_node_shared_ptr(0); + if (is_type(fq_parent)) { + auto fq_node = as_type_ptr(lstm_parent); + const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq_node); + const auto precisionsAttribute = getAttributeFromOutput(lstm_parent); + const auto precisions = precisionsAttribute.empty() + ? 
defaultPrecisions + : precisionsAttribute.as().value(); + const DataPrecision dataPrecision = getDataPrecision(lstm_parent, quantizationDetails, precisions); + auto QDQ = NetworkHelper::decomposeFakeQuantize(fq_node, + dataPrecision.precision, + dataPrecision.min, + dataPrecision.max, + dataPrecision.hasZeroPoint, + updatePrecisions); + std::shared_ptr new_fq = std::get<0>(QDQ); + std::shared_ptr deq_multiply = std::get<1>(QDQ); + auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0); + ov::disable_constant_folding(multiply_parent); + propagateSkipCleanupAttribute(deq_multiply); + this->register_new_node(new_fq); + updateOutput(context, deq_multiply, new_fq); + } else { + continue; + } + } else { + if (is_type(lstm_parent)) { + propagateSkipCleanupAttribute(lstm_parent); + } + continue; + } + } + return true; +} + +bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr lstm) const { + std::shared_ptr W, R; + if (is_type(lstm)) { + W = lstm->get_input_node_shared_ptr(3); + R = lstm->get_input_node_shared_ptr(4); + } else { + W = lstm->get_input_node_shared_ptr(2); + R = lstm->get_input_node_shared_ptr(3); + } + for (auto fq_on_weight : {W, R}) { + auto fq_node = as_type_ptr(fq_on_weight); + const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq_node); + const auto precisionsAttribute = getAttributeFromOutput(fq_on_weight); + const auto precisions = precisionsAttribute.empty() + ? defaultPrecisions + : precisionsAttribute.as().value(); + const DataPrecision dataPrecision = getDataPrecision(fq_on_weight, quantizationDetails, precisions); + auto QDQ = NetworkHelper::decomposeFakeQuantize(fq_node, + dataPrecision.precision, + dataPrecision.min, + dataPrecision.max, + dataPrecision.hasZeroPoint, + updatePrecisions); + std::shared_ptr new_fq = std::get<0>(QDQ); + std::shared_ptr deq_multiply = std::get<1>(QDQ); + if (deq_multiply == nullptr || new_fq == nullptr) { + return false; + } + auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0); + if (is_type(multiply_parent)) { + return false; + } + } + return true; +} + +bool RecurrentCellTransformation::isPrecisionPreserved(std::shared_ptr) const noexcept { + return true; +} + +void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr multiply) { + SkipCleanupAttribute::create(multiply); + auto multiply_parent = multiply->get_input_node_shared_ptr(0); + SkipCleanupAttribute::create(multiply_parent); + if (is_type(multiply_parent)) { + auto subtract_parent = multiply_parent->get_input_node_shared_ptr(0); + SkipCleanupAttribute::create(subtract_parent); + } +} + +std::shared_ptr RecurrentCellTransformation::wrap_fake_quantize( + const std::shared_ptr parameter) { + const auto input_low = ngraph::pattern::wrap_type(); + const auto input_high = ngraph::pattern::wrap_type(); + const auto output_low = ngraph::pattern::wrap_type(); + const auto output_high = ngraph::pattern::wrap_type(); + return ngraph::pattern::wrap_type({ + parameter, + input_low, + input_high, + output_low, + output_high}); +} + +std::shared_ptr RecurrentCellTransformation::wrap_quantization( + const std::shared_ptr parameter) { + const auto quantization_fake_quantize = wrap_fake_quantize(parameter); + const auto quantization_convert = ngraph::pattern::wrap_type( + {quantization_fake_quantize}); + return quantization_convert; +} + +std::shared_ptr RecurrentCellTransformation::wrap_dequantization( + const std::shared_ptr parameter, + const bool with_subtract) { + 
const auto dequantization_convert = ngraph::pattern::wrap_type({parameter}); + const auto subtract_constant = ngraph::pattern::wrap_type(); + const auto dequantization_subtract = ngraph::pattern::wrap_type( + {dequantization_convert, subtract_constant}); + const auto multiply_constant = ngraph::pattern::wrap_type(); + const auto multiply_parent = with_subtract ? dequantization_subtract : dequantization_convert; + const auto dequantization_multiply = ngraph::pattern::wrap_type( + {multiply_parent, multiply_constant}); + return dequantization_multiply; +} + +} // namespace low_precision +} // namespace pass +} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp b/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp new file mode 100644 index 00000000000000..32cc0b118d598f --- /dev/null +++ b/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp @@ -0,0 +1,20 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/rt_info/skip_cleanup_attribute.hpp" + +#include +#include +#include +#include +#include + +using namespace ngraph; +using namespace ov; + +ov::Any SkipCleanupAttribute::create( + const std::shared_ptr& node) { + auto& rt = node->get_rt_info(); + return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute()); +} diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp index 8e68f75d8cb4af..43f43b937ef8bc 100644 --- a/src/plugins/intel_cpu/src/plugin.cpp +++ b/src/plugins/intel_cpu/src/plugin.cpp @@ -88,6 +88,7 @@ #include #include #include +#include #include #include #include @@ -461,6 +462,18 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr {0, {ngraph::element::u8, ngraph::element::i8}}, {1, {ngraph::element::i8}} }), + PrecisionsRestriction::create({ + {0, {ngraph::element::u8}}, + {1, {ngraph::element::i8}}, + }), + PrecisionsRestriction::create({ + {0, {ngraph::element::u8}}, + {1, {ngraph::element::i8}}, + }), + PrecisionsRestriction::create({ + {0, {ngraph::element::u8}}, + {1, {ngraph::element::i8}}, + }), }); auto quantizationRestrictions = std::vector({ diff --git a/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp new file mode 100644 index 00000000000000..e7c1d0f62ac344 --- /dev/null +++ b/src/tests/functional/inference_engine/lp_transformations/recurrent_cell_transformation.cpp @@ -0,0 +1,589 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "layer_transformation.hpp" +#include "lpt_ngraph_functions/common/builders.hpp" +#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" +#include "lpt_ngraph_functions/recurrent_cell_function.hpp" +#include "simple_low_precision_transformer.hpp" +#include + +using namespace testing; +using namespace ngraph; +using namespace ngraph::pass; +using namespace ngraph::builder::subgraph; + +namespace { + +class RecurrentCellTransformationValues { +public: + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_X; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_X; + 
ngraph::builder::subgraph::DequantizationOperations dequantization_X; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_H; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_H; + ngraph::builder::subgraph::DequantizationOperations dequantization_H; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_W; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_W; + ngraph::builder::subgraph::DequantizationOperations dequantization_W; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_R; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_R; + ngraph::builder::subgraph::DequantizationOperations dequantization_R; + ngraph::element::Type precisionAfterOperation; + ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; +}; + +inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationValues& values) { + return out << "_" << values.fakeQuantize_X << "_" << values.convert_X << "_" << values.dequantization_X << + "_" << values.fakeQuantize_H << "_" << values.convert_H << "_" << values.dequantization_H << + "_" << values.fakeQuantize_W << "_" << values.convert_W << "_" << values.dequantization_W << + "_" << values.fakeQuantize_R << "_" << values.convert_R << "_" << values.dequantization_R; +} + +class RecurrentCellTransformationTestValues { +public: + RecurrentCellTransformationTestValues() = default; + RecurrentCellTransformationTestValues(const TestTransformationParams& params, + const RecurrentCellFunction::RNNType type, + const RecurrentCellTransformationValues& actual, + const RecurrentCellTransformationValues& result, + const bool addNotPrecisionPreservedOperation = false, + const bool checkIntervalsAlignmentAttributes = true) + : params(params), + type(type), + actual(actual), + result(result) {} + + TestTransformationParams params; + RecurrentCellFunction::RNNType type; + RecurrentCellTransformationValues actual; + RecurrentCellTransformationValues result; +}; + +inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationTestValues& values) { + return out << "_" << values.actual << "_" << values.result; +} + +typedef std::tuple, std::vector, RecurrentCellTransformationTestValues> + RecurrentCellTransformationParams; + +class RecurrentCellTransformation : public LayerTransformation, public testing::WithParamInterface { +public: + void SetUp() override { + const ngraph::element::Type precision = std::get<0>(GetParam()); + const std::vector activations_shapes = std::get<1>(GetParam()); + const std::vector weights_shapes = std::get<2>(GetParam()); + RecurrentCellTransformationTestValues testValues = std::get<3>(GetParam()); + + actualFunction = ngraph::builder::subgraph::RecurrentCellFunction::get(precision, + activations_shapes, + weights_shapes, + testValues.type, + { + testValues.actual.fakeQuantize_X, + testValues.actual.fakeQuantize_H, + testValues.actual.fakeQuantize_W, + testValues.actual.fakeQuantize_R + }, + { + testValues.actual.convert_X, + testValues.actual.convert_H, + testValues.actual.convert_W, + testValues.actual.convert_R + }, + { + testValues.actual.dequantization_X, + testValues.actual.dequantization_H, + testValues.actual.dequantization_W, + testValues.actual.dequantization_R + }); + + const auto params = TestTransformationParams::toParams(testValues.params); + + SimpleLowPrecisionTransformer transformer; + transformer.commonGraphRewrite->add_matcher(params); + 
transformer.transform(actualFunction); + + SimpleLowPrecisionTransformer clenup_transformer; + clenup_transformer.commonGraphRewrite->add_matcher(params); + clenup_transformer.commonGraphRewrite->add_matcher(params); + clenup_transformer.commonGraphRewrite->add_matcher(params); + clenup_transformer.commonGraphRewrite->add_matcher(params); + clenup_transformer.transform(actualFunction); + + // dequantization output precision depends on input precision + // to avoid huge amount of tests cases let's define dequantization output precision as input precision + if (!testValues.result.dequantizationAfter.multiply.empty()) { + testValues.result.dequantizationAfter.multiply.outPrecision = precision; + } + + referenceFunction = + ngraph::builder::subgraph::RecurrentCellFunction::get(precision, + activations_shapes, + weights_shapes, + testValues.type, + { + testValues.result.fakeQuantize_X, + testValues.result.fakeQuantize_H, + testValues.result.fakeQuantize_W, + testValues.result.fakeQuantize_R + }, + { + testValues.result.convert_X, + testValues.result.convert_H, + testValues.result.convert_W, + testValues.result.convert_R + }, + { + testValues.result.dequantization_X, + testValues.result.dequantization_H, + testValues.result.dequantization_W, + testValues.result.dequantization_R + }); + } + + static std::string getTestCaseName(testing::TestParamInfo obj) { + const ngraph::element::Type precision = std::get<0>(obj.param); + const std::vector activations_shapes = std::get<1>(obj.param); + const std::vector weights_shapes = std::get<2>(obj.param); + const RecurrentCellTransformationTestValues testValues = std::get<3>(obj.param); + + std::ostringstream result; + result << LayerTransformation::getTestCaseNameByParams(precision, activations_shapes[0], testValues.params) + << "_" << testValues.actual << "_" << testValues.result << "_"; + return result.str(); + } +}; + +TEST_P(RecurrentCellTransformation, CompareFunctions) { + actualFunction->validate_nodes_and_infer_types(); + auto res = compare_functions(actualFunction, referenceFunction); + ASSERT_TRUE(res.first) << res.second; + + ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; +} + +const std::vector precisions = { + ngraph::element::f32, + // ngraph::element::f16 +}; + +namespace testValues1 { +const std::vector> activations_shapes = {{{1, 1}, {1, 1}, {1, 1}}}; + +const std::vector> weights_shapes = {{{4, 1}, {4, 1}, {4}}}; + +const std::vector testValues = { + // LSTM Cell + {LayerTransformation::createParamsU8I8(), + RecurrentCellFunction::RNNType::LSTMCell, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + }, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + // R + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + } + }, + // multi-channel fake quantizes on weights + {LayerTransformation::createParamsU8I8(), + RecurrentCellFunction::RNNType::LSTMCell, + { + // X + {256ul, 
{}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + { 255ul, + {{4, 1}, {4, 1}, {4, 1}, {4, 1}}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}}, + {}, + {{}, {}, {}}, + // R + { 255ul, + {{4, 1}, {4, 1}, {4, 1}, {4, 1}}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}, + {-1.27f, -1.27f, -1.27f, -1.27f}, + {1.27f, 1.27f, 1.27f, 1.27f}}, + {}, + {{}, {}, {}}, + }, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {}, + {}, + { + {element::f32}, + {}, + {{0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.01f / 4.f}, ngraph::element::f32, {4, 1}} + }, + // R + {}, + {}, + { + {element::f32}, + {}, + {{0.01f / 1.f, 0.01f / 2.f, 0.01f / 3.f, 0.01f / 4.f}, ngraph::element::f32, {4, 1}} + }, + } + }, +}; +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::ValuesIn(testValues)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues1 + +namespace testValues2 { +const std::vector> activations_shapes = {{{1, 1, 16}, {1, 1, 128}, {1, 1, 128}}}; + +const std::vector> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}}; + +const std::vector testValues = { + // LSTM Sequence + {LayerTransformation::createParamsU8I8(), + RecurrentCellFunction::RNNType::LSTMSequence, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + }, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + // R + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + } + }, +}; +INSTANTIATE_TEST_SUITE_P( + DISABLED_smoke_LPT, + RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::ValuesIn(testValues)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues2 + +namespace testValues3 { +const std::vector> activations_shapes = {{{2, 3}, {2, 3}, {}}}; + +const std::vector> weights_shapes = {{{9, 3}, {9, 3}, {9}}}; + +const std::vector testValues = { + // GRU + {LayerTransformation::createParamsU8I8(), + RecurrentCellFunction::RNNType::GRU, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, 
{}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + }, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + // R + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + } + } +}; +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::ValuesIn(testValues)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues3 + +namespace testValues4 { +const std::vector> activations_shapes = {{{2, 3}, {2, 3}, {}}}; + +const std::vector> weights_shapes = {{{3, 3}, {3, 3}, {9}}}; + +const std::vector testValues = { + // RNNCell + {LayerTransformation::createParamsU8I8(), + RecurrentCellFunction::RNNType::RNNCell, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + }, + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {element::f32}, + {}, + {0.01f}, + }, + // W + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + // R + {}, + {}, + { + {element::f32}, + {}, + {0.01f} + }, + } + } +}; +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::ValuesIn(testValues)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues4 +} // namespace diff --git a/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp new file mode 100644 index 00000000000000..8373be4fda0224 --- /dev/null +++ b/src/tests/functional/plugin/cpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp @@ -0,0 +1,256 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "low_precision_transformations/recurrent_cell_transformation.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +const std::vector netPrecisions = { + ngraph::element::f32, + //ngraph::element::f16 +}; + +const std::vector trasformationParamValues = { + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(true) +}; + +namespace testValues1 { + +const std::vector params = { + // LSTMCell + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + 
{ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, + "RNNCell", + "U8" + }, + // asymmetrical FQ on weights + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, + "RNNCell", + "FP32" + } +}; + +const std::vector> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}}; +const std::vector> weights_shapes = {{{512, 16}, {512, 128}, {512}}}; + +INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues1 + +namespace testValues2 { + +const std::vector params = { + // GRU + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, + "RNNCell", + "U8" + }, + // asymmetrical FQ on weights + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, + "RNNCell", + "FP32" + } +}; + +const std::vector> activations_shapes = {{{2, 3}, {2, 3}, {}}}; +const std::vector> weights_shapes = {{{9, 3}, {9, 3}, {9}}}; + +INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues2 + +namespace testValues3 { + +const std::vector params = { + // RNNCell + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, 
{-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, + "RNNCell", + "U8" + }, + // asymmetrical FQ on weights + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, + "RNNCell", + "FP32" + } +}; + +const std::vector> activations_shapes = {{{2, 3}, {2, 3}, {}}}; +const std::vector> weights_shapes = {{{3, 3}, {3, 3}, {9}}}; + +INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues3 diff --git a/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp new file mode 100644 index 00000000000000..2d47018f459ced --- /dev/null +++ b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/recurrent_cell_transformation.cpp @@ -0,0 +1,256 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "low_precision_transformations/recurrent_cell_transformation.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +const std::vector netPrecisions = { + ngraph::element::f32, + ngraph::element::f16 +}; + +const std::vector trasformationParamValues = { + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(true) +}; + +namespace testValues1 { + +const std::vector params = { + // LSTMCell + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, + "RNNCell", + "U8" + }, + // asymmetrical FQ on weights + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell, + "RNNCell", + "FP32" + } +}; + +const std::vector> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}}; 
+const std::vector> weights_shapes = {{{512, 16}, {512, 128}, {512}}}; + +INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::Values(CommonTestUtils::DEVICE_GPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues1 + +namespace testValues2 { + +const std::vector params = { + // GRU + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, + "RNNCell", + "U8" + }, + // asymmetrical FQ on weights + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU, + "RNNCell", + "FP32" + } +}; + +const std::vector> activations_shapes = {{{2, 3}, {2, 3}, {}}}; +const std::vector> weights_shapes = {{{9, 3}, {9, 3}, {9}}}; + +INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::Values(CommonTestUtils::DEVICE_GPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues2 + +namespace testValues3 { + +const std::vector params = { + // RNNCell + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, + "RNNCell", + "U8" + }, + // asymmetrical FQ on weights + { + // X + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // H + {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}}, + {ngraph::element::u8}, + { + {ngraph::element::f32}, + {}, + {0.01f}, + }, + // W + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + // R + {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}}, + {}, + {{}, {}, {}}, + ngraph::builder::subgraph::RecurrentCellFunction::RNNType::RNNCell, + "RNNCell", + "FP32" + } +}; + +const std::vector> activations_shapes = {{{2, 3}, {2, 3}, {}}}; +const std::vector> weights_shapes = {{{3, 3}, {3, 3}, {9}}}; + +INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, 
RecurrentCellTransformation, + ::testing::Combine( + ::testing::ValuesIn(netPrecisions), + ::testing::ValuesIn(activations_shapes), + ::testing::ValuesIn(weights_shapes), + ::testing::Values(CommonTestUtils::DEVICE_GPU), + ::testing::ValuesIn(trasformationParamValues), + ::testing::ValuesIn(params)), + RecurrentCellTransformation::getTestCaseName); +} // namespace testValues3 diff --git a/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp b/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp new file mode 100644 index 00000000000000..61510e970387c5 --- /dev/null +++ b/src/tests/functional/plugin/shared/include/low_precision_transformations/recurrent_cell_transformation.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp" +#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" +#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp" + +#include "low_precision/recurrent_cell.hpp" + +#include "lpt_ngraph_functions/recurrent_cell_function.hpp" + +namespace LayerTestsDefinitions { + +class RecurrentCellTransformationParam { +public: + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_X; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_X; + ngraph::builder::subgraph::DequantizationOperations dequantization_X; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_H; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_H; + ngraph::builder::subgraph::DequantizationOperations dequantization_H; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_W; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_W; + ngraph::builder::subgraph::DequantizationOperations dequantization_W; + ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_R; + ngraph::builder::subgraph::DequantizationOperations::Convert convert_R; + ngraph::builder::subgraph::DequantizationOperations dequantization_R; + ngraph::builder::subgraph::RecurrentCellFunction::RNNType RNNType; + std::string layerName; + std::string expectedKernelType; +}; + +typedef std::tuple< + ngraph::element::Type, + std::vector, + std::vector, + std::string, + ngraph::pass::low_precision::LayerTransformation::Params, + RecurrentCellTransformationParam +>RecurrentCellTransformationParams; + +class RecurrentCellTransformation : + public testing::WithParamInterface, + public LayerTestsUtils::LayerTransformation { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + +protected: + void SetUp() override; + + void Run() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp new file mode 100644 index 00000000000000..b320bcff56fc64 --- /dev/null +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/recurrent_cell_transformation.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision_transformations/recurrent_cell_transformation.hpp" + +#include +#include +#include +#include + 
+#include + +#include "common_test_utils/common_utils.hpp" +#include "shared_test_classes/base/layer_test_utils.hpp" +#include "functional_test_utils/blob_utils.hpp" +#include "lpt_ngraph_functions/recurrent_cell_function.hpp" + +namespace LayerTestsDefinitions { + +std::string RecurrentCellTransformation::getTestCaseName(testing::TestParamInfo obj) { + ngraph::element::Type netPrecision; + std::vector activationsShape; + std::vector weightsShape; + std::string targetDevice; + RecurrentCellTransformationParam param; + ngraph::pass::low_precision::LayerTransformation::Params params; + std::tie(netPrecision, activationsShape, weightsShape, targetDevice, params, param) = obj.param; + + std::ostringstream result; + result << getTestCaseNameByParams(netPrecision, activationsShape[0], targetDevice, params) << + "FQ_X:" << param.fakeQuantize_X << "_" << + "DQ_X:" << param.dequantization_X << "_" << + "FQ_W:" << param.fakeQuantize_W << "_" << + "DQ_W:" << param.dequantization_W; + return result.str(); +} + +void RecurrentCellTransformation::SetUp() { + ngraph::element::Type precision; + std::vector activations_shapes; + std::vector weights_shapes; + RecurrentCellTransformationParam param; + ngraph::pass::low_precision::LayerTransformation::Params params; + + std::tie(precision, activations_shapes, weights_shapes, targetDevice, params, param) = this->GetParam(); + + function = ngraph::builder::subgraph::RecurrentCellFunction::get(precision, + activations_shapes, + weights_shapes, + param.RNNType, + { + param.fakeQuantize_X, + param.fakeQuantize_H, + param.fakeQuantize_W, + param.fakeQuantize_R + }, + { + param.convert_X, + param.convert_H, + param.convert_W, + param.convert_R + }, + { + param.dequantization_X, + param.dequantization_H, + param.dequantization_W, + param.dequantization_R + }); +} + +void RecurrentCellTransformation::Run() { + LayerTestsCommon::Run(); + + const auto params = std::get<5>(GetParam()); + const auto actualPrecision = getRuntimePrecisionByType(params.layerName); + auto expectedPrecision = params.expectedKernelType; + if (expectedPrecision == "FP32" && std::get<0>(GetParam()) == ngraph::element::f16) { + expectedPrecision = "FP16"; + } + EXPECT_EQ(actualPrecision, expectedPrecision); +} + +TEST_P(RecurrentCellTransformation, CompareWithRefImpl) { + Run(); +}; + +} // namespace LayerTestsDefinitions diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp new file mode 100644 index 00000000000000..e131fd83a705e8 --- /dev/null +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/recurrent_cell_function.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "low_precision/layer_transformation.hpp" +#include "common/fake_quantize_on_data.hpp" +#include "common/dequantization_operations.hpp" + +namespace ngraph { +namespace builder { +namespace subgraph { + +class RecurrentCellFunction { +public: + enum class RNNType { LSTMCell, LSTMSequence, GRU, RNNCell }; + + static std::shared_ptr get( + const ngraph::element::Type inputPrecision, + const std::vector& inputActivationsShapes, + const std::vector& inputWeightsShapes, + const RNNType type, + const std::vector& fqOnDatas, + const std::vector& converts, + const std::vector& dequantizations); +}; + +std::shared_ptr 
makeQuantizationAndDequantization(const std::shared_ptr input, + const ngraph::element::Type inputPrecision, + const std::string friendly_name, + const FakeQuantizeOnDataWithConstant& fqOnData, + const DequantizationOperations::Convert& convert, + const DequantizationOperations& dequantization); +} // namespace subgraph +} // namespace builder +} // namespace ngraph diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp new file mode 100644 index 00000000000000..c2f1b8cb1dcd43 --- /dev/null +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/recurrent_cell_function.cpp @@ -0,0 +1,164 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "lpt_ngraph_functions/recurrent_cell_function.hpp" + +#include +#include "ngraph_ops/type_relaxed.hpp" +#include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/precision_preserved_attribute.hpp" +#include "low_precision/rt_info/intervals_alignment_attribute.hpp" +#include "low_precision/rt_info/quantization_alignment_attribute.hpp" + +#include "ngraph_functions/builders.hpp" +#include "lpt_ngraph_functions/common/builders.hpp" +#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" +#include "lpt_ngraph_functions/common/builders.hpp" + +namespace ngraph { +namespace builder { +namespace subgraph { + +using namespace ngraph::pass; + +std::shared_ptr RecurrentCellFunction::get( + const ngraph::element::Type inputPrecision, + const std::vector& inputActivationsShapes, + const std::vector& inputWeightsShapes, + const RNNType type, + const std::vector& fqOnDatas, + const std::vector& converts, + const std::vector& dequantizations) { + auto X = std::make_shared(inputPrecision, inputActivationsShapes[0]); + X->set_friendly_name("X"); + std::shared_ptr parent_X = makeQuantizationAndDequantization(X, + inputPrecision, + X->get_friendly_name(), + fqOnDatas[0], + converts[0], + dequantizations[0]); + auto H = std::make_shared(inputPrecision, inputActivationsShapes[1]); + H->set_friendly_name("H"); + std::shared_ptr parent_H = makeQuantizationAndDequantization(H, + inputPrecision, + H->get_friendly_name(), + fqOnDatas[1], + converts[1], + dequantizations[1]); + auto C = std::make_shared(inputPrecision, inputActivationsShapes[2]); + C->set_friendly_name("C"); + + auto W = ngraph::opset1::Constant::create(fqOnDatas[2].empty() ? ngraph::element::i8 : inputPrecision, + inputWeightsShapes[0], + {1}); + std::shared_ptr parent_W = makeQuantizationAndDequantization(W, + inputPrecision, + W->get_friendly_name(), + fqOnDatas[2], + converts[2], + dequantizations[2]); + auto R = ngraph::opset1::Constant::create(fqOnDatas[2].empty() ? 
ngraph::element::i8 : inputPrecision,
+                                            inputWeightsShapes[1],
+                                            {1});
+    std::shared_ptr parent_R = makeQuantizationAndDequantization(R,
+                                                                  inputPrecision,
+                                                                  R->get_friendly_name(),
+                                                                  fqOnDatas[3],
+                                                                  converts[3],
+                                                                  dequantizations[3]);
+    auto B = ngraph::opset1::Constant::create(inputPrecision, inputWeightsShapes[2], {1});
+    auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {3});
+
+    std::shared_ptr rnn_layer;
+    switch (type) {
+        case RNNType::LSTMCell:
+            rnn_layer = std::make_shared(parent_X,
+                                         parent_H,
+                                         C,
+                                         parent_W,
+                                         parent_R,
+                                         B,
+                                         128);
+            rnn_layer->set_friendly_name("lstm_cell");
+            break;
+        case RNNType::LSTMSequence:
+            rnn_layer = std::make_shared(parent_X,
+                                         parent_H,
+                                         C,
+                                         seq_lengths,
+                                         parent_W,
+                                         parent_R,
+                                         B,
+                                         128,
+                                         op::RecurrentSequenceDirection::FORWARD);
+            rnn_layer->set_friendly_name("lstm_sequence");
+            break;
+        case RNNType::GRU:
+            rnn_layer = std::make_shared(parent_X,
+                                         parent_H,
+                                         parent_W,
+                                         parent_R,
+                                         3);
+            rnn_layer->set_friendly_name("gru_cell");
+            break;
+        case RNNType::RNNCell:
+            rnn_layer = std::make_shared(parent_X,
+                                         parent_H,
+                                         parent_W,
+                                         parent_R,
+                                         3);
+            rnn_layer->set_friendly_name("rnn_layer");
+            break;
+        default:
+            break;
+    }
+
+    auto& rtInfo = rnn_layer->get_rt_info();
+    bool is_lstm = type == RNNType::LSTMCell || type == RNNType::LSTMSequence;
+    rtInfo["Variant::std::string"] = "rnn_layer";
+
+    auto rnn_layer_res_1 = std::make_shared(rnn_layer->output(0));
+    rnn_layer_res_1->set_friendly_name("output_1");
+    std::shared_ptr rnn_layer_res_2 = {};
+    if (is_lstm) {
+        rnn_layer_res_2 = std::make_shared(rnn_layer->output(1));
+        rnn_layer_res_2->set_friendly_name("output_2");
+    }
+
+    ngraph::ResultVector results{rnn_layer_res_1};
+    if (rnn_layer_res_2) {
+        results.push_back(rnn_layer_res_2);
+    }
+    std::shared_ptr function = std::make_shared(
+        results,
+        is_lstm ? ngraph::ParameterVector{X, H, C} : ngraph::ParameterVector{X, H},
+        "LSTMTransformation");
+
+    return function;
+}
+
+std::shared_ptr makeQuantizationAndDequantization(const std::shared_ptr input,
+                                                  const ngraph::element::Type inputPrecision,
+                                                  const std::string friendly_name,
+                                                  const FakeQuantizeOnDataWithConstant& fqOnData,
+                                                  const DequantizationOperations::Convert& convert,
+                                                  const DequantizationOperations& dequantization) {
+    std::shared_ptr parent;
+    if (fqOnData.empty()) {
+        parent = input;
+    } else {
+        std::shared_ptr fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input, inputPrecision, fqOnData);
+        fakeQuantize1->set_friendly_name("fakeQuantize_" + friendly_name);
+        parent = fakeQuantize1;
+    }
+    if (!convert.empty()) {
+        parent = std::make_shared(parent, convert.outPrecision);
+    }
+    if (!dequantization.empty()) {
+        parent = makeDequantization(parent, dequantization);
+    }
+    return parent;
+}
+
+} // namespace subgraph
+} // namespace builder
+} // namespace ngraph
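
For context on how the new pieces are intended to interact: RecurrentCellTransformation decomposes the FakeQuantize ops feeding the recurrent cell and marks the re-created dequantization chain with SkipCleanupAttribute, and the cleanup matchers touched above (FuseConvertTransformation, FuseMultiplyToFakeQuantizeTransformation, FuseSubtractToFakeQuantizeTransformation) bail out when they find that mark. The sketch below is illustrative only and is not part of the patch; the is_marked_for_skip helper is hypothetical, and the SkipCleanupAttribute template argument in the getAttribute call is an assumption (the extracted diff above shows getAttribute(op) with its template argument stripped).

    // Illustrative sketch (not part of the patch): setting and checking the SkipCleanup mark.
    #include <memory>
    #include <ngraph/node.hpp>
    #include "low_precision/rt_info/skip_cleanup_attribute.hpp"

    // Hypothetical helper: returns true when a node carries the SkipCleanup runtime attribute,
    // mirroring how SkipCleanupAttribute::create() stores it in the node's rt_info map.
    bool is_marked_for_skip(const std::shared_ptr<ngraph::Node>& node) {
        const auto& rt = node->get_rt_info();
        return rt.count(ngraph::SkipCleanupAttribute::get_type_info_static()) != 0;
    }

    // The recurrent-cell pass marks the dequantization Multiply it rebuilds:
    //     ngraph::SkipCleanupAttribute::create(deq_multiply);
    // and a cleanup pass then skips such nodes, as in FuseConvertTransformation::canBeTransformed:
    //     if (!getAttribute<SkipCleanupAttribute>(op).empty()) { return false; }   // template argument assumed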