Fix convert_to_supported_precision for TypeRelaxed types (#23143)
If a TypeRelaxed node's origin input type is undefined, temporarily
override it with the value of the original input precision attribute.
During ConstantFolding, some nodes can have temporarily mismatched input
types (e.g. Add(f16, f32)). If such a node is TypeRelaxed, we are unable
to clone it, since TypeRelaxed::clone_with_new_inputs creates a clone with
'fake' inputs based on the current inputs, and that can trigger an
exception for certain nodes if the inputs have mismatched types.

Ticket: CVS-134604
mateusztabaka authored Mar 11, 2024
1 parent a3c7e15 commit f8d0710
Showing 4 changed files with 121 additions and 44 deletions.
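In short, the fix brackets TypeRelaxed cloning with a save/override/restore of the origin input types. A condensed sketch of the pattern, simplified from the constant_fold_utils.cpp hunk below (the helper names come from this commit; error handling is omitted, so treat it as an illustration rather than a drop-in snippet):

    #include "openvino/core/constant_fold_utils.hpp"
    #include "ov_ops/type_relaxed.hpp"

    // Sketch only: clone a possibly-TypeRelaxed node whose inputs were already
    // converted to a supported precision (see the full implementation below).
    std::shared_ptr<ov::Node> clone_safely(ov::Node* node, const ov::OutputVector& converted_inputs) {
        auto type_relaxed = dynamic_cast<ov::op::TypeRelaxedBase*>(node);
        if (type_relaxed == nullptr)
            return node->clone_with_new_inputs(converted_inputs);

        // 1. Remember the current origin input types.
        ov::element::TypeVector saved_types;
        for (size_t i = 0; i < node->get_input_size(); i++) {
            saved_types.push_back(type_relaxed->get_origin_input_type(i));
            // 2. If undefined, temporarily substitute the precision that
            //    ConstantFolding stashed in the input's rt_info.
            if (saved_types.back() == ov::element::undefined &&
                ov::util::has_original_input_precision(node->input(i))) {
                type_relaxed->set_origin_input_type(ov::util::get_original_input_precision(node->input(i)), i);
            }
        }

        // 3. With consistent origin types, cloning no longer throws on
        //    temporarily mismatched inputs such as Add(f16, f32).
        auto cloned_node = node->clone_with_new_inputs(converted_inputs);

        // 4. Restore the origin input types on the source node.
        for (size_t i = 0; i < node->get_input_size(); i++)
            type_relaxed->set_origin_input_type(saved_types[i], i);

        return cloned_node;
    }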
18 changes: 15 additions & 3 deletions src/core/dev_api/openvino/core/constant_fold_utils.hpp
@@ -14,7 +14,19 @@ OPENVINO_API
const element::TypeVector& unsupported_types();

OPENVINO_API
-bool is_type_unsupported(const ov::element::Type& type);
+bool is_type_unsupported(const element::Type& type);

+OPENVINO_API
+void save_original_input_precisions(const std::shared_ptr<Node>& node);
+
+OPENVINO_API
+bool has_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+element::Type get_original_input_precision(const Input<Node>& input);
+
+OPENVINO_API
+void remove_original_input_precision_attribute(Input<Node>& input);
+
OPENVINO_API bool node_requires_precision_conversion(const Node* const node);

@@ -25,9 +37,9 @@ OPENVINO_API bool node_requires_precision_conversion(const Node* const node);
/// \param node
///
/// \return New node with f32 inputs if the inputs require conversion or the input node otherwise
-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node);

-OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(const Node* const node, const OutputVector& inputs);
+OPENVINO_API std::shared_ptr<Node> convert_to_supported_precision(Node* const node, const OutputVector& inputs);

OPENVINO_API bool evaluate_node_with_unsupported_precision(const Node* node,
TensorVector& outputs,
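The four helpers declared above round-trip an input's element type through its rt_info under the "original_precision" key. A minimal usage sketch (hypothetical caller, assuming access to this dev-API header; the Add node and f16 precision are illustrative only):

    #include "openvino/core/constant_fold_utils.hpp"
    #include "openvino/op/add.hpp"
    #include "openvino/op/constant.hpp"

    void original_precision_round_trip() {
        auto a = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{2}, {1.0f, 2.0f});
        auto b = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{2}, {3.0f, 4.0f});
        auto add = std::make_shared<ov::op::v1::Add>(a, b);

        // Stash every input's current element type in its rt_info.
        ov::util::save_original_input_precisions(add);

        // Later, e.g. after the inputs were rewired to f32 during folding,
        // the original precision can still be queried and then cleaned up.
        auto input = add->input(0);
        if (ov::util::has_original_input_precision(input)) {
            auto original = ov::util::get_original_input_precision(input);  // f16
            ov::util::remove_original_input_precision_attribute(input);
        }
    }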
85 changes: 72 additions & 13 deletions src/core/src/constant_fold_utils.cpp
@@ -25,6 +25,29 @@ bool ov::util::is_type_unsupported(const ov::element::Type& type) {
return std::find(unsupported_types.begin(), unsupported_types.end(), type) != unsupported_types.end();
}

+void ov::util::save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
+for (size_t i = 0; i < node->get_input_size(); i++) {
+auto input = node->input(i);
+input.get_rt_info()["original_precision"] = input.get_element_type();
+}
+}
+
+bool ov::util::has_original_input_precision(const ov::Input<ov::Node>& input) {
+return input.get_rt_info().count("original_precision") > 0;
+}
+
+ov::element::Type ov::util::get_original_input_precision(const ov::Input<ov::Node>& input) {
+return input.get_rt_info().at("original_precision").as<ov::element::Type>();
+}
+
+void ov::util::remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
+auto& rt_info = input.get_rt_info();
+auto it = rt_info.find("original_precision");
+if (it != rt_info.end()) {
+rt_info.erase(it);
+}
+}

namespace {

template <typename... Args>
@@ -105,11 +128,11 @@ static const std::unordered_map<ov::NodeTypeInfo, std::function<bool(const std::
{ov::op::v4::Range::get_type_info_static(), convert_range_precision},
};

-std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* const node) {
+std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(Node* const node) {
return ov::util::convert_to_supported_precision(node, node->input_values());
}

-std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* const node, const OutputVector& inputs) {
+std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(Node* const node, const OutputVector& inputs) {
size_t num_inputs = node->get_input_size();
OutputVector converted_inputs;
converted_inputs.reserve(num_inputs);
@@ -128,23 +151,49 @@ std::shared_ptr<ov::Node> ov::util::convert_to_supported_precision(const Node* c
}
}

-// Create a new node with new (converted) inputs.
-auto cloned_node = node->clone_with_new_inputs(converted_inputs);
+std::shared_ptr<Node> cloned_node;

+auto type_relaxed = dynamic_cast<op::TypeRelaxedBase*>(node);
+if (type_relaxed != nullptr) {
+// Save TypeRelaxed's origin input types
+// If origin input type is undefined let's temporarily override it with original input precision attribute
+// value. During ConstantFolding, some nodes can have temporarily mismatched input types (e.g. Add(f16, f32)).
+// If the node is TypeRelaxed - we're unable to clone it since TypeRelaxed::clone_with_new_inputs creates a
+// clone with 'fake' inputs based on current inputs and that can trigger an exception for certain nodes if the
+// inputs have mismatched types.
+element::TypeVector origin_input_types;
+origin_input_types.reserve(num_inputs);
+for (size_t i = 0; i < num_inputs; i++) {
+const auto& origin_type = type_relaxed->get_origin_input_type(i);
+origin_input_types.push_back(origin_type);
+if (origin_type == element::undefined && has_original_input_precision(node->input(i))) {
+type_relaxed->set_origin_input_type(get_original_input_precision(node->input(i)), i);
+}
+}
+
+cloned_node = node->clone_with_new_inputs(converted_inputs);
+
+// Restore TypeRelaxed's origin input types
+for (size_t i = 0; i < num_inputs; i++) {
+type_relaxed->set_origin_input_type(origin_input_types[i], i);
+}
+
-// Override TypeRelaxed types
-auto type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
-if (type_relaxed) {
+auto cloned_type_relaxed = std::dynamic_pointer_cast<op::TypeRelaxedBase>(cloned_node);
+// Override TypeRelaxed types
for (size_t i = 0; i < num_inputs; i++) {
-if (ov::util::is_type_unsupported(type_relaxed->get_origin_input_type(i))) {
-type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
+if (ov::util::is_type_unsupported(cloned_type_relaxed->get_origin_input_type(i))) {
+cloned_type_relaxed->set_origin_input_type(cloned_node->get_input_element_type(i), i);
}
}
for (size_t i = 0; i < cloned_node->get_output_size(); i++) {
if (ov::util::is_type_unsupported(cloned_node->get_output_element_type(i))) {
-type_relaxed->set_overridden_output_type(element::f32, i);
+cloned_type_relaxed->set_overridden_output_type(element::f32, i);
}
}
cloned_node->validate_and_infer_types();
+} else {
+// Create a new node with new (converted) inputs.
+cloned_node = node->clone_with_new_inputs(converted_inputs);
}

// Handle nodes which outputs precisions don't depend on input precisions
@@ -221,9 +270,19 @@ bool ov::util::evaluate_node_with_unsupported_precision(const ov::Node* node,
}
}

-// evaluate converted node
-if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
-return false;
+auto type_relaxed = dynamic_cast<const op::TypeRelaxedBase*>(node);
+if (type_relaxed == nullptr) {
+// evaluate node with converted tensors
+if (!node->evaluate(converted_output_tensors, converted_input_tensors)) {
+return false;
+}
+} else {
+// node is const so let's clone it
+auto cloned = node->clone_with_new_inputs(node->input_values());
+cloned = convert_to_supported_precision(cloned.get());
+if (!cloned->evaluate(converted_output_tensors, converted_input_tensors)) {
+return false;
+}
}

// convert outputs tensors from f32 to original type if necessary
33 changes: 5 additions & 28 deletions src/core/src/pass/constant_folding.cpp
@@ -49,42 +49,19 @@ const auto friendly_name_from = [](const ov::Node& node, const size_t output_cou
}
};

-static void save_original_input_precisions(const std::shared_ptr<ov::Node>& node) {
-for (size_t i = 0; i < node->get_input_size(); i++) {
-auto input = node->input(i);
-input.get_rt_info()["original_precision"] = input.get_element_type();
-}
-}
-
-static bool has_original_input_precision(const ov::Input<ov::Node>& input) {
-return input.get_rt_info().count("original_precision") > 0;
-}
-
-static ov::element::Type get_original_input_precision(const ov::Input<ov::Node>& input) {
-return input.get_rt_info().at("original_precision").as<ov::element::Type>();
-}
-
-static void remove_original_input_precision_attribute(ov::Input<ov::Node>& input) {
-auto& rt_info = input.get_rt_info();
-auto it = rt_info.find("original_precision");
-if (it != rt_info.end()) {
-rt_info.erase(it);
-}
-}
-
static bool restore_original_input_precision(const std::shared_ptr<ov::Node>& node) {
bool restored = false;
if (ov::is_type<ov::op::v0::Convert>(node)) {
auto input = node->input(0);
-remove_original_input_precision_attribute(input);
+ov::util::remove_original_input_precision_attribute(input);
return restored;
}
for (size_t i = 0; i < node->get_input_size(); i++) {
auto input = node->input(i);
-if (!has_original_input_precision(input))
+if (!ov::util::has_original_input_precision(input))
continue;
-const auto original_type = get_original_input_precision(input);
-remove_original_input_precision_attribute(input);
+const auto original_type = ov::util::get_original_input_precision(input);
+ov::util::remove_original_input_precision_attribute(input);
if (original_type != node->get_input_element_type(i)) {
auto convert = std::make_shared<ov::op::v0::Convert>(node->input_value(i), original_type);
ov::OutputVector replacements(1);
@@ -206,7 +183,7 @@ bool ov::pass::ConstantFolding::pre_calculated_values_folding(const std::shared_
// we need to convert constants with those types to f32. And at some point - this f32 constant may
// become an input to a node that's not constfoldable. Then we need to convert that constant back to
// that input's original precision.
-save_original_input_precisions(node);
+util::save_original_input_precisions(node);
if (!node_has_disabled_constant_folding && util::node_requires_precision_conversion(node.get())) {
mark_node_requires_precision_conversion(node);
}
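For reference, pre_calculated_values_folding runs inside the ConstantFolding pass, so the relocated helper is exercised by a normal pass invocation. A minimal sketch of such a run (assuming an existing ov::Model; not part of the commit):

    #include "openvino/core/model.hpp"
    #include "openvino/pass/constant_folding.hpp"
    #include "openvino/pass/manager.hpp"

    void fold_constants(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        // ConstantFolding saves each node's input precisions up front and, for
        // nodes that end up not being folded, restores them via Convert nodes.
        manager.register_pass<ov::pass::ConstantFolding>();
        manager.run_passes(model);
    }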
29 changes: 29 additions & 0 deletions src/core/tests/pass/constant_folding.cpp
Expand Up @@ -16,6 +16,7 @@
#include "openvino/op/convert_like.hpp"
#include "openvino/op/loop.hpp"
#include "openvino/op/multiply.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/common_optimizations/disable_shapeof_constant_folding.hpp"
#include "transformations/utils/utils.hpp"

@@ -4000,6 +4001,34 @@ TEST_P(UnsupportedTypesTest, convert_like) {
ASSERT_EQ(m->get_results().size(), 1);
}

+TEST_P(UnsupportedTypesTest, type_relaxed) {
+Shape shape_in{2, 4, 1};
+
+const auto& type = GetParam();
+auto cond = op::v0::Constant::create(element::boolean, shape_in, {1});
+auto param = std::make_shared<op::v0::Parameter>(type, shape_in);
+auto constant1 = op::v0::Constant::create(type, shape_in, {2});
+auto then_value = std::make_shared<op::v0::Concat>(OutputVector{param, constant1}, 2);
+auto constant2 = op::v0::Constant::create(type, shape_in, {3});
+auto else_value = std::make_shared<op::v3::Broadcast>(
+constant2,
+op::v0::Constant::create(element::u64, Shape{shape_in.size()}, Shape{shape_in[0], shape_in[1], 2}));
+auto select = make_shared<op::v1::Select>(cond, then_value, else_value);
+auto type_relaxed = make_shared<op::TypeRelaxed<op::v1::Select>>(*select,
+element::TypeVector{element::boolean},
+element::TypeVector{});
+auto m = make_shared<Model>(type_relaxed, ParameterVector{param});
+
+run_constant_folding(m);
+
+EXPECT_EQ(m->get_ops().size(), 7);
+EXPECT_EQ(count_ops_of_type<op::v1::Select>(m), 1);
+EXPECT_EQ(count_ops_of_type<op::v0::Constant>(m), 3);
+EXPECT_EQ(count_ops_of_type<op::v3::Broadcast>(m), 0);
+EXPECT_EQ(count_ops_of_type<op::v0::Concat>(m), 1);
+ASSERT_EQ(m->get_results().size(), 1);
+}

static std::string unsupported_types_test_case_name(const testing::TestParamInfo<element::Type>& info) {
return info.param.get_type_name();
}
