Skip to content

Commit

Permalink
[Bugfix][TVMScript] Handle LetStmt for var1 = var2 expressions (#14320
Browse files Browse the repository at this point in the history
)

* [Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions

Usually, when using TVMScript to represent a `PrimFunc` variable
definition `var_name = expr` defines `LetStmt` with a variable named
`var_name` bound to the expression `expr`.  However, prior to this
commit, if `expr` is a `tir::Var`, the TVMScript parser would instead
silently omit the `LetStmt`, and rename all instances of that variable
to `var_name`.

The root cause was in the `VarTable.exist` check, which erroneously
returned False in all cases.  This was due to a `value is v` check,
which checked if the value was the same as the stack of
maybe-shadowing values that share the same name.  Replacing the
'value is v` check with a `value in v` check resolves this issue.

This bug dates to the initial implementation of the new TVMScript
parser in #12496.

* Avoid implicit `PrimExpr.__bool__` from `if value in value_stack`

* Use T.meta_var where variable renaming is required.
  • Loading branch information
Lunderberg authored Apr 2, 2023
1 parent 49e6695 commit 66e18fb
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
9 changes: 5 additions & 4 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ def exist(self, value: Any) -> bool:
res : bool
The existence of the value.
"""
for v in self.name2value.values():
if v is value:
return True
return False
return any(
value.same_as(known_value)
for known_value_stack in self.name2value.values()
for known_value in known_value_stack
)


def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
b_row_ind, b_col_ind = maybe_swap(k, j)
b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))

thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
Expand Down Expand Up @@ -719,7 +719,7 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
for i, j, k in T.grid(m_dim, n_dim, k_dim):
with T.block(""):
vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
B_index_0, B_index_1 = maybe_swap(vkk, vjj)
B_index_0, B_index_1 = T.meta_var(maybe_swap(vkk, vjj))
C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast(
B[B_index_0, B_index_1]
)
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,31 @@ def implicit(A: T.Buffer(1, "int32")):
assert_structural_equal(implicit, explicit)


def test_preserve_trivial_let_binding():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
T.LetStmt(i, var=j)
T.evaluate(j)

@T.prim_func
def implicit(i: T.int32):
j = i
T.evaluate(j)

assert_structural_equal(implicit, explicit)


def test_preserve_parameter_name():
@T.prim_func
def func(i: T.int32):
j = i
T.evaluate(j)

param_name = func.params[0].name
assert param_name == "i"


def test_preserve_variable_name():
"""Use variable name when generating tir::LetStmt"""

Expand Down

0 comments on commit 66e18fb

Please sign in to comment.