diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 5ba602bc3bc3..94027c36b004 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -363,7 +363,7 @@ class Canonical::Internal : public IRMutator { return Binary_(op, e, a.value, b.value); } if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr(a.value, b.value); + return ComputeExpr(a.value, b.value); } else if (is_const(b.value)) { return SumModConst(a.AsSum(), b.value); } else { diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc index 8114bb51b771..3d762c67e7fd 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/ir_simplify_test.cc @@ -27,6 +27,16 @@ TEST(IRSIMPLIFY, Mul) { CHECK(is_zero(es)); } +TEST(IRSIMPLIFY, Mod) { + auto x = tvm::Integer(10); + auto y = tvm::Integer(12); + // Mod::make is used instead of % to avoid constant folding during + // calling operator%(x,y). Mod::make doesn't try constant folding, + // and therefore, the constant folding will be attempted in CanonicalSimplify + auto mod = tvm::ir::CanonicalSimplify(tvm::ir::Mod::make(x, y)); + auto es = tvm::ir::CanonicalSimplify(mod - x); + CHECK(is_zero(es)); +} int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";