From 330e25afff2d608c28218303c897da63d63caca4 Mon Sep 17 00:00:00 2001 From: tomdol Date: Tue, 21 Mar 2023 17:43:11 +0100 Subject: [PATCH] Refactor #2 --- .../convert_interpolate11_downgrade.cpp | 1 + .../convert_interpolate11_downgrade_test.cpp | 127 +++++++----------- 2 files changed, 48 insertions(+), 80 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp b/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp index 44b189d72cdae6..98167ea4fdc0f2 100644 --- a/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp +++ b/src/common/transformations/src/transformations/op_conversions/convert_interpolate11_downgrade.cpp @@ -4,6 +4,7 @@ #include "transformations/op_conversions/convert_interpolate11_downgrade.hpp" +#include #include #include #include diff --git a/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp b/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp index fde7f5579868c0..bf0458c56eaa06 100644 --- a/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp +++ b/src/common/transformations/tests/op_conversions/convert_interpolate11_downgrade_test.cpp @@ -51,103 +51,70 @@ std::shared_ptr create_v11_model(const bool with_axes, return std::make_shared(interpolate->outputs(), model_params); } -} // namespace - -TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) { - { - function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES); - manager.register_pass(); - } - { - auto attributes = ov::opset4::Interpolate::InterpolateAttrs{}; - attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SCALES; - attributes.pads_begin = {0, 0}; - attributes.pads_end = {0, 0}; +std::shared_ptr create_v4_model(const bool with_axes, + const ov::opset4::Interpolate::ShapeCalcMode shape_calc_mode) { + auto attributes = ov::opset4::Interpolate::InterpolateAttrs{}; + attributes.shape_calculation_mode = shape_calc_mode; + attributes.pads_begin = {0, 0}; + attributes.pads_end = {0, 0}; - const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); - const auto output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{}, {1}); - const auto scales = std::make_shared(ov::element::f32, ov::Shape{2}); - const auto axes = std::make_shared(ov::element::i32, ov::Shape{2}); + const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); + std::shared_ptr output_shape; + std::shared_ptr scales; + std::shared_ptr interpolate; - const auto interpolate = - std::make_shared(input, output_shape, scales, axes, attributes); - interpolate->set_friendly_name("interpolate11"); + ov::ParameterVector model_params; + model_params.push_back(input); - function_ref = std::make_shared(interpolate->outputs(), ov::ParameterVector{input, scales, axes}); - } -} + const size_t num_scales_or_sizes = with_axes ? 2 : 4; + if (shape_calc_mode == ov::opset4::Interpolate::ShapeCalcMode::SCALES) { + scales = std::make_shared(ov::element::f32, ov::Shape{num_scales_or_sizes}); + model_params.push_back(std::dynamic_pointer_cast(scales)); + output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{}, {1}); -TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) { - { - function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES); - manager.register_pass(); + } else { + output_shape = std::make_shared(ov::element::i32, ov::Shape{num_scales_or_sizes}); + model_params.push_back(std::dynamic_pointer_cast(output_shape)); + scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{}, {1.0f}); } - { - auto attributes = ov::opset4::Interpolate::InterpolateAttrs{}; - attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SIZES; - attributes.pads_begin = {0, 0}; - attributes.pads_end = {0, 0}; - - const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); - const auto output_shape = std::make_shared(ov::element::i32, ov::Shape{2}); - const auto scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{}, {1.0f}); + if (with_axes) { const auto axes = std::make_shared(ov::element::i32, ov::Shape{2}); - - const auto interpolate = - std::make_shared(input, output_shape, scales, axes, attributes); - interpolate->set_friendly_name("interpolate11"); - - function_ref = - std::make_shared(interpolate->outputs(), ov::ParameterVector{input, output_shape, axes}); + model_params.push_back(axes); + interpolate = std::make_shared(input, output_shape, scales, axes, attributes); + } else { + interpolate = std::make_shared(input, output_shape, scales, attributes); } -} + interpolate->set_friendly_name("interpolate11"); -TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_no_axes) { - { - function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES); - manager.register_pass(); - } + return std::make_shared(interpolate->outputs(), model_params); +} - { - auto attributes = ov::opset4::Interpolate::InterpolateAttrs{}; - attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SCALES; - attributes.pads_begin = {0, 0}; - attributes.pads_end = {0, 0}; +} // namespace - const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); - const auto output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{}, {1}); - const auto scales = std::make_shared(ov::element::f32, ov::Shape{4}); +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) { + manager.register_pass(); + function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES); + function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES); +} - const auto interpolate = std::make_shared(input, output_shape, scales, attributes); - interpolate->set_friendly_name("interpolate11"); +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) { + manager.register_pass(); + function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES); + function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES); +} - function_ref = std::make_shared(interpolate->outputs(), ov::ParameterVector{input, scales}); - } +TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales_no_axes) { + manager.register_pass(); + function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES); + function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES); } TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes_no_axes) { - { - function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES); - manager.register_pass(); - } - - { - auto attributes = ov::opset4::Interpolate::InterpolateAttrs{}; - attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SIZES; - attributes.pads_begin = {0, 0}; - attributes.pads_end = {0, 0}; - - const auto input = std::make_shared(ov::element::i32, ov::Shape{1, 2, 10, 10}); - const auto output_shape = std::make_shared(ov::element::i32, ov::Shape{4}); - const auto scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{}, {1.0f}); - - const auto interpolate = std::make_shared(input, output_shape, scales, attributes); - interpolate->set_friendly_name("interpolate11"); - - function_ref = std::make_shared(interpolate->outputs(), ov::ParameterVector{input, output_shape}); - } + manager.register_pass(); + function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES); + function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES); } namespace {