Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 committed Jan 18, 2023
1 parent 874d41b commit c0df6f7
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 11 deletions.
10 changes: 3 additions & 7 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,9 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
Array<Doc> indices_doc;
for (String s : {"shape", "dtype"}) {
if (Optional<ExprDoc> doc = attrs.Get(s)) {
indices_doc.push_back(doc.value());
}
}
return TIR("Buffer")[indices_doc];
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def func(a: T.handle, b: T.handle, c: T.handle):
C = T.match_buffer(c, [128, 128, 128], dtype="uint8")

expected_output = """@T.prim_func
def main(A: T.Buffer[(128,)], B: T.Buffer[(128, 128), "int32"], C: T.Buffer[(128, 128, 128), "uint8"]):
def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128, 128), "int32"), C: T.Buffer((128, 128, 128), "uint8")):
T.evaluate(0)"""
_test(func, expected_output)

Expand Down
4 changes: 1 addition & 3 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def test_prim_func():
func,
expected="""
@T.prim_func
def main(a: T.handle, b: T.handle):
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (256, 256))
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
T.evaluate(0)""",
)

Expand Down

0 comments on commit c0df6f7

Please sign in to comment.