diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 5400328fe219..126a6e58273f 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< /*args=*/args); } +ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d) { + Map attrs = BufferAttrs(buffer, p, frame, d); + ExprDoc shape = attrs.Get("shape").value(); + ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype)); + return TIR("Buffer")->Call({shape, dtype}, {}, {}); +} + Array BufferIndices(const Array& indices, const ObjectPath& p, const IRDocsifier& d) { int n = indices.size(); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index f0f84e81d57c..36e2d08e9417 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "./utils.h" namespace tvm { @@ -34,6 +36,33 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) { return "main"; } +bool IsSimpleBuffer(const tir::Buffer& buf) { + if (!buf->strides.empty()) { + return false; + } + for (const PrimExpr& shp_i : buf->shape) { + if (!tir::UndefinedVars(shp_i).empty()) { + return false; + } + } + for (const PrimExpr& stride_i : buf->strides) { + if (!tir::UndefinedVars(stride_i).empty()) { + return false; + } + } + if (!tir::UndefinedVars(buf->elem_offset).empty()) { + return false; + } else if (buf->elem_offset->IsInstance()) { + IntImm elem_offset = Downcast(buf->elem_offset); + if (elem_offset->value != 0) { + return false; + } + } + return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment && + buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault && + !buf->axis_separators.size(); +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With frame(MakeDispatchFrame(d, func, func)); @@ -41,9 +70,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 1. Handle `func->params` Array args; args.reserve(n_args); + std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { tir::Var var = func->params[i]; ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + if (func->buffer_map.count(var)) { + tir::Buffer buffer = func->buffer_map[var]; + ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); + if (IsSimpleBuffer(buffer)) { + args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt, + BufferAttn(buffer, buffer_p, *frame, d))); + buffer_inlined.insert(buffer.get()); + continue; + } + } ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a)); } @@ -58,6 +98,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) tir::Var param = func->params[i]; if (func->buffer_map.count(param)) { tir::Buffer buffer = func->buffer_map[param]; + if (buffer_inlined.count(buffer.get())) { + continue; + } ExprDoc param = args[i]->lhs; ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); ExprDoc lhs = diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 047513dcb316..183400d974ca 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, const ObjectPath& p, const Frame& frame, const IRDocsifier& d); +/*! + * \brief Declare and define a buffer as annotation + * \param buffer The buffer to be defined + * \param p The object path + * \param f The frame + * \param d The IRDocsifier + * \return The ExprDoc corresponding to the buffer declaration + */ +ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index d62a1cd12c28..807401fb92a0 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -57,9 +57,7 @@ def test_prim_func(): func, expected=""" @T.prim_func -def main(a: T.handle, b: T.handle): - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (256, 256)) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): T.evaluate(0)""", )