Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] T.match_buffer syntax sugar in arguments for TVMScript printer #13801

Merged
merged 3 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
/*args=*/args);
}

ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
cyx-6 marked this conversation as resolved.
Show resolved Hide resolved
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
const IRDocsifier& d) {
int n = indices.size();
Expand Down
105 changes: 105 additions & 0 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/tir/stmt_functor.h>

#include "./utils.h"

namespace tvm {
Expand All @@ -34,16 +37,115 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
return "main";
}

bool IsSimpleBuffer(const tir::Buffer& buf) {
if (!buf->strides.empty()) {
return false;
}
for (const PrimExpr& shp_i : buf->shape) {
if (!tir::UndefinedVars(shp_i).empty()) {
return false;
}
}
for (const PrimExpr& stride_i : buf->strides) {
if (!tir::UndefinedVars(stride_i).empty()) {
return false;
}
}
if (!tir::UndefinedVars(buf->elem_offset).empty()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
if (elem_offset->value != 0) {
return false;
}
}
return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
!buf->axis_separators.size();
}

int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
int count = 0;
const tir::VarNode* v = nullptr;

void VisitExpr_(const tir::VarNode* op) final {
if (op == v) {
++count;
}
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::BufferStoreNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const tir::BufferLoadNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const tir::DeclBufferNode* op) final {
VisitBuffer(op->buffer.get());
tir::StmtExprVisitor::VisitStmt_(op);
}

void VisitBuffer(const tir::BufferNode* buffer) {
VisitExpr(buffer->data);
for (const PrimExpr& shape_i : buffer->shape) {
VisitExpr(shape_i);
}
for (const PrimExpr& stride_i : buffer->strides) {
VisitExpr(stride_i);
}
VisitExpr(buffer->elem_offset);
}
};

OccurrenceCounter counter;
counter.v = v.get();
counter(f->body);
for (const tir::Var& v : f->params) {
counter(v);
}
for (const auto& pair : f->buffer_map) {
counter(pair.first);
counter.VisitBuffer(pair.second.get());
}
return counter.count;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
std::unordered_map<const tir::VarNode*, int> buffer_data_counter;
for (const auto& pair : func->buffer_map) {
const tir::VarNode* data_var = pair.second->data.get();
if (!buffer_data_counter.count(data_var)) {
buffer_data_counter.insert({data_var, 0});
}
++buffer_data_counter.at(data_var);
}
// Step 1. Handle `func->params`
Array<AssignDoc> args;
args.reserve(n_args);
std::unordered_set<const tir::BufferNode*> buffer_inlined;
for (int i = 0; i < n_args; ++i) {
tir::Var var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) {
tir::Buffer buffer = func->buffer_map[var];
if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) {
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
BufferAttn(buffer, buffer_p, *frame, d)));
buffer_inlined.insert(buffer.get());
continue;
}
}
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
}
Expand All @@ -58,6 +160,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
tir::Var param = func->params[i];
if (func->buffer_map.count(param)) {
tir::Buffer buffer = func->buffer_map[param];
if (buffer_inlined.count(buffer.get())) {
continue;
}
ExprDoc param = args[i]->lhs;
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
ExprDoc lhs =
Expand Down
11 changes: 11 additions & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) {
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
const ObjectPath& p, const Frame& frame, const IRDocsifier& d);

/*!
* \brief Declare and define a buffer as annotation
* \param buffer The buffer to be defined
* \param p The object path
* \param f The frame
* \param d The IRDocsifier
* \return The ExprDoc corresponding to the buffer declaration
*/
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d);

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
52 changes: 50 additions & 2 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,56 @@ def test_prim_func():
func,
expected="""
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)


def test_prim_func_no_sugar_inlined_buffer():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
},
body=tir.Evaluate(a),
)
_assert_print(
func,
expected="""
@T.prim_func
def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
A = T.match_buffer(a, (128, 128))
T.evaluate(a)
""",
)


def test_prim_func_no_sugar_shared_buffer_data():
a = tir.Var("a", "handle")
b = tir.Var("b", "handle")
buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data
func = tir.PrimFunc(
params=[a, b],
ret_type=None,
buffer_map={
a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data),
b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data),
},
body=tir.Evaluate(0),
)
_assert_print(
func,
expected="""
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (256, 256))
T.evaluate(0)""",
B = T.match_buffer(b, (256, 256), data=A.data)
T.evaluate(0)
""",
)


Expand Down Expand Up @@ -641,6 +687,8 @@ def main():

if __name__ == "__main__":
test_prim_func()
test_prim_func_no_sugar_inlined_buffer()
test_prim_func_no_sugar_shared_buffer_data()
test_block_realize()
test_block()
test_buffer()
Expand Down