diff --git a/src/core/dev_api/openvino/core/constant_fold_utils.hpp b/src/core/dev_api/openvino/core/constant_fold_utils.hpp
index c62c14de6de35d..cd987a4ef5f3ea 100644
--- a/src/core/dev_api/openvino/core/constant_fold_utils.hpp
+++ b/src/core/dev_api/openvino/core/constant_fold_utils.hpp
@@ -14,7 +14,19 @@ OPENVINO_API
 const element::TypeVector& unsupported_types();
 
 OPENVINO_API
-bool is_type_unsupported(const ov::element::Type& type);
+bool is_type_unsupported(const element::Type& type);
+
+OPENVINO_API
+void save_original_input_precisions(const std::shared_ptr<Node>& node);
+
+OPENVINO_API
+bool has_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+element::Type get_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+void remove_original_input_precision_attribute(Input<Node>& input);
 
 OPENVINO_API
 bool node_requires_precision_conversion(const Node* const node);
@@ -25,9 +37,9 @@ OPENVINO_API bool node_requires_precision_conversion(const Node* const node);
 /// \param node
 ///
 /// \return New node with f32 inputs if the inputs require conversion or the input node otherwise
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node);
 
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node, const OutputVector& inputs);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node, const OutputVector& inputs);
 
 OPENVINO_API bool evaluate_node_with_unsupported_precision(const Node* node,
                                                            TensorVector& outputs,
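The four `original_precision` helpers promoted to the dev API above store an input's element type in its `rt_info` map and read it back later. A minimal sketch of that round-trip, assuming the dev-API include path is available; the f16 `Add` model and the `original_precision_round_trip` function are illustrative, not code from this PR:

```cpp
#include <memory>

#include "openvino/core/constant_fold_utils.hpp"  // dev API declared above
#include "openvino/op/add.hpp"
#include "openvino/op/parameter.hpp"

// Hypothetical demonstration of the save / query / remove round-trip.
void original_precision_round_trip() {
    auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{2});
    auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{2});
    auto add = std::make_shared<ov::op::v1::Add>(a, b);

    // Records each input's current element type in rt_info["original_precision"].
    ov::util::save_original_input_precisions(add);

    auto input = add->input(0);
    if (ov::util::has_original_input_precision(input)) {
        // Reads the saved type back (f16 here); folding uses it later to convert
        // an f32-folded constant back to the type the consumer expects.
        const auto original = ov::util::get_original_input_precision(input);
        (void)original;
        // Drops the attribute once the original precision has been restored.
        ov::util::remove_original_input_precision_attribute(input);
    }
}
```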
diff --git a/src/core/src/constant_fold_utils.cpp b/src/core/src/constant_fold_utils.cpp
index 6b50406e92393f..1aef7b7cbec761 100644
--- a/src/core/src/constant_fold_utils.cpp
+++ b/src/core/src/constant_fold_utils.cpp
@@ -25,6 +25,29 @@ bool ov::util::is_type_unsupported(const ov::element::Type& type) {
     return std::find(unsupported_types.begin(), unsupported_types.end(), type) != unsupported_types.end();
 }
 
+void ov::util::save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
+    for (size_t i = 0; i < node->get_input_size(); i++) {
+        auto input = node->input(i);
+        input.get_rt_info()["original_precision"] = input.get_element_type();
+    }
+}
+
+bool ov::util::has_original_input_precision(const ov::Input<ov::Node>& input) {
+    return input.get_rt_info().count("original_precision") > 0;
+}
+
+ov::element::Type ov::util::get_original_input_precision(const ov::Input<ov::Node>& input) {
+    return input.get_rt_info().at("original_precision").as<ov::element::Type>();
+}
+
+void ov::util::remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
+    auto& rt_info = input.get_rt_info();
+    auto it = rt_info.find("original_precision");
+    if (it != rt_info.end()) {
+        rt_info.erase(it);
+    }
+}
+
 namespace {
 
 template
@@ -105,11 +128,11 @@ static const std::unordered_map
-std::shared_ptr<Node> ov::util::convert_to_supported_precision(const Node* const node) {
+std::shared_ptr<Node> ov::util::convert_to_supported_precision(Node* const node) {
     return ov::util::convert_to_supported_precision(node, node->input_values());
 }
 
-std::shared_ptr<Node> ov::util::convert_to_supported_precision(const Node* const node, const OutputVector& inputs) {
+std::shared_ptr<Node> ov::util::convert_to_supported_precision(Node* const node, const OutputVector& inputs) {
     size_t num_inputs = node->get_input_size();
     OutputVector converted_inputs;
     converted_inputs.reserve(num_inputs);
@@ -128,23 +151,49 @@ std::shared_ptr<Node> ov::util::convert_to_supported_precision(const Node* c
         }
     }
 
-    // Create a new node with new (converted) inputs.
-    auto cloned_node = node->clone_with_new_inputs(converted_inputs);
+    std::shared_ptr<Node> cloned_node;
+
+    auto type_relaxed = dynamic_cast<op::TypeRelaxedBase*>(node);
+    if (type_relaxed != nullptr) {
+        // Save TypeRelaxed's origin input types.
+        // If an origin input type is undefined, temporarily override it with the original input precision attribute
+        // value. During ConstantFolding, some nodes can have temporarily mismatched input types (e.g. Add(f16, f32)).
+        // If the node is TypeRelaxed, we're unable to clone it as-is, since TypeRelaxed::clone_with_new_inputs
+        // creates a clone with 'fake' inputs based on the current inputs, and that can trigger an exception for
+        // certain nodes if the inputs have mismatched types.
+        element::TypeVector origin_input_types;
+        origin_input_types.reserve(num_inputs);
+        for (size_t i = 0; i < num_inputs; i++) {
+            const auto& origin_type = type_relaxed->get_origin_input_type(i);
+            origin_input_types.push_back(origin_type);
+            if (origin_type == element::undefined && has_original_input_precision(node->input(i))) {
+                type_relaxed->set_origin_input_type(get_original_input_precision(node->input(i)), i);
+            }
+        }
+
+        cloned_node = node->clone_with_new_inputs(converted_inputs);
+
+        // Restore TypeRelaxed's origin input types
+        for (size_t i = 0; i < num_inputs; i++) {
+            type_relaxed->set_origin_input_type(origin_input_types[i], i);
+        }
 
-    // Override TypeRelaxed types
-    auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
-    if (type_relaxed) {
+        auto cloned_type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
+        // Override TypeRelaxed types
         for (size_t i = 0; i < num_inputs; i++) {
-            if (ov::util::is_type_unsupported(type_relaxed->get_origin_input_type(i))) {
-                type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
+            if (ov::util::is_type_unsupported(cloned_type_relaxed->get_origin_input_type(i))) {
+                cloned_type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
             }
         }
         for (size_t i = 0; i < cloned_node->get_output_size(); i++) {
             if (ov::util::is_type_unsupported(cloned_node->get_output_element_type(i))) {
-                type_relaxed->set_overridden_output_type(element::f32, i);
+                cloned_type_relaxed->set_overridden_output_type(element::f32, i);
             }
         }
         cloned_node->validate_and_infer_types();
+    } else {
+        // Create a new node with new (converted) inputs.
+        cloned_node = node->clone_with_new_inputs(converted_inputs);
     }
 
     // Handle nodes whose output precisions don't depend on input precisions
@@ -221,9 +270,19 @@ bool ov::util::evaluate_node_with_unsupported_precision(const ov::Node* node,
         }
     }
 
-    // evaluate converted node
-    if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
-        return false;
+    auto type_relaxed = dynamic_cast<const op::TypeRelaxedBase*>(node);
+    if (type_relaxed == nullptr) {
+        // evaluate node with converted tensors
+        if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
+            return false;
+        }
+    } else {
+        // node is const so let's clone it
+        auto cloned = node->clone_with_new_inputs(node->input_values());
+        cloned = convert_to_supported_precision(cloned.get());
+        if (!cloned->evaluate(converted_output_tensors, converted_input_tensors)) {
+            return false;
+        }
     }
 
     // convert output tensors from f32 to original type if necessary
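Worth spelling out: `convert_to_supported_precision` loses the `const` on `node` because the TypeRelaxed branch temporarily mutates the node's origin input types before cloning. A const node therefore has to be cloned first, which is exactly what the evaluate path above does; a sketch of that calling pattern (the `fold_through_f32` wrapper is a hypothetical name, not part of the change):

```cpp
#include <memory>

#include "openvino/core/constant_fold_utils.hpp"

// Mirrors the evaluate path above: clone the const node, then let
// convert_to_supported_precision mutate the clone freely.
std::shared_ptr<ov::Node> fold_through_f32(const ov::Node* node) {
    auto cloned = node->clone_with_new_inputs(node->input_values());
    return ov::util::convert_to_supported_precision(cloned.get());
}
```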
diff --git a/src/core/src/pass/constant_folding.cpp b/src/core/src/pass/constant_folding.cpp
index 4ac985bb277ee8..3e93d0da979258 100644
--- a/src/core/src/pass/constant_folding.cpp
+++ b/src/core/src/pass/constant_folding.cpp
@@ -49,42 +49,19 @@ const auto friendly_name_from = [](const ov::Node& node, const size_t output_cou
     }
 };
 
-static void save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
-    for (size_t i = 0; i < node->get_input_size(); i++) {
-        auto input = node->input(i);
-        input.get_rt_info()["original_precision"] = input.get_element_type();
-    }
-}
-
-static bool has_original_input_precision(const ov::Input<ov::Node>& input) {
-    return input.get_rt_info().count("original_precision") > 0;
-}
-
-static ov::element::Type get_original_input_precision(const ov::Input<ov::Node>& input) {
-    return input.get_rt_info().at("original_precision").as<ov::element::Type>();
-}
-
-static void remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
-    auto& rt_info = input.get_rt_info();
-    auto it = rt_info.find("original_precision");
-    if (it != rt_info.end()) {
-        rt_info.erase(it);
-    }
-}
-
 static bool restore_original_input_precision(const std::shared_ptr<ov::Node>& node) {
     bool restored = false;
     if (ov::is_type<ov::op::v0::Convert>(node)) {
         auto input = node->input(0);
-        remove_original_input_precision_attribute(input);
+        ov::util::remove_original_input_precision_attribute(input);
         return restored;
     }
     for (size_t i = 0; i < node->get_input_size(); i++) {
         auto input = node->input(i);
-        if (!has_original_input_precision(input))
+        if (!ov::util::has_original_input_precision(input))
             continue;
-        const auto original_type = get_original_input_precision(input);
-        remove_original_input_precision_attribute(input);
+        const auto original_type = ov::util::get_original_input_precision(input);
+        ov::util::remove_original_input_precision_attribute(input);
         if (original_type != node->get_input_element_type(i)) {
             auto convert = std::make_shared<ov::op::v0::Convert>(node->input_value(i), original_type);
             ov::OutputVector replacements(1);
@@ -206,7 +183,7 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_
             // we need to convert constants with those types to f32. And at some point - this f32 constant may
             // become an input to a node that's not constfoldable. Then we need to convert that constant back to
             // that input's original precision.
-            save_original_input_precisions(node);
+            util::save_original_input_precisions(node);
             if (!node_has_disabled_constant_folding && util::node_requires_precision_conversion(node.get())) {
                 mark_node_requires_precision_conversion(node);
             }
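For orientation, the per-input restore step that `restore_original_input_precision` performs with the relocated helpers boils down to the following (`node` and `i` are as in the loop above; the `replace_source_output` call is a simplification of the pass's actual replacement bookkeeping, which also constant-folds the inserted `Convert`):

```cpp
auto input = node->input(i);
if (ov::util::has_original_input_precision(input)) {
    const auto original_type = ov::util::get_original_input_precision(input);  // e.g. f16
    ov::util::remove_original_input_precision_attribute(input);
    if (original_type != node->get_input_element_type(i)) {
        // The value was folded through f32; route it back through a Convert so
        // downstream consumers see the pre-folding element type again.
        auto convert = std::make_shared<ov::op::v0::Convert>(node->input_value(i), original_type);
        input.replace_source_output(convert);
    }
}
```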
diff --git a/src/core/tests/pass/constant_folding.cpp b/src/core/tests/pass/constant_folding.cpp
index f3e1e640881b87..6332e0ce38d5a6 100644
--- a/src/core/tests/pass/constant_folding.cpp
+++ b/src/core/tests/pass/constant_folding.cpp
@@ -16,6 +16,7 @@
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/loop.hpp"
 #include "openvino/op/multiply.hpp"
+#include "ov_ops/type_relaxed.hpp"
 #include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
 #include "transformations/utils/utils.hpp"
 
@@ -4000,6 +4001,34 @@ TEST_P(UnsupportedTypesTest, convert_like) {
     ASSERT_EQ(m->get_results().size(), 1);
 }
 
+TEST_P(UnsupportedTypesTest, type_relaxed) {
+    Shape shape_in{2, 4, 1};
+
+    const auto& type = GetParam();
+    auto cond = op::v0::Constant::create(element::boolean, shape_in, {1});
+    auto param = std::make_shared<op::v0::Parameter>(type, shape_in);
+    auto constant1 = op::v0::Constant::create(type, shape_in, {2});
+    auto then_value = std::make_shared<op::v0::Concat>(OutputVector{param, constant1}, 2);
+    auto constant2 = op::v0::Constant::create(type, shape_in, {3});
+    auto else_value = std::make_shared<op::v3::Broadcast>(
+        constant2,
+        op::v0::Constant::create(element::u64, Shape{shape_in.size()}, Shape{shape_in[0], shape_in[1], 2}));
+    auto select = make_shared<op::v1::Select>(cond, then_value, else_value);
+    auto type_relaxed = make_shared<op::TypeRelaxed<op::v1::Select>>(*select,
+                                                                     element::TypeVector{element::boolean},
+                                                                     element::TypeVector{});
+    auto m = make_shared<Model>(type_relaxed, ParameterVector{param});
+
+    run_constant_folding(m);
+
+    EXPECT_EQ(m->get_ops().size(), 7);
+    EXPECT_EQ(count_ops_of_type<op::v1::Select>(m), 1);
+    EXPECT_EQ(count_ops_of_type<op::v0::Constant>(m), 3);
+    EXPECT_EQ(count_ops_of_type<op::v0::Convert>(m), 0);
+    EXPECT_EQ(count_ops_of_type<op::v0::Concat>(m), 1);
+    ASSERT_EQ(m->get_results().size(), 1);
+}
+
 static std::string unsupported_types_test_case_name(const testing::TestParamInfo<element::Type>& info) {
     return info.param.get_type_name();
 }
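A tally of the new test's expectations, as I read the graph (editorial breakdown, not text from the PR): only the `Broadcast` branch can fold, because the `Concat` branch consumes a `Parameter`, so after folding the model should contain:

```cpp
// Expected ops after folding (7 in total, including the Result):
//   Constant (cond)               \
//   Constant (constant1)           }  the 3 expected Constants
//   Constant (folded Broadcast)   /
//   Parameter (param)
//   Concat (param, constant1)        the 1 expected Concat
//   TypeRelaxed<Select>              counted as the 1 expected Select
//   Result
// No Convert survives (count 0): the folded constant comes back in the
// original element type via the TypeRelaxed clone path fixed above.
```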