From 7c8b8e5ddaf476ba83192ca8a3918969506ddfb6 Mon Sep 17 00:00:00 2001 From: Bartosz Lesniewski Date: Mon, 15 Feb 2021 05:22:19 +0100 Subject: [PATCH] Fix comparison of constant with short float NAN values (#4299) * fix comparison of constant with short float NAN values * adjust precision, remove elvises * more templates * add ir serialization test with float16 const * remove unused prototxt --- .../models/add_abc_initializers_nan_const.bin | Bin 0 -> 17 bytes .../models/add_abc_initializers_nan_const.xml | 93 +++++++++++++++++ .../ir_serialization/serialize.cpp | 1 + .../common_test_utils/ngraph_test_utils.cpp | 96 +++++++++++++----- 4 files changed, 163 insertions(+), 27 deletions(-) create mode 100644 inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.bin create mode 100644 inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.bin b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.bin new file mode 100644 index 0000000000000000000000000000000000000000..16a41a06ca8aa97ca1b80cada20efb2d82daf818 GIT binary patch literal 17 XcmZQzU~o9#z~K0P|6T_mW^e=mG6V+M literal 0 HcmV?d00001 diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml new file mode 100644 index 00000000000000..c0ecd4025c550c --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml @@ -0,0 +1,93 @@ + + + + + + + + 2 + 2 + + + + + + + + 2 + 2 + + + + + + + 2 + 2 + + + 2 + 2 + + + + + 2 + 2 + + + + + + + 2 + 2 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp b/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp index 067ed94bc7dc06..254622157ad7f0 100644 --- a/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp +++ b/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp @@ -64,6 +64,7 @@ INSTANTIATE_TEST_CASE_P(IRSerialization, SerializationTest, std::make_tuple("split_equal_parts_2d.xml", "split_equal_parts_2d.bin"), std::make_tuple("addmul_abc.xml", "addmul_abc.bin"), std::make_tuple("add_abc_initializers.xml", "add_abc_initializers.bin"), + std::make_tuple("add_abc_initializers_nan_const.xml", "add_abc_initializers_nan_const.bin"), std::make_tuple("experimental_detectron_roi_feature_extractor.xml", ""), std::make_tuple("experimental_detectron_roi_feature_extractor_opset6.xml", ""), std::make_tuple("experimental_detectron_detection_output.xml", ""), diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index 4045c8fd00e0af..81a207d9c6bc62 100644 --- a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -325,32 +325,44 @@ struct Equal { }; template <> -struct Equal { - static bool equal_value(float lhs, float rhs) { - return std::abs(lhs - rhs) < 1e-5; +struct Equal { + static bool equal_value(ngraph::bfloat16 lhs, ngraph::bfloat16 rhs) { + if (lhs.to_bits() == rhs.to_bits()) { + return true; } + return std::abs(lhs - rhs) < 1e-3; + } }; template <> -struct Equal { - static bool equal_value(double lhs, double rhs) { - return std::abs(lhs - rhs) < 1e-5; +struct Equal { + static bool equal_value(ngraph::float16 lhs, ngraph::float16 rhs) { + if (lhs.to_bits() == rhs.to_bits()) { + return true; } + return std::abs(lhs - rhs) < 1e-3; + } }; template <> -struct Equal> { - static bool equal_value(const std::vector& lhs, const std::vector& rhs) { - return lhs.size() == rhs.size() && - std::equal(begin(lhs), end(lhs), begin(rhs), Equal::equal_value); +struct Equal { + static bool equal_value(float lhs, float rhs) { + return std::abs(lhs - rhs) < 1e-4; } }; template <> -struct Equal> { - static bool equal_value(const std::vector& lhs, const std::vector& rhs) { +struct Equal { + static bool equal_value(double lhs, double rhs) { + return std::abs(lhs - rhs) < 1e-5; + } +}; + +template +struct Equal> { + static bool equal_value(const std::vector& lhs, const std::vector& rhs) { return lhs.size() == rhs.size() && - std::equal(begin(lhs), end(lhs), begin(rhs), Equal::equal_value); + std::equal(begin(lhs), end(lhs), begin(rhs), Equal::equal_value); } }; @@ -439,6 +451,45 @@ struct Equal { } }; +using Constant = ngraph::opset1::Constant; +template <> struct Equal> { + static bool equal_value(const std::shared_ptr& lhs, + const std::shared_ptr& rhs) { + const auto lhs_t = lhs->get_element_type(); + const auto rhs_t = rhs->get_element_type(); + if (lhs_t != rhs_t) { + return false; + } + + switch (lhs_t) { + case ngraph::element::Type_t::bf16: { + auto lhs_v = lhs->cast_vector(); + auto rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + case ngraph::element::Type_t::f16: { + const auto &lhs_v = lhs->cast_vector(); + const auto &rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + case ngraph::element::Type_t::f32: { + const auto &lhs_v = lhs->cast_vector(); + const auto &rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + default: { + const auto &lhs_v = lhs->cast_vector(); + const auto &rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + } + return false; + } +}; } // namespace equal namespace str { @@ -741,22 +792,13 @@ FunctionsComparator::Result FunctionsComparator::compare( using Constant = ngraph::opset1::Constant; auto const1 = ngraph::as_type_ptr(node1->get_input_node_shared_ptr(i)); auto const2 = ngraph::as_type_ptr(node2->get_input_node_shared_ptr(i)); - - const auto equal = [](std::shared_ptr c1, std::shared_ptr c2) { - const auto& c1v = c1->cast_vector(); - const auto& c2v = c2->cast_vector(); - - return c1v.size() == c2v.size() && std::equal( - begin(c1v), end(c1v), begin(c2v), - [](const double& s1, const double& s2) { - return std::abs(s1 - s2) < 0.001; - }); - }; - - if (const1 && const2 && !equal(const1, const2)) { + using namespace ::attr_comparison::equal; + if (const1 && const2 && + !Equal>::equal_value(const1, const2)) { err_log << "Different Constant values detected\n" << node1->description() << " Input(" << i << ") and " - << node2->description() << " Input(" << i << ")" << std::endl; + << node2->description() << " Input(" << i << ")" + << std::endl; } }