Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement reference nGraph implementation for "Interpolate-4" with 5D tensor support in the "linear_onnx" mode #3948

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
127 commits
Select commit Hold shift + click to select a range
066daac
Commit.
vgavrilo Jul 30, 2020
c212c50
Merge remote-tracking branch 'upstream/master'
vgavrilo Jul 30, 2020
228af66
Merge remote-tracking branch 'upstream/master'
vgavrilo Aug 15, 2020
66b99ac
Merge remote-tracking branch 'upstream/master'
vgavrilo Aug 20, 2020
62b2452
Merge remote-tracking branch 'upstream/master'
vgavrilo Aug 26, 2020
dd5a343
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 1, 2020
146cfcb
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 2, 2020
11cfd32
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 2, 2020
f135fdc
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 4, 2020
14b3b49
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 11, 2020
29d798a
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 16, 2020
5aa69a3
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 16, 2020
3754f40
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 17, 2020
a211ce8
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 18, 2020
e7ae609
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 21, 2020
2ed2d5c
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 21, 2020
bdbfb81
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 25, 2020
29cfcfc
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 28, 2020
e64b285
Merge remote-tracking branch 'upstream/master'
vgavrilo Sep 29, 2020
ebf97c4
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 6, 2020
42de39d
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 13, 2020
88daad2
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 15, 2020
7084963
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 24, 2020
7f7de48
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 26, 2020
e58295d
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 28, 2020
429cf23
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 29, 2020
b9da719
Merge branch 'master' of https://github.com/vgavrilo/openvino
vgavrilo Oct 29, 2020
b9f8f9f
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 30, 2020
f61aef8
Merge remote-tracking branch 'upstream/master'
vgavrilo Oct 30, 2020
1b8a2c5
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 2, 2020
6676840
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 3, 2020
a317f50
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 9, 2020
c43795b
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 9, 2020
f55363c
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 10, 2020
ec40ea5
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 11, 2020
131eb6d
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 12, 2020
65f385f
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 16, 2020
d96448d
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 19, 2020
21716b2
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 20, 2020
f74d471
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 24, 2020
1608f4b
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 24, 2020
301148d
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 24, 2020
bc71ef8
Merge remote-tracking branch 'upstream/master'
vgavrilo Nov 26, 2020
a3261cd
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 1, 2020
748cef0
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 14, 2020
eabfc05
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 18, 2020
44dd347
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 23, 2020
4f979e0
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 25, 2020
3f1a97b
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 30, 2020
555f63c
Merge remote-tracking branch 'upstream/master'
vgavrilo Dec 31, 2020
f3d8290
Merge remote-tracking branch 'upstream/master'
vgavrilo Jan 11, 2021
2621b8c
Merge remote-tracking branch 'upstream/master'
vgavrilo Jan 14, 2021
ed5fe6f
Merge remote-tracking branch 'upstream/master'
vgavrilo Jan 19, 2021
cab4024
Merge remote-tracking branch 'upstream/master'
vgavrilo Jan 20, 2021
7fa0602
Merge remote-tracking branch 'upstream/master'
vgavrilo Jan 20, 2021
7832311
Written the structure InfoForLinearONNXMode5D that contains info to p…
vgavrilo Jan 21, 2021
94f92da
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
ea275ae
Started to write the method get_info_for_linear_onnx_mode5D() that re…
vgavrilo Jan 21, 2021
ef47060
Written the method InterpolateEvalHelper::get_info_for_linear_onnx_mo…
vgavrilo Jan 21, 2021
4f16e4c
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
b2e975d
Code style fix.
vgavrilo Jan 21, 2021
a58598f
Started to write calculation of 5D case of 'linear_onnx' mode.
vgavrilo Jan 21, 2021
f1a9970
Written the method void InterpolateEval<T>::linear_onnx5D_func(const …
vgavrilo Jan 21, 2021
20dfc9f
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
36a97e2
Added dispatching of 4D/5D cases of the mode 'linear_onnx'.
vgavrilo Jan 21, 2021
0584300
Fixed code style.
vgavrilo Jan 21, 2021
5c78fe1
Some fixes.
vgavrilo Jan 21, 2021
dd01d00
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
8bbc412
Code style fixes.
vgavrilo Jan 21, 2021
a111956
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
75253de
Now linear_onnx_func throws an exception for incorrect input rank.
vgavrilo Jan 21, 2021
4d34d64
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
84b9633
Code style fix.
vgavrilo Jan 21, 2021
26889c7
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 21, 2021
60dc462
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 22, 2021
497d284
Started to write tests for evaluation of 'linear_onnx' mode in the 5D…
vgavrilo Jan 22, 2021
fd35513
Added first test for linear_onnx 5D.
vgavrilo Jan 22, 2021
161e30e
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 22, 2021
7cc1865
Small fixes.
vgavrilo Jan 22, 2021
1eaac70
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 25, 2021
4b265fa
Written tests for evaluation of Interpolate-4 in linear_onnx 5D case.
vgavrilo Jan 25, 2021
67c973d
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 25, 2021
cb5dfc3
Merge branch 'vgavrilo/5d-linear_onnx-ref-impl' of https://github.com…
vgavrilo Jan 25, 2021
3ac3478
Some code style fixes.
vgavrilo Jan 25, 2021
78be0a2
Small fix.
vgavrilo Jan 25, 2021
82ffdb4
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 25, 2021
922f3ef
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 25, 2021
3c2dd6d
Corrected documentation.
vgavrilo Jan 25, 2021
4f5c4d9
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Jan 28, 2021
dba6e45
Started to write generic implementation of 'linear_onnx' mode, for an…
vgavrilo Jan 28, 2021
0c5471e
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 3, 2021
84bac71
Written the draft of a generic (for all ranks) implementation of 'lin…
vgavrilo Feb 3, 2021
1cd2448
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 3, 2021
ee482a0
Small fixes.
vgavrilo Feb 3, 2021
76b36bb
Small fix.
vgavrilo Feb 3, 2021
22b561a
Small fix.
vgavrilo Feb 3, 2021
15b3ec2
Small fix.
vgavrilo Feb 3, 2021
ad5d0f5
Code style fix.
vgavrilo Feb 4, 2021
730c47c
Small fix.
vgavrilo Feb 4, 2021
7347868
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 4, 2021
af07a9f
Code style fix.
vgavrilo Feb 4, 2021
217d215
Some fixes.
vgavrilo Feb 4, 2021
aef2868
Some fix.
vgavrilo Feb 4, 2021
6c753bc
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 4, 2021
d9206dd
Small fix.
vgavrilo Feb 4, 2021
3cdc6eb
Small fix.
vgavrilo Feb 4, 2021
6475cbd
Code style fix.
vgavrilo Feb 4, 2021
f28a8f1
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 5, 2021
6793b21
Added check for axes correctness into a generic implementation of the…
vgavrilo Feb 5, 2021
fd7facc
Now 5D case of the 'linear_onnx' mode is calculated using generic fun…
vgavrilo Feb 5, 2021
f7146fd
Code style fix.
vgavrilo Feb 5, 2021
08021ea
Deleted unused variable.
vgavrilo Feb 5, 2021
e602144
Resolved merge conflict.
vgavrilo Feb 7, 2021
fac1511
Added debug prints.
vgavrilo Feb 7, 2021
1a11d00
Small fix.
vgavrilo Feb 7, 2021
a4d41fb
Some fixes.
vgavrilo Feb 7, 2021
a988ad4
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 7, 2021
b4387d5
Code style fix.
vgavrilo Feb 7, 2021
bc49903
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 8, 2021
35bf6a8
Now all ranks are processed by a generic implementation in the 'linea…
vgavrilo Feb 8, 2021
e891ea5
Deleted name of missed test.
vgavrilo Feb 8, 2021
bd090ec
Deleted 4D case implementation of the 'linear_onnx' mode.
vgavrilo Feb 8, 2021
46c00d4
Reverted change in tests.
vgavrilo Feb 8, 2021
0f37e75
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 8, 2021
69c658e
Added needed 'const' modifiers and added a comment about the variable…
vgavrilo Feb 8, 2021
5269f08
Merge remote-tracking branch 'upstream/master' into vgavrilo/5d-linea…
vgavrilo Feb 8, 2021
944f3b0
Small fixes.
vgavrilo Feb 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 123 additions & 2 deletions docs/ops/image/Interpolate_4.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* **Type**: string
* **Default value**: none
* **Required**: *yes*
* **Note**: Only 2D and 4D tensors with `axes = {0, 1}` and `axes = {2, 3}` respectively are supported for `"mode" == "linear_onnx"`.
**Note**: Only 2D, 3D, 4D, 5D tensors with `axes = {0, 1}`, `axes = {0, 1, 2}`, `axes = {2, 3}`, `axes = {2, 3, 4}` respectively are supported for `"mode" == "linear_onnx"`.
vgavrilo marked this conversation as resolved.
Show resolved Hide resolved

* *shape_calculation_mode*

Expand Down Expand Up @@ -366,7 +366,119 @@ class InterpolateCalculation:

return result

def onnx_linear_interpolation(self, input_data):
def onnx_linear_interpolation5D(self, input_data):
rank = len(self.input_shape)
assert rank in [3, 5], "mode 'linear_onnx' supports only 3D or 5D tensors"
assert set(self.axes) == {2, 3, 4} or set(self.axes) == {0, 1, 2}, \
"mode 'linear_onnx' supports only case when axes = {2, 3, 4} or axes = {0, 1, 2}"

result = np.zeros(self.output_shape)

if rank == 3:
reshaped_data = np.reshape(input_data, (1, 1, self.input_shape[0], self.input_shape[1], self.input_shape[2]))
result = np.reshape(result, (1, 1, self.output_shape[0], self.output_shape[1], self.output_shape[2]))
else:
reshaped_data = input_data

input_shape = np.array(reshaped_data.shape).astype(np.int64)
output_shape = np.array(result.shape).astype(np.int64)

batch_size = input_shape[0];
num_channels = input_shape[1];
input_depth = input_shape[2];
input_height = input_shape[3];
input_width = input_shape[4];
output_depth = output_shape[2];
output_height = output_shape[3];
output_width = output_shape[4];

depth_scale = self.scales[0];
height_scale = self.scales[1];
width_scale = self.scales[2];

z_original = np.zeros(output_depth).astype(np.float)
y_original = np.zeros(output_height).astype(np.float)
x_original = np.zeros(output_width).astype(np.float)

in_z1 = np.zeros(output_depth).astype(np.int64)
in_z2 = np.zeros(output_depth).astype(np.int64)
in_y1 = np.zeros(output_height).astype(np.int64)
in_y2 = np.zeros(output_height).astype(np.int64)
in_x1 = np.zeros(output_width).astype(np.int64)
in_x2 = np.zeros(output_width).astype(np.int64)

dz1 = np.zeros(output_depth).astype(np.float)
dz2 = np.zeros(output_depth).astype(np.float)

dy1 = np.zeros(output_height).astype(np.float)
dy2 = np.zeros(output_height).astype(np.float)

dx1 = np.zeros(output_width).astype(np.float)
dx2 = np.zeros(output_width).astype(np.float)

for z in range(0, output_depth):
in_z = self.get_original_coordinate(z, depth_scale, output_depth, input_depth)
z_original[z] = in_z
in_z = max(0, min(in_z, input_depth - 1))
in_z1[z] = max(0, min(int(in_z), input_depth - 1))
in_z2[z] = min(in_z1[z] + 1, input_depth - 1)
dz1[z] = abs(in_z - in_z1[z])
dz2[z] = abs(in_z - in_z2[z])

if in_z1[z] == in_z2[z]:
dz1[z] = 0.5
dz2[z] = 0.5

for y in range(0, output_height):
in_y = self.get_original_coordinate(y, height_scale, output_height, input_height)
y_original[y] = in_y
in_y = max(0, min(in_y, input_height - 1))
in_y1[y] = max(0, min(int(in_y), input_height - 1))
in_y2[y] = min(in_y1[y] + 1, input_height - 1)
dy1[y] = abs(in_y - in_y1[y])
dy2[y] = abs(in_y - in_y2[y])

if in_y1[y] == in_y2[y]:
dy1[y] = 0.5
dy2[y] = 0.5

for x in range(0, output_width):
in_x = self.get_original_coordinate(x, width_scale, output_width, input_width);
x_original[x] = in_x
in_x = max(0.0, min(in_x, input_width - 1));

in_x1[x] = min(in_x, input_width - 1);
in_x2[x] = min(in_x1[x] + 1, input_width - 1);

dx1[x] = abs(in_x - in_x1[x]);
dx2[x] = abs(in_x - in_x2[x]);
if in_x1[x] == in_x2[x]:
dx1[x] = 0.5
dx2[x] = 0.5
for n in range(0, batch_size):
for c in range(0, num_channels):
for z in range(0, output_depth):
for y in range(0, output_height):
for x in range(0, output_width):
x111 = reshaped_data[n, c, in_z1[z], in_y1[y], in_x1[x]]
x211 = reshaped_data[n, c, in_z1[z], in_y1[y], in_x2[x]]
x121 = reshaped_data[n, c, in_z1[z], in_y2[y], in_x1[x]]
x221 = reshaped_data[n, c, in_z1[z], in_y2[y], in_x2[x]]
x112 = reshaped_data[n, c, in_z2[z], in_y1[y], in_x1[x]]
x212 = reshaped_data[n, c, in_z2[z], in_y1[y], in_x2[x]]
x122 = reshaped_data[n, c, in_z2[z], in_y2[y], in_x1[x]]
x222 = reshaped_data[n, c, in_z2[z], in_y2[y], in_x2[x]]

temp = dx2[x] * dy2[y] * dz2[z] * x111 + dx1[x] * dy2[y] * dz2[z] * x211
temp += dx2[x] * dy1[y] * dz2[z] * x121 + dx1[x] * dy1[y] * dz2[z] * x221
temp += dx2[x] * dy2[y] * dz1[z] * x112 + dx1[x] * dy2[y] * dz1[z] * x212
temp += dx2[x] * dy1[y] * dz1[z] * x122 + dx1[x] * dy1[y] * dz1[z] * x222

result[n, c, z, y, x] = temp

return np.reshape(result, self.output_shape)

def onnx_linear_interpolation4D(self, input_data):
rank = len(self.input_shape)
assert rank in [2, 4], "mode 'linear_onnx' supports only 2D or 4D tensors"
assert set(self.axes) == {2, 3} or set(self.axes) == {0, 1}, \
Expand Down Expand Up @@ -446,6 +558,15 @@ class InterpolateCalculation:

return np.reshape(result, self.output_shape)

def onnx_linear_interpolation(self, input_data):
rank = len(self.input_shape)
assert rank in [2, 3, 4, 5], "mode 'linear_onnx' supports only 2D, 3D, 4D, or 5D tensors"

if rank in [2, 4]:
self.onnx_linear_interpolation4D(input_data)
else:
self.onnx_linear_interpolation5D(input_data)

def nearest_interpolation(self, input_data):
result = np.zeros(self.output_shape)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,30 +244,20 @@ namespace ngraph

Coordinate get_input_coords_for_nearest_mode(const Coordinate& output_coord);

struct InfoForLinearONNXMode
struct InfoForGenericLinearONNXMode
{
std::vector<float> y_original;
std::vector<float> x_original;

std::vector<int64_t> input_width_mul_y1;
std::vector<int64_t> input_width_mul_y2;
std::vector<int64_t> in_x1;
std::vector<int64_t> in_x2;

std::vector<float> dy1;
std::vector<float> dy2;
std::vector<float> dx1;
std::vector<float> dx2;

int64_t input_data_ptr_increment;
int64_t output_data_ptr_increment;
int64_t batch_size;
int64_t num_channels;
int64_t input_height;
int64_t input_width;
int64_t output_height;
int64_t output_width;
int64_t spatial_rank;
std::vector<int64_t> input_index_multipliers;
std::vector<int64_t> output_index_multipliers;
std::vector<int64_t> input_spatial_shape;
std::vector<int64_t> output_spatial_shape;
};

InfoForLinearONNXMode get_info_for_linear_onnx_mode();
InfoForGenericLinearONNXMode get_info_for_generic_linear_onnx();

struct InfoForLinearMode
{
Expand Down Expand Up @@ -392,7 +382,7 @@ namespace ngraph
/// \param out pointer to memory block for output data
void linear_func(const T* input_data, T* out);

/// \brief Calculates interpolation as in ONNX 'linear' mode
/// \brief Calculates interpolation as in ONNX 'linear' mode (generic case)
///
/// \param input_data pointer to input data
/// \param out pointer to memory block for output data
Expand Down Expand Up @@ -456,56 +446,152 @@ namespace ngraph
template <typename T>
void InterpolateEval<T>::linear_onnx_func(const T* input_data, T* out)
{
size_t input_rank = m_input_data_shape.size();
size_t num_of_axes = m_axes.size();
const size_t input_rank = m_input_data_shape.size();

assert(input_rank > 1);

assert((input_rank == 2) || (input_rank == 4));
assert((num_of_axes == 2) || (num_of_axes == input_rank));
const size_t num_of_axes = m_axes.size();

bool correct_axes = ((m_axes[0] == 0) && (m_axes[1] == 1)) ||
((m_axes[0] == 2) && (m_axes[1] == 3));
bool correct_axes = ((input_rank == 2) && (num_of_axes == 2) && (m_axes[0] == 0) &&
(m_axes[1] == 1)) ||
((input_rank == 3) && (num_of_axes == 3) && (m_axes[0] == 0) &&
(m_axes[1] == 1) && (m_axes[2] == 2));

if ((num_of_axes == 4) && (input_rank == 4))
if (input_rank >= 4)
{
correct_axes = (m_axes[0] == 0) && (m_axes[1] == 1) && (m_axes[2] == 2) &&
(m_axes[3] == 3);
std::vector<int64_t> all_axes;
std::vector<int64_t> axes_without_batch_and_channels;
all_axes.push_back(0);
all_axes.push_back(1);
for (int64_t i = 2; i < static_cast<int64_t>(input_rank); ++i)
{
all_axes.push_back(i);
axes_without_batch_and_channels.push_back(i);
}

correct_axes = ((num_of_axes == input_rank) && (m_axes == all_axes)) ||
((num_of_axes == input_rank - 2) &&
(m_axes == axes_without_batch_and_channels));
}

assert(correct_axes);

const auto info = helper.get_info_for_linear_onnx_mode();
const auto info = helper.get_info_for_generic_linear_onnx();

const int64_t batch_size = info.batch_size;
const int64_t num_channels = info.num_channels;

const auto& input_index_multipliers = info.input_index_multipliers;
const auto& output_index_multipliers = info.output_index_multipliers;

const int64_t input_data_ptr_increment = info.input_data_ptr_increment;
const int64_t output_data_ptr_increment = info.output_data_ptr_increment;

const auto& input_spatial_shape = info.input_spatial_shape;

int64_t batch_size = info.batch_size;
int64_t num_channels = info.num_channels;
int64_t output_height = info.output_height;
int64_t output_width = info.output_width;
int64_t input_height = info.input_height;
int64_t input_width = info.input_width;
// This mode supports only interpolation with respect to spatial dimensions,
// not with respect to batch or channels. That is, we can have only two cases:
// num_of_axes == input_rank
// or
// num_of_axes == input_rank - 2.
// Hence, if num_of_axes != input_rank, then interpolated axes indices are
// [0, 1, ..., num_of_axes - 1]
// Otherwise, if num_of_axes == input_rank, interpolated axes indices are
// [2, 3, ..., num_of_axes - 1]
const int64_t axis_idx_offset = (input_rank == num_of_axes) ? 2 : 0;

const int64_t spatial_rank = info.spatial_rank;
const int64_t points_in_neighbor = 1 << spatial_rank;

const T* xdata = input_data;
T* ydata = out;
for (int64_t n = 0; n < batch_size; ++n)
{
for (int64_t c = 0; c < num_channels; ++c)
{
for (int64_t y = 0; y < output_height; ++y)
for (int64_t idx = 0; idx < output_data_ptr_increment; ++idx)
{
for (int64_t x = 0; x < output_width; ++x)
// 1. Get the current spatial coords vector.
std::vector<int64_t> output_coords(spatial_rank);
int64_t curr = idx;
for (int64_t j = 0; j < spatial_rank - 1; ++j)
{
T x11 = xdata[info.input_width_mul_y1[y] + info.in_x1[x]];
T x21 = xdata[info.input_width_mul_y1[y] + info.in_x2[x]];
T x12 = xdata[info.input_width_mul_y2[y] + info.in_x1[x]];
T x22 = xdata[info.input_width_mul_y2[y] + info.in_x2[x]];

ydata[output_width * y + x] =
static_cast<T>(info.dx2[x] * info.dy2[y] * x11 +
info.dx1[x] * info.dy2[y] * x21 +
info.dx2[x] * info.dy1[y] * x12 +
info.dx1[x] * info.dy1[y] * x22);
output_coords[j] = curr / output_index_multipliers[j];
curr %= output_index_multipliers[j];
}
output_coords[spatial_rank - 1] = curr;

// 2. Some preliminaries.
std::vector<int64_t> in1(spatial_rank);
std::vector<int64_t> in2(spatial_rank);
std::vector<float> d1(spatial_rank);
std::vector<float> d2(spatial_rank);

for (int64_t i = 0; i < spatial_rank; ++i)
{
float out_coord = static_cast<float>(output_coords[i]);

float in_coord =
helper.get_in_coord(out_coord, i + axis_idx_offset);
in_coord = std::max(
0.0f,
std::min(in_coord,
static_cast<float>(input_spatial_shape[i] - 1)));

const int64_t in_coord1 = std::min(static_cast<int64_t>(in_coord),
input_spatial_shape[i] - 1);
const int64_t in_coord2 =
std::min(in_coord1 + 1, input_spatial_shape[i] - 1);

in1[i] = in_coord1;
in2[i] = in_coord2;
d1[i] = std::fabs(in_coord - in_coord1);
d2[i] = std::fabs(in_coord - in_coord2);

if (in_coord1 == in_coord2)
{
d1[i] = 0.5f;
d2[i] = 0.5f;
}
}

// 3. Get values in all points of a neighborhood.
std::vector<T> values_of_input_points(points_in_neighbor);
for (int64_t i = 0; i < points_in_neighbor; ++i)
{
int64_t offset = 0;
for (int64_t j = 0; j < spatial_rank; ++j)
{
if (i & (1 << (spatial_rank - 1 - j)))
{
offset += in1[j] * input_index_multipliers[j];
}
else
{
offset += in2[j] * input_index_multipliers[j];
}
}
values_of_input_points[i] = xdata[offset];
}

// 4. Interpolation.
float sum = 0.0f;
for (int64_t i = 0; i < points_in_neighbor; ++i)
{
float coeff = 1.0f;
for (int64_t j = 0; j < spatial_rank; ++j)
{
coeff *= (i & (1 << (spatial_rank - 1 - j))) ? d1[j] : d2[j];
}
sum += coeff * values_of_input_points[points_in_neighbor - 1 - i];
}

// 6. Store result.
ydata[idx] = static_cast<T>(sum);
}
xdata += input_height * input_width;
ydata += output_width * output_height;

xdata += input_data_ptr_increment;
ydata += output_data_ptr_increment;
}
}
}
Expand Down
Loading