Skip to content

Commit

Permalink
Add special case for canonical simplify and fix test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Mar 29, 2019
1 parent 2231fb8 commit 6527adc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/arithmetic/canonical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,9 @@ class Canonical::Internal : public IRMutator {
if (pair.size() == 0) {
int64_t value = GetConstIntValue(v);
auto n = make_node<ComExprNode>();
if (value == 1 || value == -1) {
return make_zero(v.type());
}
// Because TVM mod is non-Euclidean mod, we apply the following simplication
// only when both a and v are positive.
bool can_simplify = true;
Expand Down
7 changes: 6 additions & 1 deletion tests/python/unittest/test_arith_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ def test_simplify_mod():
with ib.for_range(0, 16, name="i") as i:
A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]
body = ib.get()
stmt = tvm.ir_pass.CanonicalSimplify(body)
stmt = tvm.ir_pass.CanonicalSimplify(body, {j: tvm.Range(0, 6), n: tvm.Range(0, 10)})
diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
assert diff.value == 0
# if we can't prove that j is non-negative, we can't prove that (j+16) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16)
assert index != j
index = tvm.ir_pass.CanonicalSimplify((j + 16) % 16, {j: tvm.Range(0, 6)})
assert index == j
# if we can't prove that j+n*32 is non-negative, we can't prove that (j+n*32) % 16 is j%16
index = tvm.ir_pass.CanonicalSimplify(
(j + n * 32) % 16, {j: tvm.Range(0, 6)})
Expand Down

0 comments on commit 6527adc

Please sign in to comment.