Skip to content

Commit

Permalink
[bugfix] Fix the behavior of TVMScript printer (apache#9974)
Browse files Browse the repository at this point in the history
* upd

* lint
  • Loading branch information
yzh119 authored and crazydemo committed Jan 27, 2022
1 parent cf226c2 commit 2adddee
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
void TryDeallocVar(const Var& var);
bool ContainsOptionalInfo(const Stmt& stmt);
/*!
* \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified
* \brief Check if a buffer declaration satisfies:
* 1. has only 'shape' and 'dtype' arguments specified,
* 2. the shape and strides are not dynamic.
* \param buffer The match buffer to be checked
*/
bool IsSimpleBuffer(const Buffer& buffer);
Expand Down Expand Up @@ -481,14 +483,25 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {

// check if all arguments, except the first two, are specified for T.match_buffer
// if not, then this match buffer is printed out as T.buffer in prim_func arguments
// and check whether there are undefined variables in the shape/strides.
bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
if (memo_var_.find(buf->data) != memo_var_.end()) {
return false;
}
if (!buf->strides.empty()) {
return false;
}
if (buf->elem_offset->IsInstance<VarNode>()) {
for (const PrimExpr& shp_i : buf->shape) {
if (!UndefinedVars(shp_i).empty()) {
return false;
}
}
for (const PrimExpr& stride_i : buf->strides) {
if (!UndefinedVars(stride_i).empty()) {
return false;
}
}
if (!UndefinedVars(buf->elem_offset).empty()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
Expand Down Expand Up @@ -1302,6 +1315,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
// check if this param is a T.handle
if (it != op->buffer_map.end()) {
// check if this match_buffer has only the first two arguments specified
// and whether the match_buffer is a dynamic buffer.
const Buffer& buf = (*it).second;
if (IsSimpleBuffer(buf)) {
simple_buf.insert(buf);
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
from tvm.ir import assert_structural_equal
from tvm.script import tir as T
from tvm.script.parser import from_source
from tvm.testing import check_error


Expand Down Expand Up @@ -158,5 +159,27 @@ def elementwise_buffer_no_kwargs_failed(
pass


# dynamic shape gemm
@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
N = T.var("int32")
M = T.var("int32")
K = T.var("int32")
A = T.match_buffer(a, (N, K), "float32")
B = T.match_buffer(b, (K, M), "float32")
C = T.match_buffer(c, (N, M), "float32")
for i, j, k in T.grid(N, M, K):
with T.block("gemm"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


def test_dynamic_shape_gemm():
gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script())
assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 2adddee

Please sign in to comment.