diff --git a/src/core/reference/include/openvino/reference/interpolate.hpp b/src/core/reference/include/openvino/reference/interpolate.hpp index 64e21f27f0c87a..bec559691591a2 100644 --- a/src/core/reference/include/openvino/reference/interpolate.hpp +++ b/src/core/reference/include/openvino/reference/interpolate.hpp @@ -204,7 +204,7 @@ class InterpolateEvalHelper final { std::vector output_spatial_shape; }; - InfoForGenericLinearONNXMode get_info_for_generic_linear_onnx(); + InfoForGenericLinearONNXMode get_info_for_generic_linear_onnx(bool channel_last); struct InfoForLinearMode { bool antialias; @@ -408,24 +408,34 @@ void InterpolateEval::linear_onnx_func(const T* input_data, T* out) { ((input_rank == 2) && (num_of_axes == 2) && (m_axes[0] == 0) && (m_axes[1] == 1)) || ((input_rank == 3) && (num_of_axes == 3) && (m_axes[0] == 0) && (m_axes[1] == 1) && (m_axes[2] == 2)); + bool channel_last = false; + if (input_rank >= 4) { std::vector all_axes; std::vector axes_without_batch_and_channels; + std::vector axes_without_batch_and_channels_last; all_axes.push_back(0); all_axes.push_back(1); + axes_without_batch_and_channels_last.push_back(1); for (int64_t i = 2; i < static_cast(input_rank); ++i) { all_axes.push_back(i); axes_without_batch_and_channels.push_back(i); + if (i != static_cast(input_rank) - 1) + axes_without_batch_and_channels_last.push_back(i); } correct_axes = ((num_of_axes == input_rank) && (m_axes == all_axes)) || - ((num_of_axes == input_rank - 2) && (m_axes == axes_without_batch_and_channels)); + ((num_of_axes == input_rank - 2) && + (m_axes == axes_without_batch_and_channels || m_axes == axes_without_batch_and_channels_last)); + + if ((num_of_axes == input_rank - 2) && (m_axes == axes_without_batch_and_channels_last)) + channel_last = true; } if (!correct_axes) OPENVINO_THROW("Axes are not correct!"); - const auto info = helper.get_info_for_generic_linear_onnx(); + const auto info = helper.get_info_for_generic_linear_onnx(channel_last); const int64_t batch_size = info.batch_size; const int64_t num_channels = info.num_channels; @@ -444,9 +454,9 @@ void InterpolateEval::linear_onnx_func(const T* input_data, T* out) { // or // num_of_axes == input_rank - 2. // Hence, if num_of_axes != input_rank, then interpolated axes indices are - // [0, 1, ..., num_of_axes - 1] - // Otherwise, if num_of_axes == input_rank, interpolated axes indices are // [2, 3, ..., num_of_axes - 1] + // Otherwise, if num_of_axes == input_rank, interpolated axes indices are + // [0, 1, ..., num_of_axes - 1] const int64_t axis_idx_offset = (input_rank == num_of_axes) ? 2 : 0; const int64_t spatial_rank = info.spatial_rank; @@ -454,8 +464,12 @@ void InterpolateEval::linear_onnx_func(const T* input_data, T* out) { const T* xdata = input_data; T* ydata = out; + + const auto loop_channels = channel_last ? 1 : num_channels; + const auto last_spatial_divider = channel_last ? num_channels : 1; + for (int64_t n = 0; n < batch_size; ++n) { - for (int64_t c = 0; c < num_channels; ++c) { + for (int64_t c = 0; c < loop_channels; ++c) { for (int64_t idx = 0; idx < output_data_ptr_increment; ++idx) { // 1. Get the current spatial coords vector. std::vector output_coords(spatial_rank); @@ -464,7 +478,7 @@ void InterpolateEval::linear_onnx_func(const T* input_data, T* out) { output_coords[j] = curr / output_index_multipliers[j]; curr %= output_index_multipliers[j]; } - output_coords[spatial_rank - 1] = curr; + output_coords[spatial_rank - 1] = curr / last_spatial_divider; // 2. Some preliminaries. std::vector in1(spatial_rank); @@ -519,7 +533,6 @@ void InterpolateEval::linear_onnx_func(const T* input_data, T* out) { // 6. Store result. ydata[idx] = static_cast(sum); } - xdata += input_data_ptr_increment; ydata += output_data_ptr_increment; } diff --git a/src/core/reference/src/op/interpolate.cpp b/src/core/reference/src/op/interpolate.cpp index ff9bf20eb1a293..023d56a3a7a28c 100644 --- a/src/core/reference/src/op/interpolate.cpp +++ b/src/core/reference/src/op/interpolate.cpp @@ -40,7 +40,8 @@ Coordinate InterpolateEvalHelper::get_input_coords_for_nearest_mode(const Coordi return input_coord; } -InterpolateEvalHelper::InfoForGenericLinearONNXMode InterpolateEvalHelper::get_info_for_generic_linear_onnx() { +InterpolateEvalHelper::InfoForGenericLinearONNXMode InterpolateEvalHelper::get_info_for_generic_linear_onnx( + bool channel_last = false) { InfoForGenericLinearONNXMode result; std::size_t input_rank = m_input_data_shape.size(); @@ -62,30 +63,45 @@ InterpolateEvalHelper::InfoForGenericLinearONNXMode InterpolateEvalHelper::get_i output_shape = m_out_shape; } - int64_t batch_size = input_shape[0]; - int64_t num_channels = input_shape[1]; - std::size_t spatial_rank = input_shape.size() - 2; - std::vector input_index_multipliers(spatial_rank); std::vector output_index_multipliers(spatial_rank); - input_index_multipliers[spatial_rank - 1] = 1; - output_index_multipliers[spatial_rank - 1] = 1; + + int64_t input_data_ptr_increment; + int64_t output_data_ptr_increment; + + const auto mutipliers_offset = channel_last ? 2 : 3; + const auto spatial_offset = channel_last ? 1 : 2; + const auto ptr_offset = channel_last ? 1 : 2; + + int64_t batch_size = input_shape[0]; + int64_t num_channels; + if (channel_last) { + num_channels = input_shape[input_shape.size() - 1]; + input_index_multipliers[spatial_rank - 1] = num_channels; + output_index_multipliers[spatial_rank - 1] = num_channels; + } else { + num_channels = input_shape[1]; + input_index_multipliers[spatial_rank - 1] = 1; + output_index_multipliers[spatial_rank - 1] = 1; + } for (int64_t i = static_cast(spatial_rank) - 2; i >= 0; --i) { - input_index_multipliers[i] = input_index_multipliers[i + 1] * static_cast(input_shape[i + 3]); - output_index_multipliers[i] = output_index_multipliers[i + 1] * static_cast(output_shape[i + 3]); + input_index_multipliers[i] = + input_index_multipliers[i + 1] * static_cast(input_shape[i + mutipliers_offset]); + output_index_multipliers[i] = + output_index_multipliers[i + 1] * static_cast(output_shape[i + mutipliers_offset]); } - int64_t input_data_ptr_increment = input_index_multipliers[0] * static_cast(input_shape[2]); - int64_t output_data_ptr_increment = output_index_multipliers[0] * static_cast(output_shape[2]); + input_data_ptr_increment = input_index_multipliers[0] * static_cast(input_shape[ptr_offset]); + output_data_ptr_increment = output_index_multipliers[0] * static_cast(output_shape[ptr_offset]); std::vector input_spatial_shape(spatial_rank); std::vector output_spatial_shape(spatial_rank); for (size_t i = 0; i < spatial_rank; ++i) { - input_spatial_shape[i] = static_cast(input_shape[i + 2]); - output_spatial_shape[i] = static_cast(output_shape[i + 2]); + input_spatial_shape[i] = static_cast(input_shape[i + spatial_offset]); + output_spatial_shape[i] = static_cast(output_shape[i + spatial_offset]); } result.input_data_ptr_increment = input_data_ptr_increment; diff --git a/src/plugins/template/tests/functional/op_reference/interpolate.cpp b/src/plugins/template/tests/functional/op_reference/interpolate.cpp index 1363ae788a52d2..df0c2806f4594e 100644 --- a/src/plugins/template/tests/functional/op_reference/interpolate.cpp +++ b/src/plugins/template/tests/functional/op_reference/interpolate.cpp @@ -640,6 +640,21 @@ std::vector generateParamsForInterpolate_v4_linear_onnx NearestMode::ROUND_PREFER_FLOOR}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {1.0f, 4.0f}, + }, + { "linear_onnx.resize_downsample_scales_linear_align_corners_channel_last", + Shape{1, 2, 4, 1}, + {1, 2}, + Shape{1, 1, 2, 1}, + {0.6f, 0.6f}, + {1, 2}, + { InterpolateMode::LINEAR_ONNX, + ShapeCalcMode::SCALES, + zero_pads, + zero_pads, + CoordinateTransformMode::ALIGN_CORNERS, + NearestMode::ROUND_PREFER_FLOOR}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.0f, 4.0f}, } }; // clang-format on