Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Additional Stmt/Expr simplication rules #11373

Merged
merged 5 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes));
}

// cancelation rules
TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be removed since you've added the else branch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the catch, and they are now removed.

if (IsIndexType(op->dtype)) {
// Index rules
// cancelation rules
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The integer simplifcation should be fast path (and clean, hopefully without Nans), it would be better to put new rules in else branch even if it means some duplications, since (recursively) checking side effects is expensive.
Doing floating point simplification is probably fine, but given the possibility of introducing additional types, it might be safer to say something like

if (IsIndexType(op->dtype)) {
  old rules
} else if (op->dtype.is_float())
  new rules
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable, and updated as requested.

Expand Down Expand Up @@ -411,6 +419,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
} else if (op->dtype.is_float()) {
// Cancellation rules. Deliberately off of the integer path, to
// avoid introducing checks on the side effects for the fast path.
TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
}

// condition rules.
Expand Down
12 changes: 6 additions & 6 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (const BufferLoadNode* load = op->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(op->buffer->data) &&
ArrayDeepEqual(load->indices, op->indices) &&
tir::ExprDeepEqual()(load->buffer->elem_offset, op->buffer->elem_offset) &&
ArrayDeepEqual(load->buffer->shape, op->buffer->shape) &&
ArrayDeepEqual(load->buffer->strides, op->buffer->strides)) {
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(store->buffer->data) &&
ArrayDeepEqual(load->indices, store->indices) &&
tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
return Evaluate(0);
}
}
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,5 +972,13 @@ def test_div_zero_simplify():
assert "division by zero" in str(cm.execption)


def test_sub_bufferload():
ck = RewriteChecker()
buf = tvm.tir.decl_buffer([1], dtype="float32")
load = tvm.tir.BufferLoad(buf, [0])
expr = load - load
ck.verify(expr, 0.0)


if __name__ == "__main__":
pytest.main([__file__])
45 changes: 40 additions & 5 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing

from tvm import te
from tvm.script import tir as T


def test_stmt_simplify():
Expand Down Expand Up @@ -133,9 +136,41 @@ def sls(n, d):
assert "if" not in str(stmt)


def test_load_store_noop():
"""Store of a value that was just read from the same location is a no-op."""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0]

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
tvm.ir.assert_structural_equal(after, expected)


def test_load_store_noop_after_simplify():
"""As test_load_store_noop, but requiring simplification to identify.

Previously, a bug caused the self-assignment of a buffer to
checked based on the pre-simplification assignment, not the
post-simplification. This test is to identify any similar
regression.
"""

@T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0] + (5.0 - 5.0)

@T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)

after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))["main"]
tvm.ir.assert_structural_equal(after, expected)


if __name__ == "__main__":
test_stmt_simplify()
test_thread_extent_simplify()
test_if_likely()
test_basic_likely_elimination()
test_complex_likely_elimination()
tvm.testing.main()