Skip to content

Commit

Permalink
fix: Considering rtol and atol in threshold comparison for floating p…
Browse files Browse the repository at this point in the history
…oint numbers

Signed-off-by: Anurag Dixit <[email protected]>
  • Loading branch information
andi4191 committed Mar 1, 2022
1 parent ef62f6b commit 0b0ba8d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
20 changes: 9 additions & 11 deletions tests/util/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,19 @@ namespace torch_tensorrt {
namespace tests {
namespace util {

bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
double maxValue = 0.0;
for (auto& tensor : inputs) {
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
}
std::cout << "Max Difference: " << diff.abs().max().item<float>() << std::endl;
std::cout << "Acceptable Threshold: " << threshold << std::endl;
return diff.abs().max().item<float>() <= threshold * maxValue;
}

bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold) {
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol=1e-8, float rtol=1e-5) {
LOG_GRAPH(a << std::endl << b << std::endl);
auto a_float = a.toType(at::kFloat);
auto b_float = b.toType(at::kFloat);
return checkRtol(a_float - b_float, {a_float, b_float}, threshold);

auto diff = a_float - b_float;
auto result = diff.abs().max().item<float>() - (atol + rtol * b.abs().max().item<float>());

std::cout << "Max Difference: " << result << std::endl;
std::cout << "Acceptable Threshold: " << threshold << std::endl;

return result <= threshold;
}

bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
Expand Down
2 changes: 1 addition & 1 deletion tests/util/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace torch_tensorrt {
namespace tests {
namespace util {

bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold);
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol=1e-8, float rtol=1e-5);

bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);

Expand Down

0 comments on commit 0b0ba8d

Please sign in to comment.