Skip to content

Commit

Permalink
Enable cells refs in evaluate map
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Sep 9, 2020
1 parent f80c325 commit 53329df
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 39 deletions.
2 changes: 1 addition & 1 deletion ngraph/core/include/ngraph/op/lstm_cell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ namespace ngraph

static constexpr std::size_t s_gates_count{4};
};
} // v1
} // v4
} // namespace op

NGRAPH_API
Expand Down
35 changes: 21 additions & 14 deletions ngraph/test/backend/fused_op.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ NGRAPH_TEST(${BACKEND_NAME}, depth_to_space_depth_first)
7.f, 23.f, 12.f, 28.f, 14.f, 30.f, 13.f, 29.f, 15.f, 31.f});
test_case.run();
}
// TODO: enable normilizeL2 tests after normilizeL2 reference implementation
// TODO: enable normalizeL2 tests after normalizeL2 reference implementation
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_normalize_across_chw_4d)
{
Shape data_shape{1, 2, 3, 4};
Expand Down Expand Up @@ -1163,7 +1163,7 @@ NGRAPH_TEST(${BACKEND_NAME}, mvn_mean_variance_normalization_split_channels)
test_case.run();
}

// TODO: enable (RNN|LSTM|GRU)Cell tests after grn operation reference implementation
//TODO: Issue: 37514
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_grn_4d)
{
const Shape data_shape{1, 2, 3, 4};
Expand Down Expand Up @@ -1334,7 +1334,7 @@ NGRAPH_TEST(${BACKEND_NAME}, squeeze_dynamic)
EXPECT_THROW(make_shared<op::Squeeze>(data_param, axes_param), CheckFailure);
}

// TODO: enable squad diff tests after squared diff op reference implementation
// TODO: Issue: 37534
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_squared_difference)
{
const auto x1 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
Expand Down Expand Up @@ -1403,7 +1403,7 @@ NGRAPH_TEST(${BACKEND_NAME}, split_var_len_parts)
test_case.run();
}

NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__zero_bias_peepholes)
NGRAPH_TEST(${BACKEND_NAME}, lstm_cell__zero_bias_peepholes)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -1478,7 +1478,8 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__zero_bias_peepholes)
ct_test_case.run();
}

NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes)
// Peerholes unsupported in Ngraph
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -1565,7 +1566,7 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes)
ct_test_case.run();
}

NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes_clip_input_forget)
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes_clip_input_forget)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -1663,7 +1664,8 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__bias_peepholes_clip_input_forg
ct_test_case.run();
}

NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__activaction_functions)
// Hard Sigmoid is unsupprted
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_lstm_cell__activaction_functions)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -1764,6 +1766,7 @@ NGRAPH_TEST($${BACKEND_NAME}, DISABLED_lstm_cell__activaction_functions)
ct_test_case.run();
}

// TODO: Issue: 37511
NGRAPH_TEST(${BACKEND_NAME}, DISABLED_fake_quantize)
{
const Shape data_shape{1, 2, 3, 4};
Expand Down Expand Up @@ -1890,7 +1893,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_fake_quantize_with_clip_across_channels)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_fake_quantize_pdpd)
NGRAPH_TEST(${BACKEND_NAME}, fake_quantize_pdpd)
{
Shape data_shape{1, 2, 5, 5};
size_t levels = 5;
Expand Down Expand Up @@ -1939,7 +1942,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_fake_quantize_pdpd)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__no_bias)
NGRAPH_TEST(${BACKEND_NAME}, rnn_cell__no_bias)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand All @@ -1953,6 +1956,10 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__no_bias)
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R});

// ngraph::pass::Manager manager;
// manager.register_pass<ngraph::pass::ConvertRNNCellMatcher>();
// manager.run_passes(f);

auto test_case = test::TestCase<TestEngine>(function);
// X
test_case.add_input<float>(
Expand Down Expand Up @@ -1988,7 +1995,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__no_bias)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__bias_clip)
NGRAPH_TEST(${BACKEND_NAME}, rnn_cell__bias_clip)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -2050,7 +2057,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__bias_clip)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__activation_function)
NGRAPH_TEST(${BACKEND_NAME}, rnn_cell__activation_function)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -2112,7 +2119,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_rnn_cell__activation_function)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_bias_clip)
NGRAPH_TEST(${BACKEND_NAME}, gru_cell_bias_clip)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -2185,7 +2192,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_bias_clip)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_linear_before_reset)
NGRAPH_TEST(${BACKEND_NAME}, gru_cell_linear_before_reset)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down Expand Up @@ -2257,7 +2264,7 @@ NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_linear_before_reset)
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, DISABLED_gru_cell_activation_function)
NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function)
{
const size_t batch_size = 2;
const size_t input_size = 3;
Expand Down
6 changes: 6 additions & 0 deletions ngraph/test/runtime/ie/unit_test.manifest
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,12 @@ IE_CPU.onnx_model_rnn_fwd_bias_initial_h
IE_CPU.onnx_model_rnn_bidirectional
IE_CPU.onnx_model_rnn_bidirectional_const

# RNN/LSTM Cells should be converted to IE representation
IE_CPU.lstm_cell__zero_bias_peepholes
IE_CPU.rnn_cell__no_bias
IE_CPU.rnn_cell__bias_clip
IE_CPU.rnn_cell__activation_function

#-------------------------------------------------------------------------------
#
# Inference Engine GPU plugin excludes
Expand Down
109 changes: 85 additions & 24 deletions ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
#include <ngraph/runtime/reference/dot.hpp>
#include <ngraph/runtime/reference/replace_slice.hpp>
#include <ngraph/runtime/reference/gather_nd.hpp>
#include <ngraph/runtime/reference/rnn_cell.hpp>
#include <ngraph/runtime/reference/lstm_cell.hpp>
#include <ngraph/runtime/reference/gru_cell.hpp>
#include "ngraph/runtime/reference/detection_output.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include "reference/gelu.hpp"
Expand Down Expand Up @@ -171,18 +174,13 @@ namespace {
inputs[0]->get_shape(), \
op->is_exclusive(), \
op->is_reverse()); \
break;

switch (inputs[1]->get_element_type()) {
case element::Type_t::i64: {
try {
REF_CALL(element::Type_t::i64);
} catch (...) {
REF_CALL(element::Type_t::i32);
};
break;
REF_CALL(element::Type_t::i64);
}
default:
// std::cout << inputs[1]->get_element_type() << std::endl;
REF_CALL(element::Type_t::i32);
}
#undef REF_CALL
Expand All @@ -208,9 +206,9 @@ namespace {

switch (inputs[1]->get_element_type()) {
case element::Type_t::i32:
REF_CALL(element::Type_t::i32);
REF_CALL(element::Type_t::i32);
case element::Type_t::i64:
REF_CALL(element::Type_t::i64);
REF_CALL(element::Type_t::i64);
default:
return false;
}
Expand All @@ -236,9 +234,9 @@ namespace {

switch (inputs[1]->get_element_type()) {
case element::Type_t::i32:
REF_CALL(element::Type_t::i32);
REF_CALL(element::Type_t::i32);
case element::Type_t::i64:
REF_CALL(element::Type_t::i64);
REF_CALL(element::Type_t::i64);
default:
return false;
}
Expand All @@ -262,9 +260,9 @@ namespace {

switch (inputs[1]->get_element_type()) {
case element::Type_t::i32:
REF_CALL(element::Type_t::i32);
REF_CALL(element::Type_t::i32);
case element::Type_t::i64:
REF_CALL(element::Type_t::i64);
REF_CALL(element::Type_t::i64);
default:
return false;
}
Expand Down Expand Up @@ -497,6 +495,7 @@ namespace {
op->get_batch_axis(),\
op->get_origin_sequence_axis(),\
input[1]->get_data_ptr<U>());\
break;

switch (input[1]->get_element_type()) {
case element::Type_t::boolean:
Expand Down Expand Up @@ -537,8 +536,8 @@ namespace {
runtime::reference::convert<T, typename element_type_traits<U>::value_type>(\
input[0]->get_data_ptr<T>(),\
outputs[0]->get_data_ptr<U>(),\
shape_size(input[0]->get_shape()));

shape_size(input[0]->get_shape()));\
break;

switch (input[0]->get_element_type()) {
case element::Type_t::boolean:
Expand Down Expand Up @@ -587,11 +586,70 @@ namespace {
template<element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::RNNCell> &op, const HostTensorVector &outputs,
const HostTensorVector &inputs) {
// runtime::reference::rnn_cell(inputs[0]->get_data_ptr<char>(),
// outputs[0]->get_data_ptr<char>(),
// inputs[0]->get_shape(),
// outputs[0]->get_shape(),
// op->get_reduction_axes());

using T = typename element_type_traits<ET>::value_type;
runtime::reference::rnn_cell<T>(inputs[0]->get_data_ptr<ET>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<ET>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<ET>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<ET>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<ET>(),
inputs[4]->get_shape(),
outputs[0]->get_data_ptr<ET>(),
op->get_activations().front(),
op->get_clip());
return true;
}

template<element::Type_t ET>
bool evaluate(const shared_ptr<op::v4::LSTMCell> &op, const HostTensorVector &outputs,
const HostTensorVector &inputs) {

using T = typename element_type_traits<ET>::value_type;
runtime::reference::lstm_cell<T>(inputs[0]->get_data_ptr<ET>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<ET>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<ET>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<ET>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<ET>(),
inputs[4]->get_shape(),
inputs[5]->get_data_ptr<ET>(),
inputs[5]->get_shape(),
outputs[0]->get_data_ptr<ET>(),
outputs[1]->get_data_ptr<ET>(),
op->get_activations()[0],
op->get_activations()[1],
op->get_activations()[2],
op->get_clip());
return true;
}

template<element::Type_t ET>
bool evaluate(const shared_ptr<op::v3::GRUCell> &op, const HostTensorVector &outputs,
const HostTensorVector &inputs) {

using T = typename element_type_traits<ET>::value_type;
runtime::reference::gru_cell<T>(inputs[0]->get_data_ptr<ET>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<ET>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<ET>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<ET>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<ET>(),
inputs[4]->get_shape(),
outputs[0]->get_data_ptr<ET>(),
op->get_activations()[0],
op->get_activations()[1],
op->get_clip(),
op->get_linear_before_reset());
return true;
}

Expand All @@ -611,12 +669,15 @@ namespace {
return true;
}




template<typename T>
bool evaluate_node(std::shared_ptr<Node> node, const HostTensorVector &outputs, const HostTensorVector &inputs) {
switch (node->get_element_type()) {
auto element_type = node->get_output_element_type(0);
for (size_t i = 1; i < node->outputs().size(); i++) {
if (element_type != node->get_output_element_type(i)) {
throw std::logic_error("Output node element types is not equal");
}
}
switch (element_type) {
case element::Type_t::boolean:
return evaluate<element::Type_t::boolean>(as_type_ptr<T>(node), outputs, inputs);;
// case element::Type_t::bf16:
Expand Down
2 changes: 2 additions & 0 deletions ngraph/test/runtime/interpreter/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ NGRAPH_OP(EmbeddingBagOffsetsSum, ngraph::op::v3)
NGRAPH_OP(EmbeddingBagPackedSum, ngraph::op::v3)
NGRAPH_OP(ExtractImagePatches, op::v3)
NGRAPH_OP(EmbeddingSegmentsSum, ngraph::op::v3)
NGRAPH_OP(GRUCell, ngraph::op::v3)
NGRAPH_OP(NonZero, op::v3)
NGRAPH_OP(ScatterNDUpdate, op::v3)
NGRAPH_OP(ShapeOf, op::v3)

NGRAPH_OP(CTCLoss, op::v4)
NGRAPH_OP(LSTMCell, op::v4)

0 comments on commit 53329df

Please sign in to comment.