diff --git a/ngraph/src/ngraph/op/broadcast.cpp b/ngraph/src/ngraph/op/broadcast.cpp index db81524db71956..4cc2c2d8feae89 100644 --- a/ngraph/src/ngraph/op/broadcast.cpp +++ b/ngraph/src/ngraph/op/broadcast.cpp @@ -88,15 +88,21 @@ std::pair op::v3::Broadcast::get_broadcast_axes() const namespace { - PartialShape - get_result_shape_bidirectional(const Node* this_ptr, Shape& arg_shape, Shape& target_shape) + PartialShape get_result_shape_bidirectional(const Node* this_ptr, + const PartialShape& arg_shape, + Shape& target_shape) { + if (arg_shape.rank().is_dynamic()) + { + return PartialShape::dynamic(); + } + auto arg_shape_vec = static_cast>(arg_shape); PartialShape result_shape; // Add left padding to shorter target or argument shape - const auto target_padded_rank = std::max(arg_shape.size(), target_shape.size()); - while (arg_shape.size() < target_padded_rank) + const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size()); + while (arg_shape_vec.size() < target_padded_rank) { - arg_shape.insert(arg_shape.begin(), 1); + arg_shape_vec.insert(arg_shape_vec.begin(), 1); } while (target_shape.size() < target_padded_rank) { @@ -106,15 +112,28 @@ namespace result_shape = target_shape; for (auto i = 0; i < target_shape.size(); ++i) { + if (arg_shape_vec[i].is_dynamic()) + { + if (target_shape[i] == 1) + { + result_shape[i] = Dimension::dynamic(); + } + else + { + result_shape[i] = target_shape[i]; + } + continue; + } + const size_t arg_shape_dim = arg_shape_vec[i].get_length(); NODE_VALIDATION_CHECK(this_ptr, - arg_shape[i] == 1 || target_shape[i] == 1 || - arg_shape[i] == target_shape[i], + arg_shape_dim == 1 || target_shape[i] == 1 || + arg_shape_dim == target_shape[i], "Broadcast incorrect target shape. Expecting either 1 or ", - arg_shape[i], + arg_shape_dim, ". Got ", target_shape[i]); - result_shape[i] = std::max(arg_shape[i], target_shape[i]); + result_shape[i] = std::max(arg_shape_dim, target_shape[i]); } return result_shape; } @@ -141,9 +160,9 @@ void op::v3::Broadcast::validate_and_infer_types() auto result_shape = get_output_partial_shape(0); if (m_mode.m_type == BroadcastType::BIDIRECTIONAL) { - if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static()) + if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static()) { - auto arg_shape = get_input_shape(0); + auto arg_shape = get_input_partial_shape(0); const auto shape_constant = as_type_ptr(input_value(1).get_node_shared_ptr()); @@ -193,7 +212,8 @@ bool op::v3::Broadcast::evaluate(const HostTensorVector& outputs, { auto arg_shape = inputs[0]->get_shape(); Shape target_shape = op::util::BroadcastBase::get_target_shape(inputs[1]); - PartialShape result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape); + PartialShape result_shape = + get_result_shape_bidirectional(this, PartialShape{arg_shape}, target_shape); auto pair_broadcast_axes = get_broadcast_axes_bidirectional(arg_shape, result_shape.to_shape()); return op::util::BroadcastBase::evaluate_broadcast( diff --git a/ngraph/src/ngraph/op/util/broadcast_base.cpp b/ngraph/src/ngraph/op/util/broadcast_base.cpp index 0cde52548abce7..c9ea94e24ade34 100644 --- a/ngraph/src/ngraph/op/util/broadcast_base.cpp +++ b/ngraph/src/ngraph/op/util/broadcast_base.cpp @@ -46,35 +46,79 @@ op::util::BroadcastBase::BroadcastBase(const Output& arg, { } -PartialShape op::util::BroadcastBase::get_result_shape_numpy_pdpd( - const Shape& arg0_shape, +PartialShape op::util::BroadcastBase::get_result_shape_pdpd( + const PartialShape& arg0_shape, const Shape& target_shape, const op::BroadcastModeSpec& broadcast_spec) const { + if (arg0_shape.rank().is_dynamic()) + { + return PartialShape::dynamic(target_shape.size()); + } + const auto arg_rank_length = arg0_shape.rank().get_length(); PartialShape result_shape = target_shape; - auto start_axis = (broadcast_spec.m_type == op::BroadcastType::PDPD) - ? broadcast_spec.m_axis - : target_shape.size() - arg0_shape.size(); + auto start_axis = broadcast_spec.m_axis; + NODE_VALIDATION_CHECK(this, start_axis >= 0, "Broadcast target_shape has smaller rank ", target_shape.size(), " than arg shape ", - arg0_shape.size()); + arg_rank_length); for (auto i = start_axis; i < target_shape.size(); i++) { + if (arg0_shape[i - start_axis].is_dynamic()) + { + result_shape[i] = Dimension::dynamic(); + continue; + } + const size_t arg_dim = arg0_shape[i - start_axis].get_length(); NODE_VALIDATION_CHECK(this, - arg0_shape[i - start_axis] == 1 || target_shape[i] == 1 || - arg0_shape[i - start_axis] == target_shape[i], + arg_dim == 1 || target_shape[i] == 1 || arg_dim == target_shape[i], "Broadcast incorrect target shape. Expecting either 1 or ", - arg0_shape[i - start_axis], + arg_dim, " . Got ", target_shape[i]); - result_shape[i] = std::max(arg0_shape[i - start_axis], target_shape[i]); + result_shape[i] = std::max(arg_dim, target_shape[i]); } return result_shape; } +void op::util::BroadcastBase::validate_target_shape_numpy(const PartialShape& arg_shape, + const Shape& target_shape) const +{ + if (arg_shape.rank().is_dynamic()) + { + return; + } + const auto arg_rank_length = arg_shape.rank().get_length(); + auto start_axis = target_shape.size() - arg_rank_length; + NODE_VALIDATION_CHECK(this, + start_axis >= 0, + "Broadcast target_shape has smaller rank ", + target_shape.size(), + " than arg shape ", + arg_rank_length); + for (auto i = start_axis; i < target_shape.size(); i++) + { + if (arg_shape[i - start_axis].is_dynamic()) + { + continue; + } + const size_t arg_dim = arg_shape[i - start_axis].get_length(); + NODE_VALIDATION_CHECK(this, + arg_dim == 1 || arg_dim == target_shape[i], + "Input shape dimension equal ", + arg_dim, + " cannot be broadcasted (numpy mode) to ", + target_shape[i], + ". Allowed input dimension value would be 1", + target_shape[i] != 1 + ? (std::string(" or ") + std::to_string(target_shape[i])).c_str() + : ""); + } +} + void op::util::BroadcastBase::validate_target_shape_none(const Shape& arg_shape, const AxisVector& axes_mapping_val, const Shape& target_shape) const @@ -141,13 +185,28 @@ void op::util::BroadcastBase::validate_and_infer_types() } PartialShape result_shape{PartialShape::dynamic()}; - auto input_rank = input_value(0).get_partial_shape().rank(); - auto output_rank = input_value(1).get_partial_shape(); - if (input_rank.is_static() && output_rank.is_static() && output_rank[0].is_static()) + const auto& input_shape = get_input_partial_shape(0); + const auto input_rank = input_shape.rank(); + const auto& target_shape = input_value(1).get_partial_shape(); + const bool is_target_shape_known = + target_shape.rank().is_static() && target_shape[0].is_static(); + + if (m_mode.m_type == BroadcastType::BIDIRECTIONAL) { - result_shape = - PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length())); + if (input_rank.is_static() && is_target_shape_known) + { + result_shape = PartialShape::dynamic( + std::max(input_rank.get_length(), target_shape[0].get_length())); + } + } + else + { + if (is_target_shape_known) + { + result_shape = PartialShape::dynamic(target_shape[0].get_length()); + } } + const auto shape_constant = as_type_ptr(input_value(1).get_node_shared_ptr()); if (auto concat = as_type_ptr(input_value(1).get_node_shared_ptr())) @@ -205,17 +264,21 @@ void op::util::BroadcastBase::validate_and_infer_types() } } } - else if (m_mode.m_type == BroadcastType::NUMPY || m_mode.m_type == BroadcastType::PDPD) + else if (m_mode.m_type == BroadcastType::NUMPY) { - if (get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static()) + if (shape_constant) { - auto arg_shape = get_input_shape(0); - - if (shape_constant) - { - const auto target_shape = shape_constant->get_shape_val(); - result_shape = get_result_shape_numpy_pdpd(arg_shape, target_shape, m_mode); - } + const auto target_shape = shape_constant->get_shape_val(); + result_shape = target_shape; + validate_target_shape_numpy(input_shape, target_shape); + } + } + else if (m_mode.m_type == BroadcastType::PDPD) + { + if (shape_constant) + { + const auto target_shape = shape_constant->get_shape_val(); + result_shape = get_result_shape_pdpd(input_shape, target_shape, m_mode); } } set_output_type(0, get_input_element_type(0), result_shape); @@ -486,9 +549,16 @@ bool op::util::BroadcastBase::evaluate(const HostTensorVector& outputs, validate_target_shape_none(inputs[0]->get_shape(), axes_mapping_val, target_shape); result_shape = target_shape; } - else if (m_mode.m_type == BroadcastType::NUMPY || m_mode.m_type == BroadcastType::PDPD) + else if (m_mode.m_type == BroadcastType::PDPD) + { + result_shape = get_result_shape_pdpd(arg_shape, target_shape, m_mode); + pair_broadcast_axes = + get_broadcast_axes_numpy_pdpd(arg_shape, result_shape.to_shape(), m_mode); + } + else if (m_mode.m_type == BroadcastType::NUMPY) { - result_shape = get_result_shape_numpy_pdpd(arg_shape, target_shape, m_mode); + result_shape = target_shape; + validate_target_shape_numpy(arg_shape, target_shape); pair_broadcast_axes = get_broadcast_axes_numpy_pdpd(arg_shape, result_shape.to_shape(), m_mode); } diff --git a/ngraph/src/ngraph/op/util/broadcast_base.hpp b/ngraph/src/ngraph/op/util/broadcast_base.hpp index ebf6802b55d48c..28d01ed39c8719 100644 --- a/ngraph/src/ngraph/op/util/broadcast_base.hpp +++ b/ngraph/src/ngraph/op/util/broadcast_base.hpp @@ -77,9 +77,13 @@ namespace ngraph const AxisSet& broadcast_axes) const; PartialShape - get_result_shape_numpy_pdpd(const Shape& arg0_shape, - const Shape& target_shape, - const op::BroadcastModeSpec& broadcast_spec) const; + get_result_shape_pdpd(const PartialShape& arg0_shape, + const Shape& target_shape, + const op::BroadcastModeSpec& broadcast_spec) const; + + void validate_target_shape_numpy(const PartialShape& arg_shape, + const Shape& target_shape) const; + static std::pair get_broadcast_axes_numpy_pdpd(const Shape& arg_shape, const Shape& result_shape, diff --git a/ngraph/test/eval.cpp b/ngraph/test/eval.cpp index a65829e784ab2c..f9c7a6bdf47568 100644 --- a/ngraph/test/eval.cpp +++ b/ngraph/test/eval.cpp @@ -315,7 +315,7 @@ TEST(eval, evaluate_broadcast_v3_numpy_vs_bidi) Shape in_shape{1, 4, 1}; auto A = make_shared(element::f32, in_shape); - auto target_shape = op::Constant::create(element::i64, Shape{3}, {1, 1, 4}); + auto target_shape = op::Constant::create(element::i64, Shape{3}, {1, 4, 4}); auto bcast_v3_num = make_shared(A, target_shape, op::BroadcastType::NUMPY); auto fun_num = make_shared(OutputVector{bcast_v3_num}, ParameterVector{A}); @@ -343,6 +343,26 @@ TEST(eval, evaluate_broadcast_v3_numpy_vs_bidi) ASSERT_EQ(expec2, result_val2); } +TEST(eval, evaluate_broadcast_v3_bidi_3d) +{ + Shape in_shape{1, 4, 1}; + + auto A = make_shared(element::f32, in_shape); + auto target_shape = op::Constant::create(element::i64, Shape{3}, {1, 1, 3}); + auto bcast_v3_num = + make_shared(A, target_shape, op::BroadcastType::BIDIRECTIONAL); + auto fun_num = make_shared(OutputVector{bcast_v3_num}, ParameterVector{A}); + + auto result = make_shared(); + ASSERT_TRUE(fun_num->evaluate( + {result}, {make_host_tensor(in_shape, {1.0f, 2.0f, 3.0f, 4.0f})})); + EXPECT_EQ(result->get_element_type(), element::f32); + EXPECT_EQ(result->get_partial_shape(), (PartialShape{1, 4, 3})); + auto result_val = read_vector(result); + vector expec{1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f}; + ASSERT_EQ(expec, result_val); +} + TEST(eval, evaluate_broadcast_v3_bidi_4d) { Shape in_shape{4, 1, 1}; diff --git a/ngraph/test/type_prop/broadcast.cpp b/ngraph/test/type_prop/broadcast.cpp index 70ec3ef4ccb3a2..e93b7eb067fee0 100644 --- a/ngraph/test/type_prop/broadcast.cpp +++ b/ngraph/test/type_prop/broadcast.cpp @@ -439,6 +439,359 @@ TYPED_TEST_P(BroadcastTests, broadcast_axes_et_wrong) } } +// EXPLICIT MODE + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_all_inputs_dynamic) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + const auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{3}, vector{0, 1, 2}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_static_rank) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + const auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{3}, vector{0, 1, 2}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto target_shape = + op::Constant::create(element::i64, Shape{3}, vector{1, 2, 3}); + const auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3); + ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3})); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{3}, vector{0, 2, 1}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3); + ASSERT_EQ(bc->get_shape(), (Shape{1, 2, 3})); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_input_rank_static) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic(3)); + const auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + const auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{3}, vector{0, 2, 1}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_target_shape_and_input_data_rank_static) +{ + // static rank data + const auto data = make_shared(element::f32, PartialShape::dynamic(3)); + const auto target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{3}, vector{0, 2, 1}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_input) +{ + const auto target_shape = + op::Constant::create(element::i64, Shape{4}, vector{1, 1, 5, 10}); + // static rank data + const auto data = make_shared(element::f32, PartialShape::dynamic(3)); + auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10})); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{4}, vector{0, 2, 1, 3}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10})); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape) +{ + const auto data = make_shared(element::f32, PartialShape{1, 2, 3, 4}); + // dynamic target shape and axes mapping + auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{4}, vector{0, 2, 1, 3}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // static rank target shape + target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // static rank target shape and const axes mapping + target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape_const_target_shape) +{ + const auto data = make_shared(element::f32, PartialShape{4}); + auto target_shape = op::Constant::create(element::i64, Shape{4}, vector{1, 4, 2, 3}); + // dynamic axes mapping + const auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3})); + + // const axes mapping + const auto axes_mapping_const = + op::Constant::create(element::i64, Shape{1}, vector{1}); + bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_shape(), (Shape{1, 4, 2, 3})); +} + +TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_target_shape) +{ + // dynamic input + auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto target_shape = make_shared(element::i64, PartialShape{4}); + const auto axes_mapping = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic()); + + // static rank input + data = make_shared(element::f32, PartialShape::dynamic(2)); + bc = make_shared(data, target_shape, axes_mapping, "EXPLICIT"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_dynamic()); +} + +// NUMPY MODE + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_shape_dynamic) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic()); + // dynamic output shape + auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // static rank target shape + target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_constant) +{ + // dynamic data + auto data = make_shared(element::f32, PartialShape::dynamic()); + const auto target_shape = + op::Constant::create(element::i64, Shape{3}, vector{1, 2, 3}); + + auto bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3); + + // static rank data + data = make_shared(element::f32, PartialShape::dynamic(2)); + bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3); +} + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_target_shape_dynamic) +{ + // static rank data + auto data = make_shared(element::f32, PartialShape::dynamic(3)); + const auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + + auto bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // static shape data + data = make_shared(element::f32, PartialShape{3, 4, 5, 6}); + bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_target_shape_static_rank) +{ + const auto data = make_shared(element::f32, PartialShape::dynamic(3)); + const auto target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + + const auto bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); +} + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_static_shape) +{ + const auto data = make_shared(element::f32, PartialShape{1, 2, 3}); + // static rank target_shape + auto target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + + auto bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_dynamic()); + + // constant target_shape + const auto target_shape_const = + op::Constant::create(element::i64, Shape{3}, vector{3, 2, 3}); + bc = make_shared(data, target_shape_const, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 3); + ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{3, 2, 3})); +} + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_input_partially_dynamic) +{ + const Shape expected_target_shape{1, 2, 3, 4}; + const auto target_shape = op::Constant::create( + element::i64, + {expected_target_shape.size()}, + std::vector(expected_target_shape.begin(), expected_target_shape.end())); + + auto data = make_shared(element::f32, PartialShape{2, 3, Dimension::dynamic()}); + auto bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape); + + data = make_shared(element::f32, + PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}); + bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape); + + data = make_shared(element::f32, + PartialShape{2, Dimension::dynamic(), Dimension::dynamic()}); + bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape); + + data = make_shared( + element::f32, + PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}); + bc = make_shared(data, target_shape, "NUMPY"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), expected_target_shape); +} + +TYPED_TEST_P(BroadcastTests, broadcast_numpy_static_dims_incorrect) +{ + const auto target_shape = op::Constant::create(element::i64, Shape{4}, {1, 2, 3, 4}); + + auto data = + make_shared(element::f32, PartialShape{Dimension::dynamic(), 999, 3, 4}); + try + { + auto bc = make_shared(data, target_shape, "NUMPY"); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "Input shape dimension equal 999 cannot be broadcasted (numpy mode) " + "to 2. Allowed input dimension value would be 1 or 2"); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } + + data = make_shared( + element::f32, + PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 888}); + try + { + auto bc = make_shared(data, target_shape, "NUMPY"); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "Input shape dimension equal 888 cannot be broadcasted (numpy mode) " + "to 4. Allowed input dimension value would be 1 or 4"); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } + + data = make_shared( + element::f32, + PartialShape{5, Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}); + try + { + auto bc = make_shared(data, target_shape, "NUMPY"); + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + "Input shape dimension equal 5 cannot be broadcasted (numpy mode) to " + "1. Allowed input dimension value would be 1"); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + REGISTER_TYPED_TEST_CASE_P(BroadcastTests, broadcast_numpy, broadcast_axes_mapping, @@ -451,7 +804,23 @@ REGISTER_TYPED_TEST_CASE_P(BroadcastTests, broadcast_axes_wrong_rank, broadcast_fully_dynamic_target_shape, broadcast_broadcast_shape_et_wrong, - broadcast_axes_et_wrong); + broadcast_axes_et_wrong, + broadcast_explicit_all_inputs_dynamic, + broadcast_explicit_target_shape_static_rank, + broadcast_explicit_const_target_shape, + broadcast_explicit_input_rank_static, + broadcast_explicit_target_shape_and_input_data_rank_static, + broadcast_explicit_const_target_shape_static_rank_input, + broadcast_explicit_static_input_shape, + broadcast_explicit_static_input_shape_const_target_shape, + broadcast_explicit_static_target_shape, + broadcast_numpy_input_shape_dynamic, + broadcast_numpy_target_shape_constant, + broadcast_numpy_target_shape_dynamic, + broadcast_numpy_input_target_shape_static_rank, + broadcast_numpy_input_static_shape, + broadcast_numpy_input_partially_dynamic, + broadcast_numpy_static_dims_incorrect); typedef ::testing::Types BroadcastTypes; // the last empty argument resolves compiler warning on MAC: @@ -696,7 +1065,8 @@ TEST(type_prop, broadcast_v3_output_rank_deduced_from_arg) const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL; const auto broadcast_v3 = make_shared(arg, shape, broadcast_spec); - ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4))); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme( + PartialShape{Dimension::dynamic(), 8, 6, 4})); } TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input) @@ -706,5 +1076,114 @@ TEST(type_prop, broadcast_v3_output_rank_deduced_from_new_shape_input) const auto broadcast_spec = op::BroadcastType::BIDIRECTIONAL; const auto broadcast_v3 = make_shared(arg, shape, broadcast_spec); - ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(5))); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 5); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).same_scheme( + PartialShape{8, 6, Dimension::dynamic(), 5, Dimension::dynamic()})); +} + +TEST(type_prop, broadcast_v3_bidirectional_dynamic_input) +{ + const auto arg = make_shared(element::f32, PartialShape::dynamic()); + + // dynamic target shape + auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + auto broadcast_v3 = make_shared(arg, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); + + // static rank target shape + target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + broadcast_v3 = make_shared(arg, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); + + // constant target shape + const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6}); + broadcast_v3 = make_shared(arg, target_shape_const, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); +} + +TEST(type_prop, broadcast_v3_bidirectional_static_rank_input) +{ + const auto arg = make_shared(element::f32, PartialShape::dynamic(4)); + + // dynamic target shape + auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + auto broadcast_v3 = make_shared(arg, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); + + // static rank target shape + target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + broadcast_v3 = make_shared(arg, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); + + // constant target shape + const auto target_shape_const = op::Constant::create(element::i64, {3}, {2, 4, 6}); + broadcast_v3 = make_shared(arg, target_shape_const, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_dynamic()); +} + +TEST(type_prop, broadcast_v3_bidirectional_static_shape_input) +{ + const auto arg = make_shared(element::f32, PartialShape{1, 2, 3, 1}); + + // dynamic target shape + auto target_shape = make_shared(element::i64, PartialShape::dynamic()); + auto broadcast_v3 = make_shared(arg, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); + + // static rank target shape + target_shape = make_shared(element::i64, PartialShape::dynamic(1)); + broadcast_v3 = make_shared(arg, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_dynamic()); + + // constant target shape + auto target_shape_const = op::Constant::create(element::i64, {4}, {2, 2, 3, 2}); + broadcast_v3 = make_shared(arg, target_shape_const, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static()); + ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{2, 2, 3, 2})); + + target_shape_const = op::Constant::create(element::i64, {4}, {5, 2, 3, 7}); + broadcast_v3 = make_shared(arg, target_shape_const, "BIDIRECTIONAL"); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(broadcast_v3->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_TRUE(broadcast_v3->get_output_partial_shape(0).is_static()); + ASSERT_EQ(broadcast_v3->get_output_partial_shape(0), (PartialShape{5, 2, 3, 7})); +} + +TEST(type_prop, broadcast_v3_bidirectional_partially_dynamic_input) +{ + const auto target_shape = + op::Constant::create(element::i64, Shape{4}, vector{1, 1, 50, 50}); + + auto data = make_shared(element::f32, PartialShape{16, 1, Dimension::dynamic()}); + auto bc = make_shared(data, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50})); + + data = make_shared(element::f32, + PartialShape{Dimension::dynamic(), 1, Dimension::dynamic()}); + bc = make_shared(data, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50})); + + data = make_shared(element::f32, + PartialShape{16, Dimension::dynamic(), Dimension::dynamic()}); + bc = make_shared(data, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, 16, 50, 50})); + + data = make_shared( + element::f32, + PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()}); + bc = make_shared(data, target_shape, "BIDIRECTIONAL"); + ASSERT_TRUE(bc->get_output_partial_shape(0).rank().is_static()); + ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); + ASSERT_EQ(bc->get_output_partial_shape(0), (PartialShape{1, Dimension::dynamic(), 50, 50})); }