Skip to content

Commit

Permalink
[REFACTOR][IR] Introduce SeqStmt to replace ir::Block (apache#4627)
Browse files Browse the repository at this point in the history
* [REFACTOR][IR] Introduce SeqStmt to replace Block

ir::Block was used to represent a sequence of Stmts in the original low-level IR.
The nested ir::Block structure is not really friendly for recursive visits,
especially when the statements are unrolled.

This PR introduce a SeqStmt that directly stores a sequence of statements in an Array container.
The new SeqStmt will be used as a replacement of the original Block structure.

* [REFACTOR] Migrate use of Block to SeqStmt.

* [REFACTOR] Remove Block

* Add more comments per yizhi's comment
  • Loading branch information
tqchen authored and alexwong committed Feb 26, 2020
1 parent 6b74a72 commit 13ef863
Show file tree
Hide file tree
Showing 55 changed files with 612 additions and 458 deletions.
3 changes: 3 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,9 @@ inline bool is_no_op(const Stmt& stmt) {
if (const auto* op = stmt.as<ir::Evaluate>()) {
return is_const(op->value);
}
if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
return op->seq.size() == 0;
}
return false;
}

Expand Down
111 changes: 99 additions & 12 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1022,25 +1022,112 @@ class Realize : public StmtNode {
};

/*!
* \brief A sequence of statements.
* \brief The container of seq statement.
* Represent a sequence of statements.
*/
class Block : public StmtNode {
class SeqStmtNode : public StmtNode {
public:
/*! \brief The first statement. */
Stmt first;
/*! \brief The restof statments. */
Stmt rest;
/*! \brief internal sequence content. */
Array<Stmt> seq;

/*! \return get the size of the sequence */
size_t size() const {
return seq.size();
}
/*!
* \brief Get the index-th element in the sequence.
*/
Stmt operator[](size_t index) const {
return seq[index];
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("first", &first);
v->Visit("rest", &rest);
v->Visit("seq", &seq);
}

TVM_DLL static Stmt make(Stmt first, Stmt rest);
TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};

/*! \brief Sequence statement. */
class SeqStmt : public Stmt {
public:
/*!
* \brief Construct SeqStmt.
* \param seq The sequence.
*/
TVM_DLL explicit SeqStmt(Array<Stmt> seq);

/*! \return get the size of the sequence */
size_t size() const {
return operator->()->size();
}
/*!
* \brief Get the index-th element in the sequence.
*/
Stmt operator[](size_t index) const {
return (*(operator->()))[index];
}
/*!
* \brief Construct a sequence statement by flattening
* all the arrays and sequences in the arguments
* recursively.
*
* - When an argument is nullptr, it will be ignored.
* - When an argument is an array or a SeqStmt, it will be flattened recursively.
* - When an argument is a consumer block in ProducerConsumer, the consumer
* tag will be dropped as such information is not useful in lowering.
* - A normal Stmt will be appended to the end of the sequence.
*
* \note This function can directly return an element
* if it is the only element in the sequence.
*
* \param seq_args The list of arguments to be flattened.
* \tparam Args arguments
* \return The constructed statement
*/
template<typename ...Args>
static Stmt Flatten(Args&&... seq_args) {
Array<Stmt> seq;
runtime::detail::for_each(
Flattener(&seq), std::forward<Args>(seq_args)...);
if (seq.size() == 1) return seq[0];
return SeqStmt(seq);
}
/*! \brief Helper class to flatten sequence of arguments into Array. */
class Flattener {
public:
explicit Flattener(Array<Stmt>* seq)
: seq_(seq) {}

void operator()(size_t i, const Stmt& stmt) const {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else if (auto* op = stmt.as<ProducerConsumer>()) {
// NOTE: The consumer block annotation was not as useful and can be safely dropped.
if (!op->is_producer) {
operator()(0, op->body);
} else {
seq_->push_back(stmt);
}
} else {
seq_->push_back(stmt);
}
}

template<typename T>
void operator()(size_t i, const T& seq) const {
for (auto v : seq) {
this->operator()(0, v);
}
}

private:
Array<Stmt>* seq_;
};

static constexpr const char* _type_key = "Block";
TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode);
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
};

/*!
Expand Down
23 changes: 19 additions & 4 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -276,7 +276,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
}
Expand Down Expand Up @@ -408,7 +408,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const Prefetch* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
};

Expand Down Expand Up @@ -502,8 +502,23 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const Provide* op) override;
Stmt VisitStmt_(const Realize* op) override;
Stmt VisitStmt_(const Prefetch* op) override;
Stmt VisitStmt_(const Block* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const Evaluate* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
* This function can be called when a child class override
* VisitStmt_(const SeqStmtNode*) to introduce
* the special behavior to visit
*
* \param op The sequence.
* \param flatten_before_visit Whether to flatten the sequence before visit.
* \param fmutate The mutate function, can be nullptr, which defaults to Visit.
* \return The mutated result.
*/
Stmt VisitSeqStmt_(const SeqStmtNode* op,
bool flatten_before_visit,
std::function<Stmt(const Stmt&)> fmutate = nullptr);
// internal helper.
class Internal;
};
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ class Array : public ObjectRef {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item);
}
/*!
* \brief Resize the array.
* \param size The new size.
*/
inline void resize(size_t size) {
ArrayNode* n = this->CopyOnWrite();
n->data.resize(size);
}
/*!
* \brief set i-th element of the array.
* \param i The index
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import make as _make
from .. import stmt as _stmt

from .. import api as _api
from .. import ir_pass as _ir_pass

Expand All @@ -48,11 +50,7 @@ def concat_list_to_block(lst):
n = len(lst)
if n == 1:
return lst[0]
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
body = _make.Block(stmt, body)
return body
return _stmt.SeqStmt(lst)


def visit_list_to_block(visit, lst):
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,16 @@ def _pop_seq(self):
seq = self._seq_stack.pop()
if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0))
stmt = seq[-1]
seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
ret_seq = [seq[-1]]

for s in reversed(seq[:-1]):
if callable(s):
stmt = s(stmt)
ret_seq = [s(seqwrap(ret_seq))]
else:
assert isinstance(s, _stmt.Stmt)
stmt = _make.Block(s, stmt)
return stmt
ret_seq.append(s)
return seqwrap(ret_seq)

def emit(self, stmt):
"""Emit a statement to the end of current scope.
Expand Down
37 changes: 22 additions & 15 deletions python/tvm/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,20 +289,23 @@ def __init__(self,


@register_node
class Block(Stmt):
"""Block node.
class SeqStmt(Stmt):
"""Sequence of statements.
Parameters
----------
first : Stmt
The first statement.
rest : Stmt
The following statement.
seq : List[Stmt]
The statements
"""
def __init__(self, first, rest):
def __init__(self, seq):
self.__init_handle_by_constructor__(
_make.Block, first, rest)
_make.SeqStmt, seq)

def __getitem__(self, i):
return self.seq[i]

def __len__(self):
return len(self.seq)


@register_node
Expand Down Expand Up @@ -375,12 +378,14 @@ def stmt_seq(*args):
stmt : Stmt
The combined statement.
"""
ret = None
ret = []
for value in args:
if not isinstance(value, Stmt):
value = Evaluate(value)
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
ret.append(value)
if len(ret) == 1:
return ret[0]
return SeqStmt(ret)


def stmt_list(stmt):
Expand All @@ -395,12 +400,14 @@ def stmt_list(stmt):
stmt_list : list of Stmt
The unpacked list of statements
"""
if isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
if isinstance(stmt, SeqStmt):
res = []
for x in stmt:
res += stmt_list(x)
return res
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]


_make.stmt_list = stmt_list
_make.stmt_seq = stmt_seq
9 changes: 6 additions & 3 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("make._cast")
TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);


TVM_REGISTER_GLOBAL("make.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq));
});

TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
Expand Down Expand Up @@ -163,9 +169,6 @@ REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);

// overloaded, needs special handling
TVM_REGISTER_GLOBAL("make.Block")
.set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));

// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed([](
Expand Down
Loading

0 comments on commit 13ef863

Please sign in to comment.