diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 5400328fe219..126a6e58273f 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< /*args=*/args); } +ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d) { + Map attrs = BufferAttrs(buffer, p, frame, d); + ExprDoc shape = attrs.Get("shape").value(); + ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype)); + return TIR("Buffer")->Call({shape, dtype}, {}, {}); +} + Array BufferIndices(const Array& indices, const ObjectPath& p, const IRDocsifier& d) { int n = indices.size(); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index f0f84e81d57c..6094eefb65b1 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -16,6 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include +#include + #include "./utils.h" namespace tvm { @@ -34,16 +37,115 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) { return "main"; } +bool IsSimpleBuffer(const tir::Buffer& buf) { + if (!buf->strides.empty()) { + return false; + } + for (const PrimExpr& shp_i : buf->shape) { + if (!tir::UndefinedVars(shp_i).empty()) { + return false; + } + } + for (const PrimExpr& stride_i : buf->strides) { + if (!tir::UndefinedVars(stride_i).empty()) { + return false; + } + } + if (!tir::UndefinedVars(buf->elem_offset).empty()) { + return false; + } else if (buf->elem_offset->IsInstance()) { + IntImm elem_offset = Downcast(buf->elem_offset); + if (elem_offset->value != 0) { + return false; + } + } + return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment && + buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault && + !buf->axis_separators.size(); +} + +int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { + class OccurrenceCounter : public tir::StmtExprVisitor { + public: + int count = 0; + const tir::VarNode* v = nullptr; + + void VisitExpr_(const tir::VarNode* op) final { + if (op == v) { + ++count; + } + tir::StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const tir::BufferStoreNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const tir::BufferLoadNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const tir::DeclBufferNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitStmt_(op); + } + + void VisitBuffer(const tir::BufferNode* buffer) { + VisitExpr(buffer->data); + for (const PrimExpr& shape_i : buffer->shape) { + VisitExpr(shape_i); + } + for (const PrimExpr& stride_i : buffer->strides) { + VisitExpr(stride_i); + } + VisitExpr(buffer->elem_offset); + } + }; + + OccurrenceCounter counter; + counter.v = v.get(); + counter(f->body); + for (const tir::Var& v : f->params) { + counter(v); + } + for (const auto& pair : f->buffer_map) { + counter(pair.first); + counter.VisitBuffer(pair.second.get()); + } + return counter.count; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With frame(MakeDispatchFrame(d, func, func)); int n_args = func->params.size(); + std::unordered_map buffer_data_counter; + for (const auto& pair : func->buffer_map) { + const tir::VarNode* data_var = pair.second->data.get(); + if (!buffer_data_counter.count(data_var)) { + buffer_data_counter.insert({data_var, 0}); + } + ++buffer_data_counter.at(data_var); + } // Step 1. Handle `func->params` Array args; args.reserve(n_args); + std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { tir::Var var = func->params[i]; ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { + tir::Buffer buffer = func->buffer_map[var]; + if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { + ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); + args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt, + BufferAttn(buffer, buffer_p, *frame, d))); + buffer_inlined.insert(buffer.get()); + continue; + } + } ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a)); } @@ -58,6 +160,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) tir::Var param = func->params[i]; if (func->buffer_map.count(param)) { tir::Buffer buffer = func->buffer_map[param]; + if (buffer_inlined.count(buffer.get())) { + continue; + } ExprDoc param = args[i]->lhs; ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); ExprDoc lhs = diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 047513dcb316..183400d974ca 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, const ObjectPath& p, const Frame& frame, const IRDocsifier& d); +/*! + * \brief Declare and define a buffer as annotation + * \param buffer The buffer to be defined + * \param p The object path + * \param f The frame + * \param d The IRDocsifier + * \return The ExprDoc corresponding to the buffer declaration + */ +ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index d62a1cd12c28..201428b74c66 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -57,10 +57,56 @@ def test_prim_func(): func, expected=""" @T.prim_func +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): + T.evaluate(0)""", + ) + + +def test_prim_func_no_sugar_inlined_buffer(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map={ + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + }, + body=tir.Evaluate(a), + ) + _assert_print( + func, + expected=""" +@T.prim_func +def main(a: T.handle, B: T.Buffer((256, 256), "float32")): + A = T.match_buffer(a, (128, 128)) + T.evaluate(a) +""", + ) + + +def test_prim_func_no_sugar_shared_buffer_data(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map={ + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data), + }, + body=tir.Evaluate(0), + ) + _assert_print( + func, + expected=""" +@T.prim_func def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (256, 256)) - T.evaluate(0)""", + B = T.match_buffer(b, (256, 256), data=A.data) + T.evaluate(0) +""", ) @@ -641,6 +687,8 @@ def main(): if __name__ == "__main__": test_prim_func() + test_prim_func_no_sugar_inlined_buffer() + test_prim_func_no_sugar_shared_buffer_data() test_block_realize() test_block() test_buffer()