From 3a8c77e1884986a0817abcc9a312eac1dbd441cf Mon Sep 17 00:00:00 2001 From: "Efode, Irina" Date: Mon, 14 Sep 2020 13:35:25 +0300 Subject: [PATCH] PriorBox --- .../ngraph/runtime/reference/prior_box.hpp | 3 +- .../runtime/interpreter/evaluates_map.cpp | 36 +++++++++++++++++++ .../runtime/interpreter/opset_int_tbl.hpp | 1 + .../runtime/interpreter/reference/mod.hpp | 0 4 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 ngraph/test/runtime/interpreter/reference/mod.hpp diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/prior_box.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/prior_box.hpp index f9402f0ea5d5e5..6565c5dea98c37 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/prior_box.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/prior_box.hpp @@ -272,7 +272,8 @@ namespace ngraph } } - uint64_t channel_size = OH * OW; + + int64_t channel_size = OH * OW; if (variance.size() == 1) { for (uint64_t i = 0; i < channel_size; ++i) diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index de3ab3c8711d23..381d19cb3622ce 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include "ngraph/ops.hpp" #include "ngraph/runtime/reference/avg_pool.hpp" #include "ngraph/runtime/reference/batch_norm.hpp" @@ -455,6 +456,38 @@ namespace return true; } + template + bool evaluate(const shared_ptr& op, + const HostTensorVector& outputs, + const HostTensorVector& input) + { + using T = typename element_type_traits::value_type; + std::cout << "djdkldld" << std::endl; + std:: cout << input[0]->get_data_ptr()[0] << " " << input[0]->get_data_ptr()[1] << std::endl; + auto cons = dynamic_pointer_cast(op->input_value(0).get_node_shared_ptr()); + auto vec = cons->get_vector(); + runtime::reference::prior_box(input[0]->get_data_ptr(), + input[1]->get_data_ptr(), + outputs[0]->get_data_ptr(), + outputs[0]->get_shape(), + op->get_attrs()); + return true; + } + +// template +// bool evaluate(const shared_ptr& op, +// const HostTensorVector& outputs, +// const HostTensorVector& input) +// { +// using T = typename element_type_traits::value_type; +// runtime::reference::mod(input[0]->get_data_ptr(), +// input[1]->get_data_ptr(), +// outputs[0]->get_data_ptr(), +// outputs[0]->get_shape(), +// op->get_attrs()); +// return true; +// } + template bool evaluate(const shared_ptr& op, const HostTensorVector& outputs, @@ -779,6 +812,9 @@ namespace throw std::logic_error("Output node element types is not equal"); } } + if (is_type(node)) { + element_type = node->get_input_element_type(0); + } switch (element_type) { case element::Type_t::boolean: diff --git a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp index 9c2732d91e2390..e56d3b44c50b91 100644 --- a/ngraph/test/runtime/interpreter/opset_int_tbl.hpp +++ b/ngraph/test/runtime/interpreter/opset_int_tbl.hpp @@ -29,6 +29,7 @@ NGRAPH_OP(Gelu, op::v0) NGRAPH_OP(HardSigmoid, op::v0) NGRAPH_OP(LRN, ngraph::op::v0) NGRAPH_OP(MVN, ngraph::op::v0) +NGRAPH_OP(PriorBox, ngraph::op::v0) NGRAPH_OP(ReverseSequence, op::v0) NGRAPH_OP(RNNCell, op::v0) NGRAPH_OP(Selu, op::v0) diff --git a/ngraph/test/runtime/interpreter/reference/mod.hpp b/ngraph/test/runtime/interpreter/reference/mod.hpp new file mode 100644 index 00000000000000..e69de29bb2d1d6