diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/autobroadcast_binop.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/autobroadcast_binop.hpp index 70410784226478..345555b6a8426b 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/autobroadcast_binop.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/autobroadcast_binop.hpp @@ -388,19 +388,23 @@ namespace ngraph Shape arg1_padded_shape = arg1_shape; Shape arg2_padded_shape = arg2_shape; - while (arg1_padded_shape.size() < arg2_padded_shape.size()) + size_t max_shape_size = std::max({arg0_padded_shape.size(), + arg1_padded_shape.size(), + arg2_padded_shape.size()}); + + while (arg0_padded_shape.size() < max_shape_size) { - arg1_padded_shape.insert(arg1_padded_shape.begin(), 1); + arg0_padded_shape.insert(arg0_padded_shape.begin(), 1); } - while (arg2_padded_shape.size() < arg1_padded_shape.size()) + while (arg1_padded_shape.size() < max_shape_size) { - arg2_padded_shape.insert(arg2_padded_shape.begin(), 1); + arg1_padded_shape.insert(arg1_padded_shape.begin(), 1); } - while (arg0_padded_shape.size() < arg1_padded_shape.size()) + while (arg2_padded_shape.size() < max_shape_size) { - arg0_padded_shape.insert(arg0_padded_shape.begin(), 1); + arg2_padded_shape.insert(arg2_padded_shape.begin(), 1); } Shape arg0_squeezed_shape; @@ -411,7 +415,7 @@ namespace ngraph AxisSet arg2_squeezed_axes; Shape output_shape; - for (size_t i = 0; i < arg1_padded_shape.size(); i++) + for (size_t i = 0; i < max_shape_size; i++) { if (arg1_padded_shape[i] == 1) { @@ -440,9 +444,9 @@ namespace ngraph arg0_squeezed_shape.push_back(arg0_padded_shape[i]); } - output_shape.push_back(arg1_padded_shape[i] == 1 - ? arg2_padded_shape[i] - : arg1_padded_shape[i]); + output_shape.push_back(std::max({arg0_padded_shape[i], + arg2_padded_shape[i], + arg1_padded_shape[i]})); } CoordinateTransform arg0_transform(arg0_squeezed_shape); diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/extract_image_patches.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/extract_image_patches.hpp index 4e16e1c0f75ebf..b78780a3a1b5f7 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/extract_image_patches.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/extract_image_patches.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include "ngraph/shape_util.hpp" namespace ngraph @@ -10,12 +11,12 @@ namespace ngraph { namespace reference { - template - void extractImagePatches(const op::ExtractImagePatches* extImgPatches, - const T* input, - T* out, - const Shape& inShape, - const Shape& outShape) + template + void extract_image_patches(const std::shared_ptr extImgPatches, + const T* input, + T* out, + const Shape& inShape, + const Shape& outShape) { const size_t dimsSize = inShape.size(); const size_t BATCH = 0, CHANNEL = 1, HIGHT = 0, WIDTH = 1; diff --git a/ngraph/core/src/op/add.cpp b/ngraph/core/src/op/add.cpp index 3bdeea67b8137c..a41cafbb79d8cb 100644 --- a/ngraph/core/src/op/add.cpp +++ b/ngraph/core/src/op/add.cpp @@ -19,8 +19,6 @@ #include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/reference/add.hpp" -NGRAPH_SUPPRESS_DEPRECATED_START - using namespace std; using namespace ngraph; diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 536b095c2dedf4..4b5b307ee87de6 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include #include #include +#include #include "ngraph/ops.hpp" #include "ngraph/runtime/reference/avg_pool.hpp" #include "ngraph/runtime/reference/batch_norm.hpp" @@ -399,6 +401,7 @@ namespace const HostTensorVector& input) { using T = typename element_type_traits::value_type; + runtime::reference::select(input[0]->get_data_ptr(), input[1]->get_data_ptr(), input[2]->get_data_ptr(), @@ -591,7 +594,7 @@ namespace outputs[0]->get_data_ptr(), \ input[0]->get_shape(), \ op->get_batch_axis(), \ - op->get_origin_sequence_axis(), \ + op->get_sequence_axis(), \ input[1]->get_data_ptr()); \ break; @@ -615,6 +618,20 @@ namespace return true; } + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& input) + { + using T = typename element_type_traits::value_type; + runtime::reference::extract_image_patches(op, + input[0]->get_data_ptr(), + outputs[0]->get_data_ptr(), + input[0]->get_shape(), + outputs[0]->get_shape()); + return true; + } + template bool evaluate(const shared_ptr& op, const HostTensorVector& outputs, @@ -788,6 +805,91 @@ namespace return true; } + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& inputs) + { + using T = typename element_type_traits::value_type; + runtime::reference::rnn_sequence(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + inputs[2]->get_data_ptr(), + inputs[2]->get_shape(), + inputs[3]->get_data_ptr(), + inputs[3]->get_shape(), + inputs[4]->get_data_ptr(), + inputs[4]->get_shape(), + inputs[5]->get_data_ptr(), + inputs[5]->get_shape(), + outputs[0]->get_data_ptr(), + outputs[1]->get_data_ptr(), + op->get_activations()[0], + op->get_clip(), + op->get_direction()); + return true; + } + + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& inputs) + { + using T = typename element_type_traits::value_type; + runtime::reference::lstm_sequence(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + inputs[2]->get_data_ptr(), + inputs[2]->get_shape(), + inputs[3]->get_data_ptr(), + inputs[3]->get_shape(), + inputs[4]->get_data_ptr(), + inputs[4]->get_shape(), + inputs[5]->get_data_ptr(), + inputs[5]->get_shape(), + inputs[6]->get_data_ptr(), + inputs[6]->get_shape(), + outputs[0]->get_data_ptr(), + outputs[1]->get_data_ptr(), + outputs[2]->get_data_ptr(), + op->get_activations()[0], + op->get_activations()[1], + op->get_activations()[2], + op->get_clip(), + op->get_direction()); + return true; + } + + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& inputs) + { + using T = typename element_type_traits::value_type; + runtime::reference::gru_sequence(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + inputs[2]->get_data_ptr(), + inputs[2]->get_shape(), + inputs[3]->get_data_ptr(), + inputs[3]->get_shape(), + inputs[4]->get_data_ptr(), + inputs[4]->get_shape(), + inputs[5]->get_data_ptr(), + inputs[5]->get_shape(), + outputs[0]->get_data_ptr(), + outputs[1]->get_data_ptr(), + op->get_activations()[0], + op->get_activations()[1], + op->get_clip(), + op->get_direction(), + op->get_linear_before_reset()); + return true; + } + template bool evaluate(const shared_ptr& op, const HostTensorVector& outputs, diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 885ca53298bc61..8d4748caa4b3d1 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -62,3 +62,7 @@ NGRAPH_OP(ShapeOf, op::v3) NGRAPH_OP(CTCLoss, op::v4) NGRAPH_OP(LSTMCell, op::v4) + +NGRAPH_OP(GRUSequence, op::v5) +NGRAPH_OP(LSTMSequence, op::v5) +NGRAPH_OP(RNNSequence, op::v5)