From 7f2183727bd77f7dcb520d9b9424997aa94f7769 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 31 May 2019 01:29:54 -0700
Subject: [PATCH] [Relay][Hashing] Structural hash - incorporate the var type
 into its hash (#3267)

Currently, the BindVar function does not take the Var's type into account.
This causes two graphs with the same structure but different var shapes to
receive the same hash. The structural hash is used for keeping track of
which operators we have already compiled. Because of this, two operators
with different shapes end up pointing to the same compiled code. The
failure is encountered at runtime, where the expected input shape asserts
are not met.
---
 src/relay/ir/hash.cc                        |  3 +++
 tests/python/relay/test_pass_alpha_equal.py | 18 ++++++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc
index c56c4ce17067..c57475476e58 100644
--- a/src/relay/ir/hash.cc
+++ b/src/relay/ir/hash.cc
@@ -219,6 +219,9 @@ class RelayHashHandler:
   size_t BindVar(const NodeRef& var) {
     size_t hash = std::hash<int>()(var_counter++);
     CHECK_EQ(hash_map_.count(var), 0);
+    if (auto var_node = var.as<VarNode>()) {
+      hash = Combine(hash, TypeHash(var_node->type_annotation));
+    }
     hash_map_[var] = hash;
 
     const auto* ty_param = var.as<TypeVarNode>();
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 478b433180b9..0e0036565363 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -594,7 +594,24 @@ def test_graph_equal():
     # Check the difference in the text format.
     assert not alpha_equal(z0, z3)
 
+def test_hash_unequal():
+    x1 = relay.var("x1", shape=(10, 10), dtype="float32")
+    y1 = relay.var("y1", shape=(10, 10), dtype="float32")
+    func1 = relay.Function([x1, y1], relay.add(x1, y1))
+    # func2 is exactly same structure with same variables shapes and dtypes
+    x2 = relay.var("x2", shape=(10, 10), dtype="float32")
+    y2 = relay.var("y2", shape=(10, 10), dtype="float32")
+    func2 = relay.Function([x2, y2], relay.add(x2, y2))
+
+    assert ir_pass.structural_hash(func1) == ir_pass.structural_hash(func2)
+
+    # func3 is same as func1 but with different var shapes
+    x3 = relay.var("x3", shape=(20, 10), dtype="float32")
+    y3 = relay.var("y3", shape=(20, 10), dtype="float32")
+    func3 = relay.Function([x3, y3], relay.add(x3, y3))
+
+    assert not ir_pass.structural_hash(func1) == ir_pass.structural_hash(func3)
 
 
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
@@ -617,3 +634,4 @@ def test_graph_equal():
     test_op_alpha_equal()
     test_var_alpha_equal()
     test_graph_equal()
+    test_hash_unequal()