Skip to content

Commit

Permalink
TL HIP Code Generation with Block Primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Oct 27, 2024
1 parent bccac9f commit 072a5a1
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ endif(USE_CUDA)

if(USE_ROCM)
tvm_file_glob(GLOB TILE_LIBRARY_HIP_SRCS
src/tl/target/codegen_rocm.cc
src/tl/target/rt_mod_rocm.cc
src/tl/target/codegen_hip.cc
src/tl/target/rt_mod_hip.cc
)
list(APPEND TILE_LIBRARY_SRCS ${TILE_LIBRARY_HIP_SRCS})
endif(USE_ROCM)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def assert_allclose(self, reference_program: callable, atol: float = 1e-8, rtol:
# percentage_not_close = (num_not_close / total_elements) * 100
# print(f"{percentage_not_close:.2f}% of the elements are not close.")
# print(f"Total elements: {total_elements}, Not close elements: {num_not_close}")
assert torch.allclose(lhs, rhs, rtol=rtol, atol=atol), (lhs, rhs)
torch.testing.assert_close(lhs, rhs, rtol=rtol, atol=atol)

def assert_consistent(self, repeat=10):
# Used to check no race condition inside the kernel
Expand Down
11 changes: 9 additions & 2 deletions src/tl/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int w
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragmentCDNA16x16()->Repeat({1, 1}, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true);
auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 16}, false, true);
return block_layout;
}

Expand Down Expand Up @@ -271,6 +271,13 @@ Layout makeGemmABLayoutF64_Kouter(int stride, int continuous) {
return Layout(Array<PrimExpr>{stride, continuous}, {tc, ts, index});
}

// Default layout for tensor access: a plain linear mapping in which the
// two-dimensional index (i, j) is flattened to the single offset
// i * continuous + j, with no padding term and no swizzle applied.
// The iteration variables are created by make_itervar from `stride` (for "i")
// and `continuous` (for "j").
Layout makeGemmLayoutLinear(int stride, int continuous) {
IterVar row = make_itervar("i", stride);
IterVar col = make_itervar("j", continuous);
return Layout(Array{row, col}, {row * continuous + col});
}

Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size) {
IterVar i = make_itervar("i", stride);
IterVar j = make_itervar("j", continuous);
Expand Down
3 changes: 3 additions & 0 deletions src/tl/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block
const int warp_m, const int warp_n);
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n);
// Default Memory Layout
// makeGemmLayoutLinear flattens the 2-D index (i, j) to i * continuous + j
// (definition in gemm_layouts.cc).
Layout makeGemmLayoutLinear(int stride, int continuous);
// NOTE(review): presumably the padded counterpart of the linear layout, with
// element_size feeding the padding amount — confirm against the definition.
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor);

Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m,
Expand Down
12 changes: 8 additions & 4 deletions src/tl/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,21 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
results.Set(C, fragment);

if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]),
A->dtype.bits(), trans_A ? 1 : 2));
auto shared_layout =
makeGemmLayoutLinear(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]));
// TODO(lei): Handle Pad and CK Tile Swizzle
results.Set(A, shared_layout);
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n));
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]),
B->dtype.bits(), trans_B ? 2 : 1));
auto shared_layout =
makeGemmLayoutLinear(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]));
// TODO(lei): Handle Pad and CK Tile Swizzle
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
ICHECK(trans_B == false);
results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
Expand Down

0 comments on commit 072a5a1

Please sign in to comment.