Skip to content

Commit

Permalink
PriorBox
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Sep 15, 2020
1 parent 0ea28db commit 3a8c77e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ namespace ngraph
}
}

uint64_t channel_size = OH * OW;

int64_t channel_size = OH * OW;
if (variance.size() == 1)
{
for (uint64_t i = 0; i < channel_size; ++i)
Expand Down
36 changes: 36 additions & 0 deletions ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <ngraph/runtime/reference/reverse_sequence.hpp>
#include <ngraph/runtime/reference/rnn_cell.hpp>
#include <ngraph/runtime/reference/select.hpp>
#include <ngraph/runtime/reference/prior_box.hpp>
#include "ngraph/ops.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
Expand Down Expand Up @@ -455,6 +456,38 @@ namespace
return true;
}

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::PriorBox>& op,
const HostTensorVector& outputs,
const HostTensorVector& input)
{
using T = typename element_type_traits<ET>::value_type;
std::cout << "djdkldld" << std::endl;
std:: cout << input[0]->get_data_ptr<T>()[0] << " " << input[0]->get_data_ptr<T>()[1] << std::endl;
auto cons = dynamic_pointer_cast<op::v0::Constant>(op->input_value(0).get_node_shared_ptr());
auto vec = cons->get_vector<int64_t>();
runtime::reference::prior_box<T>(input[0]->get_data_ptr<T>(),
input[1]->get_data_ptr<T>(),
outputs[0]->get_data_ptr<float>(),
outputs[0]->get_shape(),
op->get_attrs());
return true;
}

// template <element::Type_t ET>
// bool evaluate(const shared_ptr<op::v1::Mod>& op,
// const HostTensorVector& outputs,
// const HostTensorVector& input)
// {
// using T = typename element_type_traits<ET>::value_type;
// runtime::reference::mod<T>(input[0]->get_data_ptr<T>(),
// input[1]->get_data_ptr<T>(),
// outputs[0]->get_data_ptr<float>(),
// outputs[0]->get_shape(),
// op->get_attrs());
// return true;
// }

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::Selu>& op,
const HostTensorVector& outputs,
Expand Down Expand Up @@ -779,6 +812,9 @@ namespace
throw std::logic_error("Output node element types is not equal");
}
}
if (is_type<op::PriorBox>(node)) {
element_type = node->get_input_element_type(0);
}
switch (element_type)
{
case element::Type_t::boolean:
Expand Down
1 change: 1 addition & 0 deletions ngraph/test/runtime/interpreter/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ NGRAPH_OP(Gelu, op::v0)
NGRAPH_OP(HardSigmoid, op::v0)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(MVN, ngraph::op::v0)
NGRAPH_OP(PriorBox, ngraph::op::v0)
NGRAPH_OP(ReverseSequence, op::v0)
NGRAPH_OP(RNNCell, op::v0)
NGRAPH_OP(Selu, op::v0)
Expand Down
Empty file.

0 comments on commit 3a8c77e

Please sign in to comment.