Skip to content

Commit

Permalink
[IR] Handle NaN in StructuralEqual and StructuralHash (#17249)
Browse files Browse the repository at this point in the history
* [IR] Handle NaN in StructuralEqual and StructuralHash

Prior to this commit, `NaN` values did not have any special handling
in either `StructuralEqual` or `StructuralHash`.

`StructuralEqual` checked whether the LHS and RHS were within some
tolerance of each other.  If the LHS and RHS are both `NaN`, this
would evaluate to false.  The updated `StructuralEqual` now checks for
this case, and returns true if both sides are `NaN`.

`StructuralHash` used the bit-pattern of a floating-point number to
compute the hash.  A `NaN` value may have any non-zero value in its
mantissa, and so this could produce distinct hashes for ASTs that
differ only by the choice of non-zero value.  The updated
`StructuralHash` uses the same
`std::numeric_limits<double>::quiet_NaN()` value for all `NaN` values.

With these changes, `StructuralEqual` and `StructuralHash` can now
compare two IR functions, even if they contain `NaN`.

Closes #17247

* lint fix
  • Loading branch information
Lunderberg authored Aug 19, 2024
1 parent 6f4ac23 commit 1ca9833
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
21 changes: 16 additions & 5 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

#include <cmath>
#include <string>

namespace tvm {
Expand All @@ -38,11 +39,21 @@ namespace tvm {
class BaseValueEqual {
public:
// Compare two floating-point values for the purpose of IR structural equality.
//
// - Two NaN values are considered equal.
// - A NaN is never equal to a non-NaN value.
// - Otherwise the values are equal when they are identical, or when their
//   difference lies strictly within an absolute tolerance of 1e-9.
//
// The NaN checks must run before any arithmetic comparison: both
// `NaN == NaN` and `abs(NaN - NaN) < atol` evaluate to false, so a plain
// fuzzy comparison would report two NaN values as different.
bool operator()(const double& lhs, const double& rhs) const {
  if (std::isnan(lhs) && std::isnan(rhs)) {
    // IEEE NaN values never compare as equal to each other.
    // However, for the purpose of comparing IR representation, two
    // NaN values are equivalent.
    return true;
  } else if (std::isnan(lhs) || std::isnan(rhs)) {
    // Exactly one side is NaN.
    return false;
  } else if (lhs == rhs) {
    // Exact match.  This also covers equal infinities, for which
    // `lhs - rhs` would be NaN and fail the tolerance check below.
    return true;
  } else {
    // fuzzy float pt comparison
    constexpr double atol = 1e-9;
    double diff = lhs - rhs;
    return diff > -atol && diff < atol;
  }
}

// Signed 64-bit integers are equal only when they are exactly identical.
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
  const bool identical = (lhs == rhs);
  return identical;
}
Expand Down
13 changes: 12 additions & 1 deletion include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/ndarray.h>

#include <cmath>
#include <functional>
#include <limits>
#include <string>

namespace tvm {
Expand All @@ -52,7 +54,16 @@ class BaseValueHash {

public:
// Hash a single-precision value through the class's Reinterpret helper.
//
// The IEEE format defines more than one bit-pattern that represents NaN.
// For the purpose of comparing IR representations, all NaN values are
// considered equivalent, so every NaN is hashed as the canonical quiet NaN.
// This mirrors the handling in the `double` overload.
uint64_t operator()(const float& key) const {
  if (std::isnan(key)) {
    return Reinterpret<float, uint32_t>(std::numeric_limits<float>::quiet_NaN());
  } else {
    return Reinterpret<float, uint32_t>(key);
  }
}
// Hash a double-precision value through the class's Reinterpret helper.
//
// Only one definition of this overload may exist in the class; the
// NaN-aware version below is the one that is kept.
uint64_t operator()(const double& key) const {
  if (std::isnan(key)) {
    // The IEEE format defines more than one bit-pattern that
    // represents NaN.  For the purpose of comparing IR
    // representations, all NaN values are considered equivalent,
    // so they must all produce the same hash.
    return Reinterpret<double, uint64_t>(std::numeric_limits<double>::quiet_NaN());
  } else {
    return Reinterpret<double, uint64_t>(key);
  }
}
// Hash a signed 64-bit integer by passing it through the class's
// Reinterpret helper (defined outside this excerpt) to obtain a uint64_t.
uint64_t operator()(const int64_t& key) const {
  const auto bits = Reinterpret<int64_t, uint64_t>(key);
  return bits;
}
// An unsigned 64-bit integer is used directly as its own hash value.
uint64_t operator()(const uint64_t& key) const {
  const uint64_t hash_value = key;
  return hash_value;
}
// Hash an `int` by passing it through the class's Reinterpret helper
// (defined outside this excerpt); the result is converted to the
// uint64_t return type.
uint64_t operator()(const int& key) const {
  const auto bits = Reinterpret<int, uint32_t>(key);
  return bits;
}
Expand Down
43 changes: 43 additions & 0 deletions tests/python/tir-base/test_tir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,5 +419,48 @@ def func(A: T.Buffer(1, "int32")):
assert '<root>.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0]


def test_nan_values_are_equivalent():
    """Two functions that each return NaN are structurally equivalent.

    Under IEEE rules, `NaN == NaN` evaluates to false, and so does
    `abs(NaN - NaN) < tolerance`.  When comparing IR representations,
    however, two NaN values must be treated as equivalent, and the
    functions containing them must produce the same structural hash.
    """

    @T.prim_func(private=True)
    def lhs_func():
        return T.float32("nan")

    @T.prim_func(private=True)
    def rhs_func():
        return T.float32("nan")

    tvm.ir.assert_structural_equal(lhs_func, rhs_func)

    lhs_hash = tvm.ir.structural_hash(lhs_func)
    rhs_hash = tvm.ir.structural_hash(rhs_func)
    assert lhs_hash == rhs_hash


def test_all_nan_values_are_equivalent():
    """Structural equality treats every NaN bit-pattern as equivalent.

    IEEE defines NaN as any value that has all exponent bits set and
    a non-zero mantissa, so many distinct bit-patterns are NaN.  For
    the purposes of comparing IR representations, all of them are
    considered equivalent and must produce the same structural hash.
    """

    # A NaN whose only set mantissa bit is the most significant one.
    nan_first_bit = np.int32(0x7FC00000).view("float32")

    # A NaN whose only set mantissa bit is the least significant one.
    nan_last_bit = np.int32(0x7F800001).view("float32")

    lhs = T.float32(nan_first_bit)
    rhs = T.float32(nan_last_bit)

    tvm.ir.assert_structural_equal(lhs, rhs)

    lhs_hash = tvm.ir.structural_hash(lhs)
    rhs_hash = tvm.ir.structural_hash(rhs)
    assert lhs_hash == rhs_hash


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit 1ca9833

Please sign in to comment.