diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index 2b2ef53a5462..678896310ba7 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -17,6 +17,7 @@ """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu from .base import ( + fast_tune, ApplyDefaultSchedule, ApplyFastTuning, BlockInfo, diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py index 872a8c2f2ff3..d595d0dcdf25 100644 --- a/python/tvm/dlight/base/__init__.py +++ b/python/tvm/dlight/base/__init__.py @@ -27,3 +27,4 @@ from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial from .schedule_rule import ScheduleRule from .transform import ApplyDefaultSchedule, ApplyFastTuning +from .utils import fast_tune diff --git a/python/tvm/dlight/base/roller/node.py b/python/tvm/dlight/base/roller/node.py index 52508cb244fc..6ff3f09fdf42 100644 --- a/python/tvm/dlight/base/roller/node.py +++ b/python/tvm/dlight/base/roller/node.py @@ -113,32 +113,6 @@ def get_tag(self, k: str) -> Any: if k not in self._tag: return None return self._tag[k] -class BufferNode(Node): - """BufferNode is a wrapper of tir.Buffer, which is used to store the buffer information.""" - - def __init__(self, buffer: tir.Buffer, tags: Dict = {}) -> None: - super().__init__() - self.buffer = buffer - self._tag: Dict = {} - for tag in tags: - self.add_tag(tag, tags[tag]) - self.set_dtype(tvm.DataType(self.buffer.dtype)) - - def set_dtype(self, dtype: tvm.DataType, id=0) -> None: - assert isinstance(dtype, tvm.DataType), type(dtype) - if dtype == tvm.DataType("bool"): - dtype = tvm.DataType("int8") - if len(self._dtypes) <= id: - self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)]) - elif self._dtypes[id] is not None: - assert self._dtypes[id] == dtype, (self._dtypes, dtype) - self._dtypes[id] = dtype - - def get_dtype(self, id=0) -> tvm.DataType: - return self._dtypes[id] - - def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: - return tvm.DataType(buffer.dtype) class PrimFuncNode(Node): @@ -161,6 +135,9 @@ def __init__(self, prim_func: PrimFunc, tags: Dict = {}) -> None: def _specialize_func(self, func: PrimFunc): # Specialize the function to make it more friendly for analysis. 
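# A minimal sketch of how "opt_shapes" travels from a PrimFunc attribute into the node tags
# consumed by `_specialize_func` (the hunk just below copies every entry of `func.attrs` into a
# tag). `func` is an assumed GEMM-like tir.PrimFunc with a dynamic tir.Var named "m", and it is
# assumed the constructor runs `_specialize_func` as in this diff.
#
#     from tvm.dlight.base.roller.node import PrimFuncNode
#
#     func = func.with_attr("opt_shapes", {"m": 256})  # hint a concrete extent for the dynamic dim
#     node = PrimFuncNode(func)
#     print(node.get_tag("opt_shapes"))                # the tag now mirrors the attribute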
+ # set attrs + for k, v in func.attrs.items(): + self.set_tag(k, v) opt_shapes = self.get_tag("opt_shapes") if opt_shapes: for name, shape in opt_shapes.items(): @@ -277,7 +254,8 @@ def propogate_inputs(self, tile, rstep={}) -> List[List[int]]: continue # should not exceed original shape trimmed_shape = [ - self.extent_warpper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) + self.extent_warpper(i) + for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) ] results.append(trimmed_shape) return results diff --git a/python/tvm/dlight/base/transform.py b/python/tvm/dlight/base/transform.py index 186a9914e3a9..a6073eb42c87 100644 --- a/python/tvm/dlight/base/transform.py +++ b/python/tvm/dlight/base/transform.py @@ -34,7 +34,7 @@ from .roller.policy import DefaultPolicy, TensorCorePolicy from .roller.arch import CUDA from .schedule_rule import ScheduleRule -from .analysis import get_tensorized_func_and_tags +from ..gpu.matmul_analysis import get_tensorized_func_and_tags from .utils import apply_and_build diff --git a/python/tvm/dlight/base/utils.py b/python/tvm/dlight/base/utils.py index e5c0715445c1..71689bfb7e81 100644 --- a/python/tvm/dlight/base/utils.py +++ b/python/tvm/dlight/base/utils.py @@ -19,24 +19,31 @@ from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind, MapResult from concurrent.futures import ThreadPoolExecutor, as_completed import numpy as np -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict from tvm import tir, IRModule from tvm.runtime import Module from tvm.tir import Schedule from tvm import dlight as dl from .analysis import get_root_block, get_reduction_blocks from .roller.arch import Arch +from tvm.dlight.base.roller.arch import CUDA +from tvm.dlight.base.roller.policy import TensorCorePolicy, DefaultPolicy +from tvm.dlight.gpu.matmul_analysis import get_tensorized_func_and_tags from ..base.roller.rasterization import NoRasterization import tempfile import re +import itertools +from tvm.ir.supply import GlobalVarSupply + def match_global_kernel(source: str) -> int: pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+" matched = re.findall(pattern, source) - assert len(matched) > 1 # may have statement before kernel + assert len(matched) > 1 # may have statement before kernel return source.index(matched[0]) -def get_rasterization_code(pannel_width:int = 8) -> str: + +def get_rasterization_code(pannel_width: int = 8) -> str: return f""" const int MAX_BLOCK_N = {pannel_width}; const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; @@ -49,9 +56,8 @@ def get_rasterization_code(pannel_width:int = 8) -> str: const auto bz = blockIdx.z; const dim3 blockIdx(bx, by, bz); """ - - ... 
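# The two helpers above (match_global_kernel / get_rasterization_code) are meant to splice the
# panel-rasterization prologue into generated CUDA source; a minimal sketch of that splice, which
# mirrors the tvm_callback_cuda_postproc hook that is commented out further below:
def inject_rasterization(cuda_source: str, panel_width: int = 8) -> str:
    # place the snippet right after the opening brace of the first __global__ kernel
    brace = cuda_source.index("{", match_global_kernel(cuda_source))
    return cuda_source[: brace + 2] + get_rasterization_code(panel_width) + cuda_source[brace + 2 :]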
- + + class CompileResult: """ Class to store the result of compilation @@ -123,8 +129,9 @@ def apply_and_build_parallel(func, configs, arch, num_repeats=5, max_workers=10) def var_warpper(v): if isinstance(v, tvm.tir.Var): - assert v.name in config.opt_shapes - return config.opt_shapes[v.name] + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return func.attrs["opt_shapes"][v.name].value elif isinstance(v, tvm.tir.IntImm): return v.value else: @@ -132,6 +139,9 @@ def var_warpper(v): profile_tensors = [] for param in func.params: + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue arg = func.buffer_map[param] if arg.dtype == "int8": profile_tensors.append( @@ -166,21 +176,20 @@ def var_warpper(v): # build in process parallel def _build(context) -> str: idx, mod, arch = context - config = configs[idx] + # TODO(lei): # this is a trick to implement rasteration, will be removed in the future - @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) - def tvm_callback_cuda_postproc(code, _): - index = code.index("{", match_global_kernel(code)) - if not isinstance(config.rasterization_plan, NoRasterization): - factor = config.rasterization_plan.panel_width_ - rasterization_code = get_rasterization_code(factor) - code = code[: index + 2] + rasterization_code + code[index + 2 :] - return code - - with tvm.transform.PassContext( - config={"tir.use_async_copy": True, "tir.merge_static_smem": False} - ): + # config = configs[idx] + # @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) + # def tvm_callback_cuda_postproc(code, _): + # index = code.index("{", match_global_kernel(code)) + # if not isinstance(config.rasterization_plan, NoRasterization): + # factor = config.rasterization_plan.panel_width_ + # rasterization_code = get_rasterization_code(factor) + # code = code[: index + 2] + rasterization_code + code[index + 2 :] + # return code + + with tvm.transform.PassContext(config={"tir.use_async_copy": True}): rt_mod = tvm.build(mod["main"], target=arch.target) from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel @@ -197,7 +206,8 @@ def tvm_callback_cuda_postproc(code, _): if map_result.status == StatusKind.TIMEOUT: print("[FastDlight] LocalBuilder: Timeout") elif map_result.status == StatusKind.EXCEPTION: - print("[FastDlight] LocalBuilder: An exception occurred ", map_result.value) + # TODO(lei): redirect the exception to file if needed + print("[FastDlight] LocalBuilder: An exception occurred ") continue elif map_result.status == StatusKind.COMPLETE: idx, code, artifact_path = map_result.value @@ -247,3 +257,201 @@ def apply_and_build( ) -> Tuple[List[CompileResult], CompileResult]: max_workers = 10 if parallel_build else 1 return apply_and_build_parallel(func, configs, arch, max_workers) + + +def fast_tune( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, +): + if target.kind.name != "cuda": + print("[FastDlight] Only support CUDA target") + return func + if "opt_shapes" in func.attrs: + # should be int value + if not all([isinstance(v.value, int) for v in func.attrs["opt_shapes"].values()]): + print("[FastDlight] The opt_shapes should be int value") + return func + + arch = CUDA(target) + + policy = DefaultPolicy(func=func, arch=arch) + try: + func, tags = get_tensorized_func_and_tags(func, arch.target) + except: + tags = None + if tags: + policy = TensorCorePolicy(func=func, arch=arch, tags=tags) + + configs = 
policy.emit_config(topk) + cpresults, best = apply_and_build(func, configs, arch, parallel_build=parallel_build) + + return cpresults, best + + +# always use the first function as the base +def collect_buffers_to_declare(func): + params = [] + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + buffers_to_declare = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + buffers_to_declare.append(buffer) + params.append(buffer.data) + + # the args should be buffers + dynamic symbolic + params += list(dyn_symbolic) + + return params, buffers_to_declare + + +# always use the first function as the base +def collect_buffers_to_declare(func): + params = [] + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + buffers_to_declare = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + buffers_to_declare.append(buffer) + params.append(buffer.data) + + # the args should be buffers + dynamic symbolic + params += list(dyn_symbolic) + + return params, buffers_to_declare + + +def refactor_specialized_func(func, params, buffers_to_declare): + body = func.body + attrs = func.attrs + global_symbol = func.attrs["global_symbol"] + if "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + + def serialize_name(opt_shapes: Dict): + return "_opt_" + "_".join([f"{k}_{v}" for k, v in opt_shapes.items()]) + + global_symbol += serialize_name(opt_shapes) + ret_type = func.ret_type + for buf in buffers_to_declare: + body = tvm.tir.DeclBuffer(buf, body=body) + + device_func = tvm.tir.PrimFunc(params, body, ret_type, attrs=attrs).without_attr( + "global_symbol" + ) + return global_symbol, device_func + + +def create_dispatch_func(func: tir.PrimFunc, refactored_funcs: List[str]): + global_symbol = func.attrs["global_symbol"] + attrs = func.attrs + buffer_map = func.buffer_map + params = func.params + ret_type = func.ret_type + + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + _invoke_params = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + _invoke_params.append(buffer.data) + _invoke_params += list(dyn_symbolic) + + func_range: List[int] = [] + global_symbols = [] + for g_var, refactor_func in refactored_funcs: + opt_shapes = refactor_func.attrs["opt_shapes"] + func_range.append(list(opt_shapes.values())[0]) + global_symbols.append(g_var) + + # TODO(lei): general the dispatch function to support multiple dynamic symbolics + assert len(dyn_symbolic) == 1, "Only support one dyanmic symbolics currently" + + ib = tvm.tir.ir_builder.create() + syb = list(dyn_symbolic)[-1] + last_range = 0 + for i, (_range, g_var) in enumerate(zip(func_range, global_symbols)): + if i == 0: + with ib.if_scope(syb <= _range): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + else: + with ib.if_scope(tvm.tir.all(syb > last_range, syb <= _range)): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + last_range = _range + with ib.if_scope(syb > last_range): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + stmt = ib.get() + dispatch_func 
= tvm.tir.PrimFunc(params, stmt, ret_type, buffer_map, attrs).with_attrs( + {"tir.is_global_func": True, "global_symbol": global_symbol} + ) + return dispatch_func + + +def create_dispatch_mod( + original_func: tir.PrimFunc, specialized_funcs: List[tir.PrimFunc] +) -> IRModule: + dispatch_mod: IRModule = tvm.IRModule() + g_var_supply = GlobalVarSupply(dispatch_mod) + refactored_funcs = [] + for func in specialized_funcs: + params, buffers_to_declare = collect_buffers_to_declare(func) + global_symbol, device_func = refactor_specialized_func(func, params, buffers_to_declare) + global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) + dispatch_mod[global_symbol] = device_func + refactored_funcs.append((global_symbol, device_func)) + dispatch_func = create_dispatch_func(original_func, refactored_funcs=refactored_funcs) + print(dispatch_func) + dispatch_mod.update(tvm.IRModule.from_expr(dispatch_func)) + return dispatch_mod + + +def fast_tune_with_dynamic_range( + func: tir.PrimFunc, target: tvm.target.Target, topk: int = 10, parallel_build: bool = True +) -> IRModule: + if target.kind.name != "cuda": + print("[FastDlight] Only support CUDA target") + return func + + if "opt_shapes" not in func.attrs: + print("[FastDlight] The primfunc has no opt_shapes, please set opt_shapes for the primfunc") + return func + else: + # should be list value + if not all([isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()]): + print("[FastDlight] The opt_shapes should be list value") + return func + + print("[FastDlight] Start fast tuning with dynamic range") + opt_shapes = func.attrs["opt_shapes"] + + # Step 1.Calculate the Cartesian product using itertools.product + product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes))) + + # Convert the Cartesian product to a list of dictionaries + specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list] + + specilized_tuned_funcs: List[tir.PrimFunc] = [] + for item in specialize_items: + func = func.with_attr("opt_shapes", item) + _, best = fast_tune(func, target, topk, parallel_build) + specilized_tuned_funcs.append(best.sch.mod["main"]) + + return create_dispatch_mod(func, specilized_tuned_funcs) diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index bf2a51bd8598..ce4881816a57 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -122,14 +122,16 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: apply_tensorization: bool = True # the batch dimension is not taken into consideration. + # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + if in_dtype not in ["int8", "float16"]: + apply_tensorization = False for item_var in block_stmt.iter_vars[1:]: extent = item_var.dom.extent if isinstance(extent, tir.expr.IntImm): if extent.value <= minimal_tensorize_threshold: apply_tensorization = False if apply_tensorization: - # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. 
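# A hedged usage sketch for the tuning entry points added above; `func` is an assumed GEMM-like
# tir.PrimFunc whose dynamic dimension is a tir.Var named "m", and the target string is only
# illustrative.
import tvm
from tvm import dlight as dl
from tvm.dlight.base.utils import fast_tune_with_dynamic_range

target = tvm.target.Target("nvidia/nvidia-a100")

# Single hinted shape: "opt_shapes" maps each dynamic var to an int.
cpresults, best = dl.fast_tune(func.with_attr("opt_shapes", {"m": 256}), target, topk=10)
tuned_func = best.sch.mod["main"]  # scheduled PrimFunc of the best candidate

# Shape range: "opt_shapes" maps each dynamic var to a list. The result is an IRModule holding
# one specialized kernel per shape plus a dispatch function that branches on the dynamic var
# (m <= 32 -> kernel tuned for 32, 32 < m <= 64 -> kernel tuned for 64, larger m -> the last one).
dispatch_mod = fast_tune_with_dynamic_range(
    func.with_attr("opt_shapes", {"m": [32, 64, 128]}), target, topk=10
)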
- in_dtype, out_dtype = get_in_out_dtypes(block_stmt) if in_dtype == "int8" and out_dtype == "int32": tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) elif utils.get_sm_version(target) >= 80: diff --git a/python/tvm/dlight/gpu/matmul_analysis.py b/python/tvm/dlight/gpu/matmul_analysis.py index e406770a9178..d673403bbf5f 100644 --- a/python/tvm/dlight/gpu/matmul_analysis.py +++ b/python/tvm/dlight/gpu/matmul_analysis.py @@ -25,8 +25,13 @@ from tvm.tir import IterVar, PrimExpr, Var from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV -from ..base.analysis import collect_block_iter_vars_used_in_access_region +from ..base.analysis import ( + collect_block_iter_vars_used_in_access_region, + get_root_block, + get_reduction_blocks, +) from tvm.target.target import Target +from tvm.tir import IndexMap def _is_one(x: PrimExpr) -> bool: @@ -259,7 +264,6 @@ def get_index_map( """ traits = detect_iter_traits(block) if traits is None: - print("[WARNING] traits is None, the block is", block) return None A_traits, B_traits, C_traits, block_traits = traits @@ -339,31 +343,6 @@ def infer_layout(layout: str, region: List[Range], kind: str = "A"): ) -def get_reduction_blocks(sch, blocks) -> Optional[List[BlockRV]]: - # Get the main computation block - def is_reduction(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.CommReduce, IterVar.DataPar} - - def is_spatial(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.DataPar} - - # NOTE: We assume there is only one reduction block in the function - # all blocks are required to be spatial or reduction - if not all([is_reduction(block) or is_spatial(block) for block in blocks]): - return None - - # There is only one reduction block - reduction_blocks = [block for block in blocks if is_reduction(block)] - if len(reduction_blocks) != 1: - return None - - return reduction_blocks - - def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: """ Detect In/Out data types for the given block based on the analysis if read/write buffers. @@ -515,15 +494,6 @@ def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 - def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: - """ - Detect In/Out data types for the given block based on the analysis if read/write buffers. 
- """ - assert len(block.reads) > 0 and len(block.writes) > 0 - in_dtype = block.reads[0].buffer.dtype - out_dtype = block.writes[0].buffer.dtype - return (in_dtype, out_dtype) - def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: tags: Dict[str, Union[List[int], int]] = {} block_stmt = sch.get(block) @@ -573,7 +543,10 @@ def check_last_trait(region: List[Range]): intrin_info["out_dtype"] = out_dtype # if the last dimension is reduce axis, the B is transposed intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) - + if "smooth_a" in func.attrs: + intrin_info["smooth_a"] = func.attrs["smooth_a"] + if "smooth_b" in func.attrs: + intrin_info["smooth_b"] = func.attrs["smooth_b"] tags["intrin_info"] = intrin_info return tags @@ -613,3 +586,45 @@ def check_last_trait(region: List[Range]): return sch.mod["main"], tags return func, None + + +def get_propagate_map(trans: bool = True, dtype="float16"): + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ) + + assert dtype in ["float16"], "Only support float16 for now" + + ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout + ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout + + # IntraWarp memory layout was occurred by ldmatrix, we should lift the ld_matrix out + def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout(thread_id, local_id) + + def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout_trans(thread_id, local_id) + + ldmatrix_index_map = ( + ldmatrix_trans_permutation_16x16_32x8_16x16 + if trans + else ldmatrix_permutation_16x16_32x8_16x16 + ) + + def permutation(i, j, kernel_i, kernel_j): + return ( + i, + j, + *ldmatrix_index_map(kernel_i, kernel_j), + ) + + # TODO(lei): index_dtype should be analyzed from the schedule + inversed_index_map = IndexMap.from_func( + ldmatrix_index_map, index_dtype="int32" + ).inverse([16, 16]) + return permutation, inversed_index_map diff --git a/python/tvm/dlight/gpu/matmul_mma.py b/python/tvm/dlight/gpu/matmul_mma.py index 5da2d1f36bd4..54ed2a1ac5b8 100644 --- a/python/tvm/dlight/gpu/matmul_mma.py +++ b/python/tvm/dlight/gpu/matmul_mma.py @@ -16,7 +16,7 @@ # under the License. 
# pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" -from typing import Literal, Optional +from typing import Literal, Optional, List from tvm import tir from tvm.target import Target @@ -24,19 +24,56 @@ from ..base.roller.rasterization import NoRasterization from ..base import analysis from .base import GPUScheduleRule +from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo from .matmul_analysis import ( auto_inline_consumer_chain, is_transpose_block, is_identity_block, + _collect_producers, inline_transpose_block, auto_inline_producers, get_index_map, get_reduction_blocks, get_dequantize_block, normalize_to_matmul, + get_propagate_map, ) +def get_index_map_3d(index_map, l=16, r=16): + def index_map_3d(b, i, j): + return ( + b, + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(b, i, j, ii, jj): + return ( + b, + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_warp_index_map(index_map, l=16, r=16, is_5d=False): + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + class MatmulTensorizationMMA(GPUScheduleRule): """ The schedule rule for float16 tensor core matmul computation. @@ -121,8 +158,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring swizzle_factor_for_l2_m = [1, None] swizzle_factor_for_l2_n = [1, None] - # swizzle_factor_for_l2_m = [4, None] - # swizzle_factor_for_l2_n = [4, None] # Step 2. Padding for dynamic shape kernels sch.pad_einsum( @@ -161,7 +196,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring j0, j1, j2, j3 = sch.split(j, factors=j_factors) k0, k1 = sch.split(k, factors=k_factors) - sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3) + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) block_axis = sch.fuse(batch, i0, j0, i1, j1) sch.bind(block_axis, "blockIdx.x") @@ -242,14 +277,9 @@ def store_output(block_outer, write_buffer_idx): # bind loops fused = sch.fuse(*sch.get_loops(block_write_smem)[-2:]) - f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) - sch.bind(f1, "threadIdx.z") - sch.bind(f2, "threadIdx.y") - sch.bind(f3, "threadIdx.x") - sch.vectorize(f4) - - # swizzling - sch.annotate(block_write_smem, ann_key="permuted_layout", ann_val=1) + f0, f1, f2 = sch.split(fused, [None, thread_x, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) # 2) Write to register block_write_reg = sch.cache_write(block_outer, write_buffer_idx, "warp") @@ -274,10 +304,6 @@ def store_output(block_outer, write_buffer_idx): ), ) - # swizzling - mma_read_block = sch.blockize(sch.get_loops(block_write_reg)[-2]) - sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1) - return block_write_smem, block_write_reg block_write_smem, block_write_reg = store_output(block_outer, 0) @@ -286,12 +312,6 @@ def store_output(block_outer, write_buffer_idx): block_init = sch.decompose_reduction(block_outer, k0) block_init_inner = sch.get_child_blocks(block_init)[0] - # unroll k - # Profiling result shows unrolling k0 is not helpful on A100 - # sch.unroll(k0) - # k00, k01 = sch.split(k0, factors=[None, 8]) - # sch.unroll(k01) - intrin_group = get_mma_intrin_group( load_scope="shared.dyn", store_scope="shared.dyn", @@ -299,7 +319,7 @@ def store_output(block_outer, write_buffer_idx): out_dtype=str(dtype_c), trans_a=is_transpose_a, 
trans_b=is_transpose_b, - not_use_mma_store_intrinic=False + not_use_mma_store_intrinic=False, ) sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) @@ -331,18 +351,16 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring func: tir.PrimFunc, config, ) -> Optional[tir.Schedule]: - if "dequantize_info" in func.attrs: dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() return dequantize_rule.apply_config(func, config) - + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel get_mma_intrin_group, ) - + sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) - output_blocks = sch.get_output_blocks(root_block) blocks = sch.get_child_blocks(root_block) if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): @@ -353,14 +371,45 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring return None main_block = reduction_blocks[0] - cache_write_required = True # main_block not in output_blocks, it should skip the analysis of reindex + output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] + def check_require_cache(func:tir.PrimFunc): + conditions:List[bool] = [] + # check if has dynamic symbolic + def check_has_dynamic(func:tir.PrimFunc): + for param in func.params: + if param not in func.buffer_map: + continue + arg = func.buffer_map[param] + for i in arg.shape: + if isinstance(i, tir.Var): + return True + return False + conditions.append(check_has_dynamic(func)) + # check if has post process + conditions.append(sch.get(main_block) not in output_blocks) + return any(conditions) + cache_write_required = check_require_cache(func) + + shared_scope = "shared" + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) # Start Schedule # Step 0. Get schedule config. 
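# A worked example (illustrative numbers) of how the config fields read in "Step 0" just below
# turn into tile factors: with config.block = [128, 128], config.warp = [64, 64],
# config.rstep = [32] and a 16x16x16 micro kernel,
#   warp_row_tiles = 64, block_row_warps = 128 // 64 = 2 (likewise for columns),
#   i_factors = [None, 1, 2, 64 // 16] = [None, 1, 2, 4], j_factors = [1, None, 2, 4],
#   k_factors = [None, 32 // 16] = [None, 2],
# so num_ty = num_tz = 2 and, with warp_size = 32, each block launches 2 * 2 * 32 = 128 threads.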
# NOTE: we can analyze the config by the hardware spec in the future # tensor core intrinsic size - intrin_info = config.intrin_info warp_row_tiles = config.warp[0] warp_col_tiles = config.warp[1] block_row_warps = config.block[0] // warp_row_tiles @@ -368,12 +417,25 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring stage = config.pipeline_stage use_async = config.use_async chunk = config.rstep[0] - - shared_scope = "shared.dyn" + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): + return (r, l) if trans else (l, r) + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16": + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) warp_size = 32 @@ -423,7 +485,7 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring j0, j1, j2, j3 = sch.split(j, factors=j_factors) k0, k1 = sch.split(k, k_factors) - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) block_idy = sch.fuse(i0, j0) block_idx = sch.fuse(i1, j1) @@ -439,13 +501,12 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring factor = config.rasterization_plan.panel_width_ # TODO(lei): this is a trick for rasterization implementation - # is not optimal. - # wait for https://github.com/apache/tvm/pull/16113 to be merged + # is not optimal. 
(5% performance loss) # require a solution for general block rasterization - # factor = 4 # should be divisible by block_idy - # if sch.get(block_idx).extent.value % factor == 0: - # block_k, block_idx = sch.split(block_idx, factors=[None, factor]) - # sch.bind(block_k, "blockIdx.z") + factor = 4 # should be divisible by block_idx + if sch.get(block_idx).extent.value % factor == 0: + block_k, block_idx = sch.split(block_idx, factors=[None, factor]) + sch.bind(block_k, "blockIdx.z") else: sch.bind(batch, "blockIdx.z") @@ -454,50 +515,89 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring sch.bind(thread_idy, "threadIdx.y") sch.bind(thread_idz, "threadIdx.z") - if intrin_info.smooth_b: + # rewrite smooth layout of shared memory + def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): + if not enable: + return sch.transform_layout( - block_outer, - ("read", 1), + block, + scope, lambda b, i, j: ( b, - i // micro_size_y, - j // micro_size_k, - i % micro_size_y, - j % micro_size_k, + i // l, + j // r, + i % l, + j % r, ), ) - def fetch_to_shared(block, idx, vec_len): + smooth_smem_layout_rewrite(block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a) + smooth_smem_layout_rewrite(block_outer, ("read", 1), *b_lr, enable=intrin_info.smooth_b) + smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): block_read = sch.cache_read(block, idx, shared_scope) sch.compute_at(block_read, k0, preserve_unit_loops=True) ndim = len(sch.get(block_read).iter_vars) fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - f_1, f_2, _, f_3, f_4 = sch.split( + f_0, f_1, f_2, f_3, f_4 = sch.split( fused, factors=[num_ty, num_tz, None, warp_size, vec_len] ) sch.bind(f_3, "threadIdx.x") - sch.bind(f_2, "threadIdx.z") - sch.bind(f_1, "threadIdx.y") + sch.bind(f_1, "threadIdx.z") + sch.bind(f_0, "threadIdx.y") sch.vectorize(f_4) + sch.unroll(f_2) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) return block_read a_g2s = fetch_to_shared( block_outer, 0, vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + trans=intrin_info.trans_a, ) b_g2s = fetch_to_shared( block_outer, 1, vec_len=list(config.vectorize.values())[1], + can_swizzle=can_swizzle_b, + is_smooth=intrin_info.smooth_b, + trans=intrin_info.trans_b, ) - # Apply Swizzling - sch.annotate(a_g2s, ann_key="permuted_layout", ann_val=True) - sch.annotate(b_g2s, ann_key="permuted_layout", ann_val=(not intrin_info.smooth_b)) + # rewrite global smooth layout + def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False): + if not enable: + return + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + producers = _collect_producers(sch, block) + propagate_block: tir.Block = producers[-1] + + # step2: transform the layout with inverse permutation + _, inverse_indexmap = get_propagate_map(trans=trans, dtype=intrin_info.in_dtype) + + def inverse_permutation(i, j, ii, jj): + return 
(i, j, *inverse_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite(sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a) + smooth_gmem_layout_rewrite(sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b) auto_inline_producers(sch, a_g2s) auto_inline_producers(sch, b_g2s) @@ -512,12 +612,7 @@ def fetch_to_shared(block, idx, vec_len): accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, thread_idy) - - if cache_write_required: - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - else: - auto_inline_consumer_chain(sch, store) + sch.reverse_compute_at(store, j2) # split the store loop to match hardware intrinsic pattern i, j = sch.get_loops(store)[-2:] @@ -525,448 +620,56 @@ def fetch_to_shared(block, idx, vec_len): j0, j1 = sch.split(j, factors=[None, micro_size_y]) sch.reorder(i0, j0, i1, j1) - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=False, - trans_b=intrin_info.trans_b, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - def get_index_map_3d(index_map): - def index_map_3d(b, i, j): - return ( - b, - i // 16, - j // 16, - *index_map(i % 16, j % 16), - ) - - return index_map_3d - - def get_index_map_5d(index_map): - """ - for layout transformed gemm, the index map should be 5d - """ - def index_map_5d(b, i, j, ii, jj): - return ( - b, i, j, - *index_map(ii, jj), - ) - - return index_map_5d - - def get_index_map(index_map, is_smooth=False): - if is_smooth: - return get_index_map_5d(index_map) - return get_index_map_3d(index_map) - - sch.transform_layout(A_mat, ("write", 0), get_index_map(index_map_a)) - sch.transform_layout(B_mat, ("write", 0), get_index_map(index_map_b, intrin_info.smooth_b)) - sch.transform_layout(store, ("read", 0), get_index_map(index_map_c)) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=True) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=(not intrin_info.smooth_b)) - sch.tensorize(bb, intrin_group["load_b"]) - except Exception: # pylint: disable=bare-except - return None - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - if cache_write_required: auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, sch.get_loops(store)[-3], preserve_unit_loops=True + ) - fused = 
sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split( + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split( fused, factors=[None, warp_size, max(list(config.vectorize.values()))] ) sch.bind(f1, "threadIdx.x") sch.vectorize(f2) - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - return sch - - - -class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - def sch_dequantize_in_register_with_config( - self, - func: tir.PrimFunc, - config, - ): - """ - quantized weight - | - V - dequantized in register - | - V - save into shared memory - | - V - compute - """ - return None - - def sch_shared_memory_prefetch_with_config( - self, - func: tir.PrimFunc, - config, - ): - ''' - For A100 Like devices, the shared memory prefetch(async) is required - to achieve optimal performance. - quantized weight - | - V - shared memory prefetch (with async copy) - | - V - dequantized into shared memory - | - V - compute - ''' - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - ) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - # TODO(leiwang): this is a hack to get the configuaration, should write a pass to analysis - dequantize_info = func.attrs['dequantize_info'] - - def check_dequantize_info(dequantize_info): - # currently only support weight only dequantization - conditions = [] - conditions.append(len(dequantize_info) == 1) - # more conditions, e.g. check the format is in [fp, nf, int] - # check if the dequantize value name is weight - return all(conditions) - - assert check_dequantize_info(dequantize_info) - - B_decode_info, = list(dequantize_info.values()) - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - cache_write_required = True # main_block not in output_blocks, it should skip the analysis of reindex - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - intrin_info = config.intrin_info - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - shared_scope = "shared.dyn" - - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. 
Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # is not optimal. - # wait for https://github.com/apache/tvm/pull/16113 to be merged - # require a solution for general block rasterization - # factor = 4 # should be divisible by block_idy - # if sch.get(block_idx).extent.value % factor == 0: - # block_k, block_idx = sch.split(block_idx, factors=[None, factor]) - # sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - sch.bind(thread_idz, "threadIdx.z") - - if intrin_info.smooth_b: - sch.transform_layout( - block_outer, - ("read", 1), - lambda b, i, j: ( - b, - i // micro_size_y, - j // micro_size_k, - i % micro_size_y, - j % micro_size_k, - ), - ) - - def fetch_to_shared(block, idx, vec_len): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_1, f_2, _, f_3, f_4 = sch.split( - fused, factors=[num_ty, num_tz, None, warp_size, vec_len] - ) - - sch.bind(f_3, "threadIdx.x") - sch.bind(f_2, "threadIdx.z") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_4) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - ) - - def decode_fetch_to_shared( - block, idx, vec_len - ): - # step1. create memory hierarchy - # global -> local -> shared - block_shared = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_shared, k0) - # TODO(lei): the factor shoule be analyzed more deeper. - B_shared_jj, B_shared_vi, B_shared_vj = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, 8]) - block_shared_local = sch.cache_read(block_shared, 0, "local") - # global -> dequantzed_local -> shared - # step2. 
inline to local block - auto_inline_producers(sch, block_shared_local) - # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read(block_shared_local, 0, "local") - # global -> prefetch_shared -> vector load -> dequantzed_local -> shared - block_shared_local_local_shared = sch.cache_read( - block_shared_local_local, - 0, - shared_scope - ) - sch.compute_at(block_shared_local, B_shared_vi) - sch.compute_at(block_shared_local_local, B_shared_vi) - - sch.annotate(block_shared, ann_key="permuted_layout", ann_val=(not intrin_info.smooth_b)) - union_len = (2 + 2) # if smooth_b, the value should be -(2 + 4) - B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) - B_shared_inner, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, block_row_warps, block_col_warps, warp_size]) - sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) - - sch.bind(B_shared_tx, "threadIdx.x") - sch.bind(B_shared_ty, "threadIdx.y") - sch.bind(B_shared_tz, "threadIdx.z") - sch.vectorize(sch.get_loops(block_shared)[-1]) - sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) - - decode_fetch_to_shared( - block_outer, - 1, - vec_len=list(config.vectorize.values())[1] - ) - - # Apply Swizzling - sch.annotate(a_g2s, ann_key="permuted_layout", ann_val=True) - - auto_inline_producers(sch, a_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, thread_idy) - - if cache_write_required: - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) else: auto_inline_consumer_chain(sch, store) - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - block_init_c = sch.decompose_reduction(block_outer, k0) block_init_c_inner = sch.get_child_blocks(block_init_c)[0] # Tensorization by hardware intrinsics - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=False, - trans_b=intrin_info.trans_b, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - def get_index_map_3d(index_map): - def index_map_3d(b, i, j): - return ( - b, - i // 16, - j // 16, - *index_map(i % 16, j % 16), - ) - - return index_map_3d - - def get_index_map_5d(index_map): - def index_map_5d(b, i, j, ii, jj): - return ( - b, i, j, - *index_map(ii, jj), - ) + sch.transform_layout( + A_mat, ("write", 0), get_warp_index_map(index_map_a, *a_lr, intrin_info.smooth_a) + ) + sch.transform_layout( + B_mat, ("write", 0), get_warp_index_map(index_map_b, *b_lr, intrin_info.smooth_b) + ) + sch.transform_layout( + store, + ("read", 0), + get_warp_index_map(index_map_c, is_5d=True), + ) - return index_map_5d - - def 
get_index_map(index_map, is_smooth=False): - if is_smooth: - return get_index_map_5d(index_map) - return get_index_map_3d(index_map) - - sch.transform_layout(A_mat, ("write", 0), get_index_map(index_map_a)) - sch.transform_layout(B_mat, ("write", 0), get_index_map(index_map_b, intrin_info.smooth_b)) - sch.transform_layout(store, ("read", 0), get_index_map(index_map_c)) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=True) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=(not intrin_info.smooth_b)) - sch.tensorize(bb, intrin_group["load_b"]) - except Exception: # pylint: disable=bare-except - return None + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) def tensorize_init_store_compute(): sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) @@ -975,37 +678,10 @@ def tensorize_init_store_compute(): tensorize_init_store_compute() - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split( - fused, factors=[None, warp_size, max(list(config.vectorize.values()))] - ) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) if use_async: sch.annotate(k0, "software_pipeline_async_stages", [0]) - return sch - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> Optional[tir.Schedule]: - def check_sm_version(arch: str) -> int: - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 - - if check_sm_version(config.arch.target.arch) < 80: - """MMA Template only support sm_80 and above""" - return None - - if config.arch.target.kind.name == "cuda" and check_sm_version(config.arch.target.arch) == 80: - return self.sch_shared_memory_prefetch_with_config(func, config) - else: - return self.sch_with_config(func, config) \ No newline at end of file + return sch diff --git a/python/tvm/dlight/gpu/matmul_mma_dequantize.py b/python/tvm/dlight/gpu/matmul_mma_dequantize.py new file mode 100644 index 000000000000..3f902a7ac67c --- /dev/null +++ b/python/tvm/dlight/gpu/matmul_mma_dequantize.py @@ -0,0 +1,540 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Literal, Optional + +from tvm import tir +from tvm.target import Target + +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_reduction_blocks, + get_dequantize_block, + normalize_to_matmul, +) + + +def get_index_map_3d(index_map, l=16, r=16): + def index_map_3d(b, i, j): + return ( + b, + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(b, i, j, ii, jj): + return ( + b, + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_index_map(index_map, l=16, r=16, is_5d=False): + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + +class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def sch_dequantize_in_register_with_config( + self, + func: tir.PrimFunc, + config, + ): + """ + Simple dequantize schedule without shared memory prefetch. + quantized weight + | + V + dequantized in register + | + V + save into shared memory + | + V + compute + """ + + return None + + def sch_shared_memory_prefetch_with_config( + self, + func: tir.PrimFunc, + config, + ): + """ + For A100 Like devices, the shared memory prefetch(async) is required + to achieve optimal performance. + quantized weight + | + V + shared memory prefetch (with async copy) + | + V + dequantized into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, + ) + from .intrin.lop3 import get_lop3_intrin_group + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + # TODO(leiwang): this is a hack to get the configuaration, can be improved by writing a pass to analysis the dequantize block. 
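# The "dequantize_info" attribute read just below is expected to look roughly like this sketch;
# the field names follow check_b_decode_info and the fast-decoding branch further down, while the
# outer key name ("B") and any schema details beyond those checks are assumptions:
example_dequantize_info = {
    "B": {                                    # single entry: weight-only dequantization
        "source_format": {"format": "int",    # one of "int" / "fp" / "af"
                          "bits": 4},         # one of 1 / 2 / 4 / 8
        "target_format": "float16",           # or "int8"
        "fast_decoding": True,                # optional: enables the LOP3 fast type conversion
    },
}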
+ dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (B_decode_info,) = list(dequantize_info.values()) + + def check_b_decode_info(B_decode_info): + conditions = [] + # check source format in ["int", "fp", "af"] + conditions.append("source_format" in B_decode_info) + conditions.append(B_decode_info["source_format"]["format"] in ["int", "fp", "af"]) + # check source bits in [1, 2, 4, 8] + conditions.append(B_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in B_decode_info) + conditions.append(B_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_b_decode_info(B_decode_info) + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + intrin_info = config.intrin_info + shared_scope = "shared" + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): + return (r, l) if trans else (l, r) + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16": + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.smooth_b) + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. 
Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + # plan rasteration + if ( + not isinstance(config.rasterization_plan, NoRasterization) + and sch.get(batch).extent.value == 1 + ): + device_func, invoke_func = config.rasterization_plan.get_code() + factor = config.rasterization_plan.panel_width_ + + # TODO(lei): this is a trick for rasterization implementation + # is not optimal. + # require a solution for general block rasterization + factor = 8 # should be divisible by block_idy + if sch.get(block_idx).extent.value % factor == 0: + block_k, block_idx = sch.split(block_idx, factors=[None, factor]) + sch.bind(block_k, "blockIdx.z") + else: + sch.bind(batch, "blockIdx.z") + + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.smooth_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[num_ty, num_tz, None, warp_size, vec_len] + ) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_1, "threadIdx.z") + sch.bind(f_0, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_2) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + # TODO(lei): the factor shoule be analyzed more deeper. 
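# A condensed sketch of the shared-memory layout policy used by fetch_to_shared above: fp16
# operands that are not pre-permuted ("smooth") get the permuted_layout swizzle, everything else
# falls back to storage_align padding to soften bank conflicts.
def plan_shared_layout(in_dtype: str, smooth: bool):
    can_swizzle = in_dtype == "float16" and not smooth
    # the pad offset is counted in elements: 8 halves or 16 int8 values, i.e. 16 bytes either way
    pad_offset = 0 if (can_swizzle or smooth) else (8 if in_dtype == "float16" else 16)
    return can_swizzle, pad_offset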
+            _, B_shared_vi, _ = sch.split(sch.get_loops(block_shared)[-1], factors=[None, 1, 8])
+            block_shared_local = sch.cache_read(block_shared, 0, "local")
+            # global -> dequantized_local -> shared
+            # step 2. inline to the local block
+            auto_inline_producers(sch, block_shared_local)
+
+            # get the target dequantize buffer's idx
+            def get_idx():
+                # for LUT dequantize, the expr is LUT(w), the idx is 1
+                # maybe we can use a more general, structure-based way
+                # to analyze the idx
+                if B_decode_info["source_format"]["format"] == "af":
+                    return 1
+                return 0
+
+            b_idx = get_idx()
+            # global -> prefetch_local -> dequantized_local -> shared
+            block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")
+            # global -> prefetch_shared -> vector load -> dequantized_local -> shared
+            block_shared_local_local_shared = sch.cache_read(
+                block_shared_local_local, 0, shared_scope
+            )
+            sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
+            sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)
+
+            dequantize_block = block_shared_local
+            # fast type conversion
+            if "fast_decoding" in B_decode_info and B_decode_info["fast_decoding"]:
+                intrin_group = get_lop3_intrin_group(
+                    in_dtype="int8", out_dtype="float16", storage_nbit=4, with_scale=False
+                )
+                sch.tensorize(sch.get_loops(dequantize_block)[-1], intrin_group["compute"])
+                sch.annotate(
+                    thread_idz, ann_key="pragma_import_c", ann_val=intrin_group["c_source"]
+                )
+
+            sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b)
+            union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2)
+            B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])
+            _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split(
+                B_shared_fused, factors=[None, num_ty, num_tz, warp_size]
+            )
+            if not (can_swizzle_b or intrin_info.smooth_b):
+                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
+                sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset)
+            sch.bind(B_shared_tx, "threadIdx.x")
+            sch.bind(B_shared_ty, "threadIdx.y")
+            sch.bind(B_shared_tz, "threadIdx.z")
+            sch.vectorize(sch.get_loops(block_shared)[-1])
+            sch.vectorize(sch.get_loops(block_shared_local_local)[-1])
+
+            sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True)
+            ndim = len(sch.get(block_shared_local_local_shared).iter_vars)
+            fused = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:])
+
+            f_0, f_1, f_2, f_3, f_4 = sch.split(
+                fused, factors=[None, num_tz, num_ty, warp_size, 16]  # int8x16 = 128 bits
+            )
+
+            sch.bind(f_3, "threadIdx.x")
+            sch.bind(f_2, "threadIdx.y")
+            sch.bind(f_1, "threadIdx.z")
+            sch.vectorize(f_4)
+            sch.unroll(f_0)
+            sch.annotate(f_0, "pragma_unroll_explicit", False)
+
+            # cache small tensors, e.g. LUT
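+            # b_idx is non-zero only for the "af" (LUT) source format, in which case the
+            # lookup table itself is staged into shared memory as well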
+            if b_idx:
+                block_shared_lut = sch.cache_read(dequantize_block, 0, shared_scope)
+                sch.reverse_compute_at(block_shared_lut, j2)
+                _, B_shared_tx = sch.split(
+                    sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]
+                )
+                sch.bind(B_shared_tx, "threadIdx.x")
+            return block_shared_local
+
+        dequantize_block = decode_fetch_to_shared(block_outer, 1)
+
+        # create read cache to load matrix from shared memory to wmma fragments
+        A_mat = sch.cache_read(block_outer, 0, "warp")
+        B_mat = sch.cache_read(block_outer, 1, "warp")
+        sch.compute_at(A_mat, k1, preserve_unit_loops=True)
+        sch.compute_at(B_mat, k1, preserve_unit_loops=True)
+
+        # create write cache to store matrix from wmma fragments to shared memory and global memory
+        if cache_write_required:
+            accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)
+
+        store = sch.cache_write(block_outer, 0, "warp")
+        sch.reverse_compute_at(store, j2)
+
+        # split the store loop to match hardware intrinsic pattern
+        i, j = sch.get_loops(store)[-2:]
+        i0, i1 = sch.split(i, factors=[None, micro_size_x])
+        j0, j1 = sch.split(j, factors=[None, micro_size_y])
+        sch.reorder(i0, j0, i1, j1)
+
+        if cache_write_required:
+            auto_inline_consumer_chain(sch, accumulator_shared_to_global)
+            sch.reverse_compute_at(
+                accumulator_shared_to_global, sch.get_loops(store)[-3], preserve_unit_loops=True
+            )
+
+            fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:])
+            f0, f1, f2 = sch.split(
+                fused, factors=[None, warp_size, max(list(config.vectorize.values()))]
+            )
+            sch.bind(f1, "threadIdx.x")
+            sch.vectorize(f2)
+            sch.unroll(f0)
+            sch.annotate(f0, "pragma_unroll_explicit", False)
+        else:
+            auto_inline_consumer_chain(sch, store)
+
+        block_init_c = sch.decompose_reduction(block_outer, k0)
+        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+        # Tensorization by hardware intrinsics
+
+        index_map_a, index_map_b, index_map_c = intrin_group["index_map"]
+
+        sch.transform_layout(
+            A_mat, ("write", 0), get_index_map(index_map_a, *a_lr, intrin_info.smooth_a)
+        )
+        sch.transform_layout(
+            B_mat, ("write", 0), get_index_map(index_map_b, *b_lr, intrin_info.smooth_b)
+        )
+        sch.transform_layout(
+            store,
+            ("read", 0),
+            get_index_map(index_map_c, is_5d=True),
+        )
+
+        i, j = sch.get_loops(A_mat)[-2:]
+        i0, i1 = sch.split(i, factors=[None, a_lr[0]])
+        j0, j1 = sch.split(j, factors=[None, a_lr[1]])
+        sch.reorder(i0, j0, i1, j1)
+        ba = sch.blockize(i1)
+        sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a)
+        sch.tensorize(ba, intrin_group["load_a"])
+
+        i, j = sch.get_loops(B_mat)[-2:]
+        i0, i1 = sch.split(i, factors=[None, b_lr[0]])
+        j0, j1 = sch.split(j, factors=[None, b_lr[1]])
+        sch.reorder(i0, j0, i1, j1)
+        bb = sch.blockize(i1)
+        sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
+        sch.tensorize(bb, intrin_group["load_b"])
+
+        def tensorize_init_store_compute():
+            sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
+            sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
+            sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])
+
+        tensorize_init_store_compute()
+
+        if stage > 1:
+            sch.annotate(
+                k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1, stage - 1]
+            )
+            sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3])
+        if use_async:
+            sch.annotate(k0, "software_pipeline_async_stages", [0])
+        return sch
+
+    def apply_config(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        config,
+    ) -> Optional[tir.Schedule]:
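+        # Dispatch: the shared-memory prefetch schedule is only used on CUDA targets
+        # reporting sm_80; anything below sm_80 is rejected, and newer architectures
+        # fall back to sch_with_config.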
+        def check_sm_version(arch: str) -> int:
+            sm_version = arch.replace("sm_", "")
+            return int(sm_version) if sm_version.isdigit() else -1
+
+        if check_sm_version(config.arch.target.arch) < 80:
+            # MMA template only supports sm_80 and above
+            return None
+
+        if (
+            config.arch.target.kind.name == "cuda"
+            and check_sm_version(config.arch.target.arch) == 80
+        ):
+            return self.sch_shared_memory_prefetch_with_config(func, config)
+        else:
+            return self.sch_with_config(func, config)
diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py
index ef7fb13adca9..76fb0e07f2a1 100644
--- a/python/tvm/dlight/gpu/reduction.py
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -90,7 +90,7 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
             block_info,
             arith.normalize_to_iter_sum(
                 detect_dominant_read(block_stmt),
-                input_iters={i.var: i.dom.extent for i in block_stmt.iter_vars},
+                input_iters={i.var: i.dom for i in block_stmt.iter_vars},
             ),
         )
         if is_inner_reduction is None and c_factor is None:
@@ -113,7 +113,7 @@ def _normalize(  # pylint: disable=too-many-branches
         access: arith.IterSumExpr,
     ) -> Tuple[Optional[bool], Optional[int]]:
         if access.base != 0:
-            return None, None
+            return None, None, None, None
         iter_to_info = {i.var: i for i in block_info.iters}
         s_loops, r_loops, c_loops, c_factor = [], [], [], None
         s_split_loop, s_split_index = None, None
@@ -124,7 +124,7 @@ def _normalize(  # pylint: disable=too-many-branches
             is_inner_reduction = info.kind == "R"
             if split_expr.lower_factor > 1:
                 if c_loops:
-                    return None, None
+                    return None, None, None, None
                 s_split_loop = loop
                 s_split_index = len(s_loops)
                 loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor])
@@ -141,7 +141,7 @@ def _normalize(  # pylint: disable=too-many-branches
             if info.kind == "S" and info.dom.extent == 1:
                 s_loops.append(info.loop_rv)
             else:
-                return None, None
+                return None, None, None, None

         loop_order = {}
         s_block_var_loops = []
@@ -281,7 +281,7 @@ def _sch_inner_spatial(
         # Schedule epilogue
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
-            sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
+            sch.reverse_compute_at(epilogue, bx)
             if is_broadcast_epilogue(sch, block, epilogue):
                 sch.set_scope(block, 0, "shared")
                 _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 65a2d081b58c..8f7069ef6282 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -613,7 +613,6 @@ class ReverseComputeInliner : public BaseInliner {
   bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) {
     const Block& consumer_block = consumer_block_realize->block;
-    LOG(INFO) << "BodyPatternAllowInline";
     if (!is_one(consumer_block_realize->predicate)) {
       // Failure: Predicate is the consumer block is not supported