Refactor quantization module to support new float8 formats
LeiWang199 committed May 20, 2024
1 parent aa658de commit b165b58
Showing 6 changed files with 90 additions and 112 deletions.
8 changes: 4 additions & 4 deletions python/bitblas/gpu/gemv_dequantize.py
@@ -49,8 +49,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -213,8 +213,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
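For orientation, here is a hypothetical weight_decode_info dict that would satisfy the source-format checks shown above once the float8 formats are accepted. Only the "source_format" keys are visible in this diff; the remaining key name follows the nearby comment and is an assumption, not the repository's actual schema. Note that this GEMV path accepts both fp_e5m2 and fp_e4m3, while the MMA dequantize schedules in the next file only add fp_e4m3.

```python
# Hypothetical example only; field names outside "source_format" are guesses
# based on the comments in check_weight_decode_info.
weight_decode_info = {
    "source_format": {
        "format": "fp_e4m3",  # now accepted alongside "uint", "int", "fp", "nf", "fp_e5m2"
        "bits": 8,            # must be one of [1, 2, 4, 8]
    },
    "target_format": "float16",  # per the "check target format" comment (assumed key)
}
```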
12 changes: 6 additions & 6 deletions python/bitblas/gpu/matmul_mma_dequantize.py
@@ -126,8 +126,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -633,8 +633,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
@@ -1123,8 +1123,8 @@ def check_weight_decode_info(weight_decode_info):
conditions = []
# check source format in ["int", "fp", "nf"]
conditions.append("source_format" in weight_decode_info)
conditions.append(
weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
conditions.append(weight_decode_info["source_format"]["format"] in
["uint", "int", "fp", "nf", "fp_e4m3"])
# check source bits in [1, 2, 4, 8]
conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
# check target format in ["float16", "int8"]
81 changes: 25 additions & 56 deletions python/bitblas/ops/general_matmul.py
@@ -8,8 +8,7 @@
from typing import Any, List, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind
from .impl.matmul_dequantize_impl import (
select_implementation as weight_dequantize_implementation,
)
select_implementation as weight_dequantize_implementation,)
from .impl.matmul_impl import select_implementation as consistent_implementation
from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
from bitblas.utils.target_detector import auto_detect_nvidia_target
@@ -110,36 +109,23 @@ def __legalize_dynamic_symbolic(self, M):

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
return (
TransformKind.IntraWarpTransform
if propagate
else TransformKind.NonTransform
)
return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate

def __initialize_propagate(
self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]
):
def __initialize_propagate(self, propagate_a: Optional[TransformKind],
propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
if (
isinstance(self.M, int)
and (self.M % MICRO_KERNEL_SIZE) == 0
and (self.K % MICRO_KERNEL_SIZE) == 0
):
if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and
(self.K % MICRO_KERNEL_SIZE) == 0):
object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform)
else:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)

if (
self.M == 1
or (self.N % MICRO_KERNEL_SIZE) != 0
or (self.K % MICRO_KERNEL_SIZE) != 0
or isinstance(self.M, Tuple)
or (self.with_zeros and self.zeros_mode == "quantized")
):
if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or
isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")):
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
else:
@@ -164,10 +150,7 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
elif (
"int" not in self.W_dtype
or self.W_dtype == self.A_dtype
):
elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
@@ -186,12 +169,8 @@ def __post_init__(self):
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
object.__setattr__(
self, "propagate_a", self.__legalize_propagate(self.propagate_a)
)
object.__setattr__(
self, "propagate_b", self.__legalize_propagate(self.propagate_b)
)
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
@@ -214,10 +193,10 @@ def __post_init__(self):
object.__setattr__(self, "with_zeros", False)

if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)

@@ -260,9 +239,8 @@ def __init__(
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
assert (
config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP
), f"Unsupported input dtype {config.A_dtype}"
assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]

self.source_format = source_format
@@ -283,8 +261,7 @@ def __init__(
if isinstance(self.M, Tuple):
self.dynamic_range = {"m": self.M}
self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs(
{"opt_shapes": self.dynamic_range}
)
{"opt_shapes": self.dynamic_range})
else:
self.dynamic_range = None

@@ -393,9 +370,7 @@ def __init__(

def _build_default_module(self, target: Target):
try:
self.optimized_func = self.apply_default_schedule(
self.prim_func_mod, target
)
self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target)
except Exception:
self.optimized_func = None
logger.warning(
@@ -446,9 +421,7 @@ def post_process(self, code: str) -> str:
return code

def retrieve_weight_shape(self):
return [
int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape
]
return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]

def transform_weight(self, weight, scale=None, zeros=None, bias=None):
"""
@@ -480,7 +453,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
if source_format == "int":
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2 ** (bit - 1)
maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
elif source_format in ["fp_e5m2", "fp_e4m3"]:
@@ -493,8 +466,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype
)
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)

weight = torch.from_numpy(weight).cuda().contiguous()

@@ -520,24 +492,21 @@ def transform_input(self, input_tensor):
raise ValueError(
f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size."
)
self.ladder_permutate_a._forward_from_prebuild_lib(
input_tensor, self.workspace
)
self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace)
return self.workspace
return input_tensor

def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args = []
args.append(self.transform_input(A))
args.append(W)

if self.lut is not None:
args.append(self.lut)

if output is None:
output = torch.empty(
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device
)
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device)
if scale is not None:
args.append(scale)
if zeros is not None:
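As a side note on the int path touched above: transform_weight clamps signed integer weights into [-maxq, maxq] and shifts them into an unsigned range before they are bit-packed by general_compress. A minimal sketch of that step, with made-up values and bit = 4; only the arithmetic mirrors the diff.

```python
import torch

# Illustrative sketch of the clamp-and-offset step in transform_weight for the
# "int" source format; the weight values here are invented.
bit = 4
maxq = 2 ** (bit - 1)                          # 8 for 4-bit weights
weight = torch.tensor([[-9.0, -3.0, 0.0, 7.0]])

# Clamp to the signed range, then offset by maxq so every value is non-negative
# and can be packed into the storage dtype by general_compress.
unsigned = torch.clamp(weight, -maxq, maxq).int() + maxq
print(unsigned)                                # tensor([[ 0,  5,  8, 15]], dtype=torch.int32)
```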
9 changes: 3 additions & 6 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -92,8 +92,7 @@ def decode_func(n, k):
w = _tir_u32_to_f4_to_f16(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(
bit, B[n, k], dtype=in_dtype)
w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
@@ -268,8 +267,7 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(
bit, B_reindex[n, k], dtype=in_dtype)
w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
@@ -457,8 +455,7 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(
bit, B_reindex[n, k], dtype=in_dtype)
w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype)
elif source_format == "nf":
w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
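One detail visible in the decode_func changes: the sub-byte formats index a packed storage element (several weights per element, with the lane position passed separately), while the fp_e4m3 path reads one storage byte per weight directly as B[n, k]. A simplified stand-in for that indexing difference; names and structure are illustrative, not the repository's TIR builder code.

```python
# Simplified stand-in, not the actual decode_func implementation.
def storage_index(source_format: str, bit: int, storage_nbit: int, n: int, k: int):
    if source_format in ("fp_e5m2", "fp_e4m3"):
        return (n, k)                      # one storage byte per weight
    n_per_elem = storage_nbit // bit       # e.g. 8 four-bit weights per uint32
    return (n, k // n_per_elem)            # packed: k % n_per_elem selects the lane
```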
11 changes: 6 additions & 5 deletions python/bitblas/quantization/quantization.py
@@ -138,23 +138,22 @@ def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype
(e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32"))
return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)


def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "int16")) << tir.const(15, "int16")
offset = tir.Select(
s_f16 == 0,
tir.const(8192, "int16"),
tir.const(-8192, "int16")
)
offset = tir.Select(s_f16 == 0, tir.const(8192, "int16"), tir.const(-8192, "int16"))
e_f16 = ((val << tir.const(7, "int16")) + offset)
return tir.reinterpret("float16", s_f16 | e_f16)


def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert nbit == 8
assert dtype == "float16"
return tir.reinterpret("e5m2_float8", val).astype("float16")


def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)

@@ -189,6 +188,7 @@ def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm

return f_convert


def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8):
storage_dtype = storage_type + str(storage_nbit)

@@ -201,4 +201,5 @@ def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):

return f_convert


# fmt: on
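To sanity-check the bit trick in _tir_u8_to_f8_e4m3_to_f16: the low 7 bits of the e4m3 byte are shifted into fp16's exponent/mantissa field and the exponent is rebiased by +8 (e4m3 bias 7 to fp16 bias 15); the ±8192 offset in the TIR version appears to fold that rebias together with cancelling the sign bit that the unmasked shift carries into bit 14. Below is a rough NumPy sketch of the equivalent unsigned-masking form, assuming the standard e4m3 layout (1 sign, 4 exponent, 3 mantissa bits) and checked against a hand-decoded reference on a few normal encodings; subnormals, zero and NaN are not handled by this sketch.

```python
# Rough sketch for intuition only; not the repository's code.
import numpy as np

def e4m3_bytes_to_f16(raw: np.ndarray) -> np.ndarray:
    v = raw.astype(np.uint16)
    sign = (v >> 7) << 15                  # move the sign bit to fp16 bit 15
    em = ((v & 0x7F) << 7) + (8 << 10)     # exponent+mantissa into fp16 fields, exponent rebias +8
    return (sign | em).astype(np.uint16).view(np.float16)

def e4m3_reference(byte: int) -> float:
    # Decode a normal e4m3 encoding by hand: (-1)^s * 2^(e-7) * (1 + m/8)
    s = (byte >> 7) & 1
    e = (byte >> 3) & 0xF
    m = byte & 0x7
    return (-1.0) ** s * 2.0 ** (e - 7) * (1 + m / 8.0)

raw = np.array([0x38, 0x40, 0xC5, 0x7A], dtype=np.uint8)  # 1.0, 2.0, -3.25, 320.0
decoded = e4m3_bytes_to_f16(raw)
assert all(float(d) == e4m3_reference(int(b)) for d, b in zip(decoded, raw))
```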