Skip to content

Commit

Permalink
ConvertInterpolate1ToInterpolate4 fixes (#6019)
Browse files Browse the repository at this point in the history
* half_pixel -> asymmetric and round_prefer_floor -> simple in ConvertInterpolate1ToInterpolate4

* test fix
  • Loading branch information
Maxim Andronov authored Jun 8, 2021
1 parent 4409a74 commit 98f45ff
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,20 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
return false;
}
attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes;
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::round_prefer_floor;
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::simple;
attrsV4.pads_begin = attrsV0.pads_begin;
attrsV4.pads_end = attrsV0.pads_end;
attrsV4.antialias = attrsV0.antialias;
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::asymmetric;
attrsV4.cube_coeff = -0.75f;
if (attrsV0.align_corners) {
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners;
} else if ((attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx ||
attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear) &&
std::all_of(attrsV4.pads_begin.begin(), attrsV4.pads_begin.end(), [](size_t i){return i == 0;}) &&
std::all_of(attrsV4.pads_end.begin(), attrsV4.pads_end.end(), [](size_t i){return i == 0;}) &&
!(input_shape_rank - 2 == 2 && attrsV0.axes == AxisSet{2, 3})) {
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
}

auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {

auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::floor,
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple,
false, -0.75);

auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);

f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
}

auto res = compare_functions(f, f_ref);
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}

Expand Down Expand Up @@ -97,16 +97,16 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) {
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f});
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});

auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear,
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx,
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor,
false, -0.75);
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple,
true, -0.75);

auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);

f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
}

auto res = compare_functions(f, f_ref);
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}

0 comments on commit 98f45ff

Please sign in to comment.