Skip to content

Commit

Permalink
Fix broadcast with BroadcastType::NONE and unit dimensions (#5093)
Browse files Browse the repository at this point in the history
* Fix broadcast with BroadcastType::NONE and unit dimensions

* fix issues found by flake

* fix broadcast negative test

* remove xfail_issue_49913
  • Loading branch information
mateusztabaka authored Apr 12, 2021
1 parent 34385eb commit 69e1eeb
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
5 changes: 3 additions & 2 deletions ngraph/core/src/op/util/broadcast_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ void op::util::BroadcastBase::validate_target_shape_none(const PartialShape& arg
if (arg_shape.rank().get_length() > 0)
{
NODE_VALIDATION_CHECK(this,
target_shape[axes_mapping_val[i]].compatible(arg_shape[i]),
target_shape[axes_mapping_val[i]].compatible(arg_shape[i]) ||
arg_shape[i].compatible(1),
"Broadcast target[axes_mapping[",
i,
"]]",
Expand Down Expand Up @@ -575,4 +576,4 @@ bool op::util::BroadcastBase::evaluate_upper(const HostTensorVector& output_valu
(get_input_size() > 2 && !input_value(2).get_tensor().has_and_set_bound()))
return false;
return default_upper_bound_evaluator(this, output_values);
}
}
4 changes: 2 additions & 2 deletions ngraph/python/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True):
"FakeQuantize_xxx has non const input on 1 port")
xfail_issue_46762 = xfail_test(reason="Incorrect result of Minimum op if uint data type is used")
xfail_issue_46765 = xfail_test(reason="select_last_index attribute is not supported by ArgMin and ArgMax")
xfail_issue_47317 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(Add): 2>': "
"Check shape_size(axes_shape) == input_rank' failed")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
Expand Down Expand Up @@ -173,3 +171,5 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True):
xfail_issue_49752 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::Pad")
xfail_issue_49753 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::StridedSlice")
xfail_issue_49754 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::TopKIE")
xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadcast_cpu - "
"Not equal to tolerance")
9 changes: 4 additions & 5 deletions ngraph/python/tests/test_onnx/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@
xfail_issue_45344,
xfail_issue_46762,
xfail_issue_46765,
xfail_issue_47317,
xfail_issue_47323,
xfail_issue_47337,
xfail_issue_48052,
xfail_issue_49207,
xfail_issue_49750,
xfail_issue_49752,
xfail_issue_49753,
xfail_issue_49754)
xfail_issue_49754,
xfail_issue_52463)


def expect_fail(test_case_path, xfail): # type: (str) -> None
Expand Down Expand Up @@ -189,9 +189,8 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None
"OnnxBackendNodeModelTest.test_argmin_no_keepdims_random_select_last_index_cpu"),
(xfail_issue_38091,
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
(xfail_issue_47317,
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_broadcast_cpu",
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_singleton_broadcast_cpu",),
(xfail_issue_52463,
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_size1_singleton_broadcast_cpu"),
(xfail_issue_47323,
"OnnxBackendPyTorchOperatorModelTest.test_operator_add_broadcast_cpu",
"OnnxBackendPyTorchOperatorModelTest.test_operator_addconstant_cpu",
Expand Down
16 changes: 16 additions & 0 deletions ngraph/test/backend/broadcast.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,22 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_algo_3d_stride_2)
broadcast_test_helper(shape_a, shape_r, axis);
}

NGRAPH_TEST(${BACKEND_NAME}, broadcast_algo_3d_diffrent_rank)
{
Shape shape_a{3, 1};
Shape shape_r{2, 3, 3};
AxisSet axis{1, 2};
broadcast_test_helper(shape_a, shape_r, axis);
}

NGRAPH_TEST(${BACKEND_NAME}, broadcast_algo_4d_same_rank)
{
Shape shape_a{2, 3, 1, 1};
Shape shape_r{2, 3, 4, 5};
AxisSet axis{0, 1, 2, 3};
broadcast_test_helper(shape_a, shape_r, axis);
}

NGRAPH_TEST(${BACKEND_NAME}, broadcast_matrix_0)
{
Shape shape_a{2, 2};
Expand Down
4 changes: 2 additions & 2 deletions ngraph/test/type_prop/broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map)

TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map_shape)
{
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 1});
auto param = make_shared<op::Parameter>(element::f32, Shape{3, 2});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{3}, {2, 3, 3});
auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 2});

Expand All @@ -162,7 +162,7 @@ TYPED_TEST_P(BroadcastTests, broadcast_fail_axes_map_shape)
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast target[axes_mapping[1]] Expected 1. Got 3");
EXPECT_HAS_SUBSTRING(error.what(), "Broadcast target[axes_mapping[1]] Expected 2. Got 3");
}
catch (...)
{
Expand Down

0 comments on commit 69e1eeb

Please sign in to comment.