Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Fix the behavior of TVMScript printer #9974

Merged
merged 2 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
void TryDeallocVar(const Var& var);
bool ContainsOptionalInfo(const Stmt& stmt);
/*!
* \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified
* \brief Check if a buffer declaration satisfies:
* 1. has only 'shape' and 'dtype' arguments specified,
* 2. the shape and strides are not dynamic.
* \param buffer The match buffer to be checked
*/
bool IsSimpleBuffer(const Buffer& buffer);
Expand Down Expand Up @@ -481,14 +483,25 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {

// check if all arguments, except the first two, are specified for T.match_buffer
// if not, then this match buffer is printed out as T.buffer in prim_func arguments
// and check whether there are undefined variables in the shape/strides.
bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
if (memo_var_.find(buf->data) != memo_var_.end()) {
return false;
}
if (!buf->strides.empty()) {
return false;
}
if (buf->elem_offset->IsInstance<VarNode>()) {
for (const PrimExpr& shp_i : buf->shape) {
if (!UndefinedVars(shp_i).empty()) {
return false;
}
}
for (const PrimExpr& stride_i : buf->strides) {
if (!UndefinedVars(stride_i).empty()) {
return false;
}
}
if (!UndefinedVars(buf->elem_offset).empty()) {
return false;
} else if (buf->elem_offset->IsInstance<IntImmNode>()) {
IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
Expand Down Expand Up @@ -1302,6 +1315,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
// check if this param is a T.handle
if (it != op->buffer_map.end()) {
// check if this match_buffer has only the first two arguments specified
// and whether the match_buffer is a dynamic buffer.
const Buffer& buf = (*it).second;
if (IsSimpleBuffer(buf)) {
simple_buf.insert(buf);
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
from tvm.ir import assert_structural_equal
from tvm.script import tir as T
from tvm.script.parser import from_source
from tvm.testing import check_error


Expand Down Expand Up @@ -158,5 +159,27 @@ def elementwise_buffer_no_kwargs_failed(
pass


# dynamic shape gemm
@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
N = T.var("int32")
M = T.var("int32")
K = T.var("int32")
A = T.match_buffer(a, (N, K), "float32")
B = T.match_buffer(b, (K, M), "float32")
C = T.match_buffer(c, (N, M), "float32")
for i, j, k in T.grid(N, M, K):
with T.block("gemm"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


def test_dynamic_shape_gemm():
gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script())
assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))