Skip to content

Commit

Permalink
seems to work
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent dd8ccf9 commit bcf212d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 27 deletions.
21 changes: 2 additions & 19 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,14 +554,12 @@ def transform_Assign(self, node):
4.1 var = T.allocate()
"""

print("parsing ", node.rhs.func_name)
if isinstance(node.rhs, ast.Call):
# Pattern 1 & Pattern 4
if isinstance(node.rhs.func_name, ast.Op):
func = None
else:
func = self.transform(node.rhs.func_name)
print(func)

if isinstance(func, WithScopeHandler):
if not func.concise_scope or not func.def_symbol:
Expand All @@ -582,27 +580,12 @@ def transform_Assign(self, node):
elif callable(func):
args = [self.transform(arg) for arg in node.rhs.params]
out = func(*args)
print(out)
print(node.lhs)
assert len(out) == len(node.lhs)

lhs_vars = []
for ast_var, value in zip(node.lhs, out):
var = tvm.te.var(
ast_var.id.name,
"int32",
span=tvm_span_from_synr(ast_var.span),
)
self.context.update_symbol(var.name, var, node)
lhs_vars.append(var)

body = self.parse_body(node)
self.context.update_symbol(ast_var.id.name, value, node)

for var, value in reversed(list(zip(lhs_vars, out))):
self.context.remove_symbol(var.name)
body = tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))

return body
return self.parse_body(node)

if isinstance(node.rhs, (ast.Call, ast.Constant)):
# Pattern 4 of let binding
Expand Down
11 changes: 3 additions & 8 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
T.writes(A_warp[thread_id, y])
A_warp[thread_id, y] = A_shared[v0, v1]

# T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
# A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
# v0, v1
# ]

@T.prim_func
def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
Expand Down Expand Up @@ -83,10 +79,9 @@ def ldmatrix_b_desc(a: T.handle, c: T.handle) -> None:
with T.block("B_shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(B_shared[v0, v1])
T.writes(B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
B_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = B_shared[
v0, v1
]
thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
T.writes(B_warp[thread_id, y])
B_warp[thread_id, y] = B_shared[v0, v1]


@T.prim_func
Expand Down

0 comments on commit bcf212d

Please sign in to comment.