diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py index 5ea83749eadf..1c5241dc8d90 100644 --- a/python/tvm/script/parser/core/doc.py +++ b/python/tvm/script/parser/core/doc.py @@ -414,13 +414,20 @@ def subscript_from_doc(x: doc.Subscript) -> ast.Subscript: ctx=from_doc(x.ctx), ) elif isinstance(x.slice, doc.Tuple): - result = ast.Subscript( - value=from_doc(x.value), - slice=ast.ExtSlice( - dims=[from_doc(i) for i in x.slice.elts], - ), - ctx=from_doc(x.ctx), - ) + + def remap_dim(doc_item: doc.Expr) -> ast.Expr: + ast_item = from_doc(doc_item) + if isinstance(ast_item, (ast.Index, ast.Slice)): + return ast_item + return ast.Index(value=ast_item) + + # ast.ExtSlice requires a non-empty list of dims, and each dim must be either + # a Slice or an Index. + if x.slice.elts: + ast_slice = ast.ExtSlice(dims=[*map(remap_dim, x.slice.elts)]) + else: + ast_slice = ast.Index(value=ast.Tuple(elts=[], ctx=from_doc(x.ctx))) + result = ast.Subscript(value=from_doc(x.value), slice=ast_slice, ctx=from_doc(x.ctx)) else: result = ast.Subscript( value=from_doc(x.value), diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 210c173141c5..ef02df497b7b 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -292,5 +292,21 @@ def non_starred(a: T.handle, b: T.handle): tvm.ir.assert_structural_equal(starred, non_starred) +def test_tir_empty_tuple_index(): + @T.macro + def bar(val): + T.evaluate(val) + + @T.prim_func(private=True) + def func_with_empty_tuple(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")): + bar(val=A[()]) + + @T.prim_func(private=True) + def expected(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")): + T.evaluate(A[()]) + + tvm.ir.assert_structural_equal(func_with_empty_tuple, expected) + + if __name__ == "__main__": tvm.testing.main()