[FP8] Support Weight Dequantize FP16xFP8_E4M3 (apache#42)
* Refactor quantization module to support new float8 formats

* Refactor quantization module to support new float8 formats

* update readme

---------

Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 authored May 20, 2024
1 parent e08db97 commit 42f379c
Showing 8 changed files with 186 additions and 89 deletions.
README.md (1 change: 1 addition & 0 deletions)
@@ -59,6 +59,7 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
|:-----------:|:-----------:|:---------------:|:---------------:|:----------------------:|:----------------------:|
| FP16 | FP16 | FP16 | FP16 | **** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | FP4_E2M1 | FP16 | FP16 | **** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | FP8_E4M3 | FP16 | FP16 | **** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | INT8 | FP16 | FP16 | **** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | UINT4/INT4 | FP16 | FP16 | **** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP16 | UINT2/INT2 | FP16 | FP16 | **** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
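The new FP16 x FP8_E4M3 row is exercised end to end roughly as sketched below. This is a minimal sketch: only `A_dtype`, `W_dtype="e4m3_float8"`, `transform_weight`, and `forward` are grounded in this commit; the `bitblas.MatmulConfig` entry point and the `accum_dtype`/`out_dtype` field names are assumptions taken from the package's public examples, and a PyTorch build that provides `torch.float8_e4m3fn` is assumed.

```python
import torch
import bitblas

# Hypothetical shapes; accum_dtype/out_dtype are assumed field names, not from this diff.
config = bitblas.MatmulConfig(
    M=1, N=1024, K=1024,
    A_dtype="float16",       # FP16 activations
    W_dtype="e4m3_float8",   # FP8 E4M3 weights enabled by this commit
    accum_dtype="float16",
    out_dtype="float16",
)
matmul = bitblas.Matmul(config=config)

A = torch.rand((1, 1024), dtype=torch.float16, device="cuda")
W = torch.rand((1024, 1024), dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)

Wq = matmul.transform_weight(W)   # repacks the FP8 bytes for the dequantize kernel
C = matmul.forward(A, Wq)         # FP16 x FP8_E4M3 -> FP16
```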
python/bitblas/gpu/gemv_dequantize.py (8 changes: 4 additions & 4 deletions)
@@ -49,8 +49,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -213,8 +213,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
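For reference, a self-contained mirror of the relaxed check above, together with a decode-info payload that now passes for FP8 E4M3 weights. Only the keys visited by the visible conditions are grounded in this diff; any other fields of the real `weight_decode_info` dict are omitted here.

```python
# Framework-free sketch of the checks shown above.
def check_weight_decode_info(weight_decode_info):
    conditions = []
    conditions.append("source_format" in weight_decode_info)
    conditions.append(weight_decode_info["source_format"]["format"] in
                      ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
    conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
    return all(conditions)

# An FP8 E4M3 weight is described as an 8-bit "fp_e4m3" source format.
assert check_weight_decode_info({"source_format": {"format": "fp_e4m3", "bits": 8}})
```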
python/bitblas/gpu/matmul_mma_dequantize.py (12 changes: 6 additions & 6 deletions)
@@ -126,8 +126,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -633,8 +633,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -1123,8 +1123,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
python/bitblas/ops/general_matmul.py (89 changes: 30 additions & 59 deletions)
@@ -8,8 +8,7 @@
from typing import Any, List, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind
from .impl.matmul_dequantize_impl import (
select_implementation as weight_dequantize_implementation,
)
select_implementation as weight_dequantize_implementation,)
from .impl.matmul_impl import select_implementation as consistent_implementation
from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
from bitblas.utils.target_detector import auto_detect_nvidia_target
@@ -110,36 +109,23 @@ def __legalize_dynamic_symbolic(self, M):

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
return (
TransformKind.IntraWarpTransform
if propagate
else TransformKind.NonTransform
)
return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate

def __initialize_propagate(
self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]
):
def __initialize_propagate(self, propagate_a: Optional[TransformKind],
propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
if (
isinstance(self.M, int)
and (self.M % MICRO_KERNEL_SIZE) == 0
and (self.K % MICRO_KERNEL_SIZE) == 0
):
if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and
(self.K % MICRO_KERNEL_SIZE) == 0):
object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform)
else:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)

if (
self.M == 1
or (self.N % MICRO_KERNEL_SIZE) != 0
or (self.K % MICRO_KERNEL_SIZE) != 0
or isinstance(self.M, Tuple)
or (self.with_zeros and self.zeros_mode == "quantized")
):
if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")):
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
else:
@@ -164,10 +150,7 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
elif (
"int" not in self.W_dtype
or self.W_dtype == self.A_dtype
):
elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
@@ -186,12 +169,8 @@ def __post_init__(self):
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
object.__setattr__(
self, "propagate_a", self.__legalize_propagate(self.propagate_a)
)
object.__setattr__(
self, "propagate_b", self.__legalize_propagate(self.propagate_b)
)
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
@@ -214,10 +193,10 @@ def __post_init__(self):
object.__setattr__(self, "with_zeros", False)

if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)

@@ -242,10 +221,9 @@ class Matmul(Operator):
"int1": ("int", 1),
"uint1": ("uint", 1),
"nf4": ("nf", 4),
"fp8_e5m2": ("fp", 8),
"fp4_e2m1": ("fp", 4),
"e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn"
"e5m2_float8": ("fp", 8),
"e4m3_float8": ("fp_e4m3", 8), # "e4m3_float8" is a trick for "float8_e4m3fn"
"e5m2_float8": ("fp_e5m2", 8),
}

def __init__(
@@ -261,9 +239,8 @@ def __init__(
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
assert (
config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP
), f"Unsupported input dtype {config.A_dtype}"
assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]

self.source_format = source_format
@@ -284,8 +261,7 @@ def __init__(
if isinstance(self.M, Tuple):
self.dynamic_range = {"m": self.M}
self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs(
{"opt_shapes": self.dynamic_range}
)
{"opt_shapes": self.dynamic_range})
else:
self.dynamic_range = None

@@ -394,9 +370,7 @@ def __init__(

def _build_default_module(self, target: Target):
try:
self.optimized_func = self.apply_default_schedule(
self.prim_func_mod, target
)
self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target)
except Exception:
self.optimized_func = None
logger.warning(
@@ -447,9 +421,7 @@ def post_process(self, code: str) -> str:
return code

def retrieve_weight_shape(self):
return [
int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape
]
return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]

def transform_weight(self, weight, scale=None, zeros=None, bias=None):
"""
@@ -481,18 +453,20 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
if source_format == "int":
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2 ** (bit - 1)
maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
elif source_format in ["fp_e5m2", "fp_e4m3"]:
weight = weight.view(torch.int8)
weight = weight.int()
else:
# For non-integer formats, simply convert weights to integers
weight = weight.int()

np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype
)
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)

weight = torch.from_numpy(weight).cuda().contiguous()

@@ -518,24 +492,21 @@ def transform_input(self, input_tensor):
raise ValueError(
f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size."
)
self.ladder_permutate_a._forward_from_prebuild_lib(
input_tensor, self.workspace
)
self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace)
return self.workspace
return input_tensor

def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args = []
args.append(self.transform_input(A))
args.append(W)

if self.lut is not None:
args.append(self.lut)

if output is None:
output = torch.empty(
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device
)
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device)
if scale is not None:
args.append(scale)
if zeros is not None:
python/bitblas/ops/impl/matmul_dequantize_impl.py (20 changes: 15 additions & 5 deletions)
@@ -11,6 +11,7 @@
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_u32_to_f4_to_f16,
_tir_u8_to_f8_e4m3_to_f16,
_tir_packed_to_unsigned_convert_with_zeros,
)

@@ -58,14 +59,17 @@ def qzeros_dequantize(k, n):
dtype=storage_dtype,
)

Dequantize_qzeros = te.compute(
(K // group_size, N),
qzeros_dequantize,
name="Dequantize_zeros",
)
Dequantize_qzeros = None
if with_zeros and zeros_mode == "quantized":
Dequantize_qzeros = te.compute(
(K // group_size, N),
qzeros_dequantize,
name="Dequantize_zeros",
)

def decode_func(n, k):
if with_zeros and zeros_mode == "quantized":
assert Dequantize_qzeros is not None, "Dequantize_zeros is None"
w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)(
bit,
B[n, k // n_float_per_elem],
@@ -87,6 +91,8 @@ def decode_func(n, k):
elif source_format == "fp":
w = _tir_u32_to_f4_to_f16(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
@@ -260,6 +266,8 @@ def decode_func(n, k):
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
@@ -446,6 +454,8 @@ def decode_func(n, k):
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
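The indexing difference in `decode_func` above is the key detail: sub-byte formats address the packed buffer as `B[n, k // n_float_per_elem]`, while the new `fp_e4m3` branch reads one full byte per weight as `B[n, k]`. A framework-free sketch of that contrast; names and values are illustrative only.

```python
import numpy as np

storage_nbit, bit = 8, 4
n_float_per_elem = storage_nbit // bit              # two 4-bit weights per byte
packed = np.array([[0x21, 0x43]], dtype=np.uint8)   # one row packing the weights 1, 2, 3, 4

def unpack_u4(byte, pos):
    return int(byte >> (pos * bit)) & 0xF

# 4-bit path: sub-byte addressing, mirroring B[n, k // n_float_per_elem]
row = [unpack_u4(packed[0, k // n_float_per_elem], k % n_float_per_elem) for k in range(4)]
assert row == [1, 2, 3, 4]

# fp_e4m3 path: bits == storage_nbit, so each weight is simply the byte at B[n, k]
fp8_bytes = np.array([[0x38, 0xB8]], dtype=np.uint8)  # one E4M3 value (1.0, -1.0) per byte
```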
python/bitblas/quantization/__init__.py (1 change: 1 addition & 0 deletions)
@@ -5,6 +5,7 @@
_tir_packed_to_signed_convert, # noqa: F401
_tir_packed_to_unsigned_convert, # noqa: F401
_tir_u32_to_f4_to_f16, # noqa: F401
_tir_u8_to_f8_e4m3_to_f16, # noqa: F401
_tir_packed_to_unsigned_convert_with_zeros, # noqa: F401
)

python/bitblas/quantization/quantization.py (17 changes: 17 additions & 0 deletions)
@@ -139,6 +139,21 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)


def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16")
offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16"))
e_f16 = ((val << tir.const(7, "int16")) + offset)
return tir.reinterpret("float16", s_f16 | e_f16)


def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("e5m2_float8", val).astype("float16")


def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)

Expand Down Expand Up @@ -173,6 +188,7 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm

return f_convert


def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)

Expand All @@ -185,4 +201,5 @@ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):

return f_convert


# fmt: on
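The new `_tir_u8_to_f8_e4m3_to_f16` helper converts an E4M3 byte to FP16 with shifts alone: the sign moves to bit 15, the exponent and mantissa shift into the FP16 fields, and adding `8 << 10 = 8192` rebiases the exponent from 7 (E4M3) to 15 (FP16); for negative inputs the `-8192` offset also cancels the stray sign bit left at bit 14. Below is a plain-Python mirror of that trick for a few normal values; zero, subnormals, and NaN are not special-cased by this shift-and-rebias scheme and are out of scope here.

```python
import numpy as np

def fp8_e4m3_to_fp16(byte: int) -> float:
    """Plain-Python mirror of the TIR shift-and-rebias trick (normal values only)."""
    s_f16 = (byte >> 7) << 15                # sign bit moved to the FP16 position
    offset = 8192 if s_f16 == 0 else -8192   # rebias by 8 << 10; also cancels the stray sign bit
    bits = (s_f16 | ((byte << 7) + offset)) & 0xFFFF
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])

# E4M3 encodings: 0x38 = 1.0, 0x30 = 0.5, 0xB8 = -1.0, 0x44 = 3.0
for byte, expected in [(0x38, 1.0), (0x30, 0.5), (0xB8, -1.0), (0x44, 3.0)]:
    assert fp8_e4m3_to_fp16(byte) == expected
```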