Skip to content

Commit

Permalink
Removed AddOldApiMapToParameters during FP16 weights compression (#22728
Browse files Browse the repository at this point in the history
)

### Details:
 - It was required for old API only
  • Loading branch information
ilya-lavrenov authored Feb 8, 2024
1 parent 778e16e commit 778e7e9
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace ov {
namespace pass {

class TRANSFORMATIONS_API CompressFloatConstantsImpl;
class TRANSFORMATIONS_API AddOldApiMapToParameters;
class TRANSFORMATIONS_API CompressFloatConstants;

} // namespace pass
Expand All @@ -33,16 +32,6 @@ class ov::pass::CompressFloatConstantsImpl : public ov::pass::MatcherPass {
CompressFloatConstantsImpl(bool postponed = false);
};

/**
* @ingroup ie_transformation_common_api
* @brief AddOldApiMapToParameters transformation adds OldApiMap to each float input to the model.
*/
class ov::pass::AddOldApiMapToParameters : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("AddOldApiMapToParameters", "0");
AddOldApiMapToParameters();
};

/**
* @ingroup ie_transformation_common_api
* @brief CompressFloatConstants transformation replaces FP32/FP64 Constants with FP16 ones.
Expand All @@ -54,6 +43,5 @@ class ov::pass::CompressFloatConstants : public ov::pass::GraphRewrite {
/// @param postponed Postponed compression, see ov::pass::CompressFloatConstantsImpl for details.
CompressFloatConstants(bool postponed = false) {
add_matcher<ov::pass::CompressFloatConstantsImpl>(postponed);
add_matcher<ov::pass::AddOldApiMapToParameters>();
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,3 @@ ov::pass::CompressFloatConstantsImpl::CompressFloatConstantsImpl(bool postponed)
auto m = std::make_shared<pattern::Matcher>(const_pattern, matcher_name);
this->register_matcher(m, callback);
}

ov::pass::AddOldApiMapToParameters::AddOldApiMapToParameters() {
MATCHER_SCOPE(AddOldApiMapToParameters);
auto param_pattern = pattern::wrap_type<ov::op::v0::Parameter>();

ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto node = pattern_map.at(param_pattern).get_node_shared_ptr();

auto param_node = std::dynamic_pointer_cast<ov::op::v0::Parameter>(node);
if (!param_node)
return false;
auto p_type = param_node->get_element_type();
if (p_type == ov::element::f32 || p_type == ov::element::f64) {
ov::set_old_api_map_element_type(node, ov::OldApiMapElementType(ov::element::Type_t::f16));
} else {
return false;
}
return true;
};

auto m = std::make_shared<pattern::Matcher>(param_pattern, matcher_name);
this->register_matcher(m, callback);
}

0 comments on commit 778e7e9

Please sign in to comment.