From c10281c9d1b32d83d1892d05794a87490477e0a2 Mon Sep 17 00:00:00 2001 From: Lyamin-Roman Date: Sat, 16 Dec 2023 01:29:17 +0900 Subject: [PATCH] [Transformations] Support precision conversion for PriorBox --- .../src/transformations/convert_precision.cpp | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/common/transformations/src/transformations/convert_precision.cpp b/src/common/transformations/src/transformations/convert_precision.cpp index 3df6802ec360e8..53b83b7dba9f36 100644 --- a/src/common/transformations/src/transformations/convert_precision.cpp +++ b/src/common/transformations/src/transformations/convert_precision.cpp @@ -130,6 +130,28 @@ bool fuse_type_to_reduce_logical(const std::shared_ptr& node, const pr return false; } +template +bool fuse_type_to_prior_box(const std::shared_ptr& node, const precisions_map& precisions) { + auto it = precisions.find(node->get_output_element_type(0)); + if (it == precisions.end()) { + return false; + } + const auto& to = it->second; + + if (auto type_relaxed = std::dynamic_pointer_cast(node)) { + type_relaxed->set_overridden_output_type(to); + return true; + } else if (const auto casted = std::dynamic_pointer_cast(node)) { + auto relaxed_op = std::make_shared>( + *casted, + ov::element::TypeVector{casted->get_input_element_type(0), casted->get_input_element_type(1)}, + ov::element::TypeVector{to}); + replace_node(node, relaxed_op); + return true; + } + return false; +} + namespace { bool node_is_replaced(const std::shared_ptr& node) { @@ -442,7 +464,8 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr& {opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10}, {opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8}, {opset13::Multinomial::get_type_info_static(), fuse_type_to_multinomial_v13}, - }; + {opset1::PriorBox::get_type_info_static(), fuse_type_to_prior_box}, + {opset8::PriorBox::get_type_info_static(), fuse_type_to_prior_box}}; for (const auto& it : m_additional_type_to_fuse_map) { type_to_fuse[it.first] = it.second;