Skip to content

Commit

Permalink
Fix some failed layer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikhail Treskin committed Sep 10, 2020
1 parent f80c325 commit 1f42b8e
Show file tree
Hide file tree
Showing 16 changed files with 118 additions and 164 deletions.

This file was deleted.

1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ namespace ngraph

} // namespace v1
using v1::Add;
NGRAPH_SUPPRESS_DEPRECATED_END
} // namespace op

NGRAPH_DEPRECATED("This operator was deprecated and will be removed with v0 operation.")
Expand Down
2 changes: 0 additions & 2 deletions ngraph/core/include/ngraph/op/divide.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,8 @@ namespace ngraph
} // namespace v1

using v1::Divide;
NGRAPH_SUPPRESS_DEPRECATED_END
} // namespace op

NGRAPH_DEPRECATED("This operator was deprecated and will be removed with v0 operation.")
NGRAPH_API
std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
} // namespace ngraph
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,5 @@ namespace ngraph
} // namespace v1

using v1::Equal;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/greater.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,5 @@ namespace ngraph
} // namespace v1

using v1::Greater;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/greater_eq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,5 @@ namespace ngraph
} // namespace v1

using v1::GreaterEqual;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/less.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,5 @@ namespace ngraph
} // namespace v1

using v1::Less;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/maximum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,5 @@ namespace ngraph
} // namespace v1

using v1::Maximum;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/minimum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,5 @@ namespace ngraph
} // namespace v1

using v1::Minimum;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/not_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,5 @@ namespace ngraph
} // namespace v1

using v1::NotEqual;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/power.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,5 @@ namespace ngraph
} // namespace v1

using v1::Power;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/select.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,5 @@ namespace ngraph
};
}
using v1::Select;
NGRAPH_SUPPRESS_DEPRECATED_END
}
}
1 change: 0 additions & 1 deletion ngraph/core/include/ngraph/op/subtract.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ namespace ngraph
} // namespace v1

using v1::Subtract;
NGRAPH_SUPPRESS_DEPRECATED_END
} // namespace op

NGRAPH_DEPRECATED("This operator was deprecated and will be removed with v0 operation.")
Expand Down
32 changes: 0 additions & 32 deletions ngraph/core/src/op/squeeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,38 +152,6 @@ namespace
const HostTensorPtr& out)
{
auto element_type = arg0->get_element_type();
out->set_element_type(element_type);

auto data_shape = arg0->get_shape();
int64_t data_rank = static_cast<int64_t>(data_shape.size());
auto axes_shape = arg1->get_shape();
NGRAPH_CHECK(axes_shape.size() <= 1, "Axes to remove must be a vector or empty.");

auto out_shape = data_shape;
// Empty axes vector
if (axes_shape.size() == 0 || axes_shape[0] == 0)
{
out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), out_shape.end());
}
else
{
// Get axes
vector<int64_t> axes = read_index_vector(arg1);
// Normalize axes
std::transform(axes.begin(),
axes.end(),
axes.begin(),
[data_rank](int64_t i) -> int64_t { return i < 0 ? data_rank + i : i; });
// Sort in decreasing order
std::set<int64_t, greater<int64_t>> axes_set(axes.begin(), axes.end());
for (int64_t axis : axes_set)
{
NGRAPH_CHECK(axis >= 0 && axis < data_rank, "Axis is out of bounds: ", axis);
NGRAPH_CHECK(out_shape[axis] == 1, "Only axis of size 1 can be removed.");
out_shape.erase(out_shape.begin() + axis);
}
}
out->set_shape(out_shape);

bool rc = true;
switch (element_type)
Expand Down
136 changes: 87 additions & 49 deletions ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,79 +496,119 @@ namespace {
input[0]->get_shape(),\
op->get_batch_axis(),\
op->get_origin_sequence_axis(),\
input[1]->get_data_ptr<U>());\

input[1]->get_data_ptr<U>()); \
break;
switch (input[1]->get_element_type()) {
case element::Type_t::boolean:
REF_CALL(element::Type_t::boolean)
REF_CALL(element::Type_t::boolean)
case element::Type_t::i8:
REF_CALL(element::Type_t::i8);
REF_CALL(element::Type_t::i8);
case element::Type_t::i16:
REF_CALL(element::Type_t::i16);
REF_CALL(element::Type_t::i16);
case element::Type_t::i32:
REF_CALL(element::Type_t::i32);
REF_CALL(element::Type_t::i32);
case element::Type_t::i64:
REF_CALL(element::Type_t::i64);
REF_CALL(element::Type_t::i64);
case element::Type_t::u8:
REF_CALL(element::Type_t::u8);
REF_CALL(element::Type_t::u8);
case element::Type_t::u16:
REF_CALL(element::Type_t::u16);
REF_CALL(element::Type_t::u16);
case element::Type_t::u32:
REF_CALL(element::Type_t::u32);
REF_CALL(element::Type_t::u32);
case element::Type_t::u64:
REF_CALL(element::Type_t::u64);
REF_CALL(element::Type_t::u64);
case element::Type_t::f16:
REF_CALL(element::Type_t::f16);
REF_CALL(element::Type_t::f16);
case element::Type_t::f32:
REF_CALL(element::Type_t::f32);
REF_CALL(element::Type_t::f32);
case element::Type_t::f64:
REF_CALL(element::Type_t::f64);
REF_CALL(element::Type_t::f64);
default:
return false;
}
#undef REF_CALL
return true;
}

template<element::Type_t ET>
// Evaluates op::v0::Convert on the interpreter backend: casts every element of
// input[0] to the destination element type OUT_ET and writes the result into
// outputs[0].
//
// Boolean destinations are special-cased through convert_to_bool so that any
// non-zero source value collapses to true; all other destinations use the
// generic runtime::reference::convert.
//
// @param op      The Convert node being evaluated (unused beyond dispatch).
// @param outputs outputs[0] receives the converted tensor (element type OUT_ET).
// @param input   input[0] holds the source tensor; its runtime element type
//                selects the switch branch below.
// @return true on success, false when the source element type is unsupported.
template <element::Type_t OUT_ET>
bool evaluate(const shared_ptr<op::v0::Convert>& op,
              const HostTensorVector& outputs,
              const HostTensorVector& input)
{
    using TO = typename element_type_traits<OUT_ET>::value_type;
    if (OUT_ET == element::Type_t::boolean)
    {
// The macro ends in `break;` so each case below both dispatches and exits the
// switch; a bare `return false` in `default` covers unsupported types.
#define REF_CALL_BOOL(TI)                                                                \
    runtime::reference::convert_to_bool<typename element_type_traits<TI>::value_type>(   \
        input[0]->get_data_ptr<TI>(),                                                    \
        outputs[0]->get_data_ptr<char>(),                                                \
        shape_size(input[0]->get_shape()));                                              \
    break;
        switch (input[0]->get_element_type())
        {
        case element::Type_t::boolean: REF_CALL_BOOL(element::Type_t::boolean);
        case element::Type_t::i8: REF_CALL_BOOL(element::Type_t::i8);
        case element::Type_t::i16: REF_CALL_BOOL(element::Type_t::i16);
        case element::Type_t::i32: REF_CALL_BOOL(element::Type_t::i32);
        case element::Type_t::i64: REF_CALL_BOOL(element::Type_t::i64);
        case element::Type_t::u8: REF_CALL_BOOL(element::Type_t::u8);
        case element::Type_t::u16: REF_CALL_BOOL(element::Type_t::u16);
        case element::Type_t::u32: REF_CALL_BOOL(element::Type_t::u32);
        case element::Type_t::u64: REF_CALL_BOOL(element::Type_t::u64);
        case element::Type_t::f16: REF_CALL_BOOL(element::Type_t::f16);
        case element::Type_t::f32: REF_CALL_BOOL(element::Type_t::f32);
        case element::Type_t::f64: REF_CALL_BOOL(element::Type_t::f64);
        default: return false;
        }
#undef REF_CALL_BOOL
    }
    else
    {
// Generic source-type -> TO conversion; `break;` is folded into the macro as
// above so each case label is a single dispatch line.
#define REF_CALL(TI)                                                                     \
    runtime::reference::convert<typename element_type_traits<TI>::value_type, TO>(       \
        input[0]->get_data_ptr<TI>(),                                                    \
        outputs[0]->get_data_ptr<TO>(),                                                  \
        shape_size(input[0]->get_shape()));                                              \
    break;
        switch (input[0]->get_element_type())
        {
        case element::Type_t::boolean: REF_CALL(element::Type_t::boolean);
        case element::Type_t::i8: REF_CALL(element::Type_t::i8);
        case element::Type_t::i16: REF_CALL(element::Type_t::i16);
        case element::Type_t::i32: REF_CALL(element::Type_t::i32);
        case element::Type_t::i64: REF_CALL(element::Type_t::i64);
        case element::Type_t::u8: REF_CALL(element::Type_t::u8);
        case element::Type_t::u16: REF_CALL(element::Type_t::u16);
        case element::Type_t::u32: REF_CALL(element::Type_t::u32);
        case element::Type_t::u64: REF_CALL(element::Type_t::u64);
        case element::Type_t::f16: REF_CALL(element::Type_t::f16);
        case element::Type_t::f32: REF_CALL(element::Type_t::f32);
        case element::Type_t::f64: REF_CALL(element::Type_t::f64);
        default: return false;
        }
#undef REF_CALL
    }
    return true;
}

// TODO: Rewrite to v1
Expand Down Expand Up @@ -600,20 +640,18 @@ namespace {
const HostTensorVector &inputs) {
using T = typename element_type_traits<ET>::value_type;
runtime::reference::pad(inputs[0]->get_data_ptr<char>(),
inputs[1]->get_data_ptr<char>(),
outputs[0]->get_data_ptr<char>(),
shape_size(inputs[0]->get_shape()),
inputs[1]->get_shape(),
outputs[0]->get_shape(),
op->get_pads_end(),
op->get_pads_begin(),
op->get_pad_mode());
inputs[1]->get_data_ptr<char>(),
outputs[0]->get_data_ptr<char>(),
shape_size(inputs[0]->get_shape()),
inputs[1]->get_shape(),
outputs[0]->get_shape(),
op->get_pads_end(),
op->get_pads_begin(),
op->get_pad_mode());
return true;
}




template<typename T>
bool evaluate_node(std::shared_ptr<Node> node, const HostTensorVector &outputs, const HostTensorVector &inputs) {
switch (node->get_element_type()) {
Expand Down
Loading

0 comments on commit 1f42b8e

Please sign in to comment.