Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] Introduce SeqStmt to replace ir::Block #4627

Merged
merged 4 commits into from
Jan 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,9 @@ inline bool is_no_op(const Stmt& stmt) {
if (const auto* op = stmt.as<ir::Evaluate>()) {
return is_const(op->value);
}
if (const auto* op = stmt.as<ir::SeqStmtNode>()) {
return op->seq.size() == 0;
}
return false;
}

Expand Down
111 changes: 99 additions & 12 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1022,25 +1022,112 @@ class Realize : public StmtNode {
};

/*!
* \brief A sequence of statements.
* \brief The container of seq statement.
* Represent a sequence of statements.
*/
class Block : public StmtNode {
class SeqStmtNode : public StmtNode {
public:
/*! \brief The first statement. */
Stmt first;
/*! \brief The restof statments. */
Stmt rest;
/*! \brief internal sequence content. */
Array<Stmt> seq;

/*! \return get the size of the sequence */
size_t size() const {
return seq.size();
}
/*!
* \brief Get the index-th element in the sequence.
*/
Stmt operator[](size_t index) const {
return seq[index];
}

void VisitAttrs(AttrVisitor* v) {
v->Visit("first", &first);
v->Visit("rest", &rest);
v->Visit("seq", &seq);
}

TVM_DLL static Stmt make(Stmt first, Stmt rest);
TVM_DLL static Stmt make(const std::vector<Stmt> &stmts);
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};

/*! \brief Sequence statement. */
class SeqStmt : public Stmt {
public:
/*!
* \brief Construct SeqStmt.
* \param seq The sequence.
*/
TVM_DLL explicit SeqStmt(Array<Stmt> seq);

/*! \return get the size of the sequence */
size_t size() const {
return operator->()->size();
}
/*!
* \brief Get the index-th element in the sequence.
*/
Stmt operator[](size_t index) const {
return (*(operator->()))[index];
}
/*!
* \brief Construct a sequence statement by flattening
* all the arrays and sequences in the arguments
* recursively.
*
* - When an argument is nullptr, it will be ignored.
* - When an argument is an array or a SeqStmt, it will be flattened recursively.
* - When an argument is a consumer block in ProducerConsumer, the consumer
* tag will be dropped as such information is not useful in lowering.
* - A normal Stmt will be appended to the end of the sequence.
*
* \note This function can directly return an element
* if it is the only element in the sequence.
*
* \param seq_args The list of arguments to be flattened.
* \tparam Args arguments
* \return The constructed statement
*/
template<typename ...Args>
static Stmt Flatten(Args&&... seq_args) {
Array<Stmt> seq;
runtime::detail::for_each(
Flattener(&seq), std::forward<Args>(seq_args)...);
if (seq.size() == 1) return seq[0];
return SeqStmt(seq);
}
/*! \brief Helper class to flatten sequence of arguments into Array. */
class Flattener {
public:
explicit Flattener(Array<Stmt>* seq)
: seq_(seq) {}

void operator()(size_t i, const Stmt& stmt) const {
if (!stmt.defined()) return;
if (auto* op = stmt.as<SeqStmtNode>()) {
operator()(0, op->seq);
} else if (auto* op = stmt.as<ProducerConsumer>()) {
// NOTE: The consumer block annotation was not as useful and can be safely dropped.
if (!op->is_producer) {
operator()(0, op->body);
} else {
seq_->push_back(stmt);
}
} else {
seq_->push_back(stmt);
}
}

template<typename T>
void operator()(size_t i, const T& seq) const {
for (auto v : seq) {
this->operator()(0, v);
}
}

private:
Array<Stmt>* seq_;
};

static constexpr const char* _type_key = "Block";
TVM_DECLARE_FINAL_OBJECT_INFO(Block, StmtNode);
TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
};

/*!
Expand Down
23 changes: 19 additions & 4 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Prefetch* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -276,7 +276,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Prefetch);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
}
Expand Down Expand Up @@ -408,7 +408,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const Provide* op) override;
void VisitStmt_(const Realize* op) override;
void VisitStmt_(const Prefetch* op) override;
void VisitStmt_(const Block* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const Evaluate* op) override;
};

Expand Down Expand Up @@ -502,8 +502,23 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const Provide* op) override;
Stmt VisitStmt_(const Realize* op) override;
Stmt VisitStmt_(const Prefetch* op) override;
Stmt VisitStmt_(const Block* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const Evaluate* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
* This function can be called when a child class override
* VisitStmt_(const SeqStmtNode*) to introduce
* the special behavior to visit
*
* \param op The sequence.
* \param flatten_before_visit Whether to flatten the sequence before visit.
* \param fmutate The mutate function, can be nullptr, which defaults to Visit.
* \return The mutated result.
*/
Stmt VisitSeqStmt_(const SeqStmtNode* op,
bool flatten_before_visit,
std::function<Stmt(const Stmt&)> fmutate = nullptr);
// internal helper.
class Internal;
};
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ class Array : public ObjectRef {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item);
}
/*!
* \brief Resize the array.
* \param size The new size.
*/
inline void resize(size_t size) {
ArrayNode* n = this->CopyOnWrite();
n->data.resize(size);
}
/*!
* \brief set i-th element of the array.
* \param i The index
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
from .. import make as _make
from .. import stmt as _stmt

from .. import api as _api
from .. import ir_pass as _ir_pass

Expand All @@ -48,11 +50,7 @@ def concat_list_to_block(lst):
n = len(lst)
if n == 1:
return lst[0]
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
body = _make.Block(stmt, body)
return body
return _stmt.SeqStmt(lst)


def visit_list_to_block(visit, lst):
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,16 @@ def _pop_seq(self):
seq = self._seq_stack.pop()
if not seq or callable(seq[-1]):
seq.append(_make.Evaluate(0))
stmt = seq[-1]
seqwrap = lambda x: x[0] if len(x) == 1 else _stmt.SeqStmt(list(reversed(x)))
ret_seq = [seq[-1]]

for s in reversed(seq[:-1]):
if callable(s):
stmt = s(stmt)
ret_seq = [s(seqwrap(ret_seq))]
else:
assert isinstance(s, _stmt.Stmt)
stmt = _make.Block(s, stmt)
return stmt
ret_seq.append(s)
return seqwrap(ret_seq)

def emit(self, stmt):
"""Emit a statement to the end of current scope.
Expand Down
37 changes: 22 additions & 15 deletions python/tvm/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,20 +289,23 @@ def __init__(self,


@register_node
class Block(Stmt):
"""Block node.
class SeqStmt(Stmt):
"""Sequence of statements.
Parameters
----------
first : Stmt
The first statement.
rest : Stmt
The following statement.
seq : List[Stmt]
The statements
"""
def __init__(self, first, rest):
def __init__(self, seq):
self.__init_handle_by_constructor__(
_make.Block, first, rest)
_make.SeqStmt, seq)

def __getitem__(self, i):
return self.seq[i]

def __len__(self):
return len(self.seq)


@register_node
Expand Down Expand Up @@ -375,12 +378,14 @@ def stmt_seq(*args):
stmt : Stmt
The combined statement.
"""
ret = None
ret = []
for value in args:
if not isinstance(value, Stmt):
value = Evaluate(value)
ret = value if ret is None else Block(ret, value)
return ret if ret else Evaluate(0)
ret.append(value)
if len(ret) == 1:
return ret[0]
return SeqStmt(ret)


def stmt_list(stmt):
Expand All @@ -395,12 +400,14 @@ def stmt_list(stmt):
stmt_list : list of Stmt
The unpacked list of statements
"""
if isinstance(stmt, Block):
return stmt_list(stmt.first) + stmt_list(stmt.rest)
if isinstance(stmt, SeqStmt):
res = []
for x in stmt:
res += stmt_list(x)
return res
if isinstance(stmt, ProducerConsumer):
return stmt_list(stmt.body)
return [stmt]


_make.stmt_list = stmt_list
_make.stmt_seq = stmt_seq
9 changes: 6 additions & 3 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ TVM_REGISTER_GLOBAL("make._cast")
TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);


TVM_REGISTER_GLOBAL("make.SeqStmt")
.set_body_typed([](Array<Stmt> seq) {
return SeqStmt(std::move(seq));
});

TVM_REGISTER_GLOBAL("make.For")
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
Expand Down Expand Up @@ -163,9 +169,6 @@ REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);

// overloaded, needs special handling
TVM_REGISTER_GLOBAL("make.Block")
.set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));

// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed([](
Expand Down
Loading