From 1ca9833db2289923c4a557385be05307afb2e9ca Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Aug 2024 08:33:54 -0500 Subject: [PATCH] [IR] Handle NaN in StructuralEqual and StructuralHash (#17249) * [IR] Handle NaN in StructuralEqual and StructuralHash Prior to this commit, `NaN` values did not have any special handling in either `StructuralEqual` or `StructuralHash`. `StructuralEqual` checked whether the LHS and RHS were within some tolerance of each other. If the LHS and RHS are both `NaN`, this would evaluate to false. The updated `StructuralEqual` now checks for this case, and returns true if both sides are `NaN`. `StructuralHash` used the bit-pattern of a floating-point number to compute the hash. A `NaN` value may have any non-zero value in its mantissa, and so this could produce distinct hashes for ASTs that differ only by the choice of non-zero value. The updated `StructuralHash` uses the same `std::numeric_limits #include +#include #include namespace tvm { @@ -38,11 +39,21 @@ namespace tvm { class BaseValueEqual { public: bool operator()(const double& lhs, const double& rhs) const { - // fuzzy float pt comparison - constexpr double atol = 1e-9; - if (lhs == rhs) return true; - double diff = lhs - rhs; - return diff > -atol && diff < atol; + if (std::isnan(lhs) && std::isnan(rhs)) { + // IEEE floats do not compare as equivalent to each other. + // However, for the purpose of comparing IR representation, two + // NaN values are equivalent. + return true; + } else if (std::isnan(lhs) || std::isnan(rhs)) { + return false; + } else if (lhs == rhs) { + return true; + } else { + // fuzzy float pt comparison + constexpr double atol = 1e-9; + double diff = lhs - rhs; + return diff > -atol && diff < atol; + } } bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 774021ad1564..553f284b8c5a 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -27,7 +27,9 @@ #include #include +#include #include +#include #include namespace tvm { @@ -52,7 +54,16 @@ class BaseValueHash { public: uint64_t operator()(const float& key) const { return Reinterpret(key); } - uint64_t operator()(const double& key) const { return Reinterpret(key); } + uint64_t operator()(const double& key) const { + if (std::isnan(key)) { + // The IEEE format defines more than one bit-pattern that + // represents NaN. For the purpose of comparing IR + // representations, all NaN values are considered equivalent. + return Reinterpret(std::numeric_limits::quiet_NaN()); + } else { + return Reinterpret(key); + } + } uint64_t operator()(const int64_t& key) const { return Reinterpret(key); } uint64_t operator()(const uint64_t& key) const { return key; } uint64_t operator()(const int& key) const { return Reinterpret(key); } diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tir-base/test_tir_structural_equal_hash.py index eca78d649b85..32099cecf4b2 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tir-base/test_tir_structural_equal_hash.py @@ -419,5 +419,48 @@ def func(A: T.Buffer(1, "int32")): assert '.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0] +def test_nan_values_are_equivalent(): + """Structural equality treats two NaN values as equivalent. + + By IEEE, a check of `NaN == NaN` returns false, as does + `abs(NaN - NaN) < tolerance`. However, for the purpose of + comparing IR representations, both NaN values are equivalent. + + """ + + @T.prim_func(private=True) + def func_1(): + return T.float32("nan") + + @T.prim_func(private=True) + def func_2(): + return T.float32("nan") + + tvm.ir.assert_structural_equal(func_1, func_2) + assert tvm.ir.structural_hash(func_1) == tvm.ir.structural_hash(func_2) + + +def test_all_nan_values_are_equivalent(): + """Structural equality treats two NaN values as equivalent. + + IEEE defines NaN as any value that has all exponent bits set, + and has a non-zero mantissa. For the purposes of comparing IR + representations, all NaN values are considered equivalent. + + """ + + # A NaN with the first payload bit set. + nan_all_zeros = np.int32(0x7FC00000).view("float32") + + # A NaN with the last payload bit set. + nan_with_payload = np.int32(0x7F800001).view("float32") + + float_1 = T.float32(nan_all_zeros) + float_2 = T.float32(nan_with_payload) + + tvm.ir.assert_structural_equal(float_1, float_2) + assert tvm.ir.structural_hash(float_1) == tvm.ir.structural_hash(float_2) + + if __name__ == "__main__": tvm.testing.main()