diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 41e7aa5b7796f..a73edb428cbab 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -667,6 +667,9 @@ inline bool is_no_op(const Stmt& stmt) { if (const auto* op = stmt.as()) { return is_const(op->value); } + if (const auto* op = stmt.as()) { + return op->seq.size() == 0; + } return false; } diff --git a/include/tvm/ir.h b/include/tvm/ir.h index c55a4695de4d3..b1cefff1e90e8 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -1022,25 +1022,112 @@ class Realize : public StmtNode { }; /*! - * \brief A sequence of statements. + * \brief The container of seq statement. + * Represent a sequence of statements. */ -class Block : public StmtNode { +class SeqStmtNode : public StmtNode { public: - /*! \brief The first statement. */ - Stmt first; - /*! \brief The restof statments. */ - Stmt rest; + /*! \brief internal sequence content. */ + Array seq; + + /*! \return get the size of the sequence */ + size_t size() const { + return seq.size(); + } + /*! + * \brief Get the index-th element in the sequence. + */ + Stmt operator[](size_t index) const { + return seq[index]; + } void VisitAttrs(AttrVisitor* v) { - v->Visit("first", &first); - v->Visit("rest", &rest); + v->Visit("seq", &seq); } - TVM_DLL static Stmt make(Stmt first, Stmt rest); - TVM_DLL static Stmt make(const std::vector &stmts); + static constexpr const char* _type_key = "SeqStmt"; + TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); +}; + +/*! \brief Sequence statement. */ +class SeqStmt : public Stmt { + public: + /*! + * \brief Construct SeqStmt. + * \param seq The sequence. + */ + TVM_DLL explicit SeqStmt(Array seq); + + /*! \return get the size of the sequence */ + size_t size() const { + return operator->()->size(); + } + /*! + * \brief Get the index-th element in the sequence. + */ + Stmt operator[](size_t index) const { + return (*(operator->()))[index]; + } + /*! + * \brief Construct a sequence statement by flattening + * all the arrays and sequences in the arguments + * recursively. + * + * - When an argument is nullptr, it will be ignored. + * - When an argument is an array or a SeqStmt, it will be flattened recursively. + * - When an argument is a consumer block in ProducerConsumer, the consumer + * tag will be dropped as such information is not useful in lowering. + * - A normal Stmt will be appended to the end of the sequence. + * + * \note This function can directly return an element + * if it is the only element in the sequence. + * + * \param seq_args The list of arguments to be flattened. + * \tparam Args arguments + * \return The constructed statement + */ + template + static Stmt Flatten(Args&&... seq_args) { + Array seq; + runtime::detail::for_each( + Flattener(&seq), std::forward(seq_args)...); + if (seq.size() == 1) return seq[0]; + return SeqStmt(seq); + } + /*! \brief Helper class to flatten sequence of arguments into Array. */ + class Flattener { + public: + explicit Flattener(Array* seq) + : seq_(seq) {} + + void operator()(size_t i, const Stmt& stmt) const { + if (!stmt.defined()) return; + if (auto* op = stmt.as()) { + operator()(0, op->seq); + } else if (auto* op = stmt.as()) { + // NOTE: The consumer block annotation was not as useful and can be safely dropped. + if (!op->is_producer) { + operator()(0, op->body); + } else { + seq_->push_back(stmt); + } + } else { + seq_->push_back(stmt); + } + } + + template + void operator()(size_t i, const T& seq) const { + for (auto v : seq) { + this->operator()(0, v); + } + } + + private: + Array* seq_; + }; - static constexpr const char* _type_key = "Block"; - TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode); + TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); }; /*! diff --git a/include/tvm/ir_functor_ext.h b/include/tvm/ir_functor_ext.h index 4d821c2c4236e..6cc6d702c7cde 100644 --- a/include/tvm/ir_functor_ext.h +++ b/include/tvm/ir_functor_ext.h @@ -253,7 +253,7 @@ class StmtFunctor { virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args ...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); @@ -276,7 +276,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(Provide); IR_STMT_FUNCTOR_DISPATCH(Realize); IR_STMT_FUNCTOR_DISPATCH(Prefetch); - IR_STMT_FUNCTOR_DISPATCH(Block); + IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(Evaluate); return vtable; } @@ -408,7 +408,7 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const Provide* op) override; void VisitStmt_(const Realize* op) override; void VisitStmt_(const Prefetch* op) override; - void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; }; @@ -502,8 +502,23 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const Provide* op) override; Stmt VisitStmt_(const Realize* op) override; Stmt VisitStmt_(const Prefetch* op) override; - Stmt VisitStmt_(const Block* op) override; + Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const Evaluate* op) override; + /*! + * \brief Alternative advance method for SeqStmtNode. + * + * This function can be called when a child class override + * VisitStmt_(const SeqStmtNode*) to introduce + * the special behavior to visit + * + * \param op The sequence. + * \param flatten_before_visit Whether to flatten the sequence before visit. + * \param fmutate The mutate function, can be nullptr, which defaults to Visit. + * \return The mutated result. + */ + Stmt VisitSeqStmt_(const SeqStmtNode* op, + bool flatten_before_visit, + std::function fmutate = nullptr); // internal helper. class Internal; }; diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 0d5ab376d50a3..7686a96c19d3c 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -271,6 +271,14 @@ class Array : public ObjectRef { ArrayNode* n = this->CopyOnWrite(); n->data.push_back(item); } + /*! + * \brief Resize the array. + * \param size The new size. + */ + inline void resize(size_t size) { + ArrayNode* n = this->CopyOnWrite(); + n->data.resize(size); + } /*! * \brief set i-th element of the array. * \param i The index diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 816a0e1d7ad38..7e5659a8e9bbf 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -37,6 +37,8 @@ from .. import _api_internal as _tvm_internal from .. import expr as _expr from .. import make as _make +from .. import stmt as _stmt + from .. import api as _api from .. import ir_pass as _ir_pass @@ -48,11 +50,7 @@ def concat_list_to_block(lst): n = len(lst) if n == 1: return lst[0] - body = lst[n - 1] - for i in range(1, n): - stmt = lst[n - 1 - i] - body = _make.Block(stmt, body) - return body + return _stmt.SeqStmt(lst) def visit_list_to_block(visit, lst): diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index d7f41a2669bfe..bf41c98a7bdd7 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -120,14 +120,16 @@ def _pop_seq(self): seq = self._seq_stack.pop() if not seq or callable(seq[-1]): seq.append(_make.Evaluate(0)) - stmt = seq[-1] + seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x))) + ret_seq = [seq[-1]] + for s in reversed(seq[:-1]): if callable(s): - stmt = s(stmt) + ret_seq = [s(seqwrap(ret_seq))] else: assert isinstance(s, _stmt.Stmt) - stmt = _make.Block(s, stmt) - return stmt + ret_seq.append(s) + return seqwrap(ret_seq) def emit(self, stmt): """Emit a statement to the end of current scope. diff --git a/python/tvm/stmt.py b/python/tvm/stmt.py index fc7f6e2cb173a..64628d1d41987 100644 --- a/python/tvm/stmt.py +++ b/python/tvm/stmt.py @@ -289,20 +289,23 @@ def __init__(self, @register_node -class Block(Stmt): - """Block node. +class SeqStmt(Stmt): + """Sequence of statements. Parameters ---------- - first : Stmt - The first statement. - - rest : Stmt - The following statement. + seq : List[Stmt] + The statements """ - def __init__(self, first, rest): + def __init__(self, seq): self.__init_handle_by_constructor__( - _make.Block, first, rest) + _make.SeqStmt, seq) + + def __getitem__(self, i): + return self.seq[i] + + def __len__(self): + return len(self.seq) @register_node @@ -375,12 +378,14 @@ def stmt_seq(*args): stmt : Stmt The combined statement. """ - ret = None + ret = [] for value in args: if not isinstance(value, Stmt): value = Evaluate(value) - ret = value if ret is None else Block(ret, value) - return ret if ret else Evaluate(0) + ret.append(value) + if len(ret) == 1: + return ret[0] + return SeqStmt(ret) def stmt_list(stmt): @@ -395,12 +400,14 @@ def stmt_list(stmt): stmt_list : list of Stmt The unpacked list of statements """ - if isinstance(stmt, Block): - return stmt_list(stmt.first) + stmt_list(stmt.rest) + if isinstance(stmt, SeqStmt): + res = [] + for x in stmt: + res += stmt_list(x) + return res if isinstance(stmt, ProducerConsumer): return stmt_list(stmt.body) return [stmt] _make.stmt_list = stmt_list -_make.stmt_seq = stmt_seq diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 2987b9ef39c16..034405f1a7f03 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("make._cast") TVM_REGISTER_GLOBAL("make._range_by_min_extent") .set_body_typed(Range::make_by_min_extent); + +TVM_REGISTER_GLOBAL("make.SeqStmt") +.set_body_typed([](Array seq) { + return SeqStmt(std::move(seq)); +}); + TVM_REGISTER_GLOBAL("make.For") .set_body_typed([]( VarExpr loop_var, Expr min, Expr extent, @@ -163,9 +169,6 @@ REGISTER_MAKE(IfThenElse); REGISTER_MAKE(Evaluate); // overloaded, needs special handling -TVM_REGISTER_GLOBAL("make.Block") - .set_body_typed(static_cast(Block::make)); - // has default args TVM_REGISTER_GLOBAL("make.Allocate") .set_body_typed([]( diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 4b95e2caf1aae..a3f145994f2ca 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -405,22 +405,22 @@ inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // N } } -void CodeGenC::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*) os << "\"" << op->value << "\""; } template inline void PrintBinaryExpr(const T* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -443,7 +443,7 @@ inline void PrintBinaryExpr(const T* op, } inline void PrintBinaryIntrinsic(const Call* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -457,65 +457,65 @@ inline void PrintBinaryIntrinsic(const Call* op, p->PrintVecBinaryOp(opstr, op->dtype, op->args[0], op->args[1], os); } } -void CodeGenC::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) std::stringstream value; this->PrintExpr(op->value, value); os << CastFromTo(value.str(), op->value.dtype(), op->dtype); } -void CodeGenC::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } -void CodeGenC::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } -void CodeGenC::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } -void CodeGenC::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } -void CodeGenC::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "/", os, this); } -void CodeGenC::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenC::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } -void CodeGenC::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } -void CodeGenC::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "==", os, this); } -void CodeGenC::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "!=", os, this); } -void CodeGenC::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<", os, this); } -void CodeGenC::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<=", os, this); } -void CodeGenC::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">", os, this); } -void CodeGenC::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">=", os, this); } -void CodeGenC::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "&&", os, this); } -void CodeGenC::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "||", os, this); } -void CodeGenC::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*) os << '!'; PrintExpr(op->a, os); } -void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) if (op->call_type == Call::Extern || op->call_type == Call::PureExtern) { os << op->name << "("; @@ -875,12 +875,13 @@ void CodeGenC::VisitStmt_(const IfThenElse* op) { stream << "}\n"; } -void CodeGenC::VisitStmt_(const Block *op) { - PrintStmt(op->first); - if (op->rest.defined()) PrintStmt(op->rest); +void CodeGenC::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + PrintStmt(stmt); + } } -void CodeGenC::VisitStmt_(const Evaluate *op) { +void CodeGenC::VisitStmt_(const Evaluate* op) { if (is_const(op->value)) return; const Call* call = op->value.as(); if (call) { @@ -906,7 +907,7 @@ void CodeGenC::VisitStmt_(const Evaluate *op) { } } -void CodeGenC::VisitStmt_(const ProducerConsumer *op) { +void CodeGenC::VisitStmt_(const ProducerConsumer* op) { PrintStmt(op->body); } diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h index b8d3570519986..eae1e4961b773 100644 --- a/src/codegen/codegen_c.h +++ b/src/codegen/codegen_c.h @@ -140,7 +140,7 @@ class CodeGenC : void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; - void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const ProducerConsumer* op) override; /*! * Print Type represetnation of type t. diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index 94ad8b76c9c91..b0d86a9f66ce1 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -1214,10 +1214,9 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) { this->VisitStmt(op->body); } -void CodeGenLLVM::VisitStmt_(const Block* op) { - this->VisitStmt(op->first); - if (op->rest.defined()) { - this->VisitStmt(op->rest); +void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + this->VisitStmt(stmt); } } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 08c836adf9d04..076ffb2af5880 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -140,7 +140,7 @@ class CodeGenLLVM : void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const LetStmt* op) override; - void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 7800e47319e0b..0709965d0e8b1 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -638,10 +638,9 @@ void CodeGenSPIRV::VisitStmt_(const LetStmt* op) { this->VisitStmt(op->body); } -void CodeGenSPIRV::VisitStmt_(const Block* op) { - VisitStmt(op->first); - if (op->rest.defined()) { - this->VisitStmt(op->rest); +void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + this->VisitStmt(stmt); } } diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 3d16377271c41..5cd88c9f267af 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -98,7 +98,7 @@ class CodeGenSPIRV: void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const LetStmt* op) override; - void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const Evaluate* op) override; void VisitStmt_(const ProducerConsumer* op) override; diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 9482b2cd649ed..23bb008a0e7e9 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -268,59 +268,59 @@ void CodeGenStackVM::PushCast(DataType dst, DataType src) { } } -void CodeGenStackVM::VisitExpr_(const StringImm *op) { +void CodeGenStackVM::VisitExpr_(const StringImm* op) { int sid = this->GetStrID(op->value); this->PushOp(StackVM::PUSH_I64, sid); } -void CodeGenStackVM::VisitExpr_(const IntImm *op) { +void CodeGenStackVM::VisitExpr_(const IntImm* op) { CHECK(op->value >= std::numeric_limits::min() && op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const UIntImm *op) { +void CodeGenStackVM::VisitExpr_(const UIntImm* op) { CHECK(op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } -void CodeGenStackVM::VisitExpr_(const FloatImm *op) { +void CodeGenStackVM::VisitExpr_(const FloatImm* op) { LOG(FATAL) << "Float Imm is not supported"; } -void CodeGenStackVM::VisitExpr_(const Variable *op) { +void CodeGenStackVM::VisitExpr_(const Variable* op) { int vid = this->GetVarID(op); this->PushOp(StackVM::LOAD_HEAP, vid); } -void CodeGenStackVM::VisitExpr_(const Cast *op) { +void CodeGenStackVM::VisitExpr_(const Cast* op) { this->Push(op->value); PushCast(op->dtype, op->value.dtype()); } -void CodeGenStackVM::VisitExpr_(const Add *op) { +void CodeGenStackVM::VisitExpr_(const Add* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Sub *op) { +void CodeGenStackVM::VisitExpr_(const Sub* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Mul *op) { +void CodeGenStackVM::VisitExpr_(const Mul* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Div *op) { +void CodeGenStackVM::VisitExpr_(const Div* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Mod *op) { +void CodeGenStackVM::VisitExpr_(const Mod* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const Min *op) { +void CodeGenStackVM::VisitExpr_(const Min* op) { this->Push(op->a); this->Push(op->b); this->PushOp(StackVM::PUSH_VALUE, -1); @@ -329,7 +329,7 @@ void CodeGenStackVM::VisitExpr_(const Min *op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const Max *op) { +void CodeGenStackVM::VisitExpr_(const Max* op) { this->Push(op->a); this->Push(op->b); this->PushOp(StackVM::PUSH_VALUE, 0); @@ -338,34 +338,34 @@ void CodeGenStackVM::VisitExpr_(const Max *op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const EQ *op) { +void CodeGenStackVM::VisitExpr_(const EQ* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const LE *op) { +void CodeGenStackVM::VisitExpr_(const LE* op) { PushBinary(StackVM::LE_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const NE *op) { +void CodeGenStackVM::VisitExpr_(const NE* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const LT *op) { +void CodeGenStackVM::VisitExpr_(const LT* op) { PushBinary(StackVM::LT_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const GE *op) { +void CodeGenStackVM::VisitExpr_(const GE* op) { PushBinary(StackVM::LT_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const GT *op) { +void CodeGenStackVM::VisitExpr_(const GT* op) { PushBinary(StackVM::LE_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const And *op) { +void CodeGenStackVM::VisitExpr_(const And* op) { this->Push(op->a); int64_t pc_jump = this->GetPC(); int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); @@ -375,7 +375,7 @@ void CodeGenStackVM::VisitExpr_(const And *op) { this->SetOperand(opr_index, diff); } -void CodeGenStackVM::VisitExpr_(const Or *op) { +void CodeGenStackVM::VisitExpr_(const Or* op) { this->Push(op->a); int64_t pc_jump = this->GetPC(); int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0); @@ -389,11 +389,11 @@ void CodeGenStackVM::VisitExpr_(const Not* op) { this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitStmt_(const ProducerConsumer *op) { +void CodeGenStackVM::VisitStmt_(const ProducerConsumer* op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const For *op) { +void CodeGenStackVM::VisitStmt_(const For* op) { CHECK(is_zero(op->min)); int vid = this->AllocVarID(op->loop_var.get()); this->PushOp(StackVM::PUSH_I64, 0); @@ -417,9 +417,10 @@ void CodeGenStackVM::VisitStmt_(const For *op) { this->SetOperand(backward_jump, loop_head - label_bjump); } -void CodeGenStackVM::VisitStmt_(const Block *op) { - this->Push(op->first); - if (op->rest.defined()) this->Push(op->rest); +void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + this->Push(stmt); + } } void CodeGenStackVM::VisitStmt_(const Evaluate *ev) { @@ -444,7 +445,7 @@ void CodeGenStackVM::VisitStmt_(const Evaluate *ev) { } } -void CodeGenStackVM::VisitStmt_(const IfThenElse *op) { +void CodeGenStackVM::VisitStmt_(const IfThenElse* op) { this->Push(op->condition); int64_t label_ejump = this->GetPC(); int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); @@ -466,29 +467,29 @@ void CodeGenStackVM::VisitStmt_(const IfThenElse *op) { } } -void CodeGenStackVM::VisitStmt_(const LetStmt *op) { +void CodeGenStackVM::VisitStmt_(const LetStmt* op) { this->Push(op->value); int64_t vid = this->AllocVarID(op->var.get()); this->PushOp(StackVM::STORE_HEAP, static_cast(vid)); this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const Ramp *op) { +void CodeGenStackVM::VisitExpr_(const Ramp* op) { LOG(FATAL) << "Ramp is not supported"; } -void CodeGenStackVM::VisitExpr_(const Broadcast *op) { +void CodeGenStackVM::VisitExpr_(const Broadcast* op) { LOG(FATAL) << "Broadcast is not supported"; } -void CodeGenStackVM::VisitExpr_(const Select *op) { +void CodeGenStackVM::VisitExpr_(const Select* op) { this->Push(op->true_value); this->Push(op->false_value); this->Push(op->condition); this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { +void CodeGenStackVM::VisitStmt_(const AssertStmt* op) { if (const auto* str = op->message.as()) { int sid = this->GetStrID(str->value); this->Push(op->condition); @@ -497,11 +498,11 @@ void CodeGenStackVM::VisitStmt_(const AssertStmt *op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const AttrStmt *op) { +void CodeGenStackVM::VisitStmt_(const AttrStmt* op) { this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const Let *op) { +void CodeGenStackVM::VisitExpr_(const Let* op) { this->Push(op->value); int64_t vid = this->AllocVarID(op->var.get()); this->PushOp(StackVM::STORE_HEAP, static_cast(vid)); diff --git a/src/codegen/stackvm/codegen_stackvm.h b/src/codegen/stackvm/codegen_stackvm.h index dcae072c102d4..7a4c0ab797fd5 100644 --- a/src/codegen/stackvm/codegen_stackvm.h +++ b/src/codegen/stackvm/codegen_stackvm.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -148,7 +148,7 @@ class CodeGenStackVM void VisitStmt_(const AttrStmt* op) final; void VisitStmt_(const AssertStmt* op) final; void VisitStmt_(const Evaluate* op) final; - void VisitStmt_(const Block* op) final; + void VisitStmt_(const SeqStmtNode* op) final; void VisitStmt_(const ProducerConsumer* op) final; private: diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 301602fb82382..00b2c230c5bb6 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -76,24 +76,24 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { os << t.bits(); } -void CodeGenHybrid::VisitExpr_(const IntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const IntImm* op, std::ostream& os) { // NOLINT(*) os << op->value; } -void CodeGenHybrid::VisitExpr_(const UIntImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const UIntImm* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloatImm* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const StringImm *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const StringImm* op, std::ostream& os) { // NOLINT(*) os << "'" << op->value << "'"; } template inline void PrintBinaryExpr(const T* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; @@ -115,7 +115,7 @@ inline void PrintBinaryExpr(const T* op, } inline void PrintBinaryIntrinsitc(const Call* op, - const char *opstr, + const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; @@ -127,7 +127,7 @@ inline void PrintBinaryIntrinsitc(const Call* op, os << ')'; } -void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Cast* op, std::ostream& os) { // NOLINT(*) if (op->dtype == op->value.dtype()) { PrintExpr(op->value, stream); } else { @@ -138,76 +138,76 @@ void CodeGenHybrid::VisitExpr_(const Cast *op, std::ostream& os) { // NOLINT(*) } } -void CodeGenHybrid::VisitExpr_(const Variable *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Variable* op, std::ostream& os) { // NOLINT(*) os << GetVarID(op); } -void CodeGenHybrid::VisitExpr_(const Add *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Add* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "+", os, this); } -void CodeGenHybrid::VisitExpr_(const Sub *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Sub* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "-", os, this); } -void CodeGenHybrid::VisitExpr_(const Mul *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Mul* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "*", os, this); } -void CodeGenHybrid::VisitExpr_(const Div *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Div* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } -void CodeGenHybrid::VisitExpr_(const FloorDiv *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloorDiv* op, std::ostream& os) { // NOLINT(*) if (op->dtype.is_int()) PrintBinaryExpr(op, "//", os, this); else PrintBinaryExpr(op, "/", os, this); } -void CodeGenHybrid::VisitExpr_(const Mod *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Mod* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenHybrid::VisitExpr_(const FloorMod *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloorMod* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "%", os, this); } -void CodeGenHybrid::VisitExpr_(const Min *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Min* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "min", os, this); } -void CodeGenHybrid::VisitExpr_(const Max *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Max* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "max", os, this); } -void CodeGenHybrid::VisitExpr_(const EQ *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const EQ* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "==", os, this); } -void CodeGenHybrid::VisitExpr_(const NE *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const NE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "!=", os, this); } -void CodeGenHybrid::VisitExpr_(const LT *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const LT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<", os, this); } -void CodeGenHybrid::VisitExpr_(const LE *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const LE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "<=", os, this); } -void CodeGenHybrid::VisitExpr_(const GT *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const GT* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">", os, this); } -void CodeGenHybrid::VisitExpr_(const GE *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const GE* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, ">=", os, this); } -void CodeGenHybrid::VisitExpr_(const And *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const And* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "&&", os, this); } -void CodeGenHybrid::VisitExpr_(const Or *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Or* op, std::ostream& os) { // NOLINT(*) PrintBinaryExpr(op, "||", os, this); } -void CodeGenHybrid::VisitExpr_(const Not *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Not* op, std::ostream& os) { // NOLINT(*) os << "not "; PrintExpr(op->a, os); } -void CodeGenHybrid::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*) if (op->call_type == Call::Halide) { os << GetTensorID(op->func, op->value_index); os << "["; @@ -313,7 +313,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmt* op) { } } -void CodeGenHybrid::VisitStmt_(const Realize *op) { +void CodeGenHybrid::VisitStmt_(const Realize* op) { CHECK(alloc_storage_scope_.count(op->func)); if (!alloc_storage_scope_[op->func].empty()) { PrintIndent(); @@ -389,19 +389,20 @@ void CodeGenHybrid::VisitStmt_(const IfThenElse* op) { } } -void CodeGenHybrid::VisitStmt_(const Block *op) { - PrintStmt(op->first); - if (op->rest.defined()) PrintStmt(op->rest); +void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { + for (Stmt stmt : op->seq) { + PrintStmt(stmt); + } } -void CodeGenHybrid::VisitStmt_(const Evaluate *op) { +void CodeGenHybrid::VisitStmt_(const Evaluate* op) { if (is_const(op->value)) return; std::string str = PrintExpr(op->value); if (!str.empty()) stream << str << "\n"; } -void CodeGenHybrid::VisitStmt_(const ProducerConsumer *op) { +void CodeGenHybrid::VisitStmt_(const ProducerConsumer* op) { PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 647ef77fc5341..27c97c73e3330 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -131,7 +131,7 @@ class CodeGenHybrid : void VisitStmt_(const AttrStmt* op) override; void VisitStmt_(const AssertStmt* op) override; void VisitStmt_(const Evaluate* op) override; - void VisitStmt_(const Block* op) override; + void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const ProducerConsumer* op) override; /*! * \brief Print Type represetnation of type t. diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 5b410d1e37416..de047f330630b 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -504,31 +504,10 @@ Stmt Prefetch::make(FunctionRef func, int value_index, DataType dtype, Region bo return Stmt(node); } -Stmt Block::make(Stmt first, Stmt rest) { - CHECK(first.defined()); - CHECK(rest.defined()); - ObjectPtr node = make_object(); - - // canonicalize. - if (const Block* b = first.as()) { - node->first = b->first; - node->rest = Block::make(b->rest, rest); - } else { - node->first = std::move(first); - node->rest = std::move(rest); - } - return Stmt(node); -} - -Stmt Block::make(const std::vector& stmts) { - if (stmts.empty()) { - return Stmt(); - } - Stmt result = stmts.back(); - for (size_t i = stmts.size() - 1; i != 0; --i) { - result = Block::make(stmts[i - 1], result); - } - return result; +SeqStmt::SeqStmt(Array seq) { + auto node = make_object(); + node->seq = std::move(seq); + data_ = std::move(node); } Stmt IfThenElse::make(Expr condition, Stmt then_case, Stmt else_case) { @@ -1032,10 +1011,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& node, NodePrinter* p) { - auto* op = static_cast(node.get()); - p->Print(op->first); - if (op->rest.defined()) p->Print(op->rest); +.set_dispatch([](const ObjectRef& node, NodePrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } }); TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) @@ -1212,7 +1192,7 @@ TVM_REGISTER_NODE_TYPE(Provide); TVM_REGISTER_NODE_TYPE(Allocate); TVM_REGISTER_NODE_TYPE(Free); TVM_REGISTER_NODE_TYPE(Realize); -TVM_REGISTER_NODE_TYPE(Block); +TVM_REGISTER_NODE_TYPE(SeqStmtNode); TVM_REGISTER_NODE_TYPE(IfThenElse); TVM_REGISTER_NODE_TYPE(Evaluate); diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 939327890fecb..6146284554b44 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -337,8 +337,8 @@ void MakeReduction(const ComputeOpNode* op, provides.emplace_back(Provide::make( t->op, t->value_index, update_value[i], args)); } - *init = Block::make(inits); - *provide = Block::make(provides); + *init = SeqStmt::Flatten(inits); + *provide = SeqStmt::Flatten(provides); if (!is_one(reduce->condition)) { *provide = IfThenElse::make(reduce->condition, *provide); } @@ -382,7 +382,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, if (debug_keep_trivial_loop) { provide = MergeNest(common, provide); } else { - provide = MergeNest(common, Block::make(init, provide)); + provide = MergeNest(common, SeqStmt::Flatten(init, provide)); } // run substitution in the on the full nest, because loop condition // could depend on outer loops. @@ -392,7 +392,7 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, for (size_t i = 0; i < self->body.size(); ++i) { provides.emplace_back(MakeProvide(self, stage->op.output(i))); } - Stmt provide = Block::make(provides); + Stmt provide = SeqStmt::Flatten(provides); provide = MergeNest(n.main_nest, provide); // run substitution in the on the full nest, because loop condition // could depend on outer loops. diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc index 4a3aa54ccc6d5..ab56fc9657d2e 100644 --- a/src/op/cross_thread_reduction.cc +++ b/src/op/cross_thread_reduction.cc @@ -100,10 +100,10 @@ Stmt MakeCrossThreadReduction( stage->op, idx, Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); } - Stmt assign_body = Block::make(assigns); + Stmt assign_body = SeqStmt::Flatten(assigns); assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(conds), assign_body); - Stmt body = Block::make(reduce_body, assign_body); + Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { body = Allocate::make( res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc index 76ecf3417d36c..a6252df05246b 100644 --- a/src/op/tensor_compute_op.cc +++ b/src/op/tensor_compute_op.cc @@ -242,7 +242,7 @@ Stmt TensorComputeOpNode::BuildProvide( update = MergeNest(binder.asserts(), update); update = op::Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); - return MergeNest(common, Block::make(init, update)); + return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. CHECK(this->intrin->body.defined()) diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index f6fa00dad859c..0df8e889efeb6 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -478,7 +478,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, update = MergeNest(binder.asserts(), update); update = Substitute(update, n.main_vmap); update = MergeNest(update_nest, update); - return MergeNest(common, Block::make(init, update)); + return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. CHECK(intrin->body.defined()) diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc index e4ff9cb457a56..a0ddcd98b2609 100644 --- a/src/pass/arg_binder.cc +++ b/src/pass/arg_binder.cc @@ -240,7 +240,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, AssertStmt::make(arith::ComputeReduce(conds, Expr()), stride_err_msg.str(), Evaluate::make(0)); check = IfThenElse::make(Not::make(is_null), check, Stmt()); - asserts_.emplace_back(Block::make(check, Evaluate::make(0))); + asserts_.emplace_back(SeqStmt({check, Evaluate::make(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index 9243051996285..33af959f102e4 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -655,24 +655,14 @@ class CoProcSyncInserter : public StmtMutator { } Stmt VisitStmt(const Stmt& stmt) final { - Stmt before, after; - auto it = insert_before_.find(stmt.get()); - if (it != insert_before_.end()) { - before = MergeSeq(std::vector( - it->second.rbegin(), it->second.rend())); - } - it = insert_after_.find(stmt.get()); - if (it != insert_after_.end()) { - after = MergeSeq(it->second); - } + auto it_before = insert_before_.find(stmt.get()); + auto it_after = insert_after_.find(stmt.get()); Stmt new_stmt = StmtMutator::VisitStmt(stmt); - if (before.defined()) { - new_stmt = Block::make(before, new_stmt); - } - if (after.defined()) { - new_stmt = Block::make(new_stmt, after); - } - return new_stmt; + + return SeqStmt::Flatten( + it_before != insert_before_.end() ? it_before->second : std::vector(), + new_stmt, + it_after != insert_after_.end() ? it_after->second : std::vector()); } private: diff --git a/src/pass/inject_double_buffer.cc b/src/pass/inject_double_buffer.cc index 84b2f705e9958..0158a949da533 100644 --- a/src/pass/inject_double_buffer.cc +++ b/src/pass/inject_double_buffer.cc @@ -147,7 +147,7 @@ class DoubleBufferInjector : public StmtExprMutator { } Stmt loop = For::make( outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - MergeSeq(loop_seq)); + SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); @@ -158,9 +158,9 @@ class DoubleBufferInjector : public StmtExprMutator { IfThenElse::make(idx < old_loop->extent, Substitute(tail_body, vmap))); } - stmt = Block::make(loop, MergeSeq(tail_seq)); + stmt = SeqStmt::Flatten(loop, tail_seq); } - stmt = Block::make(MergeSeq(it->second), stmt); + stmt = SeqStmt::Flatten(it->second, stmt); } it = loop_allocs_.find(op); if (it != loop_allocs_.end()) { diff --git a/src/pass/inject_prefetch.cc b/src/pass/inject_prefetch.cc index 73e1dc9a3e302..73725c20583c0 100644 --- a/src/pass/inject_prefetch.cc +++ b/src/pass/inject_prefetch.cc @@ -59,7 +59,7 @@ class PrefetchInjector : public StmtMutator { vectorized_.erase(iter_var); Stmt prefetch = Prefetch::make(ts->op, ts->value_index, ts->dtype, region); - return Block::make(prefetch, op->body); + return SeqStmt({prefetch, op->body}); } return ret; } diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 202a5c27bd8b2..0887a83c1a48d 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -356,20 +356,18 @@ class VTInjector : public StmtExprMutator { return IfThenElse::make(condition, then_case, else_case); } } - // Block - Stmt VisitStmt_(const Block* op) final { + + // Seq + Stmt VisitStmt_(const SeqStmtNode* op) final { CHECK_EQ(max_loop_depth_, 0); - Stmt first = this->VisitStmt(op->first); - int temp = max_loop_depth_; - max_loop_depth_ = 0; - Stmt rest = this->VisitStmt(op->rest); - max_loop_depth_ = std::max(max_loop_depth_, temp); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { - return GetRef(op); - } else { - return Block::make(first, rest); - } + auto fmutate = [this](const Stmt& s) { + int temp = max_loop_depth_; + max_loop_depth_ = 0; + Stmt ret = this->VisitStmt(s); + max_loop_depth_ = std::max(max_loop_depth_, temp); + return ret; + }; + return StmtMutator::VisitSeqStmt_(op, false, fmutate); } // Allocate Stmt VisitStmt_(const Allocate* op) final { @@ -442,12 +440,11 @@ class VTInjector : public StmtExprMutator { // only unroll if number of vthreads are small if (max_loop_depth_ == 0 && num_threads_ < 16) { // do unrolling if it is inside innermost content. - Stmt blk = Substitute(stmt, {{var_, make_zero(var_.dtype())}}); - for (int i = 1; i < num_threads_; ++i) { - blk = Block::make( - blk, Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); + Array seq; + for (int i = 0; i < num_threads_; ++i) { + seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}})); } - return blk; + return SeqStmt::Flatten(seq); } else { // insert a for loop Var idx(var_->name_hint + ".s", var_->dtype); diff --git a/src/pass/ir_deep_compare.cc b/src/pass/ir_deep_compare.cc index 6a61d5e402f94..bbee9eeb7c8a3 100644 --- a/src/pass/ir_deep_compare.cc +++ b/src/pass/ir_deep_compare.cc @@ -179,10 +179,12 @@ class IRDeepCompare : if (CompareRegion(op->bounds, rhs->bounds) != 0) return; } - void VisitStmt_(const Block* op, const Stmt& other) final { - const Block* rhs = other.as(); - if (CompareStmt(op->first, rhs->first) != 0) return; - if (CompareStmt(op->rest, rhs->rest) != 0) return; + void VisitStmt_(const SeqStmtNode* op, const Stmt& other) final { + const SeqStmtNode* rhs = other.as(); + if (CompareValue(op->size(), rhs->size()) != 0) return; + for (size_t i = 0; i < op->size(); ++i) { + if (CompareStmt(op->seq[i], rhs->seq[i]) != 0) return; + } } void VisitStmt_(const Evaluate* op, const Stmt& other) final { diff --git a/src/pass/ir_functor.cc b/src/pass/ir_functor.cc index efc43a2ffa3db..dddf90eb47aa1 100644 --- a/src/pass/ir_functor.cc +++ b/src/pass/ir_functor.cc @@ -209,9 +209,10 @@ void StmtVisitor::VisitStmt_(const Prefetch* op) { }); } -void StmtVisitor::VisitStmt_(const Block* op) { - this->VisitStmt(op->first); - this->VisitStmt(op->rest); +void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { + VisitArray(op->seq, [this](const Stmt& s) { + this->VisitStmt(s); + }); } void StmtVisitor::VisitStmt_(const Evaluate* op) { @@ -490,20 +491,63 @@ Stmt StmtMutator::VisitStmt_(const Prefetch* op) { } } -Stmt StmtMutator::VisitStmt_(const Block* op) { - Stmt first = this->VisitStmt(op->first); - Stmt rest = this->VisitStmt(op->rest); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { +Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { + Array seq = Internal::Mutate(this, op->seq); + if (seq.same_as(op->seq)) { return GetRef(op); } else { auto n = CopyOnWrite(op); - n->first = std::move(first); - n->rest = std::move(rest); + n->seq = std::move(seq); return Stmt(n); } } +// advanced visit function for seqstmt. +Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, + bool flatten_before_visit, + std::function fmutate) { + if (flatten_before_visit) { + // Pass 1, check if we need to flatten. + bool need_flatten = false; + for (size_t i = 0; i < op->seq.size(); ++i) { + Stmt tmp = (*op)[i]; + if (tmp.as()) need_flatten = true; + } + flatten_before_visit = need_flatten; + } + // function to run the visit. + auto frunvisit = [&](const SeqStmtNode* op) { + Array seq = + fmutate != nullptr ? + MutateArray(op->seq, fmutate, allow_copy_on_write_) : + Internal::Mutate(this, op->seq); + if (seq.same_as(op->seq)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->seq = std::move(seq); + return Stmt(n); + } + }; + if (flatten_before_visit) { + Array seq; + SeqStmt::Flattener flattener(&seq); + flattener(0, op->seq); + // NOTE: If copy on write is allowed + // the assignment to seq below will + // destruct the original seq. + // + // Such destruction removes duplicated reference + // count to children and still enables COW for + // child Stmt. + ObjectPtr n = CopyOnWrite(op); + n->seq = std::move(seq); + return frunvisit(n.operator->()); + } else { + return frunvisit(op); + } +} + Stmt StmtMutator::VisitStmt_(const AssertStmt* op) { Expr condition = this->VisitExpr(op->condition); Expr message = this->VisitExpr(op->message); diff --git a/src/pass/ir_util.cc b/src/pass/ir_util.cc index cdc708ce5faf2..8956a4d11e7cc 100644 --- a/src/pass/ir_util.cc +++ b/src/pass/ir_util.cc @@ -51,10 +51,10 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { CHECK(!n->else_case.defined()); n->then_case = body; body = Stmt(n); - } else if (const auto* block = s.as()) { - auto n = make_object(*block); - CHECK(is_no_op(n->rest)); - n->rest = body; + } else if (const auto* seq = s.as()) { + auto n = make_object(*seq); + CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); + n->seq.Set(n->size() - 1, body); body = Stmt(n); } else if (const auto* assert_ = s.as()) { auto n = make_object(*assert_); @@ -80,14 +80,5 @@ Stmt MergeNest(const std::vector >& nest, Stmt body) { return body; } -Stmt MergeSeq(const std::vector& seq) { - if (seq.size() == 0) return Evaluate::make(0); - Stmt body = seq[0]; - for (size_t i = 1; i < seq.size(); ++i) { - body = Block::make(body, seq[i]); - } - return body; -} - } // namespace ir } // namespace tvm diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 0f8bb990c2d3b..900d6d59853ac 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -47,13 +47,6 @@ Stmt MergeNest(const std::vector& nest, Stmt body); */ Stmt MergeNest(const std::vector >& nest, Stmt body); -/*! - * \brief combine sequence of operations. - * \param seq The sequence. - * \return The combined Stmt - */ -Stmt MergeSeq(const std::vector& seq); - /*! * \brief update array with an unary function * \param arr array diff --git a/src/pass/lift_attr_scope.cc b/src/pass/lift_attr_scope.cc index eeed10ebe3aec..4f2df7b22b097 100644 --- a/src/pass/lift_attr_scope.cc +++ b/src/pass/lift_attr_scope.cc @@ -75,17 +75,56 @@ class AttrScopeLifter : public StmtMutator { } } - Stmt VisitStmt_(const Block* op) final { - std::vector seq; - FlattenSeq(op->first, &seq); - FlattenSeq(op->rest, &seq); - seq = MutateSeq(seq); - if (seq.size() == 2 && - seq[0].same_as(op->first) && - seq[1].same_as(op->rest)) { - return GetRef(op); + Stmt VisitStmt_(const SeqStmtNode* op) final { + // remember the decorations. + std::vector attr_node; + std::vector attr_value; + + auto fmutate = [&](const Stmt& s) { + attr_node_ = ObjectRef(); + attr_value_ = Expr(); + Stmt ret = this->VisitStmt(s); + attr_node.push_back(attr_node_); + attr_value.push_back(attr_value_); + return ret; + }; + Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate); + if (attr_node.size() == 0) return ret; + + op = ret.as(); + CHECK(op != nullptr); + Array reorg; + // check if all decorations are common. + for (size_t begin = 0; begin < attr_node.size();) { + size_t end = begin + 1; + while (end < attr_node.size() && + attr_node[end].same_as(attr_node[begin]) && + ValueSame(attr_value[end], attr_value[begin])) { + ++end; + } + // covers everything + // lift attr to parent. + if (begin == 0 && end == attr_node.size()) { + attr_node_ = attr_node[0]; + attr_value_ = attr_value[0]; + return ret; + } + // construct subsegments. + Array seq; + for (size_t i = begin; i < end; ++i) { + seq.push_back(op->seq[i]); + } + Stmt stmt = SeqStmt::Flatten(seq); + if (attr_node[begin].defined()) { + stmt = AttrStmt::make( + attr_node[begin], attr_key_, attr_value[begin], stmt); + } + reorg.push_back(stmt); + begin = end; } - return MergeSeq(seq); + attr_node_ = ObjectRef(); + attr_value_ = Expr(); + return SeqStmt::Flatten(reorg); } Stmt VisitStmt_(const IfThenElse* op) final { @@ -132,71 +171,10 @@ class AttrScopeLifter : public StmtMutator { } private: - void FlattenSeq(Stmt s, std::vector* res) { - if (const Block* op = s.as()) { - FlattenSeq(op->first, res); - FlattenSeq(op->rest, res); - } else if (const ProducerConsumer* op = s.as()) { - if (!op->is_producer) { - FlattenSeq(op->body, res); - } else { - res->emplace_back(s); - } - } else { - res->emplace_back(s); - } - } - - std::vector MutateSeq(const std::vector& seq) { - std::vector res_seq; - ObjectRef curr_node; - Expr curr_value; - Stmt curr_stmt; - for (const Stmt & stmt : seq) { - attr_node_ = ObjectRef(); - attr_value_ = Expr(); - Stmt rest = this->VisitStmt(stmt); - if (attr_node_.defined() && - attr_value_.defined() && - curr_node.defined() && - curr_value.defined() && - attr_node_.same_as(curr_node) && - ValueSame(attr_value_, curr_value)) { - curr_stmt = Block::make(curr_stmt, rest); - } else { - if (curr_stmt.defined()) { - if (curr_node.defined()) { - curr_stmt = AttrStmt::make( - curr_node, attr_key_, curr_value, curr_stmt); - } - res_seq.push_back(curr_stmt); - } - curr_stmt = rest; - curr_node = attr_node_; - curr_value = attr_value_; - } - } - - if (curr_stmt.defined()) { - // keep attr_node_, attr_node_ - if (res_seq.size() == 0) { - return {curr_stmt}; - } - if (curr_node.defined()) { - curr_stmt = AttrStmt::make( - curr_node, attr_key_, curr_value, curr_stmt); - } - res_seq.push_back(curr_stmt); - // reset - attr_node_ = ObjectRef(); - attr_value_ = Expr(); - } - return res_seq; - } - // value comparison that also compares content of int constant static bool ValueSame(const Expr& a, const Expr& b) { if (a.same_as(b)) return true; + if (!a.defined() || !b.defined()) return false; if (a->type_index() != b->type_index()) return false; if (a.dtype() != b.dtype()) return false; if (const IntImm* op = a.as()) { diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index 11cf57490450c..aa8ebe1eb19bd 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -106,14 +106,16 @@ class CandidateSelector final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const Block* op) final { - bool temp = no_split_; - this->VisitStmt(op->first); - // erase the no split state of first when visit rest. - std::swap(temp, no_split_); - this->VisitStmt(op->rest); - // restore the no split flag. - no_split_ = no_split_ || temp; + void VisitStmt_(const SeqStmtNode* op) final { + bool init_no_split = no_split_; + for (Stmt stmt : op->seq) { + // erase the no split state of before visiting the next one. + bool temp = init_no_split; + std::swap(temp, no_split_); + this->VisitStmt(stmt); + // restore the no split flag. + no_split_ = no_split_ || temp; + } } void VisitExpr_(const Call* op) final { @@ -402,16 +404,6 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, return std::make_pair(interval, cond_set); } -Stmt AppendStmts(const Stmt& a, const Stmt& b) { - if (!a.defined()) { - return b; - } else if (!b.defined()) { - return a; - } else { - return Block::make(a, b); - } -} - /* * Tries to recursively partition the range of the variable (given by var) of * the for loop (given by node and stmt) into a @@ -577,8 +569,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, } } } - s = AppendStmts(pre_stmt, mid_stmt); - s = AppendStmts(s, post_stmt); + s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt); } else { Expr cond = const_true(); if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc index f77443a1b4c35..4712bccb415a6 100644 --- a/src/pass/lower_thread_allreduce.cc +++ b/src/pass/lower_thread_allreduce.cc @@ -185,7 +185,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { Var buffer_var = Downcast(call->args[2+size+i]); stores[i] = Store::make(buffer_var, values[i], 0, pred); } - return Block::make(stores); + return SeqStmt::Flatten(stores); } // Whether the threadIdx.x is involved in reduction. if (vred[0].scope.dim_index == 0) { @@ -218,7 +218,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {Expr(group_extent), Expr(reduce_extent)}, pred, Evaluate::make(0)); } - return MergeSeq(seq); + return SeqStmt::Flatten(seq); } // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, @@ -252,7 +252,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (size_t i = 0; i < size; ++i) { stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true()); } - return Block::make(stores); + return SeqStmt::Flatten(stores); }; // Step one, check for if (reduce_align > reduce_extent) { @@ -280,11 +280,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { seq.emplace_back(SyncThread("warp")); } if (in_warp_seq.size() != 0) { - Stmt warp_body = MergeSeq(in_warp_seq); + Stmt warp_body = SeqStmt::Flatten(in_warp_seq); seq.emplace_back(IfThenElse::make(in_warp_cond, warp_body)); seq.emplace_back(SyncThread("shared")); } - return MergeSeq(seq); + return SeqStmt::Flatten(seq); } // Flatten the thread index. // Also return a warp number, diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc index 0cc9ea21a61a2..c0b98793c7f91 100644 --- a/src/pass/lower_tvm_builtin.cc +++ b/src/pass/lower_tvm_builtin.cc @@ -72,11 +72,14 @@ class BuiltinLower : public StmtExprMutator { auto stmt = StmtExprMutator::VisitStmt(s); CHECK_EQ(run_shape_stack_, 0); CHECK_EQ(run_array_stack_, 0); - while (prep_seq_.size() != 0) { - stmt = Block::make(prep_seq_.back(), stmt); - prep_seq_.pop_back(); + + if (prep_seq_.size() != 0) { + Stmt ret = SeqStmt::Flatten(prep_seq_, stmt); + prep_seq_.clear(); + return ret; + } else { + return stmt; } - return stmt; } Stmt VisitStmt_(const Allocate* op) { @@ -107,12 +110,12 @@ class BuiltinLower : public StmtExprMutator { intrinsic::tvm_throw_last_error, {}, Call::Intrinsic)); - Stmt body = Block::make( + Stmt body = SeqStmt({ IfThenElse::make(Call::make(DataType::Bool(1), intrinsic::tvm_handle_is_null, {op->buffer_var}, Call::PureIntrinsic), throw_last_error), - op->body); + op->body}); Stmt alloca = LetStmt::make( op->buffer_var, @@ -133,7 +136,7 @@ class BuiltinLower : public StmtExprMutator { op->buffer_var}, Call::Extern); Stmt free_stmt = IfThenElse::make(free_op != make_zero(DataType::Int(32)), throw_last_error); - body = Block::make(alloca, free_stmt); + body = SeqStmt({alloca, free_stmt}); body = AttrStmt::make( op->buffer_var, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), diff --git a/src/pass/make_api.cc b/src/pass/make_api.cc index 274383f866916..f065502db6b4d 100644 --- a/src/pass/make_api.cc +++ b/src/pass/make_api.cc @@ -189,7 +189,7 @@ LoweredFunc MakeAPI(Stmt body, DataType::Int(32), intrinsic::tvm_call_packed, {StringImm::make(runtime::symbol::tvm_set_device), device_type, device_id}, Call::Intrinsic))); - body = Block::make(set_device, body); + body = SeqStmt({set_device, body}); } n->body = MergeNest( {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); diff --git a/src/pass/remove_no_op.cc b/src/pass/remove_no_op.cc index 8b0b57c109842..68918708b9fb0 100644 --- a/src/pass/remove_no_op.cc +++ b/src/pass/remove_no_op.cc @@ -93,15 +93,35 @@ class NoOpRemover : public StmtMutator { if (HasSideEffect(op->value)) return GetRef(op); return Evaluate::make(0); } - Stmt VisitStmt_(const Block* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); - if (is_no_op(op->first)) { - return op->rest; - } else if (is_no_op(op->rest)) { - return op->first; + + Stmt VisitStmt_(const SeqStmtNode* op) final { + Stmt ret = StmtMutator::VisitSeqStmt_(op, true); + op = ret.as(); + CHECK(op != nullptr); + bool need_compact = false; + for (size_t i = 0; i < op->size(); ++i) { + if (is_no_op(op->seq[i])) need_compact = true; + } + if (need_compact) { + auto n = CopyOnWrite(op); + size_t top = 0; + for (size_t i = 0; i < n->seq.size(); ++i) { + if (!is_no_op(n->seq[i])) { + n->seq.Set(top++, n->seq[i]); + } + } + if (top == 1) { + return n->seq[0]; + } else { + n->seq.resize(top); + return Stmt(n); + } } else { - return stmt; + if (op->size() == 1) { + return op->seq[0]; + } else { + return ret; + } } } @@ -118,7 +138,7 @@ class NoOpRemover : public StmtMutator { for (Expr e : values) { if (HasSideEffect(e)) { if (stmt.defined()) { - stmt = Block::make(stmt, Evaluate::make(e)); + stmt = SeqStmt({stmt, Evaluate::make(e)}); } else { stmt = Evaluate::make(e); } diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc index 6ace4f7f85b45..85cf2b92f9e4f 100644 --- a/src/pass/storage_sync.cc +++ b/src/pass/storage_sync.cc @@ -216,7 +216,7 @@ class ThreadSyncInserter : public StmtExprMutator { } // Mutate after query, to avoid stmt change. auto ret = StmtExprMutator::VisitStmt(stmt); - ret = Block::make(barrier, ret); + ret = SeqStmt({barrier, ret}); return ret; } else { return StmtExprMutator::VisitStmt(stmt); @@ -313,10 +313,10 @@ class ThreadSyncInserter : public StmtExprMutator { rw_stats_.clear(); Stmt kinit = Evaluate::make( Call::make(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, Call::Intrinsic)); - body = Block::make(kinit, body); + body = SeqStmt({kinit, body}); body = AttrStmt::make( op->node, op->attr_key, op->value, body); - return Block::make(prep, body); + return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { CHECK(sync_scope_.rank == StorageRank::kGlobal); diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc index 9fc87f3a0d6bf..7826a9b3a666f 100644 --- a/src/pass/unroll_loop.cc +++ b/src/pass/unroll_loop.cc @@ -120,26 +120,21 @@ class LoopUnroller : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const Block* op) final { - Stmt first = this->VisitStmt(op->first); - // cleanup state - int step_count = step_count_; - int unroll_depth = unroll_depth_; - int normal_loop_depth = normal_loop_depth_; - step_count_ = 0; - unroll_depth_ = 0; - normal_loop_depth_ = 0; - // work on rest part - Stmt rest = this->VisitStmt(op->rest); - step_count_ += step_count; - normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); - unroll_depth_ = std::max(unroll_depth_, unroll_depth); - if (first.same_as(op->first) && - rest.same_as(op->rest)) { - return GetRef(op); - } else { - return Block::make(first, rest); - } + Stmt VisitStmt_(const SeqStmtNode* op) final { + auto fmutate = [this](const Stmt& s) { + int step_count = step_count_; + int unroll_depth = unroll_depth_; + int normal_loop_depth = normal_loop_depth_; + step_count_ = 0; + unroll_depth_ = 0; + normal_loop_depth_ = 0; + Stmt ret = this->VisitStmt(s); + step_count_ += step_count; + normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); + unroll_depth_ = std::max(unroll_depth_, unroll_depth); + return ret; + }; + return StmtMutator::VisitSeqStmt_(op, false, fmutate); } Stmt Unroll(const For* op) { @@ -149,17 +144,13 @@ class LoopUnroller : public StmtExprMutator { if (value == 0) return Evaluate::make(0); Stmt body = op->body; Map vmap; - Stmt unrolled; + Array unrolled; for (int i = 0; i < value; ++i) { vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); Stmt step = Substitute(body, vmap); - if (unrolled.defined()) { - unrolled = Block::make(unrolled, step); - } else { - unrolled = step; - } + unrolled.push_back(step); } - return unrolled; + return SeqStmt::Flatten(unrolled); } private: diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index b177d6f8d22fe..2d494522b2119 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -53,7 +53,7 @@ Stmt MakePipeline(const Stage& s, if (consumer.defined() && !is_no_op(consumer)) { consumer = ProducerConsumer::make(s->op, false, consumer); - pipeline = Block::make(producer, consumer); + pipeline = SeqStmt({producer, consumer}); } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index a10ddd413c205..a37f6f97d920e 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -155,6 +155,9 @@ TEST(IRF, StmtMutator) { Expr VisitExpr_(const Add* op) final { return op->a; } + Stmt VisitStmt_(const SeqStmtNode* op) final { + return StmtMutator::VisitSeqStmt_(op, true); + } Expr VisitExpr(const Expr& expr) final { return ExprMutator::VisitExpr(expr); } @@ -219,6 +222,35 @@ TEST(IRF, StmtMutator) { auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } + { + auto body = fmakealloc(); + Stmt body2 = Evaluate::make(1); + auto* ref2 = body2.get(); + auto* extentptr = body.as()->extents.get(); + // construct a recursive SeqStmt. + body = SeqStmt({body}); + body = SeqStmt({body, body2}); + body = SeqStmt({body, body2}); + body = v(std::move(body)); + // the seq get flattened + CHECK(body.as()->size() == 3); + CHECK(body.as()->seq[0].as()->extents.get() == extentptr); + CHECK(body.as()->seq[1].get() == ref2); + } + + { + // Cannot cow because of bref + auto body = fmakealloc(); + Stmt body2 = Evaluate::make(1); + auto* extentptr = body.as()->extents.get(); + // construct a recursive SeqStmt. + body = SeqStmt({body}); + auto bref = body; + body = SeqStmt({body, body2}); + body = v(std::move(body)); + // the seq get flattened + CHECK(body.as()->seq[0].as()->extents.get() != extentptr); + } } int main(int argc, char ** argv) { diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 1f101a1e92e8b..c3c40cf740ad2 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -123,12 +123,12 @@ def test_outer_product(): assert ibody.extent.name == 'm' #Check loop body jblock = ibody.body - assert isinstance(jblock, tvm.stmt.Block) - jbody = jblock.first + assert isinstance(jblock, tvm.stmt.SeqStmt) + jbody = jblock[0] assert isinstance(jbody, tvm.stmt.AssertStmt) assert isinstance(jbody.message, tvm.expr.StringImm) assert jbody.message.value == "index out of range!" - jbody = jblock.rest + jbody = jblock[1] assert isinstance(jbody, tvm.stmt.Provide) assert jbody.func.name == 'c' assert len(jbody.args) == 2 @@ -191,12 +191,12 @@ def fanout(n, a): assert abody.func.name == 'sigma' #Check i loop body rbody = abody.body - assert isinstance(rbody.first, tvm.stmt.Provide) - assert rbody.first.func.name == 'sigma' - assert len(rbody.first.args) == 1 - assert rbody.first.args[0].value == 0 + assert isinstance(rbody[0], tvm.stmt.Provide) + assert rbody[0].func.name == 'sigma' + assert len(rbody[0].args) == 1 + assert rbody[0].args[0].value == 0 #Check fanout loop - jloop = rbody.rest.first + jloop = rbody[1] assert jloop.loop_var.name == 'j' assert jloop.min.value == 0 assert jloop.extent.value == 3 @@ -214,7 +214,7 @@ def fanout(n, a): assert value.b.name == 'a' assert len(value.b.args) == 1 assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var) - divide= rbody.rest.rest.first + divide= rbody[2] assert isinstance(divide, tvm.stmt.Provide) assert len(divide.args) == 1 assert divide.args[0].value == 0 @@ -224,7 +224,7 @@ def fanout(n, a): assert len(value.a.args) == 1 assert value.a.args[0].value == 0 assert abs(value.b.value - (1 / 3.0)) < 1e-5 - write = rbody.rest.rest.rest + write = rbody[3] assert isinstance(write, tvm.stmt.Provide) assert write.func.name == 'b' assert write.value.name == 'sigma' @@ -257,9 +257,9 @@ def looptype(a, b, c): ir = d.op.body except: return - iloop = ir.first - jloop = ir.rest.first - kloop = ir.rest.rest + iloop = ir[0] + jloop = ir[1] + kloop = ir[2] assert iloop.for_type == tvm.stmt.For.Parallel assert jloop.for_type == tvm.stmt.For.Vectorized assert kloop.for_type == tvm.stmt.For.Unrolled @@ -802,7 +802,7 @@ def sum_array(inputs): inputs = [] for i in range(n): inputs.append(tvm.placeholder((10,), name='t%s' % i, dtype='float32')) - + out = sum_array(tvm.convert(inputs)) assert len(out.op.inputs) == n diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index c910c62424f0c..8b9da90c914cc 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -34,8 +34,8 @@ def test_for(): body = body.body assert isinstance(body, tvm.stmt.For) body = body.body - assert isinstance(body, tvm.stmt.Block) - assert isinstance(body.rest, tvm.stmt.For) + assert isinstance(body, tvm.stmt.SeqStmt) + assert isinstance(body[1], tvm.stmt.For) def test_if(): ib = tvm.ir_builder.create() diff --git a/tests/python/unittest/test_lang_constructor.py b/tests/python/unittest/test_lang_constructor.py index a0d39f2daffe4..fe329494e24e5 100644 --- a/tests/python/unittest/test_lang_constructor.py +++ b/tests/python/unittest/test_lang_constructor.py @@ -146,12 +146,6 @@ def test_stmt_constructor(): assert isinstance(x, tvm.stmt.AttrStmt) assert x.value.value == 1 - x = tvm.stmt.Block(tvm.stmt.Evaluate(11), - nop) - assert isinstance(x, tvm.stmt.Block) - assert x.first.value.value == 11 - assert x.rest == nop - x = tvm.stmt.AssertStmt(tvm.const(1, "uint1"), tvm.convert("hellow"), nop) diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py index 44aca3b324bb5..7e9f59bf348da 100644 --- a/tests/python/unittest/test_lang_tensor.py +++ b/tests/python/unittest/test_lang_tensor.py @@ -171,8 +171,8 @@ def intrin_func(ins, outs): s = tvm.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate) - assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body[0], tvm.stmt.Evaluate) + assert isinstance(stmt.body.body.body[1].body, tvm.stmt.Evaluate) def test_tensor_scan(): m = tvm.var("m") diff --git a/tests/python/unittest/test_pass_equal.py b/tests/python/unittest/test_pass_equal.py index 8bd491bb5c8e3..1f5bb9cba9a9d 100644 --- a/tests/python/unittest/test_pass_equal.py +++ b/tests/python/unittest/test_pass_equal.py @@ -53,6 +53,7 @@ def func2(): A[i] = A[i] + 1 with ib.for_range(0, 10, name="j") as j: A[j] = A[j] + 2 + A[j] = A[j] + 2 return ib.get() assert tvm.ir_pass.Equal(func1(), func1()) diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index cb2fa201aff43..a3d059787ab8a 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -92,8 +92,8 @@ def test_vthread_if_then_else(): B[i] = A[i * nthread + tx] + 2 stmt = ib.get() stmt = tvm.ir_pass.InjectVirtualThread(stmt) - assert stmt.body.body.body.first.else_case != None - assert stmt.body.body.body.rest.else_case == None + assert stmt.body.body.body[0].else_case != None + assert stmt.body.body.body[1].else_case == None if __name__ == "__main__": test_vthread_extern() diff --git a/tests/python/unittest/test_pass_lift_attr_scope.py b/tests/python/unittest/test_pass_lift_attr_scope.py index d786ca8c8108c..b281e17bc6335 100644 --- a/tests/python/unittest/test_pass_lift_attr_scope.py +++ b/tests/python/unittest/test_pass_lift_attr_scope.py @@ -31,10 +31,29 @@ def test_coproc_lift(): with ib.for_range(0, 10, name="j") as j: ib.scope_attr(cp, "coproc_uop_scope", value) A[j] = A[j] + 2 + A[j] = A[j] + 3 + A[j] = A[j] + 3 body = ib.get() body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope") assert body.body.body.node == cp + # only able to lift to the common pattern of the last two fors. + ib = tvm.ir_builder.create() + A = ib.allocate("float32", n, name="A", scope="global") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, 10, name="j") as j: + A[j] = A[j] + 1 + with ib.for_range(0, 10, name="j") as j: + ib.scope_attr(cp, "coproc_uop_scope", value) + A[i] = A[i] + 1 + with ib.for_range(0, 10, name="j") as j: + ib.scope_attr(cp, "coproc_uop_scope", value) + A[i] = A[i] + 2 + + body = ib.get() + body = tvm.ir_pass.LiftAttrScope(body, "coproc_uop_scope") + assert body.body.body.body[1].node == cp + assert len(body.body.body.body) == 2 if __name__ == "__main__": test_coproc_lift() diff --git a/tests/python/unittest/test_pass_loop_partition.py b/tests/python/unittest/test_pass_loop_partition.py index 0217095067542..c58b2f6dd2988 100644 --- a/tests/python/unittest/test_pass_loop_partition.py +++ b/tests/python/unittest/test_pass_loop_partition.py @@ -64,7 +64,7 @@ def test_basic(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body.first)) + assert('if' not in str(stmt.body.body.body[0])) def test_const_loop(): n = 21 @@ -79,7 +79,7 @@ def test_const_loop(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, True) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body.first)) + assert('if' not in str(stmt.body.body.body[0])) def test_multi_loop(): ib = tvm.ir_builder.create() @@ -95,7 +95,7 @@ def test_multi_loop(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.body.first, lambda x: isinstance(x, tvm.stmt.IfThenElse)))) + assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.stmt.IfThenElse)))) def test_multi_if(): ib = tvm.ir_builder.create() @@ -115,7 +115,7 @@ def test_multi_if(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.first)) + assert('if' not in str(stmt.body[0])) def test_thread_axis(): m = tvm.var('m') @@ -134,7 +134,7 @@ def test_thread_axis(): stmt = tvm.schedule.ScheduleOps(s, bounds) stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body.first)) + assert('if' not in str(stmt.body.body.body[0])) def test_vectorize(): n = tvm.var('n') @@ -169,7 +169,7 @@ def test_condition(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, False) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select)))) + assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select)))) def test_condition_EQ(): ib = tvm.ir_builder.create() @@ -181,7 +181,7 @@ def test_condition_EQ(): stmt = ib.get() stmt = tvm.ir_pass.LoopPartition(stmt, True) stmt = tvm.ir_pass.Simplify(stmt) - assert(not any(collect_visit(stmt.first, lambda x: isinstance(x, tvm.expr.Select)))) + assert(not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.expr.Select)))) def test_thread_axis2(): n = tvm.convert(4096) @@ -197,7 +197,7 @@ def test_thread_axis2(): s[C].bind(bx, tvm.thread_axis("blockIdx.x")) s[C].bind(tx, tvm.thread_axis("threadIdx.x")) stmt = lower(s, [A, B]) - for_body = stmt.body.body.body.body.body.first + for_body = stmt.body.body.body.body.body[0] assert('threadIdx' not in str(for_body.extent)) def test_everything_during_deduction(): diff --git a/tests/python/unittest/test_pass_remove_no_op.py b/tests/python/unittest/test_pass_remove_no_op.py index d1ea450b53de7..d287b8591fb32 100644 --- a/tests/python/unittest/test_pass_remove_no_op.py +++ b/tests/python/unittest/test_pass_remove_no_op.py @@ -16,6 +16,9 @@ # under the License. import tvm +def nop(): + return tvm.stmt.Evaluate(0) + def test_remove_no_op(): i = tvm.var('i') j = tvm.var('j') @@ -37,12 +40,13 @@ def test_remove_no_op(): store = tvm.make.Store(Ab.data, tvm.make.Load(dtype, Ab.data, i) + 1, i + 1) - stmt2 = tvm.make.Block(stmt, store) + stmt2 = tvm.stmt.SeqStmt([nop(), tvm.stmt.SeqStmt([store, nop()])]) assert(tvm.ir_pass.RemoveNoOp(stmt2) == store) # remove zero extent loop stmt3 = tvm.make.For(i, 0, 0, 0, 0, store) ret = tvm.ir_pass.RemoveNoOp(stmt3) assert(isinstance(ret, tvm.stmt.Evaluate)) + if __name__ == "__main__": test_remove_no_op() diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py index 5c7024abe518a..3202d7b7d3a82 100644 --- a/tests/python/unittest/test_pass_storage_sync.py +++ b/tests/python/unittest/test_pass_storage_sync.py @@ -119,10 +119,10 @@ def __check_list(tvm_array, py_list): stmt = ib.get() stmt = tvm.ir_pass.CoProcSync(stmt) - slist = tvm.make.stmt_list(stmt.first.body.body) + slist = tvm.make.stmt_list(stmt[0].body.body) push_st = slist[2] slist = tvm.make.stmt_list(slist[-1]) - pop_st = slist[0].body.first + pop_st = slist[0].body[0] assert(push_st.value.name == "cop.coproc_dep_push") assert(__check_list(push_st.value.args, [2,3])) diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py index fc8d0dc5f5c82..c94ffe0bde14b 100644 --- a/tests/python/unittest/test_pass_unroll.py +++ b/tests/python/unittest/test_pass_unroll.py @@ -43,13 +43,13 @@ def test_unroll_loop(): ib.scope_attr(tvm.const(0, "int32"), "pragma_auto_unroll_max_step", 16) ib.emit(stmt) wrapped = ib.get() - wrapped = tvm.make.Block(wrapped, stmt) + wrapped = tvm.stmt.SeqStmt([wrapped, stmt]) assert isinstance(ret, tvm.stmt.For) ret = tvm.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) - assert isinstance(ret.first, tvm.stmt.For) - assert ret.first.for_type == tvm.stmt.For.Unrolled - assert isinstance(ret.rest, tvm.stmt.For) - assert ret.rest.for_type != tvm.stmt.For.Unrolled + assert isinstance(ret[0], tvm.stmt.For) + assert ret[0].for_type == tvm.stmt.For.Unrolled + assert isinstance(ret[1], tvm.stmt.For) + assert ret[1].for_type != tvm.stmt.For.Unrolled def test_unroll_fake_loop(): ib = tvm.ir_builder.create() @@ -65,7 +65,7 @@ def test_unroll_fake_loop(): stmt = ib.get() ret = tvm.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) - assert isinstance(ret.first, tvm.stmt.Store) + assert isinstance(ret[0], tvm.stmt.Store) def test_unroll_single_count_loops(): n = tvm.var('n') diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 5275aec4db90c..b10224376dfac 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -71,7 +71,7 @@ def test_schedule_scan(): s = tvm.create_schedule(res.op) s = s.normalize() ir = tvm.lower(s, [s_state], simple_mode=True) - assert not hasattr(ir.body.body.body.body.rest.body.body.rest.body, "condition") + assert not hasattr(ir.body.body.body.body[1].body.body[1].body, "condition") bounds = tvm.schedule.InferBound(s) assert(bounds[res.op.scan_axis].min.value == 1) stmt = tvm.schedule.ScheduleOps(s, bounds) diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 12ef7daac731f..dbce9a7b91022 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -135,7 +135,7 @@ def _do_fold(stmt): if body == stmt.body: return stmt ends = list(reversed(ends)) - body = tvm.make.stmt_seq(*(begins + [body] + ends)) + body = tvm.stmt.stmt_seq(*(begins + [body] + ends)) return tvm.make.AttrStmt( stmt.node, stmt.attr_key, stmt.value, body) return None @@ -307,7 +307,7 @@ def _do_fold(stmt): success[0] = True sync = tvm.make.Call( "int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0) - return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync)) + return tvm.stmt.SeqStmt([stmt.body, tvm.make.Evaluate(sync)]) if _match_pragma(stmt, "trim_loop"): op = stmt.body assert isinstance(op, tvm.stmt.For)