From 0b0ba8d4336d508605861fb1b3eec58b6c16e37d Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Mon, 28 Feb 2022 18:24:20 -0800 Subject: [PATCH] fix: Considering rtol and atol in threshold comparison for floating point numbers Signed-off-by: Anurag Dixit --- tests/util/util.cpp | 20 +++++++++----------- tests/util/util.h | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/util/util.cpp b/tests/util/util.cpp index 96bff9023d..d91e84ac0a 100644 --- a/tests/util/util.cpp +++ b/tests/util/util.cpp @@ -5,21 +5,19 @@ namespace torch_tensorrt { namespace tests { namespace util { -bool checkRtol(const at::Tensor& diff, const std::vector inputs, float threshold) { - double maxValue = 0.0; - for (auto& tensor : inputs) { - maxValue = fmax(tensor.abs().max().item(), maxValue); - } - std::cout << "Max Difference: " << diff.abs().max().item() << std::endl; - std::cout << "Acceptable Threshold: " << threshold << std::endl; - return diff.abs().max().item() <= threshold * maxValue; -} -bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) { +bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol=1e-8, float rtol=1e-5) { LOG_GRAPH(a << std::endl << b << std::endl); auto a_float = a.toType(at::kFloat); auto b_float = b.toType(at::kFloat); - return checkRtol(a_float - b_float, {a_float, b_float}, threshold); + + auto diff = a_float - b_float; + auto result = diff.abs().max().item() - (atol + rtol * b.abs().max().item()); + + std::cout << "Max Difference: " << result << std::endl; + std::cout << "Acceptable Threshold: " << threshold << std::endl; + + return result <= threshold; } bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) { diff --git a/tests/util/util.h b/tests/util/util.h index fdfc8884bf..c609cb9395 100644 --- a/tests/util/util.h +++ b/tests/util/util.h @@ -11,7 +11,7 @@ namespace torch_tensorrt { namespace tests { namespace util { -bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold); +bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol=1e-8, float rtol=1e-5); bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);