diff --git a/CMakeLists.txt b/CMakeLists.txt index c8cdfc336c6f..3aa281d3d23a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -385,8 +385,8 @@ endif(USE_CUDA) if(USE_ROCM) tvm_file_glob(GLOB TILE_LIBRARY_HIP_SRCS - src/tl/target/codegen_rocm.cc - src/tl/target/rt_mod_rocm.cc + src/tl/target/codegen_hip.cc + src/tl/target/rt_mod_hip.cc ) list(APPEND TILE_LIBRARY_SRCS ${TILE_LIBRARY_HIP_SRCS}) endif(USE_ROCM) diff --git a/python/tvm/tl/utils.py b/python/tvm/tl/utils.py index fd0e83dc74bc..3b1f4cb73107 100644 --- a/python/tvm/tl/utils.py +++ b/python/tvm/tl/utils.py @@ -152,7 +152,7 @@ def assert_allclose(self, reference_program: callable, atol: float = 1e-8, rtol: # percentage_not_close = (num_not_close / total_elements) * 100 # print(f"{percentage_not_close:.2f}% of the elements are not close.") # print(f"Total elements: {total_elements}, Not close elements: {num_not_close}") - assert torch.allclose(lhs, rhs, rtol=rtol, atol=atol), (lhs, rhs) + torch.testing.assert_close(lhs, rhs, rtol=rtol, atol=atol) def assert_consistent(self, repeat=10): # Used to check no race condition inside the kernel diff --git a/src/tl/layout/gemm_layouts.cc b/src/tl/layout/gemm_layouts.cc index 04b3709d69bd..000211e6e634 100644 --- a/src/tl/layout/gemm_layouts.cc +++ b/src/tl/layout/gemm_layouts.cc @@ -101,8 +101,8 @@ Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int w ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m; ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n; auto base_layout = makeGemmFragmentCDNA16x16()->Repeat({1, 1}, false); - auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false); - auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 16}, false, false); + auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true); + auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 16}, false, true); return block_layout; } @@ -271,6 +271,13 @@ Layout makeGemmABLayoutF64_Kouter(int stride, int continuous) { return Layout(Array{stride, continuous}, {tc, ts, index}); } +// The Default Layout for Tensor Access +Layout makeGemmLayoutLinear(int stride, int continuous) { + IterVar i = make_itervar("i", stride); + IterVar j = make_itervar("j", continuous); + return Layout(Array{i, j}, {i * continuous + j}); +} + Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) { IterVar i = make_itervar("i", stride); IterVar j = make_itervar("j", continuous); diff --git a/src/tl/layout/layout.h b/src/tl/layout/layout.h index d90c5cca5253..072b3a065937 100644 --- a/src/tl/layout/layout.h +++ b/src/tl/layout/layout.h @@ -147,6 +147,9 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block const int warp_m, const int warp_n); Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k, const int warp_m, const int warp_n); +// Default Memory Layout +Layout makeGemmLayoutLinear(int stride, int continuous); +Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size); Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor); Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m, diff --git a/src/tl/op/gemm.cc b/src/tl/op/gemm.cc index 6f702a3e43ee..c07599358606 100644 --- a/src/tl/op/gemm.cc +++ b/src/tl/op/gemm.cc @@ -218,8 +218,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { results.Set(C, fragment); if (A.scope() == "shared" || A.scope() == "shared.dyn") { - results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]), - A->dtype.bits(), trans_A ? 1 : 2)); + auto shared_layout = + makeGemmLayoutLinear(*as_const_int(A->shape[0]), *as_const_int(A->shape[1])); + // TODO(lei): Handle Pad and CK Tile Swizzle + results.Set(A, shared_layout); } else if (A.scope() == "local.fragment") { ICHECK(trans_A == false); results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n)); @@ -227,8 +229,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) { ICHECK(0); } if (B.scope() == "shared" || B.scope() == "shared.dyn") { - results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]), - B->dtype.bits(), trans_B ? 2 : 1)); + auto shared_layout = + makeGemmLayoutLinear(*as_const_int(B->shape[0]), *as_const_int(B->shape[1])); + // TODO(lei): Handle Pad and CK Tile Swizzle + results.Set(B, shared_layout); } else if (B.scope() == "local.fragment") { ICHECK(trans_B == false); results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));