Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR] Deprecate FreeStmt #5890

Merged
merged 1 commit into from
Jun 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,35 +545,6 @@ class Allocate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*! \brief Free the resources in the buffer before the scope ends. */
class FreeNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;

void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); }

bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
return equal(buffer_var, other->buffer_var);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }

static constexpr const char* _type_key = "tir.Free";
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};

/*!
* \brief Managed reference to FreeNode.
* \sa FreeNode
*/
class Free : public Stmt {
public:
TVM_DLL Free(Var buffer_var);

TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down
4 changes: 0 additions & 4 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -112,7 +111,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(FreeNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
Expand Down Expand Up @@ -154,7 +152,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProducerStoreNode* op) override;
void VisitStmt_(const ProducerRealizeNode* op) override;
Expand Down Expand Up @@ -246,7 +243,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProducerStoreNode* op) override;
Stmt VisitStmt_(const ProducerRealizeNode* op) override;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import Free, ProducerRealize, SeqStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list

from .function import PrimFunc
Expand Down
14 changes: 0 additions & 14 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,20 +258,6 @@ def __init__(self, node, attr_key, value, body):
_ffi_api.AttrStmt, node, attr_key, value, body)


@tvm._ffi.register_object("tir.Free")
class Free(Stmt):
"""Free node.

Parameters
----------
buffer_var : Var
The buffer variable.
"""
def __init__(self, buffer_var):
self.__init_handle_by_constructor__(
_ffi_api.Free, buffer_var)


@tvm._ffi.register_object("tir.ProducerRealize")
class ProducerRealize(Stmt):
"""ProducerRealize node.
Expand Down
1 change: 0 additions & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BufferStoreNode* op) override;
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const FreeNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Expand Down
6 changes: 0 additions & 6 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,12 +438,6 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) {
Doc doc;
doc << "free(" << Print(op->buffer_var) << ")";
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
Expand Down
19 changes: 0 additions & 19 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,25 +374,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

// Free
Free::Free(Var buffer_var) {
ObjectPtr<FreeNode> node = make_object<FreeNode>();
node->buffer_var = buffer_var;
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.Free").set_body_typed([](Var buffer_var) { return Free(buffer_var); });

TVM_REGISTER_NODE_TYPE(FreeNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FreeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FreeNode*>(node.get());
p->PrintIndent();
p->stream << "free " << op->buffer_var;
p->stream << '\n';
});

// Prefetch
Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
data_ = make_object<PrefetchNode>(buffer, bounds);
Expand Down
4 changes: 0 additions & 4 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
}
}

void StmtVisitor::VisitStmt_(const FreeNode* op) {}

void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
this->VisitExpr(op->condition);
this->VisitExpr(op->message);
Expand Down Expand Up @@ -381,8 +379,6 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
}
}

Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef<Stmt>(op); }

// Implementations of IRTransform, PostOrderVisit and Substitute
class IRApplyVisit : public StmtExprVisitor {
public:
Expand Down
4 changes: 0 additions & 4 deletions tests/python/unittest/test_tir_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,6 @@ def test_stmt_constructor():
assert x.attr_key == "xyz"
assert x.body == nop

x = tvm.tir.Free(buffer_var)
assert isinstance(x, tvm.tir.Free)
assert x.buffer_var == buffer_var

x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"),
tvm.tir.Evaluate(11),
nop)
Expand Down