From 7132faa449279d4211d34378f9ae855fb4091a53 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 17 Jan 2023 23:10:17 -0800 Subject: [PATCH 1/3] remap syntax sugar for tvmscript printer fix outdated code fix outdated unittest `match_buffer` syntax sugar fix fix bugs --- src/script/printer/tir/buffer.cc | 8 ++++ src/script/printer/tir/function.cc | 43 +++++++++++++++++++ src/script/printer/tir/utils.h | 11 +++++ .../unittest/test_tvmscript_printer_tir.py | 4 +- 4 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 5400328fe219..126a6e58273f 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array< /*args=*/args); } +ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d) { + Map attrs = BufferAttrs(buffer, p, frame, d); + ExprDoc shape = attrs.Get("shape").value(); + ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype)); + return TIR("Buffer")->Call({shape, dtype}, {}, {}); +} + Array BufferIndices(const Array& indices, const ObjectPath& p, const IRDocsifier& d) { int n = indices.size(); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index f0f84e81d57c..36e2d08e9417 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "./utils.h" namespace tvm { @@ -34,6 +36,33 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) { return "main"; } +bool IsSimpleBuffer(const tir::Buffer& buf) { + if (!buf->strides.empty()) { + return false; + } + for (const PrimExpr& shp_i : buf->shape) { + if (!tir::UndefinedVars(shp_i).empty()) { + return false; + } + } + for (const PrimExpr& stride_i : buf->strides) { + if (!tir::UndefinedVars(stride_i).empty()) { + return false; + } + } + if (!tir::UndefinedVars(buf->elem_offset).empty()) { + return false; + } else if (buf->elem_offset->IsInstance()) { + IntImm elem_offset = Downcast(buf->elem_offset); + if (elem_offset->value != 0) { + return false; + } + } + return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment && + buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault && + !buf->axis_separators.size(); +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With frame(MakeDispatchFrame(d, func, func)); @@ -41,9 +70,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 1. 
Handle `func->params` Array args; args.reserve(n_args); + std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { tir::Var var = func->params[i]; ObjectPath var_p = p->Attr("params")->ArrayIndex(i); + if (func->buffer_map.count(var)) { + tir::Buffer buffer = func->buffer_map[var]; + ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); + if (IsSimpleBuffer(buffer)) { + args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt, + BufferAttn(buffer, buffer_p, *frame, d))); + buffer_inlined.insert(buffer.get()); + continue; + } + } ExprDoc a = d->AsDoc(var->type_annotation, var_p->Attr("type_annotation")); args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a)); } @@ -58,6 +98,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) tir::Var param = func->params[i]; if (func->buffer_map.count(param)) { tir::Buffer buffer = func->buffer_map[param]; + if (buffer_inlined.count(buffer.get())) { + continue; + } ExprDoc param = args[i]->lhs; ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); ExprDoc lhs = diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 047513dcb316..183400d974ca 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, const ObjectPath& p, const Frame& frame, const IRDocsifier& d); +/*! 
+ * \brief Declare and define a buffer as annotation + * \param buffer The buffer to be defined + * \param p The object path + * \param frame The frame + * \param d The IRDocsifier + * \return The ExprDoc corresponding to the buffer declaration + */ +ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame, + const IRDocsifier& d); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index d62a1cd12c28..807401fb92a0 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -57,9 +57,7 @@ def test_prim_func(): func, expected=""" @T.prim_func -def main(a: T.handle, b: T.handle): - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (256, 256)) +def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): T.evaluate(0)""", ) From f031737bf68dd5e24f7e5d3af379fa73a1ae27d4 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 19 Jan 2023 12:01:26 -0800 Subject: [PATCH 2/3] ... --- src/script/printer/tir/function.cc | 58 ++++++++++++++++++- .../unittest/test_tvmscript_printer_tir.py | 52 +++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 36e2d08e9417..86f41a388fb8 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -17,6 +17,7 @@ * under the License. 
*/ #include +#include #include "./utils.h" @@ -63,6 +64,59 @@ bool IsSimpleBuffer(const tir::Buffer& buf) { !buf->axis_separators.size(); } +int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { + class OccurrenceCounter : public tir::StmtExprVisitor { + public: + int count = 0; + const tir::VarNode* v = nullptr; + + void VisitExpr_(const tir::VarNode* op) final { + if (op == v) { + ++count; + } + tir::StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const tir::BufferStoreNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const tir::BufferLoadNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const tir::DeclBufferNode* op) final { + VisitBuffer(op->buffer.get()); + tir::StmtExprVisitor::VisitStmt_(op); + } + + void VisitBuffer(const tir::BufferNode* buffer) { + VisitExpr(buffer->data); + for (const PrimExpr& shape_i : buffer->shape) { + VisitExpr(shape_i); + } + for (const PrimExpr& stride_i : buffer->strides) { + VisitExpr(stride_i); + } + VisitExpr(buffer->elem_offset); + } + }; + + OccurrenceCounter counter; + counter.v = v.get(); + counter(f->body); + for (const tir::Var& v : f->params) { + counter(v); + } + for (const auto& pair : f->buffer_map) { + counter(pair.first); + counter.VisitBuffer(pair.second.get()); + } + return counter.count; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With frame(MakeDispatchFrame(d, func, func)); @@ -74,10 +128,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n_args; ++i) { tir::Var var = func->params[i]; ObjectPath var_p = p->Attr("params")->ArrayIndex(i); - if (func->buffer_map.count(var)) { + if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { tir::Buffer buffer = func->buffer_map[var]; - ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); if 
(IsSimpleBuffer(buffer)) { + ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt, BufferAttn(buffer, buffer_p, *frame, d))); buffer_inlined.insert(buffer.get()); diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 807401fb92a0..6eb09f823db0 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -62,6 +62,56 @@ def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) ) +def test_prim_func_no_sugar_inlined_buffer(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map={ + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + }, + body=tir.Evaluate(a), + ) + _assert_print( + func, + expected=""" +@T.prim_func +def main(a: T.handle, B: T.Buffer((256, 256), "float32")): + A = T.match_buffer(a, (128, 128)) + T.evaluate(a) +""", + ) + + +def test_prim_func_no_sugar_shared_buffer_data(): + a = tir.Var("a", "handle") + b = tir.Var("b", "handle") + buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data + func = tir.PrimFunc( + params=[a, b], + ret_type=None, + buffer_map={ + a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data), + b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data), + }, + body=tir.Evaluate(0), + ) + print(func) + + +# _assert_print( +# func, +# expected=""" +# @T.prim_func +# def main(a: T.handle, B: T.Buffer((256, 256), "float32")): +# A = T.match_buffer(a, (128, 128)) +# T.evaluate(a) +# """, +# ) + + def test_block_realize(): i = tir.Var("i", "int32") j = tir.Var("j", "int32") @@ -639,6 +689,8 @@ def main(): if __name__ == "__main__": test_prim_func() + 
test_prim_func_no_sugar_inlined_buffer() + test_prim_func_no_sugar_shared_buffer_data() test_block_realize() test_block() test_buffer() From 2f446daed25bcd53ac84b8a2f6cdd54035025a0e Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Thu, 19 Jan 2023 14:27:16 -0800 Subject: [PATCH 3/3] add check for shared buffer data --- src/script/printer/tir/function.cc | 10 ++++++++- .../unittest/test_tvmscript_printer_tir.py | 22 +++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 86f41a388fb8..6094eefb65b1 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -121,6 +121,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { With frame(MakeDispatchFrame(d, func, func)); int n_args = func->params.size(); + std::unordered_map buffer_data_counter; + for (const auto& pair : func->buffer_map) { + const tir::VarNode* data_var = pair.second->data.get(); + if (!buffer_data_counter.count(data_var)) { + buffer_data_counter.insert({data_var, 0}); + } + ++buffer_data_counter.at(data_var); + } // Step 1. 
Handle `func->params` Array args; args.reserve(n_args); @@ -130,7 +138,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ObjectPath var_p = p->Attr("params")->ArrayIndex(i); if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { tir::Buffer buffer = func->buffer_map[var]; - if (IsSimpleBuffer(buffer)) { + if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt, BufferAttn(buffer, buffer_p, *frame, d))); diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 6eb09f823db0..201428b74c66 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -98,18 +98,16 @@ def test_prim_func_no_sugar_shared_buffer_data(): }, body=tir.Evaluate(0), ) - print(func) - - -# _assert_print( -# func, -# expected=""" -# @T.prim_func -# def main(a: T.handle, B: T.Buffer((256, 256), "float32")): -# A = T.match_buffer(a, (128, 128)) -# T.evaluate(a) -# """, -# ) + _assert_print( + func, + expected=""" +@T.prim_func +def main(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (256, 256), data=A.data) + T.evaluate(0) +""", + ) def test_block_realize():