Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2022
1 parent b1b2ab0 commit 9fad599
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
15 changes: 12 additions & 3 deletions python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,12 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
@T.prim_func
def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
shared = T.match_buffer(
shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope=shared_scope,
shared_handle,
shmem_shape,
dtype,
align=128,
offset_factor=16,
scope=shared_scope,
)
warp = T.match_buffer(
warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
Expand Down Expand Up @@ -413,10 +418,14 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))

LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn"
TensorIntrin.register(LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn"))
TensorIntrin.register(
LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")
)

LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn"
TensorIntrin.register(LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn"))
TensorIntrin.register(
LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")
)

LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
TensorIntrin.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1085,5 +1085,4 @@ def index_map(i, j):


if __name__ == "__main__":
# tvm.testing.main()
test_three_stage_gemm()
tvm.testing.main()

0 comments on commit 9fad599

Please sign in to comment.