Skip to content

Commit

Permalink
Updated to maintain separate int/float simplification paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed May 24, 2022
1 parent 6f3c757 commit 23bdc17
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {

if (IsIndexType(op->dtype)) {
// Index rules
// cancelation rules
TVM_TRY_REWRITE((x + y) - y, x);
TVM_TRY_REWRITE((x + y) - x, y);
TVM_TRY_REWRITE(x - (y + x), 0 - y);
TVM_TRY_REWRITE(x - (x + y), 0 - y);

TVM_TRY_REWRITE(min(x, y) - x, min(0, y - x));
TVM_TRY_REWRITE(min(x, y) - y, min(x - y, 0));
Expand All @@ -278,6 +283,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0));

// mul co-efficient folding
TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
TVM_TRY_REWRITE(x * y - x, x * (y - 1));
TVM_TRY_REWRITE(y * x - x, x * (y - 1));
TVM_TRY_REWRITE(x - y * x, x * (1 - y));
Expand Down Expand Up @@ -413,6 +419,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
} else if (op->dtype.is_float()) {
// Cancellation rules. Deliberately off of the integer path, to
// avoid introducing checks on the side effects for the fast path.
TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
}

// condition rules.
Expand Down

0 comments on commit 23bdc17

Please sign in to comment.