Skip to content

Commit

Permalink
[Script] Be more careful when generating ast.ExtSlice for Subscript
Browse files Browse the repository at this point in the history
The ast.ExtSlice expects a non-empty list, otherwise evaluation
fails with "error: empty dims on ExtSlice". Also, each element
in "dims" list of ExtSlice must be either Slice or Index.

In python3.8 an expression A[()] is parsed (by ast) as Subscript
with slice being Index(value=Tuple(elts=[])). When we translate a
subscript from doc.AST to ast, we unconditionally convert every
tuple to ast.ExtSlice, which in this case is incorrect.

The fix is to map empty tuple back to the Index(Tuple[])) instead
of ExtSlice. In other cases, ensure that members of ExtSlice are
of correct types.
  • Loading branch information
Krzysztof Parzyszek committed Aug 3, 2023
1 parent 49fd318 commit 5502502
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
20 changes: 13 additions & 7 deletions python/tvm/script/parser/core/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,13 +414,19 @@ def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
ctx=from_doc(x.ctx),
)
elif isinstance(x.slice, doc.Tuple):
result = ast.Subscript(
value=from_doc(x.value),
slice=ast.ExtSlice(
dims=[from_doc(i) for i in x.slice.elts],
),
ctx=from_doc(x.ctx),
)
def remap_dim(doc_item: doc.Expr) -> ast.Expr:
ast_item = from_doc(doc_item)
if isinstance(ast_item, (ast.Index, ast.Slice)):
return ast_item
return ast.Index(value=ast_item)

# ast.ExtSlice requires a non-empty list of dims, and each dim must be either
# a Slice or an Index.
if x.slice.elts:
ast_slice = ast.ExtSlice(dims=[*map(remap_dim, x.slice.elts)])
else:
ast_slice = ast.Index(value=ast.Tuple(elts=[], ctx=from_doc(x.ctx)))
result = ast.Subscript(value=from_doc(x.value), slice=ast_slice, ctx=from_doc(x.ctx))
else:
result = ast.Subscript(
value=from_doc(x.value),
Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,21 @@ def non_starred(a: T.handle, b: T.handle):
tvm.ir.assert_structural_equal(starred, non_starred)


def test_tir_empty_tuple_index():
@T.macro
def bar(val):
T.evaluate(val)

@T.prim_func(private=True)
def func_with_empty_tuple(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")):
bar(val=A[()])

@T.prim_func(private=True)
def expected(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")):
T.evaluate(A[()])

tvm.ir.assert_structural_equal(func_with_empty_tuple, expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 5502502

Please sign in to comment.