Skip to content

Commit

Permalink
opset1::OneHot Fix shape infer function for dynamic input shape case (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgenya Stepyreva authored Feb 5, 2021
1 parent 6083c7f commit 47127fb
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
4 changes: 2 additions & 2 deletions inference-engine/src/legacy_api/src/ngraph_ops/onehot_ie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ op::OneHotIE::OneHotIE(const Output<ngraph::Node>& input, int axis, int depth, f
void op::OneHotIE::validate_and_infer_types() {
const PartialShape& arg_shape = get_input_partial_shape(0);

if (arg_shape.is_dynamic()) {
if (arg_shape.rank().is_dynamic()) {
set_output_type(0, m_type, PartialShape::dynamic());
} else {
Shape output_shape = arg_shape.to_shape();
vector<Dimension> output_shape{arg_shape};
int normalized_axis = m_axis;
if (m_axis < 0)
normalized_axis = m_axis + static_cast<int>(arg_shape.to_shape().size());
Expand Down
9 changes: 2 additions & 7 deletions ngraph/core/src/op/one_hot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,10 @@ void op::v1::OneHot::validate_and_infer_types()
PartialShape result_shape{PartialShape::dynamic()};
const auto& depth = input_value(1).get_node_shared_ptr();
const auto& depth_constant = get_constant_from_source(input_value(1));
if (indices_shape.is_static() && indices_shape.rank().is_static() && depth_constant)
if (indices_shape.rank().is_static() && depth_constant)
{
std::vector<Dimension> out_dims{indices_shape};
const auto indices_rank = indices_shape.rank().get_length();

std::vector<Dimension> out_dims(indices_rank);
for (auto i = 0; i < indices_rank; i++)
{
out_dims[i] = indices_shape[i];
}
m_axis =
ngraph::normalize_axis(this, m_axis, indices_rank + 1, -indices_rank - 1, indices_rank);

Expand Down
10 changes: 10 additions & 0 deletions ngraph/test/type_prop/one_hot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ TEST(type_prop, one_hot_v1_output_shape)
auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
ASSERT_EQ(ont_hot->get_element_type(), element::u32);
ASSERT_EQ(ont_hot->get_shape(), (Shape{3, 2}));

auto dyn_indices = make_shared<op::Parameter>(element::i64, PartialShape{{1, 3}});
auto dyn_ont_hot = make_shared<op::v1::OneHot>(dyn_indices, depth, on_value, off_value, axis);
ASSERT_EQ(dyn_ont_hot->get_output_element_type(0), element::u32);
ASSERT_EQ(dyn_ont_hot->get_output_partial_shape(0), (PartialShape{{1, 3}, 2}));
}

TEST(type_prop, one_hot_v1_output_shape_2)
Expand All @@ -43,6 +48,11 @@ TEST(type_prop, one_hot_v1_output_shape_2)
auto ont_hot = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
ASSERT_EQ(ont_hot->get_element_type(), element::f32);
ASSERT_EQ(ont_hot->get_shape(), (Shape{1, 3, 2, 4, 3}));

auto dyn_indices = make_shared<op::Parameter>(element::i64, PartialShape{1, {3, 5}, 2, 3});
auto dyn_ont_hot = make_shared<op::v1::OneHot>(dyn_indices, depth, on_value, off_value, axis);
ASSERT_EQ(dyn_ont_hot->get_output_element_type(0), element::f32);
ASSERT_EQ(dyn_ont_hot->get_output_partial_shape(0), (PartialShape{1, {3, 5}, 2, 4, 3}));
}

TEST(type_prop, one_hot_v1_indices_elem_not_integral)
Expand Down

0 comments on commit 47127fb

Please sign in to comment.