Skip to content

Commit

Permalink
[TVMScript] T.match_buffer syntax sugar in arguments for TVMScript …
Browse files Browse the repository at this point in the history
…printer (#13801)

This PR implements the syntax sugar of `T.match_buffer` for new TVMScript printer. This syntax sugar will replace the `T.handle` in `T.prim_func` arguments, with matched simple buffer. For example, it will change
```python
@T.prim_func
def func(a: T.handle, b: T.handle, c: T.handle):
  A = T.match_buffer(a, [128], dtype="float32")
  B = T.match_buffer(b, [128, 128], dtype="int32")
  C = T.match_buffer(c, [128, 128, 128], dtype="uint8")
```
into
```python
@T.prim_func
def main(A: T.Buffer[(128,)], B: T.Buffer[(128, 128), "int32"], C: T.Buffer[(128, 128, 128), "uint8"]):
  T.evaluate(0)
```

Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
cyx-6 and junrushao authored Jan 20, 2023
1 parent cfa65b2 commit 6c2d485
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
/*args=*/args);
}

ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
const IRDocsifier& d) {
int n = indices.size();
Expand Down
105 changes: 105 additions & 0 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/tir/stmt_functor.h>

#include "./utils.h"

namespace tvm {
Expand All @@ -34,16 +37,115 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
return "main";
}

bool IsSimpleBuffer(const tir::Buffer& buf) {
if (!buf->strides.empty()) {
return false;
}
for (const PrimExpr& shp_i : buf->shape) {
if (!tir::UndefinedVars(shp_i).empty()) {
return false;
}
}
for (const PrimExpr& stride_i : buf->strides) {
if (!tir::UndefinedVars(stride_i).empty()) {
return false;
}
}
if (!tir::UndefinedVars(buf->elem_offset).empty()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
if (elem_offset->value != 0) {
return false;
}
}
return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
!buf->axis_separators.size();
}

int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
int count = 0;
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}
};

OccurrenceCounter counter;
counter.v = v.get();
counter(f->body);
for (const tir::Var& v : f->params) {
counter(v);
}
for (const auto& pair : f->buffer_map) {
counter(pair.first);
counter.VisitBuffer(pair.second.get());
}
return counter.count;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
std::unordered_map<const tir::VarNode*, int> buffer_data_counter;
for (const auto& pair : func->buffer_map) {
const tir::VarNode* data_var = pair.second->data.get();
if (!buffer_data_counter.count(data_var)) {
buffer_data_counter.insert({data_var, 0});
}
++buffer_data_counter.at(data_var);
}
// Step 1. Handle `func->params`
Array<AssignDoc> args;
args.reserve(n_args);
std::unordered_set<const tir::BufferNode*> buffer_inlined;
for (int i = 0; i < n_args; ++i) {
tir::Var var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) {
tir::Buffer buffer = func->buffer_map[var];
if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) {
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
BufferAttn(buffer, buffer_p, *frame, d)));
buffer_inlined.insert(buffer.get());
continue;
}
}
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
}
Expand All @@ -58,6 +160,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
tir::Var param = func->params[i];
if (func->buffer_map.count(param)) {
tir::Buffer buffer = func->buffer_map[param];
if (buffer_inlined.count(buffer.get())) {
continue;
}
ExprDoc param = args[i]->lhs;
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
ExprDoc lhs =
Expand Down
11 changes: 11 additions & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) {
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
const ObjectPath& p, const Frame& frame, const IRDocsifier& d);

/*!
* \brief Declare and define a buffer as annotation
* \param buffer The buffer to be defined
* \param p The object path
* \param f The frame
* \param d The IRDocsifier
* \return The ExprDoc corresponding to the buffer declaration
*/
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d);

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
52 changes: 50 additions & 2 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,56 @@ def test_prim_func():
func,
expected="""
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)


def test_prim_func_no_sugar_inlined_buffer():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
},
body=tir.Evaluate(a),
)
_assert_print(
func,
expected="""
@T.prim_func
def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
A = T.match_buffer(a, (128, 128))
T.evaluate(a)
""",
)


def test_prim_func_no_sugar_shared_buffer_data():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data),
},
body=tir.Evaluate(0),
)
_assert_print(
func,
expected="""
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (256, 256))
T.evaluate(0)""",
B = T.match_buffer(b, (256, 256), data=A.data)
T.evaluate(0)
""",
)


Expand Down Expand Up @@ -641,6 +687,8 @@ def main():

if __name__ == "__main__":
test_prim_func()
test_prim_func_no_sugar_inlined_buffer()
test_prim_func_no_sugar_shared_buffer_data()
test_block_realize()
test_block()
test_buffer()
Expand Down

0 comments on commit 6c2d485

Please sign in to comment.