Skip to content

Commit

Permalink
[Transformations] Support precision conversion for PriorBox
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Dec 15, 2023
1 parent 137180b commit c10281c
Showing 1 changed file with 24 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,28 @@ bool fuse_type_to_reduce_logical(const std::shared_ptr<ov::Node>& node, const pr
return false;
}

template <class T>
bool fuse_type_to_prior_box(const std::shared_ptr<ov::Node>& node, const precisions_map& precisions) {
auto it = precisions.find(node->get_output_element_type(0));
if (it == precisions.end()) {
return false;
}
const auto& to = it->second;

if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
type_relaxed->set_overridden_output_type(to);
return true;
} else if (const auto casted = std::dynamic_pointer_cast<T>(node)) {
auto relaxed_op = std::make_shared<op::TypeRelaxed<T>>(
*casted,
ov::element::TypeVector{casted->get_input_element_type(0), casted->get_input_element_type(1)},
ov::element::TypeVector{to});
replace_node(node, relaxed_op);
return true;
}
return false;
}

namespace {

bool node_is_replaced(const std::shared_ptr<Node>& node) {
Expand Down Expand Up @@ -442,7 +464,8 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ov::Model>&
{opset10::Unique::get_type_info_static(), fuse_type_to_unique_v10},
{opset8::RandomUniform::get_type_info_static(), fuse_type_to_random_uniform_v8},
{opset13::Multinomial::get_type_info_static(), fuse_type_to_multinomial_v13},
};
{opset1::PriorBox::get_type_info_static(), fuse_type_to_prior_box<opset1::PriorBox>},
{opset8::PriorBox::get_type_info_static(), fuse_type_to_prior_box<opset8::PriorBox>}};

for (const auto& it : m_additional_type_to_fuse_map) {
type_to_fuse[it.first] = it.second;
Expand Down

0 comments on commit c10281c

Please sign in to comment.