From 4533a7117187959cfaade23cabba6eff1d4bf491 Mon Sep 17 00:00:00 2001 From: sbalandi Date: Thu, 18 Jan 2024 17:35:28 +0100 Subject: [PATCH] [Op Conformance] Update compare accuracy function --- .../common_test_utils/src/ov_tensor_utils.cpp | 54 +++++++++++++++++-- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/src/tests/test_utils/common_test_utils/src/ov_tensor_utils.cpp b/src/tests/test_utils/common_test_utils/src/ov_tensor_utils.cpp index 7743d4b9098455..281310fb4e83c7 100644 --- a/src/tests/test_utils/common_test_utils/src/ov_tensor_utils.cpp +++ b/src/tests/test_utils/common_test_utils/src/ov_tensor_utils.cpp @@ -305,11 +305,24 @@ ov::runtime::Tensor create_and_fill_tensor_consistently(const ov::element::Type constexpr double eps = std::numeric_limits::epsilon(); inline double less(double a, double b) { - return (b - a) > (std::fmax(std::fabs(a), std::fabs(b)) * eps); + return std::fabs(a - b) > eps && a < b; } inline double less_or_equal(double a, double b) { - return (b - a) >= (std::fmax(std::fabs(a), std::fabs(b)) * eps); + bool res = true; + if (std::isnan(a) || std::isnan(b)) { + res = false; + } else if (std::isinf(b) && b > 0) { + // b is grater than any number or eq the +Inf + res = true; + } else if (std::isinf(a) && a > 0) { + res = false; + } else { + res = (std::fabs(b - a) <= (std::fmax(std::fabs(a), std::fabs(b)) * eps) || a < b); + } + double eq_midle_res = std::fabs(b - a); + bool eq_res = (std::fabs(b - a) <= (std::fmax(std::fabs(a), std::fabs(b)) * eps)); + return res; } struct Error { @@ -369,13 +382,33 @@ void compare(const ov::Tensor& expected, if (abs_threshold == std::numeric_limits::max() && rel_threshold == std::numeric_limits::max()) { if (sizeof(ExpectedT) == 1 || sizeof(ActualT) == 1) { abs_threshold = 1.; + rel_threshold = 1.; + if (expected.get_element_type() == ov::element::Type_t::boolean) { + abs_threshold = 0.; + rel_threshold = 0.; + } } else { std::vector abs_values(shape_size_cnt); for (size_t i = 0; i < shape_size_cnt; i++) { abs_values[i] = std::fabs(static_cast(expected_data[i])); } auto abs_median = calculate_median(abs_values); + auto elem_type = expected.get_element_type(); + abs_threshold = abs_median * 0.05 < 1e-5 ? 1e-5 : 0.05 * abs_median; + + if (elem_type == ov::element::Type_t::boolean) { + abs_threshold = 0.; + } else if (elem_type.is_integral_number()) { + abs_threshold = 1.0; + } else if (elem_type == ov::element::Type_t::f32 || elem_type == ov::element::Type_t::f64) { + abs_threshold = abs_median * 0.05 < 1e-5 ? 1e-5 : 0.05 * abs_median; + } else if (elem_type == ov::element::Type_t::bf16 || elem_type == ov::element::Type_t::f16) { + abs_threshold = abs_median * 0.05 < 1e-3 ? 1e-3 : 0.05 * abs_median; + } + + rel_threshold = abs_threshold; + if (std::is_integral::value) { abs_threshold = std::ceil(abs_threshold); } @@ -388,10 +421,21 @@ void compare(const ov::Tensor& expected, std::cout << "[ COMPARATION ] abs_threshold: " << abs_threshold << std::endl; } + auto max_type_expected = std::numeric_limits::max(); + auto max_type_actual = std::numeric_limits::max(); + auto min_type_expected = std::numeric_limits::min(); + auto min_type_actual = std::numeric_limits::min(); Error abs_error(abs_threshold), rel_error(rel_threshold); for (size_t i = 0; i < shape_size_cnt; ++i) { double expected_value = expected_data[i]; double actual_value = actual_data[i]; + if ((std::isinf(expected_value) || expected_value >= max_type_expected) && + (std::isinf(actual_value) || actual_value >= max_type_actual)) { + continue; + } else if ((std::isinf(expected_value) || expected_value <= min_type_expected) && + (std::isinf(actual_value) || actual_value <= min_type_actual)) { + continue; + } if (std::isnan(expected_value) && std::isnan(actual_value)) continue; if (std::isnan(expected_value)) { @@ -406,15 +450,15 @@ void compare(const ov::Tensor& expected, } double abs = std::fabs(expected_value - actual_value); - double rel = expected_value ? (abs / std::fabs(expected_value)) : abs; - + double rel = + expected_value && actual_value && !std::isinf(expected_value) ? (abs / std::fabs(expected_value)) : 0; abs_error.update(abs, i); rel_error.update(rel, i); } abs_error.mean /= shape_size_cnt; rel_error.mean /= shape_size_cnt; - if (!(less_or_equal(abs_error.max, abs_threshold) && less_or_equal(rel_error.max, rel_threshold))) { + if (!(less_or_equal(abs_error.max, abs_threshold) || less_or_equal(rel_error.mean, rel_threshold))) { std::ostringstream out_stream; out_stream << "abs_max < abs_threshold && rel_max < rel_threshold" << "\n\t abs_max: " << abs_error.max << "\n\t\t coordinate " << abs_error.max_coordinate