diff --git a/3rdparty/tvm b/3rdparty/tvm index 0a24d6597..511057718 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0a24d6597641a389349b8985ff346150bdaf54e5 +Subproject commit 51105771898a7f40617547e928353536db336722 diff --git a/benchmark/dsl/convolution.py b/benchmark/dsl/convolution.py index 9bb9f4e48..519481c9a 100644 --- a/benchmark/dsl/convolution.py +++ b/benchmark/dsl/convolution.py @@ -43,8 +43,12 @@ def conv2d_nhwc_hwio(n, f, h, w, c, kh, kw, s, d, p, in_dtype="float16", out_dty C = te.compute( out_shape, lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * dilation_h, w * stride_w + kw * dilation_w, c,] * B[kh, kw, - c, f], + pad[ + n, + h * stride_h + kh * dilation_h, + w * stride_w + kw * dilation_w, + c, + ] * B[kh, kw, c, f], axis=[kh, kw, c], ), name="C", diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 937a3c1c7..3074e3fcb 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -10,6 +10,7 @@ if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" + os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl") sys.path.insert(0, install_tvm_path + "/python") develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") @@ -18,6 +19,7 @@ if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" + os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl") sys.path.insert(0, develop_tvm_path + "/python") import tvm as tvm # noqa: E402 diff --git a/bitblas/base/roller/policy/default.py b/bitblas/base/roller/policy/default.py index e9f7b809f..9bda7fe6c 100644 --- a/bitblas/base/roller/policy/default.py +++ b/bitblas/base/roller/policy/default.py @@ -285,9 +285,7 @@ def _optimize(node, rstep): all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } + rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis} score = 0 shape = node.propagate_inputs(td.get_tile(node), rstep=rstep) for i, input_buffer in enumerate(node.input_buffers): @@ -325,9 +323,7 @@ def _enlarge(rstep_id): break else: cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } + rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis} return rstep for node in self.ordered_nodes: diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index f5a0d1f24..722095657 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -211,9 +211,9 @@ def get_node_reduce_step_candidates(self, node): else: # must be a a multiple of wmma_k return { - k.var.name: - [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] - for k in node.raxis + k.var.name: [ + x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k) + ] for k in node.raxis } def check_tile_shape_isvalid(self, td: TileDict): diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 
46336e0c2..1a9ababd2 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -5,6 +5,7 @@ import ctypes import os import os.path as osp +import sys import tempfile import subprocess import logging @@ -47,8 +48,21 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False): "-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}", ] + if with_tl: - tvm_root = osp.join(osp.dirname(__file__), "../../../3rdparty/tvm") + install_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm") + develop_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm") + + tvm_root = next((path for path in [install_tvm_path, develop_tvm_path] + if os.path.exists(path) and path not in sys.path), None) + + if "TL_TEMPLATE_PATH " in os.environ: + tl_template_path = os.environ["TL_TEMPLATE_PATH"] + else: + tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) + tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) if "TL_CUTLASS_PATH" in os.environ: cutlass_path = os.environ["TL_CUTLASS_PATH"] diff --git a/bitblas/common.py b/bitblas/common.py index 2a4576bc8..b2023f7b8 100644 --- a/bitblas/common.py +++ b/bitblas/common.py @@ -5,4 +5,4 @@ BITBLAS_DEFAULT_CACHE_PATH = os.path.expanduser("~/.cache/bitblas") -MAX_ERROR_MESSAGE_LENGTH = 100 +MAX_ERROR_MESSAGE_LENGTH = 200 diff --git a/bitblas/gpu/element_wise.py b/bitblas/gpu/element_wise.py index 3d67937e8..8e0e78545 100644 --- a/bitblas/gpu/element_wise.py +++ b/bitblas/gpu/element_wise.py @@ -45,15 +45,9 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring vector_factors = [1] * len(block_factors) vector_factors[-1] = vec_len - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): continue for loop, iter_type in zip(sch.get_loops(block), dom_kind): @@ -68,15 +62,11 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring thread_loops = [] inner_loops = [] for s_loop, block_factor, step_factor, thread_factor in zip( - s_loops, block_factors, step_factors, thread_factors - ): + s_loops, block_factors, step_factors, thread_factors): block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) vthread_loop, inner_loop = sch.split( - inner_loop, factors=[None, thread_factor * step_factor] - ) - thread_loop, inner_loop = sch.split( - inner_loop, factors=[None, step_factor] - ) + inner_loop, factors=[None, thread_factor * step_factor]) + thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor]) block_loops.append(block_loop) vthread_loops.append(vthread_loop) thread_loops.append(thread_loop) @@ -84,14 +74,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring # inner virtual thread first vthread_loops = list(reversed(vthread_loops)) - sch.reorder( - *block_loops, - *vthread_loops, - *thread_loops, - *inner_loops, - *r_loops, - *o_loops - ) + sch.reorder(*block_loops, *vthread_loops, *thread_loops, *inner_loops, *r_loops, + *o_loops) sch.bind(sch.fuse(*block_loops), "blockIdx.x") sch.bind(sch.fuse(*thread_loops), "threadIdx.x") if len(vthread_loops) > 3: @@ -99,10 +83,10 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring for i, ax in enumerate(vthread_loops): 
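The lib_generator hunk above resolves the bundled TVM source tree and the tile-language template directory when compiling with `with_tl=True`. Below is a minimal standalone sketch of that lookup; `find_tl_template_path` is a hypothetical helper name, and it assumes the environment key is spelled `TL_TEMPLATE_PATH` without the trailing space that appears in the hunk's membership test (with the space, the override branch can never match, and the unconditional assignment after the conditional always wins anyway).

import os
import os.path as osp
import sys


def find_tl_template_path(module_file: str) -> str:
    # Candidate TVM roots: installed layout (.../bitblas/3rdparty/tvm) and
    # develop layout (.../<repo>/3rdparty/tvm), mirroring the hunk above.
    install_tvm_path = osp.join(
        osp.dirname(osp.abspath(module_file)), "../..", "3rdparty", "tvm")
    develop_tvm_path = osp.join(
        osp.dirname(osp.abspath(module_file)), "../../..", "3rdparty", "tvm")

    tvm_root = next(
        (path for path in (install_tvm_path, develop_tvm_path)
         if osp.exists(path) and path not in sys.path),
        None,
    )
    if tvm_root is None:
        raise FileNotFoundError("no bundled TVM source tree found")

    # Environment override first, then the src/tl directory inside the TVM tree.
    if "TL_TEMPLATE_PATH" in os.environ:
        return os.environ["TL_TEMPLATE_PATH"]
    return osp.abspath(osp.join(tvm_root, "src/tl"))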
sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) - + # vectorize the last axis ax = inner_loops[-1] if sch.get(ax).extent.value > 1: sch.vectorize(ax) - + return sch diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 466466ed9..280731657 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -1677,7 +1677,17 @@ def get_lop3_intrin_group( if is_ladder_stage3: key += "_offset" + if out_dtype == "float16": + d4f = "f16" + elif out_dtype == "int8": + d4f = "i8s" + else: + raise ValueError("Unsupported target dtype: {}".format(target_dtype)) + source_symbol = "u" if source_format == "uint" else "s" + func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f) + return { + "func_name": func_name, "c_source": import_c_map[key], "compute": _intrin, } diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 9ddc4500b..351dd3739 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -680,9 +680,8 @@ def check_last_trait(region: List[Range]): minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 # the batch dimension is not taken into consideration. extent = block_stmt.iter_vars[1].dom.extent - if isinstance(extent, - tir.expr.IntImm) and (extent.value < - (1 if allow_gemv else minimal_tensorize_threshold)): + if isinstance(extent, tir.expr.IntImm) and (extent.value < (1 if allow_gemv else + minimal_tensorize_threshold)): return func, None for item_var in block_stmt.iter_vars[2:]: extent = item_var.dom.extent diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 9932e69fc..f5796f589 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -152,6 +152,20 @@ def get_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 return get_index_map_3d(index_map, l, r) +def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append( + weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"]) + return all(conditions) + + class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): """ The schedule rule for float16 tensor core matmul computation. 
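The get_lop3_intrin_group change above derives a decode kernel name from the output dtype, the source sign, and the source bit width (for example decode_i4u_to_f16 for 4-bit unsigned weights dequantized to float16). A small sketch of that naming rule, written as a standalone function for illustration (the raise in the hunk references target_dtype, which reads like it should be out_dtype):

def lop3_decode_func_name(out_dtype: str, source_format: str, source_bit: int) -> str:
    # Map the target dtype to the suffix used by the C decode intrinsics.
    if out_dtype == "float16":
        target_suffix = "f16"
    elif out_dtype == "int8":
        target_suffix = "i8s"
    else:
        raise ValueError(f"Unsupported target dtype: {out_dtype}")

    # "u" for unsigned sources, "s" for signed sources.
    source_symbol = "u" if source_format == "uint" else "s"
    return f"decode_i{source_bit}{source_symbol}_to_{target_suffix}"


assert lop3_decode_func_name("float16", "uint", 4) == "decode_i4u_to_f16"
assert lop3_decode_func_name("int8", "int", 2) == "decode_i2s_to_i8s"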
@@ -212,19 +226,6 @@ def check_dequantize_info(dequantize_info): (weight_decode_info,) = list(dequantize_info.values()) - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" # Start Schedule @@ -727,19 +728,6 @@ def check_dequantize_info(dequantize_info): (weight_decode_info,) = list(dequantize_info.values()) - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" # Start Schedule @@ -1225,20 +1213,6 @@ def check_dequantize_info(dequantize_info): (weight_decode_info,) = list(dequantize_info.values()) - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append( - weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"]) - return all(conditions) - assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" # Start Schedule @@ -1820,19 +1794,6 @@ def check_dequantize_info(dequantize_info): (weight_decode_info,) = list(dequantize_info.values()) - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" # Start Schedule diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py index 19112486c..b296d1dde 100644 --- a/bitblas/ops/base_scheduler.py +++ b/bitblas/ops/base_scheduler.py @@ -1,12 
+1,24 @@ from tvm import IRModule from tvm.tir import PrimFunc -from typing import Union +from typing import Union, Callable from dataclasses import dataclass, field from tvm.tir.transform import Simplify from abc import ABC, abstractmethod from bitblas.base.arch import TileDevice +# Decorator to simplify the output of a function +def maybe_simplify(self, func: Callable): + + def wrapper(*args, **kwargs): + stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs) + if self._enable_simplify: + return self.Simplify(stmt) + return stmt + + return wrapper + + @dataclass class BaseScheduler(ABC): diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 0c7d5be0f..e71b18971 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -12,6 +12,7 @@ from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation from .tirscript.matmul_impl import select_implementation as consistent_implementation from .tilelang.dense import select_scheduler as consistent_scheduler +from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 from bitblas.utils.target_detector import auto_detect_nvidia_target from dataclasses import dataclass @@ -591,7 +592,26 @@ def _select_scheduler(self): propagate_b=self.propagate_b, ) else: - raise ValueError("Currently only support native compute for scheduler") + return weight_dequantize_scheduler( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + zeros_mode=self.zeros_mode, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) def post_process(self, code: str) -> str: code = tensor_replace_dp4a(code) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 1a75ef54d..227de7ad3 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -317,6 +317,8 @@ def __repr__(self): return ("{" f"block_M={self.block_row_warps * self.warp_row_tiles}," f"block_N={self.block_col_warps * self.warp_col_tiles}," + f"warp_M={self.warp_row_tiles}," + f"warp_N={self.warp_col_tiles}," f"block_K={self.chunk}," f"threads={self.block_row_warps * self.block_col_warps * warp_size}," f"num_stages={self.num_stages}," @@ -512,9 +514,12 @@ def main( # Store results from shared memory to global memory for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y,] + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] return self.maybe_simplify(main) @@ -649,8 +654,12 @@ def main( micro_size_y, micro_size_k, ): - B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, jj, kk,] + B_shared[j, k, jj, kk] = B[ + bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, + jj, + kk, + ] # Perform the 
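base_scheduler.py above gains a module-level maybe_simplify wrapper that runs the scheduler's Simplify pass over whatever a builder function returns, gated on the instance's _enable_simplify flag. As written it takes the scheduler instance explicitly, so it is applied by wrapping a bound method rather than with @ syntax; a self-contained toy showing the intended behavior (DummyScheduler is illustrative, not part of the codebase):

from typing import Callable


def maybe_simplify(self, func: Callable):
    # Wrap `func`; post-process its result with `self.Simplify` when enabled.
    def wrapper(*args, **kwargs):
        stmt = func(*args, **kwargs)
        if self._enable_simplify:
            return self.Simplify(stmt)
        return stmt

    return wrapper


class DummyScheduler:
    # Stand-in for BaseScheduler: a flag plus a Simplify hook.
    _enable_simplify = True

    @staticmethod
    def Simplify(stmt):
        return f"simplified({stmt})"

    def build(self):
        return "raw_prim_func"


sched = DummyScheduler()
build_simplified = maybe_simplify(sched, sched.build)
print(build_simplified())  # -> simplified(raw_prim_func)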
matrix multiplication on tensor core fragments for ki in T.serial(0, (block_K // micro_size_k)): @@ -683,9 +692,12 @@ def main( # Store results from shared memory to global memory for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y,] + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] return self.maybe_simplify(main) @@ -864,9 +876,12 @@ def main( ) for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y,] + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] return main @@ -976,8 +991,12 @@ def main( micro_size_y, micro_size_k, ): - B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, jj, kk,] + B_shared[j, k, jj, kk] = B[ + bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, + jj, + kk, + ] for ki in T.serial(0, (block_K // micro_size_k)): @@ -1006,8 +1025,11 @@ def main( ) for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y,] + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] return main diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index 59e481eb9..bc13c9d4c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -1,2 +1,102 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + +from .block_primitive_tensorcore import ( + MatmulDequantizeScheduler, # noqa: F401 +) + +from bitblas.ops.common import TransformKind +from typing import Union + + +def parse_layout(layout: str): + if len(layout) != 2 or layout[0] not in "nt" or layout[1] not in "nt": + raise ValueError(f"Invalid layout: {layout}") + + trans_A = layout[0] == 't' + trans_B = layout[1] == 't' + + return trans_A, trans_B + + +def is_non_transform_kind(kind) -> bool: + return kind == TransformKind.NonTransform + + +def select_scheduler( + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a: Union[int, TransformKind] = TransformKind.NonTransform, + propagate_b: Union[int, TransformKind] = TransformKind.NonTransform, +): + ''' + Fine-grained Interface is preferred as it provides more flexibility + and can be used to implement high performance kernel. 
+ ''' + if isinstance(propagate_a, int): + propagate_a = TransformKind(propagate_a) + if isinstance(propagate_b, int): + propagate_b = TransformKind(propagate_b) + if with_bias: + raise NotImplementedError + + trans_A, trans_B = parse_layout(layout) + + def can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): + conditions = [] + conditions.append(trans_A is False) + conditions.append(trans_B is True) + conditions.append(propagate_a == TransformKind.NonTransform) + conditions.append(propagate_b == TransformKind.NonTransform) + return all(conditions) + + def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): + conditions = [] + conditions.append(trans_A is False) + conditions.append(trans_B is True) + conditions.append(propagate_a == TransformKind.NonTransform) + conditions.append(propagate_b == TransformKind.LDMatrixTransform) + return all(conditions) + + def can_apply_block_scheduler(propagate_a, propagate_b): + conditions = [] + conditions.append(propagate_a == TransformKind.NonTransform) + conditions.append(propagate_b == TransformKind.NonTransform) + return all(conditions) + + if can_apply_block_scheduler(propagate_a, propagate_b): + return MatmulDequantizeScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + with_bias=with_bias, + zeros_mode=zeros_mode, + ) + else: + raise ValueError(f"Unsupported configuration: {layout}, {propagate_a}, {propagate_b}") diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py new file mode 100644 index 000000000..7a06d6959 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -0,0 +1,429 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
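With the dispatch above, a plain-layout dequantize GEMM (no input or weight propagation) is routed to the block-level tile-language scheduler. A hedged usage sketch, assuming a standard BitBLAS install; the shapes and quantization settings are illustrative:

from bitblas.ops.general_matmul.tilelang.dequantize import select_scheduler
from bitblas.ops.common import TransformKind

# 4-bit unsigned weights stored in int8, float16 activations, "nt" layout.
scheduler = select_scheduler(
    M=1,
    N=1024,
    K=1024,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    bit=4,
    storage_dtype="int8",
    source_format="uint",
    layout="nt",
    propagate_a=TransformKind.NonTransform,
    propagate_b=TransformKind.NonTransform,
)

# The block scheduler can then emit a kernel with its default tile configuration.
prim_func = scheduler.with_default_config()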
+from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional, List, Literal +from bitblas.ops.base_scheduler import BaseScheduler +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.base.roller.rasterization import NoRasterization +from bitblas.base.utils import get_roller_hints_from_func +from dataclasses import dataclass +from bitblas.ops.general_matmul.tirscript import ( + matmul_dequantize_select_implementation,) +from bitblas.tl.base_hint import BaseTLHint +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) +from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulDequantizeScheduler(BaseScheduler): + + # OP Related Config + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None + trans_A: bool = False + trans_B: bool = False + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + + # Dequantize Config + num_bits: int = 4 + storage_dtype: str = "int8" + source_format: str = "uint" + with_scaling: bool = False + with_zeros: bool = False + group_size: int = -1 + fast_decoding: bool = False + with_bias: bool = False + zeros_mode: Literal["original", "rescale", "quantized"] = "original", + + # Default Tile Related Params + block_M: int = 128 + block_N: int = 128 + block_K: int = 32 + num_stages: int = 2 + threads: int = 128 + enable_rasterization: bool = False # Enhance L2 Locality + + class TLHint(BaseTLHint): + + def __init__(self): + super().__init__() + + @classmethod + def from_roller_hint(cls, hint: Hint): + tl_hint = cls() + for key, value in hint.__dict__.items(): + setattr(tl_hint, key, value) + + block = hint.block + warp = hint.warp + rstep = hint.rstep + num_stages = hint.pipeline_stage + rasterization_plan = hint.rasterization_plan + enable_rasterization = not isinstance(rasterization_plan, NoRasterization) + + block_row_warps = block[0] // warp[0] + block_col_warps = block[1] // warp[1] + warp_size = 32 # NVIDIA GPU warp size is 32 + if num_stages == 1: + num_stages = 0 # disable pipelining + + tl_hint.block_M = block[0] + tl_hint.block_N = block[1] + tl_hint.block_K = rstep[0] + tl_hint.num_stages = num_stages + tl_hint.threads = warp_size * block_row_warps * block_col_warps + tl_hint.enable_rasterization = enable_rasterization + + return tl_hint + + def get_config_params(self): + return { + "block_M": self.block_M, + "block_N": self.block_N, + "block_K": self.block_K, + "num_stages": self.num_stages, + "threads": self.threads, + "enable_rasterization": self.enable_rasterization, + } + + def __repr__(self): + return ("{" + f"block_M={self.block_M}," + f"block_N={self.block_N}," + f"block_K={self.block_K}," + f"num_stages={self.num_stages}," + f"threads={self.threads}," + f"enable_rasterization={self.enable_rasterization}" + "}") + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # Simple TIR Compute Expression + ir_module = matmul_dequantize_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + bit=self.num_bits, + 
storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + zeros_mode=self.zeros_mode) + + roller_hints = get_roller_hints_from_func( + ir_module["main"], + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + def serialze_hints_to_configs(hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + return serialze_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + return self.get_roller_configs(arch, topk) + + def with_default_config(self): + block_M = getattr(self, "block_M", 64) + block_N = getattr(self, "block_N", 64) + block_K = getattr(self, "block_K", 32) + num_stages = getattr(self, "num_stages", 2) + threads = getattr(self, "threads", 128) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + def _apply_config_dequant_only( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + assert block_M is not None, "block_M is required" + assert block_N is not None, "block_N is required" + assert block_K is not None, "block_K is required" + assert num_stages is not None, "num_stages is required" + assert threads is not None, "threads is required" + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + + # check is dequantize only + + def check_is_dequantize_only(): + return not self.with_scaling + + if not check_is_dequantize_only(): + raise ValueError("Not a Dequantize Only Configuration") + + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + fast_decoding = self.fast_decoding + + num_bits = self.num_bits + storage_dtype = self.storage_dtype + source_format = self.source_format + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = 8 // num_bits + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + group_size = self.group_size + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + import_source: Optional[str] = None + func_name: str = "" + if fast_decoding is True: + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=storage_dtype, + source_format=source_format, + source_bit=num_bits, + ) + import_source = lop3_intrin_info["c_source"] + func_name = lop3_intrin_info["func_name"] + assert import_source is not None, "lop3_intrin_info is not found" + assert func_name is not None, 
"lop3_intrin_info is not found" + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + T.use_swizzle(10, enable=enable_rasterization) + + T.import_source(import_source) + + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + tx * local_size_compressed + v + vi = index // (block_K // num_elems_per_byte) + vj = index % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + + if fast_decoding is True: + T.call_extern( + func_name, + T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), + dtype=in_dtype) + else: + for v in T.serial(0, local_size): + B_dequantize_local[v] = self._decode_func( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // block_K + vj = index % block_K + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + def _apply_config_with_scaling( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + raise NotImplementedError("Scaling Configuration is not implemented") + + def _apply_config_with_scaling_zeros_original_or_rescale( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + raise NotImplementedError("Scaling and Zeros Original Configuration is not implemented") + + def _apply_config_with_scaling_zeros_quantized( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + raise NotImplementedError("Scaling and Zeros Rescale Configuration is not implemented") + + def apply_config( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + assert block_M is not None, "block_M is required" + assert block_N is not None, "block_N is required" + assert block_K is not 
None, "block_K is required" + assert num_stages is not None, "num_stages is required" + assert threads is not None, "threads is required" + trans_A, trans_B = self.trans_A, self.trans_B + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + + args = [block_M, block_N, block_K, num_stages, threads, enable_rasterization] + + dequant_prim_func = None + + if not with_scaling: + dequant_prim_func = self._apply_config_dequant_only(*args) + elif not with_zeros: + dequant_prim_func = self._apply_config_with_scaling(*args) + elif zeros_mode in ["original", "rescale"]: + dequant_prim_func = self._apply_config_with_scaling_zeros_original_or_rescale(*args) + elif zeros_mode == "quantized": + dequant_prim_func = self._apply_config_with_scaling_zeros_quantized(*args) + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + if dequant_prim_func is None: + raise ValueError("Unsupported Configuration") + + return self.maybe_simplify(dequant_prim_func) + + @property + def _decode_func(self): + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + storage_dtype = self.storage_dtype + + in_dtype = self.in_dtype + source_format = self.source_format + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + num_bits = self.num_bits + + dequant_func = None + + def naive_cast_dequant(x): + return x.astype(in_dtype) + + if with_zeros and zeros_mode == "quantized": + dequant_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) + elif source_format == "uint": + if num_bits == 8: + # 8 num_bits does not need to be compressed + dequant_func = naive_cast_dequant + else: + dequant_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) + elif source_format == "int": + if num_bits == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + dequant_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) + elif num_bits == 8: + # 8 num_bits does not need to be compressed + dequant_func = naive_cast_dequant + else: + dequant_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) + elif source_format == "fp": + dequant_func = _tir_u32_to_f4_to_f16 + elif source_format == "fp_e4m3": + dequant_func = _tir_u8_to_f8_e4m3_to_f16 + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + return dequant_func + + def __post_init__(self): + # Add Config Validation + return diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py new file mode 100644 index 000000000..c98474ec0 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -0,0 +1,413 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
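For the plain uint path, _decode_func above falls back to _tir_packed_to_unsigned_convert: each int8 storage element holds 8 // num_bits quantized values, and element v of a tile is recovered by shifting and masking its host byte (B_local[v // num_elems_per_byte], position v % num_elems_per_byte). A pure-Python round-trip sketch of that packing/unpacking arithmetic; the function names here are illustrative, not the TIR helpers themselves:

def pack_uint(values, num_bits=4):
    # Pack unsigned `num_bits`-wide values into a list of bytes (int8 storage).
    elems_per_byte = 8 // num_bits
    packed = []
    for i in range(0, len(values), elems_per_byte):
        byte = 0
        for j, v in enumerate(values[i:i + elems_per_byte]):
            byte |= (v & ((1 << num_bits) - 1)) << (j * num_bits)
        packed.append(byte)
    return packed


def unpack_uint(packed, index, num_bits=4):
    # Mirror of the shift-and-mask decode: locate the host byte, then extract.
    elems_per_byte = 8 // num_bits
    byte = packed[index // elems_per_byte]
    pos = index % elems_per_byte
    return (byte >> (pos * num_bits)) & ((1 << num_bits) - 1)


weights = [3, 0, 15, 7, 1, 12, 9, 5]
storage = pack_uint(weights, num_bits=4)
assert [unpack_uint(storage, i, num_bits=4) for i in range(len(weights))] == weights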
+from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional, List, Literal +from bitblas.tl.utils import ( + get_mma_micro_size, # noqa: F401 + make_swizzle_layout, # noqa: F401 +) + +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitter, # noqa: F401 + TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 +) +from bitblas.ops.common import TransformKind # noqa: F401 +from bitblas.ops.base_scheduler import BaseScheduler +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.base.roller.rasterization import NoRasterization +from bitblas.base.utils import get_roller_hints_from_func +from dataclasses import dataclass +from bitblas.ops.general_matmul.tirscript import ( + matmul_dequantize_select_implementation,) +from bitblas.tl.base_hint import BaseTLHint +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulDequantizeScheduler(BaseScheduler): + + # OP Related Config + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None + trans_A: bool = False + trans_B: bool = False + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + + # Dequantize Config + num_bits: int = 4 + storage_dtype: str = "int8" + source_format: str = "uint" + with_scaling: bool = False + with_zeros: bool = False + group_size: int = -1 + fast_decoding: bool = False + with_bias: bool = False + zeros_mode: Literal["original", "rescale", "quantized"] = "original", + + # Default Tile Related Params + block_M: int = 64 + block_N: int = 64 + block_K: int = 32 + num_stages: int = 2 + threads: int = 128 + enable_rasterization: bool = False # Enhance L2 Locality + + class TLHint(BaseTLHint): + + def __init__(self): + super().__init__() + + @classmethod + def from_roller_hint(cls, hint: Hint): + tl_hint = cls() + for key, value in hint.__dict__.items(): + setattr(tl_hint, key, value) + + block = hint.block + warp = hint.warp + rstep = hint.rstep + num_stages = hint.pipeline_stage + rasterization_plan = hint.rasterization_plan + enable_rasterization = not isinstance(rasterization_plan, NoRasterization) + + block_row_warps = block[0] // warp[0] + block_col_warps = block[1] // warp[1] + warp_size = 32 # NVIDIA GPU warp size is 32 + if num_stages == 1: + num_stages = 0 # disable pipelining + + tl_hint.block_M = block[0] + tl_hint.block_N = block[1] + tl_hint.block_K = rstep[0] + tl_hint.num_stages = num_stages + tl_hint.threads = warp_size * block_row_warps * block_col_warps + tl_hint.enable_rasterization = enable_rasterization + + return tl_hint + + def get_config_params(self): + return { + "block_M": self.block_M, + "block_N": self.block_N, + "block_K": self.block_K, + "num_stages": self.num_stages, + "threads": self.threads, + "enable_rasterization": self.enable_rasterization, + } + + def __repr__(self): + return ("{" + f"block_M={self.block_M}," + f"block_N={self.block_N}," + f"block_K={self.block_K}," + f"num_stages={self.num_stages}," + f"threads={self.threads}," + f"enable_rasterization={self.enable_rasterization}" + "}") + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # Simple TIR Compute 
Expression + ir_module = matmul_dequantize_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + bit=self.num_bits, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + zeros_mode=self.zeros_mode) + + roller_hints = get_roller_hints_from_func( + ir_module["main"], + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + def serialze_hints_to_configs(hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + return serialze_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + return self.get_roller_configs(arch, topk) + + def with_default_config(self): + block_M = getattr(self, "block_M", 64) + block_N = getattr(self, "block_N", 64) + block_K = getattr(self, "block_K", 32) + num_stages = getattr(self, "num_stages", 2) + threads = getattr(self, "threads", 128) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + def _apply_config_dequant_only( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + assert block_M is not None, "block_M is required" + assert block_N is not None, "block_N is required" + assert block_K is not None, "block_K is required" + assert num_stages is not None, "num_stages is required" + assert threads is not None, "threads is required" + M, N, K = self.M, self.N, self.K + trans_A, trans_B = self.trans_A, self.trans_B + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + + # check is dequantize only + + def check_is_dequantize_only(): + return not self.with_scaling + + if not check_is_dequantize_only(): + raise ValueError("Not a Dequantize Only Configuration") + + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + + num_bits = self.num_bits + storage_dtype = self.storage_dtype + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = 8 // num_bits + + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + group_size = self.group_size + if group_size == -1: + group_size = K + + A_shape = (M, K) + B_shape = (N, K // storage_nbit * num_bits) + + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, 
in_dtype) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + tx = T.thread_binding(0, threads, thread="threadIdx.x") + + T.use_swizzle(10, enable=enable_rasterization) + + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + tx * local_size_compressed + v + vi = index // (block_K // num_elems_per_byte) + vj = index % (block_K // num_elems_per_byte) + B_local[v] = B_shared[vi, vj] + for v in T.serial(0, local_size): + B_dequantize_local[v] = self._decode_func( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + dtype=in_dtype, + ) + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // block_K + vj = index % block_K + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + def _apply_config_with_scaling( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + raise NotImplementedError("Scaling Configuration is not implemented") + + def _apply_config_with_scaling_zeros_original_or_rescale( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + raise NotImplementedError("Scaling and Zeros Original Configuration is not implemented") + + def _apply_config_with_scaling_zeros_quantized( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + raise NotImplementedError("Scaling and Zeros Rescale Configuration is not implemented") + + def apply_config( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + assert block_M is not None, "block_M is required" + assert block_N is not None, "block_N is required" + assert block_K is not None, "block_K is required" + assert num_stages is not None, "num_stages is required" + assert threads is not None, "threads is required" + trans_A, trans_B = self.trans_A, self.trans_B + + assert trans_A is False, "Dequantize only implement for trans_A=False currently" + assert trans_B is True, "Dequantize only implement for trans_B=TRue currently" + + with_scaling = self.with_scaling + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + + args = [block_M, block_N, block_K, 
num_stages, threads, enable_rasterization] + + dequant_prim_func = None + if not with_scaling: + dequant_prim_func = self._apply_config_dequant_only(*args) + + if not with_zeros: + dequant_prim_func = self._apply_config_with_scaling(*args) + + if zeros_mode in ["original", "rescale"]: + dequant_prim_func = self._apply_config_with_scaling_zeros_original_or_rescale(*args) + elif zeros_mode == "quantized": + dequant_prim_func = self._apply_config_with_scaling_zeros_quantized(*args) + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + if dequant_prim_func is None: + raise ValueError("Unsupported Configuration") + + return self.maybe_simplify(dequant_prim_func) + + @property + def _decode_func(self): + with_zeros = self.with_zeros + zeros_mode = self.zeros_mode + storage_dtype = self.storage_dtype + + in_dtype = self.in_dtype + source_format = self.source_format + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + bit = self.bit + + dequant_func = None + + def naive_cast_dequant(x): + return x.astype(in_dtype) + + if with_zeros and zeros_mode == "quantized": + dequant_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) + elif source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + dequant_func = naive_cast_dequant + else: + dequant_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + dequant_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) + elif bit == 8: + # 8 bit does not need to be compressed + dequant_func = naive_cast_dequant + else: + dequant_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) + elif source_format == "fp": + dequant_func = _tir_u32_to_f4_to_f16 + elif source_format == "fp_e4m3": + dequant_func = _tir_u8_to_f8_e4m3_to_f16 + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + return dequant_func + + def __post_init__(self): + # Add Config Validation + return diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index 938a821ce..d928c451d 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -196,7 +196,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(self.pass_context if self.pass_context else {}) }): - rt_mod, _ = tl.lower(tl_prim_func, target=target) + rt_mod = tl.lower(tl_prim_func, target=target, runtime_only=True) else: raise ValueError(f"Unsupported backend: {self.backend}") except Exception as build_runtime_error: # noqa: F841 diff --git a/bitblas/quantization/__init__.py b/bitblas/quantization/__init__.py index d29cb679a..48059c8bd 100644 --- a/bitblas/quantization/__init__.py +++ b/bitblas/quantization/__init__.py @@ -9,4 +9,8 @@ _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 ) -from .utils import gen_quant4, general_compress # noqa: F401 +from .utils import ( + gen_quant4, # noqa: F401 + general_compress, # noqa: F401 + interleave_weight, # noqa: F401 +) diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index f3db7d88a..0f7adb791 100644 --- 
a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -117,8 +117,10 @@ def _warp_ldmatrix_a( ".b16", A_local_buf.data, i * inst.local_size_a, - T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x, - rk * inst.chunk + ki * inst.micro_size_k,]), + T.address_of(A_shared_buf[ + ty * inst.warp_row_tiles + i * inst.micro_size_x, + rk * inst.chunk + ki * inst.micro_size_k, + ]), get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed), ) diff --git a/integration/BitNet/vllm_workspace/conftest.py b/integration/BitNet/vllm_workspace/conftest.py index fd5e162af..c99f334cb 100644 --- a/integration/BitNet/vllm_workspace/conftest.py +++ b/integration/BitNet/vllm_workspace/conftest.py @@ -61,12 +61,10 @@ class _ImageAssetsBase(UserList[ImageAsset]): class _ImageAssets(_ImageAssetsBase): def __init__(self) -> None: - super().__init__( - [ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ] - ) + super().__init__([ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ]) def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: """ @@ -173,8 +171,7 @@ def __init__( SentenceTransformer( model_name, device="cpu", - ).to(dtype=torch_dtype) - ) + ).to(dtype=torch_dtype)) else: if is_vision_model: auto_cls = AutoModelForVision2Seq @@ -192,8 +189,7 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, **model_kwargs, - ) - ) + )) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -268,9 +264,7 @@ def generate_greedy( **kwargs, ) - return [ - (output_ids[0], output_str[0]) for output_ids, output_str in outputs - ] + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_beam_search( self, @@ -288,9 +282,7 @@ def generate_beam_search( for i in range(len(outputs)): output_ids, output_str = outputs[i] for j in range(len(output_ids)): - output_ids[j] = [ - x for x in output_ids[j] if x != self.tokenizer.pad_token_id - ] + output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id] outputs[i] = (output_ids, output_str) return outputs @@ -329,9 +321,7 @@ def generate_greedy_logprobs( self.model.get_output_embeddings().weight.t(), ) if self.model.get_output_embeddings().bias is not None: - logits += self.model.get_output_embeddings().bias.unsqueeze( - 0 - ) + logits += self.model.get_output_embeddings().bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) all_logprobs.append(seq_logprobs) @@ -377,13 +367,8 @@ def generate_greedy_logprobs_limit( last_hidden_states, self.model.get_output_embeddings().weight.t(), ) - if ( - getattr(self.model.get_output_embeddings(), "bias", None) - is not None - ): - logits += self.model.get_output_embeddings().bias.unsqueeze( - 0 - ) + if (getattr(self.model.get_output_embeddings(), "bias", None) is not None): + logits += self.model.get_output_embeddings().bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) @@ -409,10 +394,8 @@ def generate_greedy_logprobs_limit( all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [ - (output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs - ] + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) @@ -477,9 +460,7 @@ 
def generate( for i, image in enumerate(images): inputs[i]["multi_modal_data"] = {"image": image} - req_outputs = self.model.generate( - inputs, sampling_params=sampling_params - ) + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) outputs: List[Tuple[List[List[int]], List[str]]] = [] for req_output in req_outputs: @@ -511,9 +492,7 @@ def generate_w_logprobs( for i, image in enumerate(images): inputs[i]["multi_modal_data"] = {"image": image} - req_outputs = self.model.generate( - inputs, sampling_params=sampling_params - ) + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] for req_output in req_outputs: for sample in req_output.outputs: @@ -531,9 +510,7 @@ def generate_greedy( ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images) - return [ - (output_ids[0], output_str[0]) for output_ids, output_str in outputs - ] + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_greedy_logprobs( self, @@ -543,16 +520,11 @@ def generate_greedy_logprobs( images: Optional[List[Image.Image]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams( - temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs - ) - outputs = self.generate_w_logprobs( - prompts, greedy_logprobs_params, images=images - ) + temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, images=images) - return [ - (output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in outputs - ] + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] def generate_beam_search( self, @@ -594,9 +566,7 @@ def get_tokenizer_pool_config(tokenizer_group_type): if tokenizer_group_type is None: return None if tokenizer_group_type == "ray": - return TokenizerPoolConfig( - pool_size=1, pool_type="ray", extra_config={} - ) + return TokenizerPoolConfig(pool_size=1, pool_type="ray", extra_config={}) raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") diff --git a/integration/BitNet/vllm_workspace/utils.py b/integration/BitNet/vllm_workspace/utils.py index 0d5e304d8..32877113a 100644 --- a/integration/BitNet/vllm_workspace/utils.py +++ b/integration/BitNet/vllm_workspace/utils.py @@ -3,18 +3,15 @@ TokensText = Tuple[List[int], str] -def check_outputs_equal(outputs_0_lst: List[TokensText], - outputs_1_lst: List[TokensText], name_0: str, - name_1: str): +def check_outputs_equal(outputs_0_lst: List[TokensText], outputs_1_lst: List[TokensText], + name_0: str, name_1: str): """ Compare the two sequences generated by different models, which should be equal. 
""" assert len(outputs_0_lst) == len(outputs_1_lst) - for prompt_idx, (outputs_0, - outputs_1) in enumerate(zip(outputs_0_lst, - outputs_1_lst)): + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): output_ids_0, output_str_0 = outputs_0 output_ids_1, output_str_1 = outputs_1 @@ -30,8 +27,7 @@ def check_outputs_equal(outputs_0_lst: List[TokensText], def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], - outputs_1_lst: List[TokensTextLogprobs], name_0: str, - name_1: str): + outputs_1_lst: List[TokensTextLogprobs], name_0: str, name_1: str): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. @@ -39,27 +35,22 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs], assert len(outputs_0_lst) == len(outputs_1_lst) # Loop through responses to each prompt. - for prompt_idx, (outputs_0, - outputs_1) in enumerate(zip(outputs_0_lst, - outputs_1_lst)): + for prompt_idx, (outputs_0, outputs_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)): output_ids_0, output_str_0, logprobs_0 = outputs_0 output_ids_1, output_str_1, logprobs_1 = outputs_1 # Loop through generated tokens. - for idx, (output_id_0, - output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): # If generated tokens don't match, then if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other - assert output_id_0 in logprobs_1[idx], ( - f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") - assert output_id_1 in logprobs_0[idx], ( - f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + assert output_id_0 in logprobs_1[idx], (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_id_1 in logprobs_0[idx], (f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") # Break out since sequences will now diverge. 
break diff --git a/requirements-dev.txt b/requirements-dev.txt index 99c101afb..0b09c0856 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ # formatting -yapf==0.32.0 +yapf==0.40.2 toml==0.10.2 tomli==2.0.1 ruff==0.1.5 diff --git a/requirements-test.txt b/requirements-test.txt index 194cb1ba8..13fd3d1af 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,5 +1,5 @@ # formatting -yapf==0.32.0 +yapf==0.40.2 toml==0.10.2 tomli==2.0.1 ruff==0.1.5 diff --git a/testing/python/operators/test_general_matmul_ops_backend_tl.py b/testing/python/operators/test_general_matmul_ops_backend_tl.py index f9b20c5ef..3e9d55530 100644 --- a/testing/python/operators/test_general_matmul_ops_backend_tl.py +++ b/testing/python/operators/test_general_matmul_ops_backend_tl.py @@ -75,6 +75,156 @@ def matmul_finetune(M, assert get_codegen_result(matmul) +def matmul_torch_forward(M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + propagate_b=None): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=propagate_b, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") + + assert layout == "nt", "Only support nt layout" + + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, A_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, W_dtype)) + + LB = matmul.transform_weight(B) + bitblas_output = matmul(A, LB) + ref_output = torch.matmul(A, B.T).to(getattr(torch, out_dtype)) + torch.testing.assert_close(bitblas_output, ref_output, rtol=1e-1, atol=1e-1) + + +def matmul_torch_forward_dequant(M, + N, + K, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, + propagate_b=None): + import torch + torch.random.manual_seed(0) + import numpy as np + from bitblas.quantization import general_compress + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=propagate_b, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") + + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + source_format, bit = matmul.BITBLAS_TRICK_DTYPE_MAP[W_dtype] + maxq = 2**(bit - 1) + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() + 
ref_result = torch.matmul(inputs[0], + (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) + if with_bias: + ref_result = ref_result + bias + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + if matmul.weight_transform is not None: + permuted_inputs.append(matmul.weight_transform(intweight.cpu()).cuda()) + else: + permuted_inputs.append(intweight) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append( + torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + elif zeros_mode == "rescale": + original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + if with_bias: + permuted_inputs.append(bias) + permuted_inputs.append(inputs[2]) + matmul(*permuted_inputs[:-1], output=permuted_inputs[-1]) + print(permuted_inputs[-1]) + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) + + def test_matmul_codegen_default(): matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), @@ -95,6 +245,18 @@ def test_matmul_finetune(): False, False, None, False) +def test_matmul_torch_forward(): + matmul_torch_forward(1024, 1024, 1024, "float16", "float16", "float16", "float16", "nt", None, + None, None, None, None, False) + matmul_torch_forward(1024, 1024, 1024, "float16", "float16", "float16", "float16", "nt", None, + None, None, None, None, True) + + +def test_matmul_torch_dequant_forward(): + matmul_torch_forward_dequant(1024, 1024, 1024, "float16", "int4", "float16", "float16", "nt", + None, None, None, None, None, False) + + # fmt: on if __name__ == "__main__": bitblas.testing.main() diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 9308a9428..349a69752 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -10,6 +10,9 @@ MatmulWeightPropagationScheduler, ) +from bitblas.ops.general_matmul.tilelang.dequantize import ( + MatmulDequantizeScheduler,) + import torch import torch.backends @@ -416,6 +419,116 @@ def assert_matmul_weight_propagation_apply_config_correctness( torch.testing.assert_close(C, ref_c, rtol=1e0, atol=1e0) +def assert_matmul_blocked_dequant_with_default_correctness( + M, + N, + K, + trans_A=False, + trans_B=True, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + zeros_mode="original", +): + import numpy as np + from bitblas.quantization import general_compress, interleave_weight + matmul = MatmulDequantizeScheduler( + M=M, + N=N, + K=K, + trans_A=trans_A, + trans_B=trans_B, + 
in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + num_bits=bit, + storage_dtype=storage_dtype, + source_format=source_format, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + fast_decoding=fast_decoding, + zeros_mode=zeros_mode, + ).with_default_config() + + mod, params = tl.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + + input_shape = (M, K) + weight_shape = (N, K) + output_shape = (M, N) + inputs = [] + inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) + maxq = 2**(bit - 1) + zeros = maxq + if source_format == "uint": + inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) + elif source_format == "int": + inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) + else: + raise NotImplementedError + + inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) + + intweight = inputs[1] + intweight = intweight.cpu().to(torch.int8) + if source_format == "int": + intweight = intweight + maxq + if with_zeros: + inputs[1] = inputs[1] - zeros + + ref_result = torch.matmul(inputs[0], inputs[1].t().to(torch.float16)) + + permuted_inputs = [] + permuted_inputs.append(inputs[0]) + qw = general_compress(intweight.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + # lop3 transformation + if fast_decoding: + qw = interleave_weight(qw, bit, target_dtype=in_dtype) + permuted_inputs.append(torch.from_numpy(qw).cuda()) + if with_scaling: + if group_size == -1: + group_size = K + permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) + if with_zeros: + if zeros_mode == "original": + permuted_inputs.append( + torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) + elif zeros_mode == "rescale": + original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros + scaled_zeros = original_zeros * permuted_inputs[-1] + permuted_inputs.append(scaled_zeros) + elif zeros_mode == "quantized": + original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros + qzeros = general_compress( + original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) + permuted_inputs.append(torch.from_numpy(qzeros).cuda()) + else: + raise NotImplementedError + + permuted_inputs.append(inputs[2]) + + mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + + mod(*permuted_inputs) + + print(permuted_inputs[-1]) + print(ref_result) + if zeros_mode == "rescale": + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) + else: + torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) + + def test_matmul_blocked(): # Default assert_matmul_blocked_with_default_correctness(1024, 1024, 1024) @@ -447,5 +560,12 @@ def test_matmul_weight_propagation(): 1024, 1024, 1024, enable_rasterization=True) +def test_matmul_blocked_dequant_with_default(): + assert_matmul_blocked_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=4) + assert_matmul_blocked_dequant_with_default_correctness( + 1024, 1024, 1024, source_format="uint", bit=2) + + if __name__ == "__main__": bitblas.testing.main() diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 1f9f44ab5..006b0665a 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ 
b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -8,7 +8,7 @@ from tvm import tl as TL import tvm.tl.language as T from bitblas.quantization import _tir_packed_to_unsigned_convert -from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.utils import (make_swizzle_layout) from bitblas.tl.macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform,) @@ -17,21 +17,6 @@ torch.manual_seed(0) -def make_swizzle_layout(shared_buf): - dtype = shared_buf.dtype - shape = shared_buf.shape - - can_swizzle = shape[-1] * DataType(dtype).bits == 512 - if not can_swizzle: - return T.Layout(shape, lambda *args: args) - - def transform_func(i, j): - new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) - return [new_warp_i, new_warp_j] - - return T.Layout(shape, transform_func) - - def matmul( M, N, @@ -48,11 +33,16 @@ def matmul( ): num_elems_per_byte = 8 // num_bits storage_dtype = "int8" + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) A_shape = (M, K) B_shape = (N, K // num_elems_per_byte) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K // num_elems_per_byte) B_dequantize_shared_shape = (block_N, block_K) + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte import tvm.tl.language as T @@ -65,8 +55,8 @@ def main( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype) - B_local = T.alloc_local([8], storage_dtype) - B_dequantize_local = T.alloc_local([16], in_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) @@ -75,27 +65,31 @@ def main( T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) - - for i, j in T.Parallel(block_N, block_K // num_elems_per_byte): - B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j] - - for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)): - for v in T.vectorized(0, 4): - vi = (i * threads * 4 + tx * 4 + v) // (block_K // num_elems_per_byte) - vj = (i * threads * 4 + tx * 4 + v) % (block_K // num_elems_per_byte) + T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = i * threads * local_size_compressed + tx * local_size_compressed + v + vi = index // (block_K // num_elems_per_byte) + vj = index % (block_K // num_elems_per_byte) B_local[v] = B_shared[vi, vj] - for v in T.serial(0, 8): - B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)( - num_bits, - B_local[v // 2], - v % 2, - dtype=in_dtype, - ) - for v in T.vectorized(0, 8): - vi = (i * threads * 8 + tx * 8 + v) // (block_K) - vj = (i * threads * 8 + tx * 8 + v) % (block_K) + for v in T.serial(0, local_size): + B_dequantize_local[v] = _tir_packed_to_unsigned_convert( + storage_type, storage_nbit)( + num_bits, + B_local[v // num_elems_per_byte], + v % num_elems_per_byte, + 
dtype=in_dtype, + ) + for v in T.vectorized(0, local_size): + index = i * threads * local_size + tx * local_size + v + vi = index // block_K + vj = index % block_K B_dequantize_shared[vi, vj] = B_dequantize_local[v] + T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * block_M, bx * block_N]) return main @@ -433,6 +427,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct def test_run_dequantize_gemm(): + run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128) run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index 4d7be551b..d0587ebef 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -165,9 +165,12 @@ def main( # Store shared into global for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y,] + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] return main diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index 9ef592d2d..4d1318960 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -169,9 +169,12 @@ def main( # Store shared into global for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, - bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y, - i % micro_size_x, j % micro_size_y,] + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] return main
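
Note on the reformatted comparison helpers in integration/BitNet/vllm_workspace/utils.py: check_logprobs_close only requires that, at the first position where the two token streams diverge, each model's token appears in the other model's top-N logprobs. A minimal usage sketch is below; the input tuples are hypothetical placeholders, and the import assumes the script is run from integration/BitNet/vllm_workspace.

# Hypothetical usage sketch for check_logprobs_close; data values are invented
# for illustration and are not produced by this patch.
from utils import check_logprobs_close  # assumes cwd = integration/BitNet/vllm_workspace

# Each entry: (token_ids, generated_text, per-position top-N logprobs keyed by token id)
outputs_ref = [([1, 2, 3], "a b c", [{1: -0.1}, {2: -0.2}, {3: -0.30, 4: -0.35}])]
outputs_new = [([1, 2, 4], "a b d", [{1: -0.1}, {2: -0.2}, {4: -0.25, 3: -0.40}])]

# Token ids differ at position 2 (3 vs 4); the check passes because each token
# is present in the other model's logprobs for that position, then it breaks out.
check_logprobs_close(outputs_ref, outputs_new, name_0="reference", name_1="candidate")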
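
Note on the new matmul_torch_forward test in test_general_matmul_ops_backend_tl.py: it drives the TileLang ("tl") backend end to end through transform_weight and the operator call. A condensed sketch of the same call pattern follows; it requires a CUDA device and a BitBLAS build with the TL backend, shapes are arbitrary, and omitted MatmulConfig fields are left at their defaults.

# Condensed sketch of the call pattern exercised by matmul_torch_forward.
# Requires CUDA and the TL backend; not a substitute for the test itself.
import torch
from bitblas import Matmul, MatmulConfig

config = MatmulConfig(
    M=1024, N=1024, K=1024,
    A_dtype="float16", W_dtype="float16",
    accum_dtype="float16", out_dtype="float16",
    layout="nt",                      # the test only supports the "nt" layout
    propagate_a=False, propagate_b=False,
)
matmul = Matmul(config=config, enable_tuning=False, backend="tl")

A = torch.rand(1024, 1024, dtype=torch.float16, device="cuda")
B = torch.rand(1024, 1024, dtype=torch.float16, device="cuda")

LB = matmul.transform_weight(B)       # apply any required weight layout transform
out = matmul(A, LB)
torch.testing.assert_close(out, (A @ B.T).to(torch.float16), rtol=1e-1, atol=1e-1)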
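
Note on the revised dequantize GEMM kernel in test_tilelang_dequantize_gemm.py: the per-thread vector sizes are now derived from a 128-bit transaction (local_size = 128 // bits(in_dtype), local_size_compressed = local_size // num_elems_per_byte), and element v of a packed byte is addressed as B_local[v // num_elems_per_byte] at position v % num_elems_per_byte. The sketch below mirrors that packing arithmetic in plain NumPy for illustration only; it does not call bitblas.quantization.general_compress, and the low-bits-first ordering is an assumption made for the example.

# Illustrative packing/unpacking that mirrors the kernel's indexing; it is not
# the bitblas general_compress API, and the bit ordering here is an assumption.
import numpy as np

num_bits = 4
num_elems_per_byte = 8 // num_bits            # 2 unsigned 4-bit values per int8 byte

values = np.arange(8, dtype=np.uint8) % (1 << num_bits)

# Pack: element v goes into byte v // num_elems_per_byte at slot v % num_elems_per_byte.
packed = np.zeros(len(values) // num_elems_per_byte, dtype=np.uint8)
for v, x in enumerate(values):
    packed[v // num_elems_per_byte] |= x << (num_bits * (v % num_elems_per_byte))

# Unpack with the same indexing the kernel uses for B_local.
mask = (1 << num_bits) - 1
unpacked = np.array(
    [(packed[v // num_elems_per_byte] >> (num_bits * (v % num_elems_per_byte))) & mask
     for v in range(len(values))],
    dtype=np.uint8)
assert np.array_equal(values, unpacked)

# Vector sizes for float16 activations, matching the kernel's derivation:
MAX_TRANSACTION_SIZE_IN_BITS = 128
in_dtype_bits = 16
local_size = MAX_TRANSACTION_SIZE_IN_BITS // in_dtype_bits   # 8 fp16 elements per copy
local_size_compressed = local_size // num_elems_per_byte     # 4 packed int8 bytes per copy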