From 47127fb021805bdff866b235ddb39fdf276e97bc Mon Sep 17 00:00:00 2001 From: Evgenya Stepyreva Date: Fri, 5 Feb 2021 10:10:40 +0300 Subject: [PATCH] opset1::OneHot Fix shape infer function for dynamic input shape case (#4163) --- .../src/legacy_api/src/ngraph_ops/onehot_ie.cpp | 4 ++-- ngraph/core/src/op/one_hot.cpp | 9 ++------- ngraph/test/type_prop/one_hot.cpp | 10 ++++++++++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/inference-engine/src/legacy_api/src/ngraph_ops/onehot_ie.cpp b/inference-engine/src/legacy_api/src/ngraph_ops/onehot_ie.cpp index 2c964ec21d8098..403b82fcd61ec5 100644 --- a/inference-engine/src/legacy_api/src/ngraph_ops/onehot_ie.cpp +++ b/inference-engine/src/legacy_api/src/ngraph_ops/onehot_ie.cpp @@ -19,10 +19,10 @@ op::OneHotIE::OneHotIE(const Output& input, int axis, int depth, f void op::OneHotIE::validate_and_infer_types() { const PartialShape& arg_shape = get_input_partial_shape(0); - if (arg_shape.is_dynamic()) { + if (arg_shape.rank().is_dynamic()) { set_output_type(0, m_type, PartialShape::dynamic()); } else { - Shape output_shape = arg_shape.to_shape(); + vector output_shape{arg_shape}; int normalized_axis = m_axis; if (m_axis < 0) normalized_axis = m_axis + static_cast(arg_shape.to_shape().size()); diff --git a/ngraph/core/src/op/one_hot.cpp b/ngraph/core/src/op/one_hot.cpp index 1dee5b6b0cc81a..eddbe8ce2965d0 100644 --- a/ngraph/core/src/op/one_hot.cpp +++ b/ngraph/core/src/op/one_hot.cpp @@ -77,15 +77,10 @@ void op::v1::OneHot::validate_and_infer_types() PartialShape result_shape{PartialShape::dynamic()}; const auto& depth = input_value(1).get_node_shared_ptr(); const auto& depth_constant = get_constant_from_source(input_value(1)); - if (indices_shape.is_static() && indices_shape.rank().is_static() && depth_constant) + if (indices_shape.rank().is_static() && depth_constant) { + std::vector out_dims{indices_shape}; const auto indices_rank = indices_shape.rank().get_length(); - - std::vector out_dims(indices_rank); - for (auto i = 0; i < indices_rank; i++) - { - out_dims[i] = indices_shape[i]; - } m_axis = ngraph::normalize_axis(this, m_axis, indices_rank + 1, -indices_rank - 1, indices_rank); diff --git a/ngraph/test/type_prop/one_hot.cpp b/ngraph/test/type_prop/one_hot.cpp index 09886dd18a690b..c55a393afdcf67 100644 --- a/ngraph/test/type_prop/one_hot.cpp +++ b/ngraph/test/type_prop/one_hot.cpp @@ -31,6 +31,11 @@ TEST(type_prop, one_hot_v1_output_shape) auto ont_hot = make_shared(indices, depth, on_value, off_value, axis); ASSERT_EQ(ont_hot->get_element_type(), element::u32); ASSERT_EQ(ont_hot->get_shape(), (Shape{3, 2})); + + auto dyn_indices = make_shared(element::i64, PartialShape{{1, 3}}); + auto dyn_ont_hot = make_shared(dyn_indices, depth, on_value, off_value, axis); + ASSERT_EQ(dyn_ont_hot->get_output_element_type(0), element::u32); + ASSERT_EQ(dyn_ont_hot->get_output_partial_shape(0), (PartialShape{{1, 3}, 2})); } TEST(type_prop, one_hot_v1_output_shape_2) @@ -43,6 +48,11 @@ TEST(type_prop, one_hot_v1_output_shape_2) auto ont_hot = make_shared(indices, depth, on_value, off_value, axis); ASSERT_EQ(ont_hot->get_element_type(), element::f32); ASSERT_EQ(ont_hot->get_shape(), (Shape{1, 3, 2, 4, 3})); + + auto dyn_indices = make_shared(element::i64, PartialShape{1, {3, 5}, 2, 3}); + auto dyn_ont_hot = make_shared(dyn_indices, depth, on_value, off_value, axis); + ASSERT_EQ(dyn_ont_hot->get_output_element_type(0), element::f32); + ASSERT_EQ(dyn_ont_hot->get_output_partial_shape(0), (PartialShape{1, {3, 5}, 2, 4, 3})); } TEST(type_prop, one_hot_v1_indices_elem_not_integral)