From b165b58dd7678d9fa7a29415f946cd31a9af06fb Mon Sep 17 00:00:00 2001 From: LeiWang199 Date: Mon, 20 May 2024 14:48:07 +0000 Subject: [PATCH] Refactor quantization module to support new float8 formats --- python/bitblas/gpu/gemv_dequantize.py | 8 +- python/bitblas/gpu/matmul_mma_dequantize.py | 12 +-- python/bitblas/ops/general_matmul.py | 81 ++++++------------- .../ops/impl/matmul_dequantize_impl.py | 9 +-- python/bitblas/quantization/quantization.py | 11 +-- .../operators/test_general_matmul_fp8.py | 81 +++++++++++-------- 6 files changed, 90 insertions(+), 112 deletions(-) diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py index db1898f9ea17..47e4bf42cd0a 100644 --- a/python/bitblas/gpu/gemv_dequantize.py +++ b/python/bitblas/gpu/gemv_dequantize.py @@ -49,8 +49,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] @@ -213,8 +213,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py index db2f78334bf5..96461db453e0 100644 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -126,8 +126,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] @@ -633,8 +633,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] @@ -1123,8 +1123,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index 6212bf483ead..35eee1fb6e6b 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -8,8 +8,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union from .operator import Operator, TransformKind from .impl.matmul_dequantize_impl import ( - select_implementation as weight_dequantize_implementation, -) + select_implementation as weight_dequantize_implementation,) from .impl.matmul_impl import select_implementation as consistent_implementation from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4 from bitblas.utils.target_detector import auto_detect_nvidia_target @@ -110,36 +109,23 @@ def __legalize_dynamic_symbolic(self, M): def __legalize_propagate(self, propagate): if isinstance(propagate, bool): - return ( - TransformKind.IntraWarpTransform - if propagate - else TransformKind.NonTransform - ) + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) elif isinstance(propagate, int): return TransformKind(propagate) return propagate - def __initialize_propagate( - self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind] - ): + def __initialize_propagate(self, propagate_a: Optional[TransformKind], + propagate_b: Optional[TransformKind]): MICRO_KERNEL_SIZE = 16 - if ( - isinstance(self.M, int) - and (self.M % MICRO_KERNEL_SIZE) == 0 - and (self.K % MICRO_KERNEL_SIZE) == 0 - ): + if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and + (self.K % MICRO_KERNEL_SIZE) == 0): object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) else: object.__setattr__(self, "propagate_a", TransformKind.NonTransform) - if ( - self.M == 1 - or (self.N % MICRO_KERNEL_SIZE) != 0 - or (self.K % MICRO_KERNEL_SIZE) != 0 - or isinstance(self.M, Tuple) - or (self.with_zeros and self.zeros_mode == "quantized") - ): + if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or + isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")): object.__setattr__(self, "propagate_a", TransformKind.NonTransform) object.__setattr__(self, "propagate_b", TransformKind.NonTransform) else: @@ -164,10 +150,7 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]): def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): if fast_decoding is not None: object.__setattr__(self, "fast_decoding", fast_decoding) - elif ( - "int" not in self.W_dtype - or self.W_dtype == self.A_dtype - ): + elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype): object.__setattr__(self, "fast_decoding", False) else: object.__setattr__(self, "fast_decoding", True) @@ -186,12 +169,8 @@ def __post_init__(self): object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M)) # set propagate_a and propagate_b to default value if it is None - object.__setattr__( - self, "propagate_a", self.__legalize_propagate(self.propagate_a) - ) - object.__setattr__( - self, "propagate_b", self.__legalize_propagate(self.propagate_b) - ) + object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) + object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) # This is hack to legalize propagate_a and b # TODO(lei): should be removed in the future when tc+br template is ready. @@ -214,10 +193,10 @@ def __post_init__(self): object.__setattr__(self, "with_zeros", False) if self.A_dtype == self.W_dtype and self.W_dtype in [ - "float16", - "int8", - "e4m3_float8", - "e5m2_float8", + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", ]: object.__setattr__(self, "storage_dtype", self.W_dtype) @@ -260,9 +239,8 @@ def __init__( if target is None: target = auto_detect_nvidia_target() logger.info(f"Auto detected target: {target}") - assert ( - config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP - ), f"Unsupported input dtype {config.A_dtype}" + assert (config.A_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}" source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype] self.source_format = source_format @@ -283,8 +261,7 @@ def __init__( if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range} - ) + {"opt_shapes": self.dynamic_range}) else: self.dynamic_range = None @@ -393,9 +370,7 @@ def __init__( def _build_default_module(self, target: Target): try: - self.optimized_func = self.apply_default_schedule( - self.prim_func_mod, target - ) + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) except Exception: self.optimized_func = None logger.warning( @@ -446,9 +421,7 @@ def post_process(self, code: str) -> str: return code def retrieve_weight_shape(self): - return [ - int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape - ] + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] def transform_weight(self, weight, scale=None, zeros=None, bias=None): """ @@ -480,7 +453,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): if source_format == "int": assert not self.with_scaling, "scale should be False for int source format" assert not self.with_zeros, "zeros should be False for int source format" - maxq = 2 ** (bit - 1) + maxq = 2**(bit - 1) # Clamp weight values to be within the quantizable range and adjust weight = torch.clamp(weight, -maxq, maxq).int() + maxq elif source_format in ["fp_e5m2", "fp_e4m3"]: @@ -493,8 +466,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): np_storage_dtype = getattr(np, self.storage_dtype) weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype - ) + weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) weight = torch.from_numpy(weight).cuda().contiguous() @@ -520,9 +492,7 @@ def transform_input(self, input_tensor): raise ValueError( f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." ) - self.ladder_permutate_a._forward_from_prebuild_lib( - input_tensor, self.workspace - ) + self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) return self.workspace return input_tensor @@ -530,14 +500,13 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args = [] args.append(self.transform_input(A)) args.append(W) - + if self.lut is not None: args.append(self.lut) if output is None: output = torch.empty( - A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device - ) + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) if scale is not None: args.append(scale) if zeros is not None: diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index 66573e68abe4..6e6b098c19a1 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -92,8 +92,7 @@ def decode_func(n, k): w = _tir_u32_to_f4_to_f16( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16( - bit, B[n, k], dtype=in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) elif source_format == "nf": w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( bit, @@ -268,8 +267,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16( - bit, B_reindex[n, k], dtype=in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) elif source_format == "nf": w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( bit, @@ -457,8 +455,7 @@ def decode_func(n, k): dtype=in_dtype, ) elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16( - bit, B_reindex[n, k], dtype=in_dtype) + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) elif source_format == "nf": w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( bit, diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py index 00ddb7af409e..d9f36094794d 100644 --- a/python/bitblas/quantization/quantization.py +++ b/python/bitblas/quantization/quantization.py @@ -138,23 +138,22 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 assert dtype == "float16" s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16") - offset = tir.Select( - s_f16 == 0, - tir.const(8192, "int16"), - tir.const(-8192, "int16") - ) + offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16")) e_f16 = ((val << tir.const(7, "int16")) + offset) return tir.reinterpret("float16", s_f16 | e_f16) + def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 assert dtype == "float16" return tir.reinterpret("e5m2_float8", val).astype("float16") + def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) @@ -189,6 +188,7 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm return f_convert + def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) @@ -201,4 +201,5 @@ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): return f_convert + # fmt: on diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 3b9dd6bfa78a..3d0a7be2f583 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -5,21 +5,22 @@ from bitblas import MatmulConfig, Matmul import logging from bitblas import set_log_level - + set_log_level(logging.DEBUG) + # TODO(lei): should add requirements for cuda and sm version @pytest.mark.parametrize( "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, - None), - (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, - None), - (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, - None), - (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, - None), + (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, + None, None), + (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, + None, None, None), + (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, + None, None), + (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, + None, None, None), ], ) def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, @@ -46,9 +47,9 @@ def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) - + def map_torch_type(intype): - + typemap = { 'e4m3_float8': torch.float8_e4m3fn, 'e5m2_float8': torch.float8_e5m2, @@ -61,29 +62,37 @@ def map_torch_type(intype): numpytype_a = map_torch_type(A_dtype) numpytype_b = map_torch_type(W_dtype) numpytype_c = map_torch_type(out_dtype) - - torch_a = torch.rand(M*K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() - torch_b = torch.rand(N*K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() - ref_out = torch.matmul(torch_a.to(torch.float32), torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul(torch_a.to(torch.float32), torch_b.to(torch.float32)) + + torch_a = torch.rand(M * K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() + ref_out = torch.matmul(torch_a.to(torch.float32), + torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul( + torch_a.to(torch.float32), torch_b.to(torch.float32)) ref_out = ref_out.to(numpytype_c) - + print("torch_ref_out", ref_out) new_torch_b = matmul.transform_weight(torch_b) bitblas_out = matmul(torch_a, new_torch_b) print("bitblas_out", bitblas_out) - + + # TODO(lei): should add requirements for cuda and sm version @pytest.mark.parametrize( "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, None, None), - (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, None, None), - (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, None), - (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, None), + (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, + None, None), + (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, + None, None), + (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, + None), + (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, + None, None), ], ) -def test_matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, - group_size, with_scaling, with_zeros, zeros_mode): +def test_matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, + layout, with_bias, group_size, with_scaling, + with_zeros, zeros_mode): import torch torch.random.manual_seed(0) @@ -108,9 +117,9 @@ def test_matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum print(matmul.src_name) input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) - + def map_torch_type(intype): - + typemap = { 'e4m3_float8': torch.float8_e4m3fn, 'e5m2_float8': torch.float8_e5m2, @@ -123,10 +132,10 @@ def map_torch_type(intype): numpytype_a = map_torch_type(A_dtype) numpytype_b = map_torch_type(W_dtype) numpytype_c = map_torch_type(out_dtype) - - torch_a = torch.rand(M*K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() - torch_b = torch.rand(N*K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() - + + torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() + torch_fp16_a = torch_a.to(torch.float16) torch_fp16_b = torch_b.t().to(torch.float16) if layout == "nt" else torch_b.to(torch.float16) scale_tensor = None @@ -135,29 +144,31 @@ def map_torch_type(intype): group_size = -1 if group_size == -1: group_size = K - scale_tensor = torch.rand(N * K // group_size).uniform_(-4, 4).reshape([N, K // group_size]).type(torch.float16).cuda() + scale_tensor = torch.rand(N * K // group_size).uniform_(-4, 4).reshape( + [N, K // group_size]).type(torch.float16).cuda() # scale_tensor = torch.ones([N, K // group_size]).type(torch.float16).cuda() rescale_b = torch.zeros_like(torch_b).type(torch.float16) for i in range(K): rescale_b[:, i] = torch_b.to(torch.float16)[:, i] * scale_tensor[:, i // group_size] - torch_fp16_b = rescale_b.t().to(torch.float16) if layout == "nt" else rescale_b.to(torch.float16) - + torch_fp16_b = rescale_b.t().to(torch.float16) if layout == "nt" else rescale_b.to( + torch.float16) + ref_out = torch.matmul(torch_fp16_a, torch_fp16_b) ref_out = ref_out.to(numpytype_c) - - + permuted_inputs = [] permuted_inputs.append(torch_a) permuted_inputs.append(matmul.transform_weight(torch_b)) if with_scaling: permuted_inputs.append(scale_tensor) bitblas_out = matmul(*permuted_inputs) - + print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2) + # fmt: on if __name__ == "__main__": bitblas.testing.main()