Skip to content

Commit

Permalink
Merge pull request #2 from TileLang/bitblas_tl
Browse files Browse the repository at this point in the history
[Dev] Merge the latest bitblas modification to upstream
  • Loading branch information
LeiWang1999 authored Oct 4, 2024
2 parents ef0837f + f1ad5c1 commit cd230c5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions python/tvm/tl/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,13 @@ def Kernel(*blocks: List[tir.PrimExpr], threads: Union[int, List[int], Tuple] =
return _ffi_api.KernelLaunch(blocks, threads, attrs)


def use_swizzle(panel_size: int, order: str = "row"):
def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
device_func = (
"rasterization2DRow" if order == "row" else "rasterization2DColumn"
)
return T.attr(
None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>"
)
) if enable else None


def alloc_shared(shape, dtype, scope="shared.dyn"):
Expand Down
36 changes: 18 additions & 18 deletions src/tl/tl_templates/threadblock_swizzle.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,33 @@ namespace tl {

template <int panel_width>
TL_DEVICE dim3 rasterization2DRow() {
const int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const int grid_size = gridDim.x * gridDim.y;
const int panel_size = panel_width * gridDim.x;
const int panel_offset = block_idx % panel_size;
const int panel_idx = block_idx / panel_size;
const int total_panel = cutlass::ceil_div(grid_size, panel_size);
const int stride =
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.x;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.x;
const int col_idx =
const unsigned int col_idx =
(panel_idx & 1) ? gridDim.x - 1 - panel_offset / stride : panel_offset / stride;
const int row_idx = panel_offset % stride + panel_idx * panel_width;
const unsigned int row_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}

template <int panel_width>
TL_DEVICE dim3 rasterization2DColumn() {
const int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const int grid_size = gridDim.x * gridDim.y;
const int panel_size = panel_width * gridDim.y;
const int panel_offset = block_idx % panel_size;
const int panel_idx = block_idx / panel_size;
const int total_panel = cutlass::ceil_div(grid_size, panel_size);
const int stride =
const unsigned int block_idx = blockIdx.x + blockIdx.y * gridDim.x;
const unsigned int grid_size = gridDim.x * gridDim.y;
const unsigned int panel_size = panel_width * gridDim.y;
const unsigned int panel_offset = block_idx % panel_size;
const unsigned int panel_idx = block_idx / panel_size;
const unsigned int total_panel = cutlass::ceil_div(grid_size, panel_size);
const unsigned int stride =
panel_idx + 1 < total_panel ? panel_width : (grid_size - panel_idx * panel_size) / gridDim.y;
const int row_idx =
const unsigned int row_idx =
(panel_idx & 1) ? gridDim.y - 1 - panel_offset / stride : panel_offset / stride;
const int col_idx = panel_offset % stride + panel_idx * panel_width;
const unsigned int col_idx = panel_offset % stride + panel_idx * panel_width;
return {col_idx, row_idx, blockIdx.z};
}

Expand Down

0 comments on commit cd230c5

Please sign in to comment.