From bcf212dda0f94c51f55c48921f61d92fd3b83777 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sat, 14 May 2022 07:16:42 +0900
Subject: [PATCH] seems to work

---
 python/tvm/script/parser.py              | 21 ++-------------------
 .../unittest/test_mma_16x8x16_4k_tune.py | 11 +++--------
 2 files changed, 5 insertions(+), 27 deletions(-)

diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 1d76815dc3b8..2a19dfc33dc2 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -554,14 +554,12 @@ def transform_Assign(self, node):
             4.1 var = T.allocate()
         """
 
-        print("parsing ", node.rhs.func_name)
         if isinstance(node.rhs, ast.Call):
             # Pattern 1 & Pattern 4
             if isinstance(node.rhs.func_name, ast.Op):
                 func = None
             else:
                 func = self.transform(node.rhs.func_name)
-                print(func)
 
             if isinstance(func, WithScopeHandler):
                 if not func.concise_scope or not func.def_symbol:
@@ -582,27 +580,12 @@ def transform_Assign(self, node):
             elif callable(func):
                 args = [self.transform(arg) for arg in node.rhs.params]
                 out = func(*args)
-                print(out)
-                print(node.lhs)
                 assert len(out) == len(node.lhs)
 
-                lhs_vars = []
                 for ast_var, value in zip(node.lhs, out):
-                    var = tvm.te.var(
-                        ast_var.id.name,
-                        "int32",
-                        span=tvm_span_from_synr(ast_var.span),
-                    )
-                    self.context.update_symbol(var.name, var, node)
-                    lhs_vars.append(var)
-
-                body = self.parse_body(node)
+                    self.context.update_symbol(ast_var.id.name, value, node)
 
-                for var, value in reversed(list(zip(lhs_vars, out))):
-                    self.context.remove_symbol(var.name)
-                    body = tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-                return body
+                return self.parse_body(node)
 
         if isinstance(node.rhs, (ast.Call, ast.Constant)):
             # Pattern 4 of let binding
diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
index ddeb931ff9ed..856eb5d53659 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -31,10 +31,6 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
                 T.writes(A_warp[thread_id, y])
                 A_warp[thread_id, y] = A_shared[v0, v1]
-                # T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                # A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                #     v0, v1
-                # ]
 
 
 @T.prim_func
 def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
@@ -83,10 +79,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
             with T.block("B_shared_warp"):
                 v0, v1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(B_shared[v0, v1])
-                T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-                B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
-                    v0, v1
-                ]
+                thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+                T.writes(B_warp[thread_id, y])
+                B_warp[thread_id, y] = B_shared[v0, v1]
 
 
 @T.prim_func
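
Note on the layout helper used in the test hunks: both ldmatrix_a_desc and
ldmatrix_b_desc now call shared_16x16_to_ldmatrix_32x8_layout instead of spelling
out the index arithmetic inline. The helper's definition is not part of this
patch; judging from the deleted expressions, it presumably maps a (row, column)
coordinate of a 16x16 tile to a (thread id, per-thread element index) coordinate
of the 32-thread x 8-element ldmatrix fragment. A minimal sketch, reconstructed
from the deleted arithmetic (parameter names are illustrative, not necessarily
what the test file uses):

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        # 16x16 = 256 tile elements spread across 32 threads, 8 per thread.
        # These are exactly the index expressions deleted in the hunks above.
        thread_id = (i % 8) * 4 + (j % 8) // 2
        local_id = (j // 8) * 4 + (i // 8) * 2 + j % 2
        return thread_id, local_id

Factoring the arithmetic into a single helper keeps the desc prim_funcs readable
and ensures the A and B intrinsics agree on the same warp-level layout.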