Skip to content

Commit

Permalink
remap syntax sugar for tvmscript printer
Browse files Browse the repository at this point in the history
fix outdated code

fix outdated unittest

`match_buffer` syntax sugar

fix

fix bugs
  • Loading branch information
cyx-6 committed Jan 18, 2023
1 parent 60358a1 commit 7132faa
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 3 deletions.
8 changes: 8 additions & 0 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,14 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
/*args=*/args);
}

ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
const IRDocsifier& d) {
int n = indices.size();
Expand Down
43 changes: 43 additions & 0 deletions src/script/printer/tir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/device_api.h>

#include "./utils.h"

namespace tvm {
Expand All @@ -34,16 +36,54 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
return "main";
}

bool IsSimpleBuffer(const tir::Buffer& buf) {
if (!buf->strides.empty()) {
return false;
}
for (const PrimExpr& shp_i : buf->shape) {
if (!tir::UndefinedVars(shp_i).empty()) {
return false;
}
}
for (const PrimExpr& stride_i : buf->strides) {
if (!tir::UndefinedVars(stride_i).empty()) {
return false;
}
}
if (!tir::UndefinedVars(buf->elem_offset).empty()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
if (elem_offset->value != 0) {
return false;
}
}
return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
!buf->axis_separators.size();
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
// Step 1. Handle `func->params`
Array<AssignDoc> args;
args.reserve(n_args);
std::unordered_set<const tir::BufferNode*> buffer_inlined;
for (int i = 0; i < n_args; ++i) {
tir::Var var = func->params[i];
ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
if (func->buffer_map.count(var)) {
tir::Buffer buffer = func->buffer_map[var];
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
if (IsSimpleBuffer(buffer)) {
args.push_back(AssignDoc(DefineBuffer(buffer, *frame, d), NullOpt,
BufferAttn(buffer, buffer_p, *frame, d)));
buffer_inlined.insert(buffer.get());
continue;
}
}
ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
}
Expand All @@ -58,6 +98,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
tir::Var param = func->params[i];
if (func->buffer_map.count(param)) {
tir::Buffer buffer = func->buffer_map[param];
if (buffer_inlined.count(buffer.get())) {
continue;
}
ExprDoc param = args[i]->lhs;
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
ExprDoc lhs =
Expand Down
11 changes: 11 additions & 0 deletions src/script/printer/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) {
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
const ObjectPath& p, const Frame& frame, const IRDocsifier& d);

/*!
* \brief Declare and define a buffer as annotation
* \param buffer The buffer to be defined
* \param p The object path
* \param f The frame
* \param d The IRDocsifier
* \return The ExprDoc corresponding to the buffer declaration
*/
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d);

} // namespace printer
} // namespace script
} // namespace tvm
Expand Down
4 changes: 1 addition & 3 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def test_prim_func():
func,
expected="""
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (256, 256))
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)

Expand Down

0 comments on commit 7132faa

Please sign in to comment.