From f8d071002b6fe0bca2617558221850e01fe202ad Mon Sep 17 00:00:00 2001
From: Mateusz Tabaka
Date: Mon, 11 Mar 2024 18:48:04 +0100
Subject: [PATCH] Fix convert_to_supported_precision for TypeRelaxed types
 (#23143)

During ConstantFolding, some nodes can temporarily have mismatched input
types (e.g. Add(f16, f32)). If such a node is TypeRelaxed, we are unable
to clone it directly, since TypeRelaxed::clone_with_new_inputs creates a
clone with 'fake' inputs based on the current inputs, and that can trigger
an exception for certain nodes when the inputs have mismatched types.

To work around that, if a TypeRelaxed node's origin input type is
undefined, temporarily override it with the value of the original input
precision attribute, clone the node, and then restore the saved origin
input types. This also requires dropping const from the node parameter of
convert_to_supported_precision so the origin types can be modified.

Ticket: CVS-134604
---
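Note (illustrative, not part of the patch): the clone path added below boils
down to a save/override/restore dance around clone_with_new_inputs. A
condensed sketch of it, using only names that appear in this patch, with
surrounding code and error handling elided:

    // Temporarily fill undefined origin input types from the
    // "original_precision" rt_info attribute so clone_with_new_inputs
    // sees consistent input types, then restore the saved origin types.
    ov::element::TypeVector saved_types;
    saved_types.reserve(num_inputs);
    for (size_t i = 0; i < num_inputs; i++) {
        saved_types.push_back(type_relaxed->get_origin_input_type(i));
        if (saved_types.back() == ov::element::undefined &&
            ov::util::has_original_input_precision(node->input(i))) {
            type_relaxed->set_origin_input_type(
                ov::util::get_original_input_precision(node->input(i)), i);
        }
    }
    auto cloned_node = node->clone_with_new_inputs(converted_inputs);  // safe now
    for (size_t i = 0; i < num_inputs; i++) {
        type_relaxed->set_origin_input_type(saved_types[i], i);  // restore
    }
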
 .../openvino/core/constant_fold_utils.hpp |  18 +++-
 src/core/src/constant_fold_utils.cpp      |  85 ++++++++++++++++---
 src/core/src/pass/constant_folding.cpp    |  33 ++-----
 src/core/tests/pass/constant_folding.cpp  |  29 +++++++
 4 files changed, 121 insertions(+), 44 deletions(-)

diff --git a/src/core/dev_api/openvino/core/constant_fold_utils.hpp b/src/core/dev_api/openvino/core/constant_fold_utils.hpp
index c62c14de6de35d..cd987a4ef5f3ea 100644
--- a/src/core/dev_api/openvino/core/constant_fold_utils.hpp
+++ b/src/core/dev_api/openvino/core/constant_fold_utils.hpp
@@ -14,7 +14,19 @@
 OPENVINO_API const element::TypeVector& unsupported_types();
 
 OPENVINO_API
-bool is_type_unsupported(const ov::element::Type& type);
+bool is_type_unsupported(const element::Type& type);
+
+OPENVINO_API
+void save_original_input_precisions(const std::shared_ptr<Node>& node);
+
+OPENVINO_API
+bool has_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+element::Type get_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+void remove_original_input_precision_attribute(Input<Node>& input);
 
 OPENVINO_API
 bool node_requires_precision_conversion(const Node* const node);
@@ -25,9 +37,9 @@ OPENVINO_API bool node_requires_precision_conversion(const Node* const node);
 /// \param node
 ///
 /// \return New node with f32 inputs if the inputs require conversion or the input node otherwise
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node);
 
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node, const OutputVector& inputs);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node, const OutputVector& inputs);
 
 OPENVINO_API bool evaluate_node_with_unsupported_precision(const Node* node,
                                                            TensorVector& outputs,
diff --git a/src/core/src/constant_fold_utils.cpp b/src/core/src/constant_fold_utils.cpp
index 6b50406e92393f..1aef7b7cbec761 100644
--- a/src/core/src/constant_fold_utils.cpp
+++ b/src/core/src/constant_fold_utils.cpp
@@ -25,6 +25,29 @@ bool ov::util::is_type_unsupported(const ov::element::Type& type) {
     return std::find(unsupported_types.begin(), unsupported_types.end(), type) != unsupported_types.end();
 }
 
+void ov::util::save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
+    for (size_t i = 0; i < node->get_input_size(); i++) {
+        auto input = node->input(i);
+        input.get_rt_info()["original_precision"] = input.get_element_type();
+    }
+}
+
+bool ov::util::has_original_input_precision(const ov::Input<ov::Node>& input) {
+    return input.get_rt_info().count("original_precision") > 0;
+}
+
+ov::element::Type ov::util::get_original_input_precision(const ov::Input<ov::Node>& input) {
+    return input.get_rt_info().at("original_precision").as<ov::element::Type>();
+}
+
+void ov::util::remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
+    auto& rt_info = input.get_rt_info();
+    auto it = rt_info.find("original_precision");
+    if (it != rt_info.end()) {
+        rt_info.erase(it);
+    }
+}
+
 namespace {
 
 template
@@ -105,11 +128,11 @@ static const std::unordered_map
 
-std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* const node) {
+std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(Node* const node) {
     return ov::util::convert_to_supported_precision(node, node->input_values());
 }
 
-std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* const node, const OutputVector& inputs) {
+std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(Node* const node, const OutputVector& inputs) {
     size_t num_inputs = node->get_input_size();
     OutputVector converted_inputs;
     converted_inputs.reserve(num_inputs);
@@ -128,23 +151,49 @@ std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* c
         }
     }
 
-    // Create a new node with new (converted) inputs.
-    auto cloned_node = node->clone_with_new_inputs(converted_inputs);
+    std::shared_ptr<ov::Node> cloned_node;
+
+    auto type_relaxed = dynamic_cast<ov::op::TypeRelaxedBase*>(node);
+    if (type_relaxed != nullptr) {
+        // Save TypeRelaxed's origin input types
+        // If origin input type is undefined let's temporarily override it with original input precision attribute
+        // value. During ConstantFolding, some nodes can have temporarily mismatched input types (e.g. Add(f16, f32)).
+        // If the node is TypeRelaxed - we're unable to clone it since TypeRelaxed::clone_with_new_inputs creates a
+        // clone with 'fake' inputs based on current inputs and that can trigger an exception for certain nodes if the
+        // inputs have mismatched types.
+        element::TypeVector origin_input_types;
+        origin_input_types.reserve(num_inputs);
+        for (size_t i = 0; i < num_inputs; i++) {
+            const auto& origin_type = type_relaxed->get_origin_input_type(i);
+            origin_input_types.push_back(origin_type);
+            if (origin_type == element::undefined && has_original_input_precision(node->input(i))) {
+                type_relaxed->set_origin_input_type(get_original_input_precision(node->input(i)), i);
+            }
+        }
+
+        cloned_node = node->clone_with_new_inputs(converted_inputs);
+
+        // Restore TypeRelaxed's origin input types
+        for (size_t i = 0; i < num_inputs; i++) {
+            type_relaxed->set_origin_input_type(origin_input_types[i], i);
+        }
 
-    // Override TypeRelaxed types
-    auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(cloned_node);
-    if (type_relaxed) {
+        auto cloned_type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(cloned_node);
+        // Override TypeRelaxed types
         for (size_t i = 0; i < num_inputs; i++) {
-            if (ov::util::is_type_unsupported(type_relaxed->get_origin_input_type(i))) {
-                type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
+            if (ov::util::is_type_unsupported(cloned_type_relaxed->get_origin_input_type(i))) {
+                cloned_type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
             }
         }
         for (size_t i = 0; i < cloned_node->get_output_size(); i++) {
             if (ov::util::is_type_unsupported(cloned_node->get_output_element_type(i))) {
-                type_relaxed->set_overridden_output_type(element::f32, i);
+                cloned_type_relaxed->set_overridden_output_type(element::f32, i);
             }
         }
         cloned_node->validate_and_infer_types();
+    } else {
+        // Create a new node with new (converted) inputs.
+        cloned_node = node->clone_with_new_inputs(converted_inputs);
     }
 
     // Handle nodes which outputs precisions don't depend on input precisions
@@ -221,9 +270,19 @@ bool ov::util::evaluate_node_with_unsupported_precision(const ov::Node* node,
         }
     }
 
-    // evaluate converted node
-    if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
-        return false;
+    auto type_relaxed = dynamic_cast<const ov::op::TypeRelaxedBase*>(node);
+    if (type_relaxed == nullptr) {
+        // evaluate node with converted tensors
+        if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
+            return false;
+        }
+    } else {
+        // node is const so let's clone it
+        auto cloned = node->clone_with_new_inputs(node->input_values());
+        cloned = convert_to_supported_precision(cloned.get());
+        if (!cloned->evaluate(converted_output_tensors, converted_input_tensors)) {
+            return false;
+        }
     }
 
     // convert outputs tensors from f32 to original type if necessary
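Note (illustrative, not part of the patch): evaluate_node_with_unsupported_precision
receives the node as const, so its origin input types cannot be patched in
place the way the clone path does; instead a clone is converted and
evaluated. A condensed sketch of that branch, using only names from this
patch:

    // The original node is const, so work on a converted clone instead of
    // mutating the TypeRelaxed origin input types in place.
    auto type_relaxed = dynamic_cast<const ov::op::TypeRelaxedBase*>(node);
    if (type_relaxed != nullptr) {
        auto cloned = node->clone_with_new_inputs(node->input_values());
        cloned = ov::util::convert_to_supported_precision(cloned.get());
        return cloned->evaluate(converted_output_tensors, converted_input_tensors);
    }
    return node->evaluate(converted_output_tensors, converted_input_tensors);
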
diff --git a/src/core/src/pass/constant_folding.cpp b/src/core/src/pass/constant_folding.cpp
index 4ac985bb277ee8..3e93d0da979258 100644
--- a/src/core/src/pass/constant_folding.cpp
+++ b/src/core/src/pass/constant_folding.cpp
@@ -49,42 +49,19 @@ const auto friendly_name_from = [](const ov::Node& node, const size_t output_cou
     }
 };
 
-static void save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
-    for (size_t i = 0; i < node->get_input_size(); i++) {
-        auto input = node->input(i);
-        input.get_rt_info()["original_precision"] = input.get_element_type();
-    }
-}
-
-static bool has_original_input_precision(const ov::Input<ov::Node>& input) {
-    return input.get_rt_info().count("original_precision") > 0;
-}
-
-static ov::element::Type get_original_input_precision(const ov::Input<ov::Node>& input) {
-    return input.get_rt_info().at("original_precision").as<ov::element::Type>();
-}
-
-static void remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
-    auto& rt_info = input.get_rt_info();
-    auto it = rt_info.find("original_precision");
-    if (it != rt_info.end()) {
-        rt_info.erase(it);
-    }
-}
-
 static bool restore_original_input_precision(const std::shared_ptr<ov::Node>& node) {
     bool restored = false;
     if (ov::is_type<ov::op::v0::Convert>(node)) {
         auto input = node->input(0);
-        remove_original_input_precision_attribute(input);
+        ov::util::remove_original_input_precision_attribute(input);
         return restored;
     }
     for (size_t i = 0; i < node->get_input_size(); i++) {
         auto input = node->input(i);
-        if (!has_original_input_precision(input))
+        if (!ov::util::has_original_input_precision(input))
             continue;
-        const auto original_type = get_original_input_precision(input);
-        remove_original_input_precision_attribute(input);
+        const auto original_type = ov::util::get_original_input_precision(input);
+        ov::util::remove_original_input_precision_attribute(input);
         if (original_type != node->get_input_element_type(i)) {
             auto convert = std::make_shared<ov::op::v0::Convert>(node->input_value(i), original_type);
             ov::OutputVector replacements(1);
@@ -206,7 +183,7 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_
             // we need to convert constants with those types to f32. And at some point - this f32 constant may
             // become an input to a node that's not constfoldable. Then we need to convert that constant back to
             // that input's original precision.
-            save_original_input_precisions(node);
+            util::save_original_input_precisions(node);
             if (!node_has_disabled_constant_folding && util::node_requires_precision_conversion(node.get())) {
                 mark_node_requires_precision_conversion(node);
             }
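Note (illustrative, not part of the patch): the four rt_info helpers above
were previously static functions in this file; the patch moves them into
ov::util so constant_fold_utils.cpp can reuse them. From a consumer's point
of view, the attribute round-trip they implement looks roughly like this;
replace_source_output stands in for the replacement bookkeeping the real
restore_original_input_precision performs:

    // If folding changed an input's element type, splice a Convert back to
    // the precision recorded in the "original_precision" rt_info attribute.
    auto input = node->input(i);
    if (ov::util::has_original_input_precision(input)) {
        const auto original_type = ov::util::get_original_input_precision(input);
        ov::util::remove_original_input_precision_attribute(input);
        if (original_type != node->get_input_element_type(i)) {
            auto convert =
                std::make_shared<ov::op::v0::Convert>(node->input_value(i), original_type);
            input.replace_source_output(convert);  // simplified replacement step
        }
    }
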
diff --git a/src/core/tests/pass/constant_folding.cpp b/src/core/tests/pass/constant_folding.cpp
index f3e1e640881b87..6332e0ce38d5a6 100644
--- a/src/core/tests/pass/constant_folding.cpp
+++ b/src/core/tests/pass/constant_folding.cpp
@@ -16,6 +16,7 @@
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/loop.hpp"
 #include "openvino/op/multiply.hpp"
+#include "ov_ops/type_relaxed.hpp"
 #include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
 #include "transformations/utils/utils.hpp"
 
@@ -4000,6 +4001,34 @@ TEST_P(UnsupportedTypesTest, convert_like) {
     ASSERT_EQ(m->get_results().size(), 1);
 }
 
+TEST_P(UnsupportedTypesTest, type_relaxed) {
+    Shape shape_in{2, 4, 1};
+
+    const auto& type = GetParam();
+    auto cond = op::v0::Constant::create(element::boolean, shape_in, {1});
+    auto param = std::make_shared<op::v0::Parameter>(type, shape_in);
+    auto constant1 = op::v0::Constant::create(type, shape_in, {2});
+    auto then_value = std::make_shared<op::v0::Concat>(OutputVector{param, constant1}, 2);
+    auto constant2 = op::v0::Constant::create(type, shape_in, {3});
+    auto else_value = std::make_shared<op::v3::Broadcast>(
+        constant2,
+        op::v0::Constant::create(element::u64, Shape{shape_in.size()}, Shape{shape_in[0], shape_in[1], 2}));
+    auto select = make_shared<op::v1::Select>(cond, then_value, else_value);
+    auto type_relaxed = make_shared<op::TypeRelaxed<op::v1::Select>>(*select,
+                                                                     element::TypeVector{element::boolean},
+                                                                     element::TypeVector{});
+    auto m = make_shared<Model>(type_relaxed, ParameterVector{param});
+
+    run_constant_folding(m);
+
+    EXPECT_EQ(m->get_ops().size(), 7);
+    EXPECT_EQ(count_ops_of_type<op::v1::Select>(m), 1);
+    EXPECT_EQ(count_ops_of_type<op::v0::Constant>(m), 3);
+    EXPECT_EQ(count_ops_of_type<op::v0::Convert>(m), 0);
+    EXPECT_EQ(count_ops_of_type<op::v0::Concat>(m), 1);
+    ASSERT_EQ(m->get_results().size(), 1);
+}
+
 static std::string unsupported_types_test_case_name(const testing::TestParamInfo<element::Type>& info) {
     return info.param.get_type_name();
 }
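
Note (illustrative, not part of the patch): run_constant_folding is an
existing helper in this test file. A reasonable mental model of it, assuming
the usual ov::pass::Manager API, is a thin wrapper like the hypothetical
sketch below; the real helper may also validate the model:

    #include "openvino/pass/constant_folding.hpp"
    #include "openvino/pass/manager.hpp"

    // Hypothetical stand-in for the helper used by the tests above.
    static void run_constant_folding(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        manager.register_pass<ov::pass::ConstantFolding>();
        manager.run_passes(model);
    }

The test itself checks that a TypeRelaxed<Select> whose else-branch is fully
constant folds without introducing Converts: the Broadcast collapses into a
Constant (3 Constants remain), while Select and Concat survive because the
Parameter input is not foldable.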