From 42f379c3c218193e05223b19ec7981b10eb4b53f Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Mon, 20 May 2024 22:50:57 +0800 Subject: [PATCH] [FP8] Support Weight Dequantize FP16xFP8_E4M3 (#42) * Refactor quantization module to support new float8 formats * Refactor quantization module to support new float8 formats * update readme --------- Co-authored-by: LeiWang199 --- README.md | 1 + python/bitblas/gpu/gemv_dequantize.py | 8 +- python/bitblas/gpu/matmul_mma_dequantize.py | 12 +- python/bitblas/ops/general_matmul.py | 89 +++++------- .../ops/impl/matmul_dequantize_impl.py | 20 ++- python/bitblas/quantization/__init__.py | 1 + python/bitblas/quantization/quantization.py | 17 +++ .../operators/test_general_matmul_fp8.py | 127 +++++++++++++++--- 8 files changed, 186 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index 66ca98f2dab5..7a979c8e387c 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and |:-----------:|:-----------:|:---------------:|:---------------:|:----------------------:|:----------------------:| | FP16 | FP16 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP16 | FP4_E2M1 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | FP8_E4M3 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP16 | INT8 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP16 | UINT4/INT4 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | FP16 | UINT2/INT2 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py index a9157750a226..47e4bf42cd0a 100644 --- a/python/bitblas/gpu/gemv_dequantize.py +++ b/python/bitblas/gpu/gemv_dequantize.py @@ -49,8 +49,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] @@ -213,8 +213,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py index e4c5a272a857..96461db453e0 100644 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -126,8 +126,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] @@ -633,8 +633,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] @@ -1123,8 +1123,8 @@ def check_weight_decode_info(weight_decode_info): conditions = [] # check source format in ["int", "fp", "nf"] conditions.append("source_format" in weight_decode_info) - conditions.append( - weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"]) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) # check source bits in [1, 2, 4, 8] conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) # check target format in ["float16", "int8"] diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index 0045934f32f2..35eee1fb6e6b 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -8,8 +8,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union from .operator import Operator, TransformKind from .impl.matmul_dequantize_impl import ( - select_implementation as weight_dequantize_implementation, -) + select_implementation as weight_dequantize_implementation,) from .impl.matmul_impl import select_implementation as consistent_implementation from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4 from bitblas.utils.target_detector import auto_detect_nvidia_target @@ -110,36 +109,23 @@ def __legalize_dynamic_symbolic(self, M): def __legalize_propagate(self, propagate): if isinstance(propagate, bool): - return ( - TransformKind.IntraWarpTransform - if propagate - else TransformKind.NonTransform - ) + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) elif isinstance(propagate, int): return TransformKind(propagate) return propagate - def __initialize_propagate( - self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind] - ): + def __initialize_propagate(self, propagate_a: Optional[TransformKind], + propagate_b: Optional[TransformKind]): MICRO_KERNEL_SIZE = 16 - if ( - isinstance(self.M, int) - and (self.M % MICRO_KERNEL_SIZE) == 0 - and (self.K % MICRO_KERNEL_SIZE) == 0 - ): + if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and + (self.K % MICRO_KERNEL_SIZE) == 0): object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) else: object.__setattr__(self, "propagate_a", TransformKind.NonTransform) - if ( - self.M == 1 - or (self.N % MICRO_KERNEL_SIZE) != 0 - or (self.K % MICRO_KERNEL_SIZE) != 0 - or isinstance(self.M, Tuple) - or (self.with_zeros and self.zeros_mode == "quantized") - ): + if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or + isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")): object.__setattr__(self, "propagate_a", TransformKind.NonTransform) object.__setattr__(self, "propagate_b", TransformKind.NonTransform) else: @@ -164,10 +150,7 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]): def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): if fast_decoding is not None: object.__setattr__(self, "fast_decoding", fast_decoding) - elif ( - "int" not in self.W_dtype - or self.W_dtype == self.A_dtype - ): + elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype): object.__setattr__(self, "fast_decoding", False) else: object.__setattr__(self, "fast_decoding", True) @@ -186,12 +169,8 @@ def __post_init__(self): object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M)) # set propagate_a and propagate_b to default value if it is None - object.__setattr__( - self, "propagate_a", self.__legalize_propagate(self.propagate_a) - ) - object.__setattr__( - self, "propagate_b", self.__legalize_propagate(self.propagate_b) - ) + object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) + object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) # This is hack to legalize propagate_a and b # TODO(lei): should be removed in the future when tc+br template is ready. @@ -214,10 +193,10 @@ def __post_init__(self): object.__setattr__(self, "with_zeros", False) if self.A_dtype == self.W_dtype and self.W_dtype in [ - "float16", - "int8", - "e4m3_float8", - "e5m2_float8", + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", ]: object.__setattr__(self, "storage_dtype", self.W_dtype) @@ -242,10 +221,9 @@ class Matmul(Operator): "int1": ("int", 1), "uint1": ("uint", 1), "nf4": ("nf", 4), - "fp8_e5m2": ("fp", 8), "fp4_e2m1": ("fp", 4), - "e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn" - "e5m2_float8": ("fp", 8), + "e4m3_float8": ("fp_e4m3", 8), # "e4m3_float8" is a trick for "float8_e4m3fn" + "e5m2_float8": ("fp_e5m2", 8), } def __init__( @@ -261,9 +239,8 @@ def __init__( if target is None: target = auto_detect_nvidia_target() logger.info(f"Auto detected target: {target}") - assert ( - config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP - ), f"Unsupported input dtype {config.A_dtype}" + assert (config.A_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}" source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype] self.source_format = source_format @@ -284,8 +261,7 @@ def __init__( if isinstance(self.M, Tuple): self.dynamic_range = {"m": self.M} self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range} - ) + {"opt_shapes": self.dynamic_range}) else: self.dynamic_range = None @@ -394,9 +370,7 @@ def __init__( def _build_default_module(self, target: Target): try: - self.optimized_func = self.apply_default_schedule( - self.prim_func_mod, target - ) + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) except Exception: self.optimized_func = None logger.warning( @@ -447,9 +421,7 @@ def post_process(self, code: str) -> str: return code def retrieve_weight_shape(self): - return [ - int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape - ] + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] def transform_weight(self, weight, scale=None, zeros=None, bias=None): """ @@ -481,9 +453,12 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): if source_format == "int": assert not self.with_scaling, "scale should be False for int source format" assert not self.with_zeros, "zeros should be False for int source format" - maxq = 2 ** (bit - 1) + maxq = 2**(bit - 1) # Clamp weight values to be within the quantizable range and adjust weight = torch.clamp(weight, -maxq, maxq).int() + maxq + elif source_format in ["fp_e5m2", "fp_e4m3"]: + weight = weight.view(torch.int8) + weight = weight.int() else: # For non-integer formats, simply convert weights to integers weight = weight.int() @@ -491,8 +466,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None): np_storage_dtype = getattr(np, self.storage_dtype) weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype - ) + weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) weight = torch.from_numpy(weight).cuda().contiguous() @@ -518,9 +492,7 @@ def transform_input(self, input_tensor): raise ValueError( f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." ) - self.ladder_permutate_a._forward_from_prebuild_lib( - input_tensor, self.workspace - ) + self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) return self.workspace return input_tensor @@ -528,14 +500,13 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args = [] args.append(self.transform_input(A)) args.append(W) - + if self.lut is not None: args.append(self.lut) if output is None: output = torch.empty( - A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device - ) + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) if scale is not None: args.append(scale) if zeros is not None: diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index b7c7c64ee65e..6e6b098c19a1 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -11,6 +11,7 @@ _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, _tir_packed_to_unsigned_convert_with_zeros, ) @@ -58,14 +59,17 @@ def qzeros_dequantize(k, n): dtype=storage_dtype, ) - Dequantize_qzeros = te.compute( - (K // group_size, N), - qzeros_dequantize, - name="Dequantize_zeros", - ) + Dequantize_qzeros = None + if with_zeros and zeros_mode == "quantized": + Dequantize_qzeros = te.compute( + (K // group_size, N), + qzeros_dequantize, + name="Dequantize_zeros", + ) def decode_func(n, k): if with_zeros and zeros_mode == "quantized": + assert Dequantize_qzeros is not None, "Dequantize_zeros is None" w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( bit, B[n, k // n_float_per_elem], @@ -87,6 +91,8 @@ def decode_func(n, k): elif source_format == "fp": w = _tir_u32_to_f4_to_f16( bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) elif source_format == "nf": w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( bit, @@ -260,6 +266,8 @@ def decode_func(n, k): k % n_float_per_elem, dtype=in_dtype, ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) elif source_format == "nf": w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( bit, @@ -446,6 +454,8 @@ def decode_func(n, k): k % n_float_per_elem, dtype=in_dtype, ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) elif source_format == "nf": w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( bit, diff --git a/python/bitblas/quantization/__init__.py b/python/bitblas/quantization/__init__.py index 0ca9ab377575..d29cb679a957 100644 --- a/python/bitblas/quantization/__init__.py +++ b/python/bitblas/quantization/__init__.py @@ -5,6 +5,7 @@ _tir_packed_to_signed_convert, # noqa: F401 _tir_packed_to_unsigned_convert, # noqa: F401 _tir_u32_to_f4_to_f16, # noqa: F401 + _tir_u8_to_f8_e4m3_to_f16, # noqa: F401 _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 ) diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py index e390fa6406e9..d9f36094794d 100644 --- a/python/bitblas/quantization/quantization.py +++ b/python/bitblas/quantization/quantization.py @@ -139,6 +139,21 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) +def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16") + offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16")) + e_f16 = ((val << tir.const(7, "int16")) + offset) + return tir.reinterpret("float16", s_f16 | e_f16) + + +def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + return tir.reinterpret("e5m2_float8", val).astype("float16") + + def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) @@ -173,6 +188,7 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm return f_convert + def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): storage_dtype = storage_type + str(storage_nbit) @@ -185,4 +201,5 @@ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): return f_convert + # fmt: on diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 720e8fb321c8..3d0a7be2f583 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -9,17 +9,18 @@ set_log_level(logging.DEBUG) +# TODO(lei): should add requirements for cuda and sm version @pytest.mark.parametrize( "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ - (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, - None), - (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, - None), - (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, - None), - (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, - None), + (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, + None, None), + (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, + None, None, None), + (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, + None, None), + (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, + None, None, None), ], ) def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, @@ -46,9 +47,9 @@ def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) - + def map_torch_type(intype): - + typemap = { 'e4m3_float8': torch.float8_e4m3fn, 'e5m2_float8': torch.float8_e5m2, @@ -61,17 +62,113 @@ def map_torch_type(intype): numpytype_a = map_torch_type(A_dtype) numpytype_b = map_torch_type(W_dtype) numpytype_c = map_torch_type(out_dtype) - - torch_a = torch.rand(M*K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() - torch_b = torch.rand(N*K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() - ref_out = torch.matmul(torch_a.to(torch.float32), torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul(torch_a.to(torch.float32), torch_b.to(torch.float32)) + + torch_a = torch.rand(M * K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() + ref_out = torch.matmul(torch_a.to(torch.float32), + torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul( + torch_a.to(torch.float32), torch_b.to(torch.float32)) ref_out = ref_out.to(numpytype_c) - + print("torch_ref_out", ref_out) new_torch_b = matmul.transform_weight(torch_b) bitblas_out = matmul(torch_a, new_torch_b) print("bitblas_out", bitblas_out) + +# TODO(lei): should add requirements for cuda and sm version +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, + None, None), + (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None, + None, None), + (1, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, None, + None), + (1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, 32, True, + None, None), + ], +) +def test_matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, + layout, with_bias, group_size, with_scaling, + with_zeros, zeros_mode): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False) + print(matmul.src_name) + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + + def map_torch_type(intype): + + typemap = { + 'e4m3_float8': torch.float8_e4m3fn, + 'e5m2_float8': torch.float8_e5m2, + } + if intype in typemap: + return typemap[intype] + else: + return getattr(torch, intype) + + numpytype_a = map_torch_type(A_dtype) + numpytype_b = map_torch_type(W_dtype) + numpytype_c = map_torch_type(out_dtype) + + torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda() + + torch_fp16_a = torch_a.to(torch.float16) + torch_fp16_b = torch_b.t().to(torch.float16) if layout == "nt" else torch_b.to(torch.float16) + scale_tensor = None + if with_scaling: + if group_size is None: + group_size = -1 + if group_size == -1: + group_size = K + scale_tensor = torch.rand(N * K // group_size).uniform_(-4, 4).reshape( + [N, K // group_size]).type(torch.float16).cuda() + # scale_tensor = torch.ones([N, K // group_size]).type(torch.float16).cuda() + rescale_b = torch.zeros_like(torch_b).type(torch.float16) + for i in range(K): + rescale_b[:, i] = torch_b.to(torch.float16)[:, i] * scale_tensor[:, i // group_size] + torch_fp16_b = rescale_b.t().to(torch.float16) if layout == "nt" else rescale_b.to( + torch.float16) + + ref_out = torch.matmul(torch_fp16_a, torch_fp16_b) + ref_out = ref_out.to(numpytype_c) + + permuted_inputs = [] + permuted_inputs.append(torch_a) + permuted_inputs.append(matmul.transform_weight(torch_b)) + if with_scaling: + permuted_inputs.append(scale_tensor) + bitblas_out = matmul(*permuted_inputs) + + print("torch_ref_out", ref_out) + print("bitblas_out", bitblas_out) + + torch.testing.assert_allclose(ref_out, bitblas_out, rtol=1e-2, atol=1e-2) + + # fmt: on if __name__ == "__main__": bitblas.testing.main()