Skip to content

Commit

Permalink
Refactor #2
Browse files Browse the repository at this point in the history
  • Loading branch information
tomdol committed Mar 21, 2023
1 parent 987f4a4 commit 330e25a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "transformations/op_conversions/convert_interpolate11_downgrade.hpp"

#include <array>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset11.hpp>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,103 +51,70 @@ std::shared_ptr<ov::Model> create_v11_model(const bool with_axes,

return std::make_shared<ov::Model>(interpolate->outputs(), model_params);
}
} // namespace

TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) {
{
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
}

{
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SCALES;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};
std::shared_ptr<ov::Model> create_v4_model(const bool with_axes,
const ov::opset4::Interpolate::ShapeCalcMode shape_calc_mode) {
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = shape_calc_mode;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};

const auto input = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
const auto output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{}, {1});
const auto scales = std::make_shared<ov::opset4::Parameter>(ov::element::f32, ov::Shape{2});
const auto axes = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{2});
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
std::shared_ptr<ov::Node> output_shape;
std::shared_ptr<ov::Node> scales;
std::shared_ptr<ov::opset4::Interpolate> interpolate;

const auto interpolate =
std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, axes, attributes);
interpolate->set_friendly_name("interpolate11");
ov::ParameterVector model_params;
model_params.push_back(input);

function_ref = std::make_shared<ov::Model>(interpolate->outputs(), ov::ParameterVector{input, scales, axes});
}
}
const size_t num_scales_or_sizes = with_axes ? 2 : 4;
if (shape_calc_mode == ov::opset4::Interpolate::ShapeCalcMode::SCALES) {
scales = std::make_shared<ov::opset4::Parameter>(ov::element::f32, ov::Shape{num_scales_or_sizes});
model_params.push_back(std::dynamic_pointer_cast<ov::opset4::Parameter>(scales));
output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{}, {1});

TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) {
{
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
} else {
output_shape = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{num_scales_or_sizes});
model_params.push_back(std::dynamic_pointer_cast<ov::opset4::Parameter>(output_shape));
scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{}, {1.0f});
}

{
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SIZES;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};

const auto input = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
const auto output_shape = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{2});
const auto scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{}, {1.0f});
if (with_axes) {
const auto axes = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{2});

const auto interpolate =
std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, axes, attributes);
interpolate->set_friendly_name("interpolate11");

function_ref =
std::make_shared<ov::Model>(interpolate->outputs(), ov::ParameterVector{input, output_shape, axes});
model_params.push_back(axes);
interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, axes, attributes);
} else {
interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, attributes);
}
}
interpolate->set_friendly_name("interpolate11");

TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_no_axes) {
{
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
}
return std::make_shared<ov::Model>(interpolate->outputs(), model_params);
}

{
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SCALES;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};
} // namespace

const auto input = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
const auto output_shape = ov::opset4::Constant::create(ov::element::i32, ov::Shape{}, {1});
const auto scales = std::make_shared<ov::opset4::Parameter>(ov::element::f32, ov::Shape{4});
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES);
}

const auto interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, attributes);
interpolate->set_friendly_name("interpolate11");
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITH_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
function_ref = create_v4_model(WITH_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES);
}

function_ref = std::make_shared<ov::Model>(interpolate->outputs(), ov::ParameterVector{input, scales});
}
TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_scales_no_axes) {
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SCALES);
function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SCALES);
}

TEST_F(TransformationTestsF, ConvertInterpolate11ToInterpolate4_sizes_no_axes) {
{
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
}

{
auto attributes = ov::opset4::Interpolate::InterpolateAttrs{};
attributes.shape_calculation_mode = ov::opset4::Interpolate::ShapeCalcMode::SIZES;
attributes.pads_begin = {0, 0};
attributes.pads_end = {0, 0};

const auto input = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{1, 2, 10, 10});
const auto output_shape = std::make_shared<ov::opset4::Parameter>(ov::element::i32, ov::Shape{4});
const auto scales = ov::opset4::Constant::create(ov::element::f32, ov::Shape{}, {1.0f});

const auto interpolate = std::make_shared<ov::opset4::Interpolate>(input, output_shape, scales, attributes);
interpolate->set_friendly_name("interpolate11");

function_ref = std::make_shared<ov::Model>(interpolate->outputs(), ov::ParameterVector{input, output_shape});
}
manager.register_pass<ov::pass::ConvertInterpolate11ToInterpolate4>();
function = create_v11_model(WITHOUT_AXES, ov::opset11::Interpolate::ShapeCalcMode::SIZES);
function_ref = create_v4_model(WITHOUT_AXES, ov::opset4::Interpolate::ShapeCalcMode::SIZES);
}

namespace {
Expand Down

0 comments on commit 330e25a

Please sign in to comment.