Skip to content

Commit

Permalink
fix f16/f32 element type mismatch for shape subgraphs: Parameter type should not be fused for precision sensitive nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Sep 19, 2023
1 parent b7dcae3 commit 3f30904
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
}

for (const auto& param : f->get_parameters()) {
if (skip_precision_sensitive && fp16_compression_is_disabled(param) && has_fp16_compression)
continue;
is_changed |= fuse_type_to_parameter(param, precisions, convert_input_output_precision);
}

Expand Down
91 changes: 91 additions & 0 deletions src/common/transformations/tests/utils/convert_precision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "openvino/core/model.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "openvino/opsets/opset3.hpp"
#include "openvino/opsets/opset4.hpp"
#include "openvino/opsets/opset5.hpp"
Expand Down Expand Up @@ -2136,3 +2137,93 @@ TEST(TransformationTests, ConvertPrecisionExplicitConvertsMultiSubgraphs) {
const auto& results = model->get_results();
ASSERT_EQ("if_result", results[0]->get_input_node_ptr(0)->get_friendly_name());
}

TEST(TransformationTests, align_mixed_fp16_fp32_with_parameter_for_shape_1) {
    shared_ptr<Model> model, model_ref;
    pass::Manager manager;
    {
        // Original model: a data Parameter is reshaped using a shape value that is
        // computed in fp32 from a second Parameter. The shape-computing subgraph is
        // precision sensitive and must not be compressed to f16.
        auto data = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 224, 224});
        auto shape_param = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});

        auto scale = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
        auto scaled_shape = make_shared<ov::op::v1::Multiply>(shape_param, scale);
        auto reduce_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
        auto flat_shape_f32 = make_shared<ov::op::v1::ReduceProd>(scaled_shape, reduce_axis);
        auto flat_shape_i64 = make_shared<ov::op::v0::Convert>(flat_shape_f32, element::i64);
        auto reshape = make_shared<ov::op::v1::Reshape>(data, flat_shape_i64, false);

        model = make_shared<Model>(NodeVector{reshape}, ParameterVector{data, shape_param});

        type_to_fuse_map empty_type_to_fuse_map = {};
        bool keep_precision_sensitive_in_fp32 = true;
        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
                                                      empty_type_to_fuse_map,
                                                      keep_precision_sensitive_in_fp32);
        manager.run_passes(model);
    }

    {
        // Reference model: the data Parameter is converted to f16, while the shape
        // subgraph (including its Parameter) stays in fp32.
        auto data = make_shared<ov::op::v0::Parameter>(element::f16, Shape{1, 3, 224, 224});
        auto shape_param = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});

        // even for FP16 compressed model shape subgraph should be kept in fp32
        auto scale = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
        auto scaled_shape = make_shared<ov::op::v1::Multiply>(shape_param, scale);
        auto reduce_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
        auto flat_shape_f32 = make_shared<ov::op::v1::ReduceProd>(scaled_shape, reduce_axis);
        auto flat_shape_i64 = make_shared<ov::op::v0::Convert>(flat_shape_f32, element::i64);
        auto reshape = make_shared<ov::op::v1::Reshape>(data, flat_shape_i64, false);

        model_ref = make_shared<Model>(NodeVector{reshape}, ParameterVector{data, shape_param});
    }

    // Compare transformed model against the hand-built reference.
    const FunctionsComparator func_comparator = FunctionsComparator::with_default();
    FunctionsComparator::Result result = func_comparator(model_ref, model);
    ASSERT_TRUE(result.valid) << result.message;
}

TEST(TransformationTests, align_mixed_fp16_fp32_with_parameter_for_shape_2) {
    shared_ptr<Model> model, model_ref;
    pass::Manager manager;
    {
        // Same topology as the _1 test, but ConvertPrecision is asked to keep the
        // model's input/output precisions untouched (convert_input_output_precision = false).
        auto data = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 224, 224});
        auto shape_param = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});

        auto scale = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
        auto scaled_shape = make_shared<ov::op::v1::Multiply>(shape_param, scale);
        auto reduce_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
        auto flat_shape_f32 = make_shared<ov::op::v1::ReduceProd>(scaled_shape, reduce_axis);
        auto flat_shape_i64 = make_shared<ov::op::v0::Convert>(flat_shape_f32, element::i64);
        auto reshape = make_shared<ov::op::v1::Reshape>(data, flat_shape_i64, false);

        model = make_shared<Model>(NodeVector{reshape}, ParameterVector{data, shape_param});

        type_to_fuse_map empty_type_to_fuse_map = {};
        bool keep_precision_sensitive_in_fp32 = true;
        const bool convert_input_output_precision = false;
        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
                                                      empty_type_to_fuse_map,
                                                      keep_precision_sensitive_in_fp32,
                                                      convert_input_output_precision);
        manager.run_passes(model);
    }

    {
        // Reference model: the data Parameter keeps fp32 and explicit Converts are
        // inserted around the f16 body (f32 -> f16 before Reshape, f16 -> f32 after),
        // so the model's external interface stays fp32.
        auto data = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 224, 224});
        auto data_f16 = make_shared<ov::op::v0::Convert>(data, element::f16);
        auto shape_param = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});

        // even for FP16 compressed model shape subgraph should be kept in fp32
        auto scale = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
        auto scaled_shape = make_shared<ov::op::v1::Multiply>(shape_param, scale);
        auto reduce_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
        auto flat_shape_f32 = make_shared<ov::op::v1::ReduceProd>(scaled_shape, reduce_axis);
        auto flat_shape_i64 = make_shared<ov::op::v0::Convert>(flat_shape_f32, element::i64);
        auto reshape = make_shared<ov::op::v1::Reshape>(data_f16, flat_shape_i64, false);
        auto result_f32 = make_shared<ov::op::v0::Convert>(reshape, element::f32);

        model_ref = make_shared<Model>(NodeVector{result_f32}, ParameterVector{data, shape_param});
    }

    // Compare transformed model against the hand-built reference.
    const FunctionsComparator func_comparator = FunctionsComparator::with_default();
    FunctionsComparator::Result result = func_comparator(model_ref, model);
    ASSERT_TRUE(result.valid) << result.message;
}

0 comments on commit 3f30904

Please sign in to comment.