diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7c699c42aecb..fdccabcd235d 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -186,10 +186,11 @@ def exist(self, value: Any) -> bool: res : bool The existence of the value. """ - for v in self.name2value.values(): - if v is value: - return True - return False + return any( + value.same_as(known_value) + for known_value_stack in self.name2value.values() + for known_value in known_value_stack + ) def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index da194f885d1c..3bc16f234fba 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -245,7 +245,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j, k in T.grid(M_DIM, N_DIM, k_dim): with T.block("C"): i, j, k = T.axis.remap("SSR", [i, j, k]) - b_row_ind, b_col_ind = maybe_swap(k, j) + b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j)) thread_id_C, local_id_C = T.meta_var(index_map_C(i, j)) thread_id_A, local_id_A = T.meta_var(index_map_A(i, k)) @@ -719,7 +719,7 @@ def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: for i, j, k in T.grid(m_dim, n_dim, k_dim): with T.block(""): vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) - B_index_0, B_index_1 = maybe_swap(vkk, vjj) + B_index_0, B_index_1 = T.meta_var(maybe_swap(vkk, vjj)) C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * maybe_cast( B[B_index_0, B_index_1] ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 184722cd36bc..ac1262b9b517 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -399,6 +399,31 @@ def implicit(A: T.Buffer(1, "int32")): assert_structural_equal(implicit, explicit) +def test_preserve_trivial_let_binding(): + @T.prim_func + def explicit(i: T.int32): + j = T.int32() + T.LetStmt(i, var=j) + T.evaluate(j) + + @T.prim_func + def implicit(i: T.int32): + j = i + T.evaluate(j) + + assert_structural_equal(implicit, explicit) + + +def test_preserve_parameter_name(): + @T.prim_func + def func(i: T.int32): + j = i + T.evaluate(j) + + param_name = func.params[0].name + assert param_name == "i" + + def test_preserve_variable_name(): """Use variable name when generating tir::LetStmt"""