Skip to content

Commit

Permalink
[TF FE] Simplify ResizeBilinear and ResizeNearestNeighbor translators (
Browse files Browse the repository at this point in the history
…#19099)

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Aug 9, 2023
1 parent 37eef6e commit 4ee47fc
Showing 1 changed file with 25 additions and 29 deletions.
54 changes: 25 additions & 29 deletions src/frontends/tensorflow_common/src/op/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/interpolate.hpp"

#include "common_op_table.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"

using namespace std;
using namespace ov;
using namespace ov::opset8;
using namespace ov::op;

namespace ov {
namespace frontend {
Expand All @@ -30,58 +36,48 @@ OutputVector translate_interpolate_op(const NodeContext& node) {
" is True, the attribute align_corners must be False.");

// prepare attributes for OpenVINO Interpolate operation
Interpolate::InterpolateAttrs interpolate_attrs;
interpolate_attrs.shape_calculation_mode = Interpolate::ShapeCalcMode::SIZES;
v11::Interpolate::InterpolateAttrs interpolate_attrs;
interpolate_attrs.shape_calculation_mode = v11::Interpolate::ShapeCalcMode::SIZES;
if (op_type == "ResizeNearestNeighbor") {
interpolate_attrs.mode = Interpolate::InterpolateMode::NEAREST;
interpolate_attrs.nearest_mode = Interpolate::NearestMode::FLOOR;
interpolate_attrs.mode = v11::Interpolate::InterpolateMode::NEAREST;
interpolate_attrs.nearest_mode = v11::Interpolate::NearestMode::FLOOR;
} else if (op_type == "ResizeBilinear") {
auto input_rank = images.get_partial_shape().rank();
if (input_rank.is_static() && input_rank.get_length() == 4) {
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR_ONNX;
interpolate_attrs.mode = v11::Interpolate::InterpolateMode::LINEAR_ONNX;
} else {
interpolate_attrs.mode = Interpolate::InterpolateMode::LINEAR;
interpolate_attrs.mode = v11::Interpolate::InterpolateMode::LINEAR;
}
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_FLOOR;
interpolate_attrs.nearest_mode = v11::Interpolate::NearestMode::ROUND_PREFER_FLOOR;
}

if (tf_align_corners) {
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
interpolate_attrs.nearest_mode = Interpolate::NearestMode::ROUND_PREFER_CEIL;
interpolate_attrs.coordinate_transformation_mode = v11::Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
if (interpolate_attrs.mode == v11::Interpolate::InterpolateMode::NEAREST) {
interpolate_attrs.nearest_mode = v11::Interpolate::NearestMode::ROUND_PREFER_CEIL;
}
} else if (tf_half_pixel_centers) {
if (interpolate_attrs.mode == Interpolate::InterpolateMode::NEAREST) {
if (interpolate_attrs.mode == v11::Interpolate::InterpolateMode::NEAREST) {
interpolate_attrs.coordinate_transformation_mode =
Interpolate::CoordinateTransformMode::TF_HALF_PIXEL_FOR_NN;
v11::Interpolate::CoordinateTransformMode::TF_HALF_PIXEL_FOR_NN;
} else {
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::HALF_PIXEL;
interpolate_attrs.coordinate_transformation_mode = v11::Interpolate::CoordinateTransformMode::HALF_PIXEL;
}
} else {
interpolate_attrs.coordinate_transformation_mode = Interpolate::CoordinateTransformMode::ASYMMETRIC;
interpolate_attrs.coordinate_transformation_mode = v11::Interpolate::CoordinateTransformMode::ASYMMETRIC;
}

// prepare scales input
auto images_shape = make_shared<ShapeOf>(images, element::i32);
auto spatial_shape = make_shared<Slice>(images_shape,
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{3}),
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{1}),
make_shared<Constant>(element::i64, Shape{1}, std::vector<int64_t>{0}));
auto scales = make_shared<Divide>(make_shared<Convert>(size, element::f32),
make_shared<Convert>(spatial_shape, element::f32));

// since Interpolate is layout agnostic
// we can avoid Transpose operation by specifying axes = {1, 2} for original NHWC layout
auto axes = make_shared<Constant>(element::i32, Shape{2}, std::vector<int>({1, 2}));
auto axes = make_shared<v0::Constant>(element::i32, Shape{2}, std::vector<int>({1, 2}));

// according to the specification of ResizeBilinear,
// it always returns FP32 output type so we immediately align input type for it
if (op_type == "ResizeBilinear") {
images = make_shared<Convert>(images, element::f32);
images = make_shared<v0::Convert>(images, element::f32);
}

auto interpolate = make_shared<Interpolate>(images, size, scales, axes, interpolate_attrs);
auto interpolate = make_shared<v11::Interpolate>(images, size, axes, interpolate_attrs);
set_node_name(node.get_name(), interpolate);
return {interpolate};
}
Expand Down

0 comments on commit 4ee47fc

Please sign in to comment.