diff --git a/ngraph/core/include/ngraph/op/lstm_cell.hpp b/ngraph/core/include/ngraph/op/lstm_cell.hpp index c830cae247fa7c..b05fe46c7feda1 100644 --- a/ngraph/core/include/ngraph/op/lstm_cell.hpp +++ b/ngraph/core/include/ngraph/op/lstm_cell.hpp @@ -401,7 +401,7 @@ namespace ngraph static constexpr std::size_t s_gates_count{4}; }; - } // v1 + } // v4 } // namespace op NGRAPH_API diff --git a/ngraph/test/backend/fused_op.in.cpp b/ngraph/test/backend/fused_op.in.cpp index 90ca33cf060523..f0ae7a9d4f8765 100644 --- a/ngraph/test/backend/fused_op.in.cpp +++ b/ngraph/test/backend/fused_op.in.cpp @@ -243,7 +243,7 @@ NGRAPH_TEST(${BACKEND_NAME}, depth_to_space_depth_first) 7.f, 23.f, 12.f, 28.f, 14.f, 30.f, 13.f, 29.f, 15.f, 31.f}); test_case.run(); } -// TODO: enable normilizeL2 tests after normilizeL2 reference implementation +// TODO: enable normalizeL2 tests after normalizeL2 reference implementation NGRAPH_TEST(${BACKEND_NAME}, DISABLED_normalize_across_chw_4d) { Shape data_shape{1, 2, 3, 4}; @@ -1163,7 +1163,7 @@ NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels) test_case.run(); } -// TODO: enable (RNN|LSTM|GRU)Cell tests after grn operation reference implementation +//TODO: Issue: 37514 NGRAPH_TEST(${BACKEND_NAME}, DISABLED_grn_4d) { const Shape data_shape{1, 2, 3, 4}; @@ -1334,7 +1334,7 @@ NGRAPH_TEST(${BACKEND_NAME}, squeeze_dynamic) EXPECT_THROW(make_shared(data_param, axes_param), CheckFailure); } -// TODO: enable squad diff tests after squared diff op reference implementation +// TODO: Issue: 37534 NGRAPH_TEST(${BACKEND_NAME}, DISABLED_squared_difference) { const auto x1 = make_shared(element::f32, Shape{2, 2}); @@ -1403,7 +1403,7 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts) test_case.run(); } -NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__zero_bias_peepholes) +NGRAPH_TEST(${BACKEND_NAME}, lstm_cell__zero_bias_peepholes) { const size_t batch_size = 2; const size_t input_size = 3; @@ -1478,7 +1478,8 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__zero_bias_peepholes) ct_test_case.run(); } -NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes) +// Peerholes unsupported in Ngraph +NGRAPH_TEST(${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes) { const size_t batch_size = 2; const size_t input_size = 3; @@ -1565,7 +1566,7 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes) ct_test_case.run(); } -NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes_clip_input_forget) +NGRAPH_TEST(${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes_clip_input_forget) { const size_t batch_size = 2; const size_t input_size = 3; @@ -1663,7 +1664,8 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes_clip_input_forg ct_test_case.run(); } -NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__activaction_functions) +// Hard Sigmoid is unsupprted +NGRAPH_TEST(${BACKEND_NAME}, DISABLED_lstm_cell__activaction_functions) { const size_t batch_size = 2; const size_t input_size = 3; @@ -1764,6 +1766,7 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__activaction_functions) ct_test_case.run(); } +// TODO: Issue: 37511 NGRAPH_TEST(${BACKEND_NAME}, DISABLED_fake_quantize) { const Shape data_shape{1, 2, 3, 4}; @@ -1939,7 +1942,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_fake_quantize_pdpd) test_case.run(); } -NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__no_bias) +NGRAPH_TEST(${BACKEND_NAME}, rnn_cell__no_bias) { const size_t batch_size = 2; const size_t input_size = 3; @@ -1988,7 +1991,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__no_bias) test_case.run(); } -NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__bias_clip) +NGRAPH_TEST(${BACKEND_NAME}, rnn_cell__bias_clip) { const size_t batch_size = 2; const size_t input_size = 3; @@ -2050,7 +2053,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__bias_clip) test_case.run(); } -NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__activation_function) +NGRAPH_TEST(${BACKEND_NAME}, rnn_cell__activation_function) { const size_t batch_size = 2; const size_t input_size = 3; @@ -2112,7 +2115,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__activation_function) test_case.run(); } -NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_bias_clip) +NGRAPH_TEST(${BACKEND_NAME}, gru_cell_bias_clip) { const size_t batch_size = 2; const size_t input_size = 3; @@ -2185,7 +2188,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_bias_clip) test_case.run(); } -NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_linear_before_reset) +NGRAPH_TEST(${BACKEND_NAME}, gru_cell_linear_before_reset) { const size_t batch_size = 2; const size_t input_size = 3; @@ -2257,7 +2260,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_linear_before_reset) test_case.run(); } -NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_activation_function) +NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function) { const size_t batch_size = 2; const size_t input_size = 3; diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index 783fe8c99877c8..4d7666b32976ee 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -1125,6 +1125,12 @@ IE_CPU.onnx_model_rnn_fwd_bias_initial_h IE_CPU.onnx_model_rnn_bidirectional IE_CPU.onnx_model_rnn_bidirectional_const +# RNN/LSTM Cells should be converted to IE representation +IE_CPU.lstm_cell__zero_bias_peepholes +IE_CPU.rnn_cell__no_bias +IE_CPU.rnn_cell__bias_clip +IE_CPU.rnn_cell__activation_function + #------------------------------------------------------------------------------- # # Inference Engine GPU plugin excludes diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 2a94152c6e5520..a2c428bcadba97 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -37,6 +37,9 @@ #include #include #include +#include +#include +#include #include "ngraph/runtime/reference/detection_output.hpp" #include "ngraph/runtime/reference/scatter_nd_update.hpp" #include "reference/gelu.hpp" @@ -171,18 +174,13 @@ namespace { inputs[0]->get_shape(), \ op->is_exclusive(), \ op->is_reverse()); \ + break; switch (inputs[1]->get_element_type()) { case element::Type_t::i64: { - try { - REF_CALL(element::Type_t::i64); - } catch (...) { - REF_CALL(element::Type_t::i32); - }; - break; + REF_CALL(element::Type_t::i64); } default: -// std::cout << inputs[1]->get_element_type() << std::endl; REF_CALL(element::Type_t::i32); } #undef REF_CALL @@ -208,9 +206,9 @@ namespace { switch (inputs[1]->get_element_type()) { case element::Type_t::i32: - REF_CALL(element::Type_t::i32); + REF_CALL(element::Type_t::i32); case element::Type_t::i64: - REF_CALL(element::Type_t::i64); + REF_CALL(element::Type_t::i64); default: return false; } @@ -236,9 +234,9 @@ namespace { switch (inputs[1]->get_element_type()) { case element::Type_t::i32: - REF_CALL(element::Type_t::i32); + REF_CALL(element::Type_t::i32); case element::Type_t::i64: - REF_CALL(element::Type_t::i64); + REF_CALL(element::Type_t::i64); default: return false; } @@ -262,9 +260,9 @@ namespace { switch (inputs[1]->get_element_type()) { case element::Type_t::i32: - REF_CALL(element::Type_t::i32); + REF_CALL(element::Type_t::i32); case element::Type_t::i64: - REF_CALL(element::Type_t::i64); + REF_CALL(element::Type_t::i64); default: return false; } @@ -497,6 +495,7 @@ namespace { op->get_batch_axis(),\ op->get_origin_sequence_axis(),\ input[1]->get_data_ptr());\ + break; switch (input[1]->get_element_type()) { case element::Type_t::boolean: @@ -537,8 +536,8 @@ namespace { runtime::reference::convert::value_type>(\ input[0]->get_data_ptr(),\ outputs[0]->get_data_ptr(),\ - shape_size(input[0]->get_shape())); - + shape_size(input[0]->get_shape()));\ + break; switch (input[0]->get_element_type()) { case element::Type_t::boolean: @@ -587,11 +586,70 @@ namespace { template bool evaluate(const shared_ptr &op, const HostTensorVector &outputs, const HostTensorVector &inputs) { -// runtime::reference::rnn_cell(inputs[0]->get_data_ptr(), -// outputs[0]->get_data_ptr(), -// inputs[0]->get_shape(), -// outputs[0]->get_shape(), -// op->get_reduction_axes()); + + using T = typename element_type_traits::value_type; + runtime::reference::rnn_cell(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + inputs[2]->get_data_ptr(), + inputs[2]->get_shape(), + inputs[3]->get_data_ptr(), + inputs[3]->get_shape(), + inputs[4]->get_data_ptr(), + inputs[4]->get_shape(), + outputs[0]->get_data_ptr(), + op->get_activations().front(), + op->get_clip()); + return true; + } + + template + bool evaluate(const shared_ptr &op, const HostTensorVector &outputs, + const HostTensorVector &inputs) { + + using T = typename element_type_traits::value_type; + runtime::reference::lstm_cell(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + inputs[2]->get_data_ptr(), + inputs[2]->get_shape(), + inputs[3]->get_data_ptr(), + inputs[3]->get_shape(), + inputs[4]->get_data_ptr(), + inputs[4]->get_shape(), + inputs[5]->get_data_ptr(), + inputs[5]->get_shape(), + outputs[0]->get_data_ptr(), + outputs[1]->get_data_ptr(), + op->get_activations()[0], + op->get_activations()[1], + op->get_activations()[2], + op->get_clip()); + return true; + } + + template + bool evaluate(const shared_ptr &op, const HostTensorVector &outputs, + const HostTensorVector &inputs) { + + using T = typename element_type_traits::value_type; + runtime::reference::gru_cell(inputs[0]->get_data_ptr(), + inputs[0]->get_shape(), + inputs[1]->get_data_ptr(), + inputs[1]->get_shape(), + inputs[2]->get_data_ptr(), + inputs[2]->get_shape(), + inputs[3]->get_data_ptr(), + inputs[3]->get_shape(), + inputs[4]->get_data_ptr(), + inputs[4]->get_shape(), + outputs[0]->get_data_ptr(), + op->get_activations()[0], + op->get_activations()[1], + op->get_clip(), + op->get_linear_before_reset()); return true; } @@ -611,12 +669,15 @@ namespace { return true; } - - - template bool evaluate_node(std::shared_ptr node, const HostTensorVector &outputs, const HostTensorVector &inputs) { - switch (node->get_element_type()) { + auto element_type = node->get_output_element_type(0); + for (size_t i = 1; i < node->outputs().size(); i++) { + if (element_type != node->get_output_element_type(i)) { + throw std::logic_error("Output node element types is not equal"); + } + } + switch (element_type) { case element::Type_t::boolean: return evaluate(as_type_ptr(node), outputs, inputs);; // case element::Type_t::bf16: diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 062d380ae3806b..41a0b4b68e0165 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -52,8 +52,10 @@ NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3) NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3) NGRAPH_OP(ExtractImagePatches, op::v3) NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3) +NGRAPH_OP(GRUCell, ngraph::op::v3) NGRAPH_OP(NonZero, op::v3) NGRAPH_OP(ScatterNDUpdate, op::v3) NGRAPH_OP(ShapeOf, op::v3) NGRAPH_OP(CTCLoss, op::v4) +NGRAPH_OP(LSTMCell, op::v4)