Skip to content

Commit

Permalink
[BUGFIX][IR] Fix String SEqual (apache#5275)
Browse files Browse the repository at this point in the history
* fix String SEqual

* retrigger ci
  • Loading branch information
zhiics authored and Trevor Morris committed Apr 16, 2020
1 parent 719a0c9 commit 36df966
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/node/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ struct StringObjTrait {
SEqualReducer equal) {
if (lhs == rhs) return true;
if (lhs->size != rhs->size) return false;
if (lhs->data != rhs->data) return true;
return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
if (lhs->data == rhs->data) return true;
return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
}
};

Expand Down
6 changes: 3 additions & 3 deletions tests/python/relay/test_ir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_function_attr():
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.with_attr("FuncName", tvm.tir.StringImm("a"))
func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a"))

x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
Expand All @@ -366,7 +366,7 @@ def test_function_attr():
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b"))
func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b"))
assert not consistent_equal(func0, func1)


Expand Down Expand Up @@ -698,7 +698,7 @@ def test_fn_attribute():
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.runtime.container.String("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())

assert not consistent_equal(add_1_fn, add_fn)
Expand Down

0 comments on commit 36df966

Please sign in to comment.