diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 277a6042cbeb..023bb0d3ef00 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -222,7 +222,9 @@ class TVMScriptPrinter : public StmtFunctor, void TryDeallocVar(const Var& var); bool ContainsOptionalInfo(const Stmt& stmt); /*! - * \brief check if a buffer declaration has only 'shape' and 'dtype' arguments specified + * \brief Check if a buffer declaration satisfies: + * 1. has only 'shape' and 'dtype' arguments specified, + * 2. the shape and strides are not dynamic. * \param buffer The match buffer to be checked */ bool IsSimpleBuffer(const Buffer& buffer); @@ -481,6 +483,7 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { // check if all arguments, except the first two, are specified for T.match_buffer // if not, then this match buffer is printed out as T.buffer in prim_func arguments +// and check whether there are undefined variables in the shape/strides. bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) { if (memo_var_.find(buf->data) != memo_var_.end()) { return false; @@ -488,7 +491,17 @@ bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) { if (!buf->strides.empty()) { return false; } - if (buf->elem_offset->IsInstance()) { + for (const PrimExpr& shp_i : buf->shape) { + if (!UndefinedVars(shp_i).empty()) { + return false; + } + } + for (const PrimExpr& stride_i : buf->strides) { + if (!UndefinedVars(stride_i).empty()) { + return false; + } + } + if (!UndefinedVars(buf->elem_offset).empty()) { return false; } else if (buf->elem_offset->IsInstance()) { IntImm elem_offset = Downcast(buf->elem_offset); @@ -1302,6 +1315,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { // check if this param is a T.handle if (it != op->buffer_map.end()) { // check if this match_buffer has only the first two arguments specified + // and whether the match_buffer is a dynamic buffer. const Buffer& buf = (*it).second; if (IsSimpleBuffer(buf)) { simple_buf.insert(buf); diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 58458b38d7f3..0e77b2a49454 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -20,6 +20,7 @@ import pytest from tvm.ir import assert_structural_equal from tvm.script import tir as T +from tvm.script.parser import from_source from tvm.testing import check_error @@ -158,5 +159,27 @@ def elementwise_buffer_no_kwargs_failed( pass +# dynamic shape gemm +@T.prim_func +def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle): + N = T.var("int32") + M = T.var("int32") + K = T.var("int32") + A = T.match_buffer(a, (N, K), "float32") + B = T.match_buffer(b, (K, M), "float32") + C = T.match_buffer(c, (N, M), "float32") + for i, j, k in T.grid(N, M, K): + with T.block("gemm"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +def test_dynamic_shape_gemm(): + gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script()) + assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))