diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index cbf3ba0c691e0..c5883fd072c57 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -115,7 +115,12 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"): @T.prim_func def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: shared = T.match_buffer( - shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope=shared_scope, + shared_handle, + shmem_shape, + dtype, + align=128, + offset_factor=16, + scope=shared_scope, ) warp = T.match_buffer( warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" @@ -413,10 +418,14 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn" -TensorIntrin.register(LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")) +TensorIntrin.register( + LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn") +) LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn" -TensorIntrin.register(LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")) +TensorIntrin.register( + LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn") +) LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" TensorIntrin.register( diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index cb7efd9438084..f89dbacee33d8 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1085,5 +1085,4 @@ def index_map(i, j): if __name__ == "__main__": - # tvm.testing.main() - test_three_stage_gemm() + tvm.testing.main()