Skip to content

Commit

Permalink
ReverseSeq (#9)
Browse files Browse the repository at this point in the history
* ReverseSeq

* Select

* ExtractImagePatches, Seqence

* Fix Code Style

* remove extra

* Remove etra line@
  • Loading branch information
iefode authored Sep 22, 2020
1 parent f0a5399 commit a5c32c1
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,19 +388,23 @@ namespace ngraph
Shape arg1_padded_shape = arg1_shape;
Shape arg2_padded_shape = arg2_shape;

while (arg1_padded_shape.size() < arg2_padded_shape.size())
size_t max_shape_size = std::max({arg0_padded_shape.size(),
arg1_padded_shape.size(),
arg2_padded_shape.size()});

while (arg0_padded_shape.size() < max_shape_size)
{
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
}

while (arg2_padded_shape.size() < arg1_padded_shape.size())
while (arg1_padded_shape.size() < max_shape_size)
{
arg2_padded_shape.insert(arg2_padded_shape.begin(), 1);
arg1_padded_shape.insert(arg1_padded_shape.begin(), 1);
}

while (arg0_padded_shape.size() < arg1_padded_shape.size())
while (arg2_padded_shape.size() < max_shape_size)
{
arg0_padded_shape.insert(arg0_padded_shape.begin(), 1);
arg2_padded_shape.insert(arg2_padded_shape.begin(), 1);
}

Shape arg0_squeezed_shape;
Expand All @@ -411,7 +415,7 @@ namespace ngraph
AxisSet arg2_squeezed_axes;
Shape output_shape;

for (size_t i = 0; i < arg1_padded_shape.size(); i++)
for (size_t i = 0; i < max_shape_size; i++)
{
if (arg1_padded_shape[i] == 1)
{
Expand Down Expand Up @@ -440,9 +444,9 @@ namespace ngraph
arg0_squeezed_shape.push_back(arg0_padded_shape[i]);
}

output_shape.push_back(arg1_padded_shape[i] == 1
? arg2_padded_shape[i]
: arg1_padded_shape[i]);
output_shape.push_back(std::max({arg0_padded_shape[i],
arg2_padded_shape[i],
arg1_padded_shape[i]}));
}

CoordinateTransform arg0_transform(arg0_squeezed_shape);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <ngraph/ops.hpp>
#include "ngraph/shape_util.hpp"

namespace ngraph
Expand All @@ -10,12 +11,12 @@ namespace ngraph
{
namespace reference
{
template <typename T, typename U>
void extractImagePatches(const op::ExtractImagePatches* extImgPatches,
const T* input,
T* out,
const Shape& inShape,
const Shape& outShape)
template <typename T>
void extract_image_patches(const std::shared_ptr<op::ExtractImagePatches> extImgPatches,
const T* input,
T* out,
const Shape& inShape,
const Shape& outShape)
{
const size_t dimsSize = inShape.size();
const size_t BATCH = 0, CHANNEL = 1, HIGHT = 0, WIDTH = 1;
Expand Down
2 changes: 0 additions & 2 deletions ngraph/core/src/op/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/add.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;
using namespace ngraph;

Expand Down
104 changes: 103 additions & 1 deletion ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <ngraph/runtime/reference/convert.hpp>
#include <ngraph/runtime/reference/dequantize.hpp>
#include <ngraph/runtime/reference/dot.hpp>
#include <ngraph/runtime/reference/extract_image_patches.hpp>
#include <ngraph/runtime/reference/gather_nd.hpp>
#include <ngraph/runtime/reference/gru_cell.hpp>
#include <ngraph/runtime/reference/lstm_cell.hpp>
Expand All @@ -33,6 +34,7 @@
#include <ngraph/runtime/reference/reverse_sequence.hpp>
#include <ngraph/runtime/reference/rnn_cell.hpp>
#include <ngraph/runtime/reference/select.hpp>
#include <ngraph/runtime/reference/sequences.hpp>
#include "ngraph/ops.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
Expand Down Expand Up @@ -399,6 +401,7 @@ namespace
const HostTensorVector& input)
{
using T = typename element_type_traits<ET>::value_type;

runtime::reference::select<T>(input[0]->get_data_ptr<const char>(),
input[1]->get_data_ptr<const T>(),
input[2]->get_data_ptr<const T>(),
Expand Down Expand Up @@ -591,7 +594,7 @@ namespace
outputs[0]->get_data_ptr<T>(), \
input[0]->get_shape(), \
op->get_batch_axis(), \
op->get_origin_sequence_axis(), \
op->get_sequence_axis(), \
input[1]->get_data_ptr<U>()); \
break;

Expand All @@ -615,6 +618,20 @@ namespace
return true;
}

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v3::ExtractImagePatches>& op,
const HostTensorVector& outputs,
const HostTensorVector& input)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::extract_image_patches<T>(op,
input[0]->get_data_ptr<T>(),
outputs[0]->get_data_ptr<T>(),
input[0]->get_shape(),
outputs[0]->get_shape());
return true;
}

template <element::Type_t OUT_ET>
bool evaluate(const shared_ptr<op::v0::Convert>& op,
const HostTensorVector& outputs,
Expand Down Expand Up @@ -788,6 +805,91 @@ namespace
return true;
}

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v5::RNNSequence>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::rnn_sequence<T>(inputs[0]->get_data_ptr<char>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<char>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<char>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<char>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<char>(),
inputs[4]->get_shape(),
inputs[5]->get_data_ptr<char>(),
inputs[5]->get_shape(),
outputs[0]->get_data_ptr<char>(),
outputs[1]->get_data_ptr<char>(),
op->get_activations()[0],
op->get_clip(),
op->get_direction());
return true;
}

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v5::LSTMSequence>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::lstm_sequence<T>(inputs[0]->get_data_ptr<char>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<char>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<char>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<char>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<char>(),
inputs[4]->get_shape(),
inputs[5]->get_data_ptr<char>(),
inputs[5]->get_shape(),
inputs[6]->get_data_ptr<char>(),
inputs[6]->get_shape(),
outputs[0]->get_data_ptr<char>(),
outputs[1]->get_data_ptr<char>(),
outputs[2]->get_data_ptr<char>(),
op->get_activations()[0],
op->get_activations()[1],
op->get_activations()[2],
op->get_clip(),
op->get_direction());
return true;
}

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v5::GRUSequence>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::gru_sequence<T>(inputs[0]->get_data_ptr<char>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<char>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<char>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<char>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<char>(),
inputs[4]->get_shape(),
inputs[5]->get_data_ptr<char>(),
inputs[5]->get_shape(),
outputs[0]->get_data_ptr<char>(),
outputs[1]->get_data_ptr<char>(),
op->get_activations()[0],
op->get_activations()[1],
op->get_clip(),
op->get_direction(),
op->get_linear_before_reset());
return true;
}

template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v1::Pad>& op,
const HostTensorVector& outputs,
Expand Down
4 changes: 4 additions & 0 deletions ngraph/test/runtime/interpreter/opset_int_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ NGRAPH_OP(ShapeOf, op::v3)

NGRAPH_OP(CTCLoss, op::v4)
NGRAPH_OP(LSTMCell, op::v4)

NGRAPH_OP(GRUSequence, op::v5)
NGRAPH_OP(LSTMSequence, op::v5)
NGRAPH_OP(RNNSequence, op::v5)

0 comments on commit a5c32c1

Please sign in to comment.