diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 6e7ed418b17a..cbc7a51d5b25 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -64,11 +65,12 @@ struct ExprDeepEqual { TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); /*! - * \brief Whether the expression have side effect. + * \brief Analyze the side effect * \param expr The expression to be checked. - * \return whether expression have side effect + * + * \return CallEffectKind, can be kPure, kReadState or kUpdateState */ -TVM_DLL bool HasSideEffect(const PrimExpr& expr); +TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); /*! * \brief Whether e expression used any var in variable set.. diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 6139f73e5107..726289cebf09 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1018,8 +1018,9 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) // components which have side effects should also be preserved for (size_t i = 0; i < used.size(); ++i) { - if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || - HasSideEffect(op->combiner->result[i])) { + if (SideEffect(op->source[i]) > CallEffectKind::kReadState || + SideEffect(op->combiner->identity_element[i]) > CallEffectKind::kReadState || + SideEffect(op->combiner->result[i]) > CallEffectKind::kReadState) { mark_used(i); } } diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 2a026611cbad..8fb69b31857a 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -37,7 +37,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!tir::HasSideEffect(value)) { + if (SideEffect(value) <= CallEffectKind::kPure) { analyzer_->Bind(op->var, value); } // We keep the let-binding here @@ -154,7 +154,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); - if (!tir::HasSideEffect(value)) { + if (SideEffect(value) <= CallEffectKind::kPure) { analyzer_->Bind(op->var, value); } // We keep the let-binding here diff --git a/src/te/schedule/operation_inline.cc b/src/te/schedule/operation_inline.cc index fd613f47107a..aab30ede5dbf 100644 --- a/src/te/schedule/operation_inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -54,7 +54,7 @@ class OperationInliner final : public StmtExprMutator { bool has_side_effect = false; for (size_t i = 0; i < op->indices.size(); ++i) { - if (HasSideEffect(op->indices[i])) has_side_effect = true; + if (SideEffect(op->indices[i]) > CallEffectKind::kReadState) has_side_effect = true; } if (has_side_effect) { for (size_t i = 0; i < args_.size(); ++i) { diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index f2955f33e225..e5124dfdc965 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -147,7 +147,7 @@ class InjectScanStep : public StmtMutator { class SchedulePostProc : public StmtExprMutator { public: Stmt VisitStmt_(const LetStmtNode* op) final { - if (!HasSideEffect(op->value)) { + if (SideEffect(op->value) <= CallEffectKind::kPure) { var_value_[op->var.get()] = this->VisitExpr(op->value); return this->VisitStmt(op->body); } else { diff --git a/src/tir/analysis/side_effect.cc b/src/tir/analysis/side_effect.cc index 923cda3e41ea..5613961e2b66 100644 --- a/src/tir/analysis/side_effect.cc +++ b/src/tir/analysis/side_effect.cc @@ -33,34 +33,47 @@ namespace tir { class ExprSideEffect : public ExprVisitor { public: void VisitExpr(const PrimExpr& e) final { - if (has_side_effect_) return; + if (kind_ == CallEffectKind::kUpdateState) return; ExprVisitor::VisitExpr(e); } + void VisitExpr_(const LoadNode* op) final { + this->UpdateEffect(CallEffectKind::kReadState); + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const BufferLoadNode* op) final { + this->UpdateEffect(CallEffectKind::kReadState); + ExprVisitor::VisitExpr_(op); + } + void VisitExpr_(const CallNode* op) final { static auto op_call_effect = Op::GetAttrMap("TCallEffectKind"); if (auto* ptr_op = op->op.as()) { - auto effect_kind = op_call_effect[GetRef(ptr_op)]; - if (effect_kind != CallEffectKind::kPure && effect_kind != CallEffectKind::kExprAnnotation) { - has_side_effect_ = true; - return; - } else { - ExprVisitor::VisitExpr_(op); - } + this->UpdateEffect(static_cast(op_call_effect[GetRef(ptr_op)]->value)); } else { - has_side_effect_ = true; - return; + this->UpdateEffect(CallEffectKind::kOpaque); + } + ExprVisitor::VisitExpr_(op); + } + + void UpdateEffect(CallEffectKind effect_kind) { + if (effect_kind > CallEffectKind::kUpdateState) { + effect_kind = CallEffectKind::kUpdateState; + } + if (effect_kind > kind_) { + kind_ = effect_kind; } } - bool has_side_effect_{false}; + CallEffectKind kind_{CallEffectKind::kPure}; }; -bool HasSideEffect(const PrimExpr& e) { - ExprSideEffect v; - v(e); - return v.has_side_effect_; +CallEffectKind SideEffect(const PrimExpr& e) { + ExprSideEffect visitor; + visitor(e); + return visitor.kind_; } } // namespace tir diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index cd3a4b7483cc..baa1c3c368fd 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -90,7 +90,7 @@ class NoOpRemover : public StmtMutator { return is_no_op(op->body) ? op->body : stmt; } Stmt VisitStmt_(const EvaluateNode* op) final { - if (HasSideEffect(op->value)) return GetRef(op); + if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef(op); return Evaluate(0); } @@ -127,7 +127,7 @@ class NoOpRemover : public StmtMutator { private: Stmt MakeEvaluate(PrimExpr value) { - if (HasSideEffect(value)) { + if (SideEffect(value) > CallEffectKind::kReadState) { return Evaluate(value); } else { return Evaluate(0); @@ -136,7 +136,7 @@ class NoOpRemover : public StmtMutator { Stmt MakeEvaluate(const Array& values) { Stmt stmt; for (PrimExpr e : values) { - if (HasSideEffect(e)) { + if (SideEffect(e) > CallEffectKind::kReadState) { if (stmt.defined()) { stmt = SeqStmt({stmt, Evaluate(e)}); } else { diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 3088b6bbfc97..df8816c8f693 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -60,7 +60,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { // Won't face the deep expression explosion problem as in Let expression. // attempt to inline as much as possible if the value integer type(can be index). if (!op->value.dtype().is_int()) return false; - return !tir::HasSideEffect(op->value); + return SideEffect(op->value) <= CallEffectKind::kPure; } Stmt VisitStmt_(const LetStmtNode* op) { diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 75ae743f79ef..169ac1401445 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -70,7 +70,8 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) { + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && + simplify_let_) { return body; } else { PrimExpr value = this->VisitExpr(op->value); @@ -101,7 +102,8 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value) && simplify_let_) { + if (use_count_.at(op->var.get()) == 0 && SideEffect(op->value) <= CallEffectKind::kReadState && + simplify_let_) { return body; } else { PrimExpr value = this->VisitExpr(op->value); diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/tir_analysis_side_effect.cc similarity index 68% rename from tests/cpp/simple_passes_test.cc rename to tests/cpp/tir_analysis_side_effect.cc index 36b36452f4fc..26dedabb9304 100644 --- a/tests/cpp/simple_passes_test.cc +++ b/tests/cpp/tir_analysis_side_effect.cc @@ -21,16 +21,17 @@ #include #include #include +#include -TEST(SimplePasses, HasSideEffect) { +TEST(SimplePasses, SideEffect) { using namespace tvm; - auto n = te::var("n"); - Array shape; - shape.push_back(n); - - auto A = te::placeholder(shape, DataType::Float(32), "A"); - - CHECK(!tvm::tir::HasSideEffect(A[0])); + auto A = tir::Var("A", DataType::Handle()); + auto i = tir::Var("i", DataType::Int(32)); + CHECK(tir::SideEffect(tir::Load(DataType::Float(32), A, i, tir::const_true(1))) == + tir::CallEffectKind::kReadState); + CHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == tir::CallEffectKind::kPure); + CHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), {})) == + tir::CallEffectKind::kUpdateState); } int main(int argc, char** argv) {