From 31743feac263dc7f354a9c6b4206f370943cc238 Mon Sep 17 00:00:00 2001 From: xinhe Date: Thu, 25 Jan 2024 17:43:07 +0800 Subject: [PATCH] use inplace=True mode for WOQ (#1557) * use inplace mode for WOQ Signed-off-by: xin3he --- .../torch_utils/autoround/model_wrapper.py | 43 +- .../adaptor/torch_utils/model_wrapper.py | 42 +- .../adaptor/torch_utils/weight_only.py | 40 +- .../torch/algorithms/weight_only/__init__.py | 4 + .../torch/algorithms/weight_only/gptq.py | 52 +- .../torch/algorithms/weight_only/rtn.py | 449 +----------------- .../torch/algorithms/weight_only/utility.py | 435 +++++++++++++++++ .../torch/quantization/layers.py | 62 +-- neural_compressor/torch/utils/__init__.py | 2 + 9 files changed, 594 insertions(+), 535 deletions(-) create mode 100644 neural_compressor/torch/algorithms/weight_only/utility.py diff --git a/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py b/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py index bd73fddd94d..3c47c6f1bbb 100644 --- a/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py @@ -127,7 +127,6 @@ def __init__( dtype=self.float_type, ).to(device), ) - self.scales = self.scales.T self.register_buffer( "qweight", torch.zeros( @@ -135,7 +134,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qweight = self.qweight.T self.register_buffer( "qzeros", torch.zeros( @@ -143,7 +141,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qzeros = self.qzeros.T self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: self.compression_dtype = compression_dtype @@ -193,6 +190,10 @@ def __init__( self.bias = None def pack(self, int_weight, scale, zp, bias): + if self.use_optimum_format: + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow @@ -206,8 +207,8 @@ def pack(self, int_weight, scale, zp, bias): assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.T - self.qweight = self.qweight.T + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() origin_shape = int_weight.shape target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." @@ -223,15 +224,15 @@ def pack(self, int_weight, scale, zp, bias): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.T + self.qweight = self.qweight.t_().contiguous() if zp is not None: zp = zp.to(self.device) if self.use_optimum_format: zp -= 1 if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - self.qzeros = self.qzeros.T + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() assert hasattr(self, "qzeros"), "zp is not set when initializing." 
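The hunks above drop the constructor-time transposes of scales/qweight/qzeros and instead have pack() flip the optimum-format buffers in place with t_().contiguous() for the duration of packing, flipping them back when packing is done. A toy sketch of that pattern on a standalone buffer (shapes are illustrative, not the module's real ones):

    import torch

    # Toy illustration of the transpose-for-packing pattern used by pack() above.
    # t_() swaps the strides of the same 2-D tensor in place; contiguous() then
    # materializes a packed layout once, instead of keeping a permanently
    # transposed buffer registered on the module.
    qweight = torch.zeros(1024, 128, dtype=torch.int32)   # optimum-format layout
    qweight = qweight.t_().contiguous()                    # (128, 1024) while packing
    # ... bit-pack int weights into qweight column by column ...
    qweight = qweight.t_().contiguous()                    # restore the original layout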
target_shape = self.qzeros.shape for j in range(target_shape[1]): @@ -243,16 +244,16 @@ def pack(self, int_weight, scale, zp, bias): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.T + self.qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() def recover(self): logger.debug(f"Recovering {self} weight") - scales = self.scales.T if self.use_optimum_format else self.scales - qweight = self.qweight.T if self.use_optimum_format else self.qweight + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) @@ -264,8 +265,8 @@ def recover(self): # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T - qweight = qweight.T + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() origin_shape = weight.shape target_shape = qweight.shape for j in range(target_shape[1]): @@ -280,7 +281,7 @@ def recover(self): tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T + weight = weight.t_().contiguous() if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) for k, v in self.int2float_mapping.items(): @@ -290,10 +291,10 @@ def recover(self): if hasattr(self, "qzeros"): zp_dtype = self.compression_dtype # to avoid overflow when weight-zp zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - qzeros = qzeros.T + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() origin_shape = zp.shape target_shape = qzeros.shape for j in range(target_shape[1]): @@ -307,7 +308,7 @@ def recover(self): tmp &= mask zp[:, index] = tmp.type(zp_dtype) if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T + zp = zp.t_().contiguous() if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 diff --git a/neural_compressor/adaptor/torch_utils/model_wrapper.py b/neural_compressor/adaptor/torch_utils/model_wrapper.py index 7376992b9fb..9a3cd6361f2 100644 --- a/neural_compressor/adaptor/torch_utils/model_wrapper.py +++ b/neural_compressor/adaptor/torch_utils/model_wrapper.py @@ -327,9 +327,9 @@ def __init__( def pack(self, int_weight, scale, zp, bias, g_idx=None): if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow @@ -350,8 +350,8 @@ def pack(self, int_weight, scale, zp, 
bias, g_idx=None): assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.T - self.qweight = self.qweight.T + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() origin_shape = int_weight.shape target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." @@ -367,15 +367,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.T + self.qweight = self.qweight.t_().contiguous() if zp is not None: zp = zp.to(self.device) if self.use_optimum_format: zp -= 1 if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - self.qzeros = self.qzeros.T + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() assert hasattr(self, "qzeros"), "zp is not set when initializing." target_shape = self.qzeros.shape for j in range(target_shape[1]): @@ -387,16 +387,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.T + self.qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() def recover(self): logger.debug(f"Recovering {self} weight") - scales = self.scales.T if self.use_optimum_format else self.scales - qweight = self.qweight.T if self.use_optimum_format else self.qweight + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) @@ -411,8 +411,8 @@ def recover(self): # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T - qweight = qweight.T + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() origin_shape = weight.shape target_shape = qweight.shape for j in range(target_shape[1]): @@ -427,7 +427,7 @@ def recover(self): tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T + weight = weight.t_().contiguous() if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) for k, v in self.int2float_mapping.items(): @@ -437,10 +437,10 @@ def recover(self): if hasattr(self, "qzeros"): zp_dtype = self.compression_dtype # to avoid overflow when weight-zp zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - qzeros = qzeros.T + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() origin_shape = zp.shape 
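recover() above rebuilds the integer weights by pulling n_pack = compress_bits // bits fields out of each compressed column with shifts and masks, mirroring the layout written by pack(). A simplified, unsigned-only sketch of that unpacking (toy sizes; the real code also handles sign extension and the optimum-format zero-point offset):

    import torch

    # Simplified 4-bit unpack following the shift/mask convention of pack() above:
    # value e of a column was shifted left by bits * e before being OR-ed in, so it
    # is recovered by shifting right by bits * e and masking off the low bits.
    bits, compress_bits = 4, 32
    n_pack = compress_bits // bits                     # 8 int4 values per int32
    mask = 2**bits - 1                                 # 0b1111
    qweight = torch.randint(0, 2**31 - 1, (2, 4), dtype=torch.int32)
    weight = torch.zeros(2, 4 * n_pack, dtype=torch.int32)
    for j in range(qweight.shape[1]):
        for e in range(n_pack):
            index = j * n_pack + e
            weight[:, index] = (qweight[:, j] >> (bits * e)) & mask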
target_shape = qzeros.shape for j in range(target_shape[1]): @@ -454,7 +454,7 @@ def recover(self): tmp &= mask zp[:, index] = tmp.type(zp_dtype) if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T + zp = zp.t_().contiguous() if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 diff --git a/neural_compressor/adaptor/torch_utils/weight_only.py b/neural_compressor/adaptor/torch_utils/weight_only.py index b029d137e41..4a4fcf19d95 100644 --- a/neural_compressor/adaptor/torch_utils/weight_only.py +++ b/neural_compressor/adaptor/torch_utils/weight_only.py @@ -330,8 +330,8 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", en history = [] for i_s in range(int(max_shrink * n_grid)): ratio = 1 - i_s / n_grid # 1, 0.805-1.0 - cur_weight = quant_weight( - m.weight.data, + quant_weight( + m.weight.data, # in-place mode num_bits=num_bits, group_size=group_size, scheme=scheme, @@ -339,7 +339,8 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", en full_range=enable_full_range, quantile=ratio, ) - loss = (org_weight - cur_weight).float().pow(2).mean().item() + loss = (org_weight - m.weight.data).float().pow(2).mean().item() + m.weight.data.copy_(org_weight) history.append(loss) is_best = loss < best_error if is_best: @@ -429,14 +430,17 @@ def rtn_quantize( if num_bits <= 0: logger.info(f"Skip {name}") continue - weight = m.weight.T if group_dim == 0 else m.weight + # contiguous is not an in-place op and returns Tensor instead of Parameter, so set it back to m.weight.data. + # transpose should be executed on Parameter level because Param.data.t_() is not an in-place op. + # Parameter.T is an in-place op while Tensor.T is not. + m.weight.data = m.weight.t_().data.contiguous() if group_dim == 0 else m.weight.data if enable_mse_search: quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range) if return_int: from .model_wrapper import WeightOnlyLinear _, scale, zp = quant_weight( - weight, + m.weight.data, num_bits, group_size, scheme, @@ -446,9 +450,9 @@ def rtn_quantize( full_range=enable_full_range, ) if group_dim == 0: - weight.transpose_(0, 1) - scale = scale.T if group_dim == 0 else scale - zp = zp.T if group_dim == 0 and zp is not None else zp + m.weight.t_() + scale = scale.t_().contiguous() if group_dim == 0 else scale + zp = zp.t_().contiguous() if group_dim == 0 and zp is not None else zp new_module = WeightOnlyLinear( m.in_features, m.out_features, @@ -463,14 +467,14 @@ def rtn_quantize( device=device, use_optimum_format=use_optimum_format, ) - new_module.pack(weight, scale, zp, m.bias) + new_module.pack(m.weight.data, scale, zp, m.bias) if name == "": return new_module else: set_module(model, name, new_module) else: quant_weight( - weight, + m.weight.data, num_bits, group_size, scheme, @@ -479,7 +483,7 @@ def rtn_quantize( full_range=enable_full_range, ) if group_dim == 0: - weight.transpose_(0, 1) + m.weight.t_() if orig_dtype != torch.float: m = m.to(orig_dtype) return model @@ -651,18 +655,18 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1): if zp is not None: zp = zp.to(device) if group_size == -1: - return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp) + return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_() int_weight = torch.zeros(weight.shape).to(device) leng = weight.shape[1] // group_size tail_flag = False if weight.shape[1] % group_size == 0 else 
True for i in range(leng): - int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1) + int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) if zp is not None: - int_weight_tmp += zp[:, i].unsqueeze(1) - int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp) + int_weight_tmp.add_(zp[:, i].unsqueeze(1)) + int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_()) if tail_flag: - int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1) + int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1)) if zp is not None: - int_weight_tmp += zp[:, -1].unsqueeze(1) - int_weight[:, leng * group_size :] = torch.round(int_weight_tmp) + int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) + int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_()) return int_weight diff --git a/neural_compressor/torch/algorithms/weight_only/__init__.py b/neural_compressor/torch/algorithms/weight_only/__init__.py index 8989ae9d722..ac8feca4f40 100644 --- a/neural_compressor/torch/algorithms/weight_only/__init__.py +++ b/neural_compressor/torch/algorithms/weight_only/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .utility import * +from .rtn import rtn_quantize +from .gptq import gptq_quantize diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 9c7d3453e8e..1eb1722beec 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -30,10 +30,7 @@ import transformers from tqdm import tqdm -from neural_compressor.common import Logger - -logger = Logger().get_logger() - +from neural_compressor.torch.utils import logger DEBUG = False @@ -758,14 +755,15 @@ def find_params(self, x, weight=False): self.maxq = self.maxq.to(dev) # NF4 FP4 if self.wdtype != "int": - from .rtn import quant_weight + from .utility import quant_tensor - _, scale, zero = quant_weight( - x, + tmp = x.clone() # make sure x is not replaced + _, scale, zero = quant_tensor( + tmp, self.wbits, self.group_size, scheme=self.scheme, - data_type=self.wdtype, + dtype=self.wdtype, quantile=1.0, return_int=True, full_range=False, @@ -850,16 +848,16 @@ def find_params(self, x, weight=False): self.zero = self.zero.reshape(shape) if self.double_quant: - from .rtn import quant_weight + from .utility import quant_tensor orig_scale_shape = self.scale.shape self.scale = self.scale.reshape(1, -1) - self.scale = quant_weight( + quant_tensor( self.scale, self.double_quant_bits, self.double_quant_group_size, scheme=self.double_quant_scheme, - data_type=self.double_quant_dtype, + dtype=self.double_quant_dtype, quantile=1.0, return_int=False, full_range=False, @@ -879,9 +877,11 @@ def find_params(self, x, weight=False): def quantize(self, x, scale, zero, maxq): """Do quantization.""" if self.wdtype != "int": - from .rtn import quantize_4bit + from .utility import quantize_4bit - return quantize_4bit(x, data_type=self.wdtype, scale=scale) + tmp = x.clone() + + return quantize_4bit(tmp, dtype=self.wdtype, scale=scale) else: if maxq < 0: return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero @@ -950,6 +950,32 @@ def gptq_config_mapping(configs_mapping: Dict[Tuple[str, Callable], 
GPTQConfig]) return weight_config, nsamples, use_max_length, pad_max_length, device, dataloader_len +def gptq_quantize( + model, + weight_config={}, + dataloader=None, + nsamples=128, + use_max_length=True, + pad_max_length=2048, + device=None, + layer_wise=False, + model_path=None, +): + """Run weight-only quantization with.""" + # TODO: unify weight_config keys, add docstring, and support default config + assert isinstance(model, torch.nn.Module), "only support torch module" + if layer_wise: + assert model_path is not None, "model_path should not be None when use layer_wise mode" + from .gptq import GPTQuantizer + + gptq_quantizer = GPTQuantizer( + model, weight_config, dataloader, nsamples, use_max_length, pad_max_length, device, layer_wise=layer_wise + ) + fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization(model_path=model_path) + logger.info("GPTQ quantizing done.") + return fp32_modified_model, gptq_config + + def apply_gptq_quantize(model, configs_mapping, *args, **kwargs): """Apply gptq.""" # TODO: unify weight_config keys, add docstring, and support default config diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py index c2b3929a220..9b4f0456007 100644 --- a/neural_compressor/torch/algorithms/weight_only/rtn.py +++ b/neural_compressor/torch/algorithms/weight_only/rtn.py @@ -20,383 +20,14 @@ import torch -from torch.nn import functional as F -from neural_compressor.common import DEBUG, Logger, level +from neural_compressor.torch.utils import logger from neural_compressor.torch.utils.utility import set_module -logger = Logger().get_logger() - - -NF4 = [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, -] -FP4_BNB = [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0] -FP4_E2M1 = [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.0625, 0, 0.0625, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] - -# the order is the same as float list, bit value range is [-7, 7] -# 1111 = -1, 1110 = -2, 1101= -3, ... - -NF4_BIT = [7, 1, 2, 3, 4, 5, 6, 0, -8, -7, -6, -5, -4, -3, -2, -1] -FP4_BNB_BIT = [-5, -6, -3, -4, -1, -2, -7, 0, 1, 6, 7, 4, 5, 2, 3] -FP4_E2M1_BIT = [-1, -2, -3, -4, -5, -6, -7, 0, 1, 2, 3, 4, 5, 6, 7] - -FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1} -INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT} - - -def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False, **kwargs): - """Quantize tensor to NF4/FP4 data type. - - Args: - tensor: input tensor - quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): data type. Defaults to 'nf4'. - return_int (bool, optional): whether return int data. Defaults to False. - - Returns: - q_tensor: fake quantized tensor - """ - assert data_type in FLOAT_MAPPING, "unexpected data type." 
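In the GPTQ changes above, quant_tensor from the new utility module mutates its input tensor, which is why find_params now clones x (tmp = x.clone()) before quantizing whenever the original values are still needed. A minimal sketch of that convention, assuming the quant_tensor signature introduced by utility.py later in this patch:

    import torch
    from neural_compressor.torch.algorithms.weight_only import quant_tensor

    weight = torch.randn(32, 256)
    ref = weight.clone()                       # keep the untouched fp32 values
    quant_tensor(weight, bits=4, group_size=32, scheme="asym")
    # `weight` itself now holds the fake-quantized values; `ref` does not
    mse = (ref - weight).float().pow(2).mean()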
- allow_data = FLOAT_MAPPING[data_type] - allow_data_bit = INT_MAPPING[data_type] - # get scale and update tensor - if "scale" in kwargs: - scale = kwargs["scale"] - else: - scale = tensor.abs().max(1)[0] * quantile / max(allow_data) - scale.unsqueeze_(dim=-1) - tensor = tensor / scale - mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)] - q_tensor = torch.zeros_like(tensor) - for i in range(len(allow_data)): - data = allow_data_bit[i] if return_int else allow_data[i] - if i == 0: - q_tensor += torch.where(tensor <= mid_data[i], data, 0) - elif i == len(allow_data) - 1: - q_tensor += torch.where(tensor > mid_data[i - 1], data, 0) - else: - q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0) - double_quant = kwargs.get("double_quant", False) - if return_int or double_quant: - return q_tensor, scale, None - return q_tensor * scale - - -def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False, **kwargs): - """Quant and dequant tensor with asym schema. - - Args: - weight: input weight - num_bits (int, optional): num_bits. Defaults to 4. - quantile (float, optional): percentile of clip. Defaults to 1.0. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - - Returns: - output: qdq weight - """ - maxq = torch.tensor(2**num_bits - 1) - zeros = torch.zeros(weight.shape[0], device=weight.device) - wmin = torch.minimum(weight.min(1)[0], zeros) - wmax = torch.maximum(weight.max(1)[0], zeros) - wmin = wmin * quantile - wmax = wmax * quantile - tmp = (wmin == 0) & (wmax == 0) - wmin[tmp] = -1 - wmax[tmp] = +1 - scale = (wmax - wmin) / maxq - zp = torch.round(-wmin / scale) - scale.unsqueeze_(dim=-1) - zp.unsqueeze_(dim=-1) - q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq) - double_quant = kwargs.get("double_quant", False) - if return_int or double_quant: - return q, scale, zp - return scale * (q - zp) - - -def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False, **kwargs): - """Quant and dequant tensor with sym schema. - - Args: - weight : input weight - num_bits (int, optional): num_bits. Defaults to 4. - quantile (float, optional): percentile of clip. Defaults to 1.0. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - full_range (bool, optional): Choose sym range whether use -2**(bits-1). - For example: 4 bit - scale = amax / 8 if full_range else amax / 7 - If True, scale = -scale if abs(min)> abs(max) else scale - Defaults to False. 
- - Returns: - output: qdq weight - """ - # assert num_bits > 1, "symmetric scheme only supports num_bits > 1" - maxq = torch.tensor(2 ** (num_bits - 1) - 1).to(weight.device) - minq = torch.tensor(-(2 ** (num_bits - 1))).to(weight.device) - if num_bits == 1: # pragma: no cover - maxq = torch.tensor(2 ** (num_bits - 1)) - minq = torch.tensor(2 ** (num_bits - 1) - 1) - max_val = torch.max(weight, 1)[0] - min_val = torch.min(weight, 1)[0] - flip_flag = torch.abs(max_val) > torch.abs(min_val) - wmax = torch.max(torch.abs(max_val), torch.abs(min_val)) - wmax = wmax * quantile - tmp = wmax == 0 - wmax[tmp] = +1 - if full_range: - # use -8, 8 to make sure amax is not changed after fake quant - scale = wmax / (-minq) - tmp = scale * flip_flag.int() - scale -= 2 * tmp # set negetive scale with flip_flag - else: - scale = wmax / maxq - scale.unsqueeze_(dim=-1) - q = torch.clamp(torch.round(weight / scale), minq, maxq) - double_quant = kwargs.get("double_quant", False) - if return_int or double_quant: - return q, scale, None - return scale * q - - -def qdq_weight_actor( - weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False, **kwargs -): - """Quant and dequant tensor per channel. - - Args: - weight : input weight - num_bits (int, optional): num_bits. Defaults to 4. - quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - full_range (bool, optional): Choose sym range whether use -2**(bits-1). - - Returns: - output: qdq weight - """ - assert num_bits > 0, "num_bits should be larger than 0" - - if data_type in FLOAT_MAPPING.keys(): - return quantize_4bit(weight, quantile=quantile, data_type=data_type, return_int=return_int, **kwargs) - if scheme == "sym": - return qdq_weight_sym(weight, num_bits, quantile, return_int, full_range, **kwargs) - else: - return qdq_weight_asym(weight, num_bits, quantile, return_int, **kwargs) - - -def quant_weight( - weight, - num_bits=4, - group_size=-1, - scheme="asym", - quantile=1.0, - data_type="int", - return_int=False, - full_range=False, - **kwargs, -): - """Quant and dequant tensor with group size. - - Args: - weight: input weight - num_bits (int, optional): num_bits. Defaults to 4. - group_size (int, optional): how many elements share one scale/zp. Defaults to -1. - scheme (str, optional): sym or asym. Defaults to "asym". - quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. - return_int (bool, optional): Choose return fp32 or int8/uint8 data. - Defaults to False. - full_range (bool, optional): Choose sym range whether use -2**(bits-1). - - Returns: - output: qdq weight. 
- """ - double_quant = kwargs.get("double_quant", False) - if num_bits <= 0: # pragma: no cover - return weight - # case 1, group size = -1 - if group_size == -1 or weight.shape[1] < group_size: - group_size = weight.shape[1] - # case 2, reshape based on group size - orig_shape = weight.shape - if weight.shape[1] % group_size == 0: - weight = weight.reshape(-1, group_size) - weight = qdq_weight_actor( - weight, - num_bits, - scheme=scheme, - quantile=quantile, - return_int=return_int, - full_range=full_range, - data_type=data_type, - **kwargs, - ) - if return_int or double_quant: - weight, scale, zp = weight - weight = weight.reshape(orig_shape) - scale = scale.reshape(orig_shape[0], -1) - if zp is not None: - zp = zp.reshape(orig_shape[0], -1) - q_state = weight, scale, zp - else: - return weight.reshape(orig_shape) - else: - # case 3, process left part split by group size - split_index = weight.shape[1] // group_size * group_size - weight1 = weight[:, :split_index] - weight1 = weight1.reshape(-1, group_size) - weight1 = qdq_weight_actor( - weight1, - num_bits, - scheme=scheme, - quantile=quantile, - return_int=return_int, - full_range=full_range, - data_type=data_type, - **kwargs, - ) - if return_int or double_quant: - weight1, scale1, zp1 = weight1 - scale1 = scale1.reshape(orig_shape[0], -1) - if zp1 is not None: - zp1 = zp1.reshape(orig_shape[0], -1) - weight1 = weight1.reshape(orig_shape[0], split_index) - weight2 = weight[:, split_index:] - weight2 = qdq_weight_actor( - weight2, - num_bits, - scheme=scheme, - data_type=data_type, - quantile=quantile, - return_int=return_int, - full_range=full_range, - **kwargs, - ) - if return_int or double_quant: - weight2, scale2, zp2 = weight2 - weight = torch.cat([weight1, weight2], dim=1) - scale = torch.cat([scale1, scale2], dim=1) - zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1) - q_state = (weight, scale, zp) - else: - weight = torch.cat([weight1, weight2], dim=1) - return weight - if double_quant: - weight, scale, zp = q_state - double_quant_dtype = kwargs.get("double_quant_dtype", "fp32") - double_quant_num_bits = kwargs.get("double_quant_num_bits", 8) - double_quant_scheme = kwargs.get("double_quant_scheme", "sym") - double_quant_group_size = kwargs.get("double_quant_group_size", 256) - double_quant_return_int = kwargs.get("double_quant_return_int", return_int) - # process scale - orig_scale_shape = scale.shape - scale = scale.reshape(1, -1) - scale = quant_weight( - scale, - double_quant_num_bits, - double_quant_group_size, - scheme=double_quant_scheme, - quantile=1.0, - data_type=double_quant_dtype, - return_int=double_quant_return_int, - full_range=False, - double_quant=False, - ) - if return_int: - if double_quant_return_int: - scale, hyper_scale, hyper_zp = scale - scale = scale.reshape(orig_scale_shape) - return weight, (scale, hyper_scale, hyper_zp), zp - else: - scale = scale.reshape(orig_scale_shape) - return weight, scale, zp - else: - scale = scale.reshape(orig_scale_shape) - if weight.shape[1] % group_size != 0: - if zp is not None: - weight1 = weight1.reshape(-1, group_size) - zp[:, :-1].reshape(-1, 1) - weight2 = weight2 - zp[:, -1].reshape(-1, 1) - else: - weight1 = weight1.reshape(-1, group_size) - weight1 = weight1 * scale[:, :-1].reshape(-1, 1) - weight1 = weight1.reshape(orig_shape[0], -1) - weight2 = weight2 * scale[:, -1].reshape(-1, 1) - weight = torch.cat([weight1, weight2], dim=1) - else: - if zp is not None: - weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1) - else: - weight = 
weight.reshape(-1, group_size) - weight = weight * scale.reshape(-1, 1) - weight = weight.reshape(orig_shape[0], -1) - return weight - else: - return q_state - - -def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False): - """Search best clip range of each linears in current block. - - Args: - m (torch.nn.Module): torch module. - num_bits (int, optional): num bits. - group_size (int, optional): how many elements share one scale/zp. - scheme (str, optional): sym or asym. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. - enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1). - - Returns: - best_clip_ratio (float): best percentile of clip - """ - org_weight = m.weight.data - logger.info("Searching the best clip range with RTN algorithm") - best_error = float("inf") - best_clip_ratio = None - n_grid = 200 - max_shrink = 0.2 - history = [] - for i_s in range(int(max_shrink * n_grid)): - ratio = 1 - i_s / n_grid # 1, 0.805-1.0 - cur_weight = quant_weight( - m.weight.data, - num_bits=num_bits, - group_size=group_size, - scheme=scheme, - data_type=data_type, - full_range=enable_full_range, - quantile=ratio, - ) - loss = (org_weight - cur_weight).float().pow(2).mean().item() - history.append(loss) - is_best = loss < best_error - if is_best: - best_error = loss - best_clip_ratio = ratio - logger.debug("The loss history of different clip range:{}".format(history)) - logger.debug("The best clip ratio is {}".format(best_clip_ratio)) - return best_clip_ratio +from .utility import quant_tensor, search_clip +@torch.no_grad() def rtn_quantize( model, num_bits=4, @@ -405,13 +36,13 @@ def rtn_quantize( quantile=1.0, weight_config={}, return_int=False, - data_type="int", + dtype="int", enable_full_range=False, enable_mse_search=False, group_dim=1, **kwargs, ): - """Quant the model with round to nearst method. + """Quant the model with round to nearest method. Args: model: torch module @@ -419,7 +50,7 @@ def rtn_quantize( group_size (int, optional): how many elements share one scale/zp. Defaults to 32. scheme (str, optional): sym or asym. Defaults to "asym". quantile (float, optional): percentile of clip. Defaults to 1.0. - data_type (str, optional): select from int, nf4, fp4. Defaults to int. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. weight_config (dict, optional): specific layer wise configurations. Defaults to {}. 
For example, weight_config={ @@ -463,7 +94,7 @@ def rtn_quantize( if m.__class__.__name__ not in supported_layers: continue if name in weight_config: # pragma: no cover - data_type = weight_config[name].get("dtype", "int") + dtype = weight_config[name].get("dtype", "int") num_bits = weight_config[name]["bits"] group_size = weight_config[name]["group_size"] scheme = weight_config[name]["scheme"] @@ -472,25 +103,25 @@ def rtn_quantize( f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, " + f"scheme={scheme}, quantile={quantile}" ) - if data_type != "int": - log_msg += f", dtype={data_type}" + if dtype != "int": + log_msg += f", dtype={dtype}" elif scheme == "sym": # nf4/fp4 is always [-7,7] log_msg += f", enable_full_range={enable_full_range}" - if data_type == "fp32": + if dtype == "fp32": continue logger.debug(f"RTN quantized module:{name, m}") logger.debug(log_msg) weight = m.weight.T if group_dim == 0 else m.weight if enable_mse_search: - quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range) + quantile = search_clip(m, num_bits, group_size, scheme, dtype, enable_full_range) if return_int: - int_weight, scale, zp = quant_weight( + int_weight, scale, zp = quant_tensor( weight, num_bits, group_size, scheme, quantile, - data_type=data_type, + dtype=dtype, return_int=True, full_range=enable_full_range, **double_quant_config, @@ -505,7 +136,7 @@ def rtn_quantize( m.out_features, num_bits, group_size, - dtype=data_type, + dtype=dtype, zp=zp is not None, bias=m.bias is not None, compression_dtype=compression_dtype, @@ -519,13 +150,13 @@ def rtn_quantize( else: set_module(model, name, new_module) else: - q_weight = quant_weight( + q_weight = quant_tensor( weight, num_bits, group_size, scheme, quantile, - data_type=data_type, + dtype=dtype, full_range=enable_full_range, **double_quant_config, ) @@ -534,52 +165,6 @@ def rtn_quantize( return model -def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"): - """Quant and dequant tensor with group size. - - Args: - weight: input weight - scale: scale - zp: zero point - group_size (int, optional): how many elements share one scale/zp. Defaults to -1. - dtype: data type, for NF4 FP4 - - Returns: - output: int weight. 
- """ - device = weight.device - scale = scale.to(device) - # NF4 FP4 - if dtype in FLOAT_MAPPING.keys(): - int_weight = quantize_4bit( - weight, - quantile=1.0, - data_type=dtype, - return_int=True, - scale=scale, - )[0] - return int_weight - # INT - if zp is not None: - zp = zp.to(device) - if group_size == -1: - return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp) - int_weight = torch.zeros(weight.shape).to(device) - leng = weight.shape[1] // group_size - tail_flag = False if weight.shape[1] % group_size == 0 else True - for i in range(leng): - int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1) - if zp is not None: - int_weight_tmp += zp[:, i].unsqueeze(1) - int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp) - if tail_flag: - int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1) - if zp is not None: - int_weight_tmp += zp[:, -1].unsqueeze(1) - int_weight[:, leng * group_size :] = torch.round(int_weight_tmp) - return int_weight - - from neural_compressor.torch.quantization.config import RTNConfig @@ -603,7 +188,7 @@ def apply_rtn_on_single_module(module: torch.nn.Module, quant_config: RTNConfig) group_size, scheme, return_int=return_int, - data_type=dtype, + dtype=dtype, enable_full_range=enable_full_range, enable_mse_search=enable_mse_search, group_dim=group_dim, diff --git a/neural_compressor/torch/algorithms/weight_only/utility.py b/neural_compressor/torch/algorithms/weight_only/utility.py new file mode 100644 index 00000000000..fa9633adb67 --- /dev/null +++ b/neural_compressor/torch/algorithms/weight_only/utility.py @@ -0,0 +1,435 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from neural_compressor.torch.utils import logger + +NF4 = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, +] +FP4_BNB = [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0] +FP4_E2M1 = [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.0625, 0, 0.0625, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + +# the order is the same as float list, bit value range is [-7, 7] +# 1111 = -1, 1110 = -2, 1101= -3, ... + +NF4_BIT = [7, 1, 2, 3, 4, 5, 6, 0, -8, -7, -6, -5, -4, -3, -2, -1] +FP4_BNB_BIT = [-5, -6, -3, -4, -1, -2, -7, 0, 1, 6, 7, 4, 5, 2, 3] +FP4_E2M1_BIT = [-1, -2, -3, -4, -5, -6, -7, 0, 1, 2, 3, 4, 5, 6, 7] + +FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1} +INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT} + + +def quantize_4bit(tensor, quantile=1.0, dtype="nf4", return_int=False, **kwargs): + """Quantize tensor to NF4/FP4 data type. 
+ + Args: + tensor: input tensor + quantile (float, optional): percentile of clip. Defaults to 1.0. + dtype (str, optional): data type. Defaults to 'nf4'. + return_int (bool, optional): whether return int data. Defaults to False. + + Returns: + q_tensor: fake quantized tensor + """ + assert dtype in FLOAT_MAPPING, "unexpected data type." + allow_data = FLOAT_MAPPING[dtype] + allow_data_bit = INT_MAPPING[dtype] + # get scale and update tensor + if "scale" in kwargs: + scale = kwargs["scale"] + else: + scale = tensor.abs().max(1)[0] * quantile / max(allow_data) + scale.unsqueeze_(dim=-1) + tensor.div_(scale) + mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)] + q_tensor = torch.zeros_like(tensor) + for i in range(len(allow_data)): + data = allow_data_bit[i] if return_int else allow_data[i] + if i == 0: + q_tensor += torch.where(tensor <= mid_data[i], data, 0) + elif i == len(allow_data) - 1: + q_tensor += torch.where(tensor > mid_data[i - 1], data, 0) + else: + q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0) + tensor.copy_(q_tensor) + double_quant = kwargs.get("double_quant", False) + if return_int or double_quant: + return tensor, scale, None + return tensor.mul_(scale) + + +def qdq_weight_asym(weight, bits=4, quantile=1.0, return_int=False, **kwargs): + """Quant and dequant tensor with asym schema. + + Args: + weight: input weight + bits (int, optional): bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + + Returns: + output: qdq weight + """ + maxq = torch.tensor(2**bits - 1) + zeros = torch.zeros(weight.shape[0], device=weight.device) + wmin = torch.minimum(weight.min(1)[0], zeros) + wmax = torch.maximum(weight.max(1)[0], zeros) + wmin = wmin * quantile + wmax = wmax * quantile + tmp = (wmin == 0) & (wmax == 0) + wmin[tmp] = -1 + wmax[tmp] = +1 + scale = (wmax - wmin) / maxq + zp = torch.round(-wmin / scale) + scale.unsqueeze_(dim=-1) + zp.unsqueeze_(dim=-1) + weight.div_(scale) + weight.round_() + weight.clamp_(0, maxq) + double_quant = kwargs.get("double_quant", False) + if return_int or double_quant: + return weight, scale, zp + weight.sub_(zp) + return weight.mul_(scale) + + +def qdq_weight_sym(weight, bits=4, quantile=1.0, return_int=False, full_range=False, **kwargs): + """Quant and dequant tensor with sym schema. + + Args: + weight : input weight + bits (int, optional): bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + For example: 4 bit + scale = amax / 8 if full_range else amax / 7 + If True, scale = -scale if abs(min)> abs(max) else scale + Defaults to False. 
+ + Returns: + output: qdq weight + """ + # assert bits > 1, "symmetric scheme only supports bits > 1" + maxq = torch.tensor(2 ** (bits - 1) - 1).to(weight.device) + minq = torch.tensor(-(2 ** (bits - 1))).to(weight.device) + if bits == 1: # pragma: no cover + maxq = torch.tensor(2 ** (bits - 1)) + minq = torch.tensor(2 ** (bits - 1) - 1) + max_val = torch.max(weight, 1)[0] + min_val = torch.min(weight, 1)[0] + flip_flag = torch.abs(max_val) > torch.abs(min_val) + wmax = torch.max(torch.abs(max_val), torch.abs(min_val)) + wmax = wmax * quantile + tmp = wmax == 0 + wmax[tmp] = +1 + if full_range: + # use -8, 8 to make sure amax is not changed after fake quant + scale = wmax / (-minq) + tmp = scale * flip_flag.int() + scale -= 2 * tmp # set negetive scale with flip_flag + else: + scale = wmax / maxq + scale.unsqueeze_(dim=-1) + weight.div_(scale) + weight.round_() + weight.clamp_(minq, maxq) + double_quant = kwargs.get("double_quant", False) + if return_int or double_quant: + return weight, scale, None + return weight.mul_(scale) + + +def qdq_weight_actor(weight, bits, scheme, quantile=1.0, dtype="int", return_int=False, full_range=False, **kwargs): + """Quant and dequant tensor per channel. It is an in-place op. + + Args: + weight : input weight + bits (int, optional): bits. Defaults to 4. + quantile (float, optional): percentile of clip. Defaults to 1.0. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + output: qdq weight + """ + assert bits > 0, "bits should be larger than 0" + + if dtype in FLOAT_MAPPING.keys(): + return quantize_4bit(weight, quantile=quantile, dtype=dtype, return_int=return_int, **kwargs) + if scheme == "sym": + return qdq_weight_sym(weight, bits, quantile, return_int, full_range, **kwargs) + else: + return qdq_weight_asym(weight, bits, quantile, return_int, **kwargs) + + +def quant_tensor( + weight, + bits=4, + group_size=-1, + scheme="asym", + quantile=1.0, + dtype="int", + return_int=False, + full_range=False, + **kwargs, +): + """Quant and dequant tensor with group size. + + Args: + weight: input weight + bits (int, optional): bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to -1. + scheme (str, optional): sym or asym. Defaults to "asym". + quantile (float, optional): percentile of clip. Defaults to 1.0. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. + return_int (bool, optional): Choose return fp32 or int8/uint8 data. + Defaults to False. + full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + output: qdq weight. 
+ """ + double_quant = kwargs.get("double_quant", False) + if bits <= 0: # pragma: no cover + return weight + # case 1, group size = -1 + if group_size == -1 or weight.shape[1] < group_size: + group_size = weight.shape[1] + # case 2, reshape based on group size + orig_shape = weight.shape + if weight.shape[1] % group_size == 0: + weight = weight.reshape(-1, group_size) + weight = qdq_weight_actor( + weight, + bits, + scheme=scheme, + quantile=quantile, + return_int=return_int, + full_range=full_range, + dtype=dtype, + **kwargs, + ) + if return_int or double_quant: + weight, scale, zp = weight + weight = weight.reshape(orig_shape) + scale = scale.reshape(orig_shape[0], -1) + if zp is not None: + zp = zp.reshape(orig_shape[0], -1) + q_state = weight, scale, zp + else: + return weight.reshape(orig_shape) + else: + # case 3, process left part split by group size + split_index = weight.shape[1] // group_size * group_size + weight1 = weight[:, :split_index] + weight1 = weight1.reshape(-1, group_size) + weight1 = qdq_weight_actor( + weight1, + bits, + scheme=scheme, + quantile=quantile, + return_int=return_int, + full_range=full_range, + dtype=dtype, + **kwargs, + ) + if return_int or double_quant: + weight1, scale1, zp1 = weight1 + scale1 = scale1.reshape(orig_shape[0], -1) + if zp1 is not None: + zp1 = zp1.reshape(orig_shape[0], -1) + weight1 = weight1.reshape(orig_shape[0], split_index) + weight2 = weight[:, split_index:] + weight2 = qdq_weight_actor( + weight2, + bits, + scheme=scheme, + dtype=dtype, + quantile=quantile, + return_int=return_int, + full_range=full_range, + **kwargs, + ) + if return_int or double_quant: + weight2, scale2, zp2 = weight2 + weight.copy_(torch.cat([weight1, weight2], dim=1)) + scale = torch.cat([scale1, scale2], dim=1) + zp = None if zp2 is None else torch.cat([zp1, zp2], dim=1) + q_state = (weight, scale, zp) + else: + weight.copy_(torch.cat([weight1, weight2], dim=1)) + return weight + if double_quant: + weight, scale, zp = q_state + double_quant_dtype = kwargs.get("double_quant_dtype", "fp32") + double_quant_bits = kwargs.get("double_quant_bits", 8) + double_quant_scheme = kwargs.get("double_quant_scheme", "sym") + double_quant_group_size = kwargs.get("double_quant_group_size", 256) + double_quant_return_int = kwargs.get("double_quant_return_int", return_int) + # process scale + orig_scale_shape = scale.shape + scale = scale.reshape(1, -1) + scale = quant_tensor( + scale, + dtype=double_quant_dtype, + bits=double_quant_bits, + group_size=double_quant_group_size, + scheme=double_quant_scheme, + quantile=1.0, + return_int=double_quant_return_int, + full_range=False, + double_quant=False, + ) + if return_int: + if double_quant_return_int: + scale, hyper_scale, hyper_zp = scale + scale = scale.reshape(orig_scale_shape) + return weight, (scale, hyper_scale, hyper_zp), zp + else: + scale = scale.reshape(orig_scale_shape) + return weight, scale, zp + else: + scale = scale.reshape(orig_scale_shape) + if weight.shape[1] % group_size != 0: + if zp is not None: + weight1 = weight1.reshape(-1, group_size).sub_(zp[:, :-1].reshape(-1, 1)) + weight2 = weight2.sub_(zp[:, -1].reshape(-1, 1)) + else: + weight1 = weight1.reshape(-1, group_size) + weight1 = weight1.mul_(scale[:, :-1].reshape(-1, 1)) + weight1 = weight1.reshape(orig_shape[0], -1) + weight2 = weight2.mul_(scale[:, -1].reshape(-1, 1)) + weight = torch.cat([weight1, weight2], dim=1) + else: + if zp is not None: + weight = weight.reshape(-1, group_size) - zp.reshape(-1, 1) + else: + weight = weight.reshape(-1, 
group_size) + weight = weight.mul_(scale.reshape(-1, 1)) + weight = weight.reshape(orig_shape[0], -1) + return weight + else: + return q_state + + +def search_clip(m, bits=4, group_size=32, scheme="asym", dtype="int", enable_full_range=False): + """Search best clip range of each linear in current block. + + Args: + m (torch.nn.Module): torch module. + bits (int, optional): num bits. + group_size (int, optional): how many elements share one scale/zp. + scheme (str, optional): sym or asym. + dtype (str, optional): select from int, nf4, fp4. Defaults to int. + enable_full_range (bool, optional): Choose sym range whether use -2**(bits-1). + + Returns: + best_clip_ratio (float): best percentile of clip + """ + org_weight = m.weight.data.clone() + logger.info("Searching the best clip range with RTN algorithm") + best_error = float("inf") + best_clip_ratio = None + n_grid = 200 + max_shrink = 0.2 + history = [] + for i_s in range(int(max_shrink * n_grid)): + ratio = 1 - i_s / n_grid # 1, 0.805-1.0 + cur_weight = quant_tensor( + m.weight.data, + dtype=dtype, + bits=bits, + group_size=group_size, + scheme=scheme, + full_range=enable_full_range, + quantile=ratio, + ) + loss = (org_weight - cur_weight).float().pow(2).mean().item() + history.append(loss) + is_best = loss < best_error + if is_best: + best_error = loss + best_clip_ratio = ratio + logger.debug("The loss history of different clip range:{}".format(history)) + logger.debug("The best clip ratio is {}".format(best_clip_ratio)) + return best_clip_ratio + + +def quant_weight_w_scale(weight, scale, zp, group_size=-1, dtype="int"): + """Quant and dequant tensor with group size. + + Args: + weight: input weight + scale: scale + zp: zero point + group_size (int, optional): how many elements share one scale/zp. Defaults to -1. + dtype: data type, for NF4 FP4 + + Returns: + output: int weight. 
+ """ + device = weight.device + scale = scale.to(device) + # NF4 FP4 + if dtype in FLOAT_MAPPING.keys(): + int_weight = quantize_4bit( + weight, + quantile=1.0, + dtype=dtype, + return_int=True, + scale=scale, + )[0] + return int_weight + # INT + if zp is not None: + zp = zp.to(device) + if group_size == -1: + return weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_() + int_weight = torch.zeros(weight.shape).to(device) + leng = weight.shape[1] // group_size + tail_flag = False if weight.shape[1] % group_size == 0 else True + for i in range(leng): + int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(scale[:, i].unsqueeze(1)) + if zp is not None: + int_weight_tmp.add_(zp[:, i].unsqueeze(1)) + int_weight[:, i * group_size : (i + 1) * group_size].copy_(int_weight_tmp.round_()) + if tail_flag: + int_weight_tmp = weight[:, leng * group_size :].div_(scale[:, -1].unsqueeze(1)) + if zp is not None: + int_weight_tmp.add_(zp[:, -1].unsqueeze(1)) + int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_()) + return int_weight diff --git a/neural_compressor/torch/quantization/layers.py b/neural_compressor/torch/quantization/layers.py index 66e1df14aaa..b1ce99d3a59 100644 --- a/neural_compressor/torch/quantization/layers.py +++ b/neural_compressor/torch/quantization/layers.py @@ -26,7 +26,7 @@ from torch.nn import functional as F from neural_compressor.common import DEBUG, level, logger -from neural_compressor.torch.algorithms.weight_only.rtn import quant_weight +from neural_compressor.torch.algorithms.weight_only import quant_tensor def get_torch_version(): @@ -164,7 +164,7 @@ def __init__( self.use_optimum_format = use_optimum_format self.dtype = dtype if "int" not in self.dtype: # for nf4, fp4 - from neural_compressor.torch.algorithms.weight_only.rtn import FLOAT_MAPPING, INT_MAPPING + from neural_compressor.torch.algorithms.weight_only import FLOAT_MAPPING, INT_MAPPING float_list = FLOAT_MAPPING[self.dtype] int_list = INT_MAPPING[self.dtype] @@ -200,7 +200,6 @@ def __init__( dtype=self.float_type, ).to(device), ) - self.scales = self.scales.T self.register_buffer( "qweight", torch.zeros( @@ -208,7 +207,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qweight = self.qweight.T self.register_buffer( "qzeros", torch.zeros( @@ -216,7 +214,6 @@ def __init__( dtype=self.compression_dtype, ).to(device), ) - self.qzeros = self.qzeros.T self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) else: self.compression_dtype = compression_dtype @@ -270,6 +267,10 @@ def __init__( self.g_idx = None def pack(self, int_weight, scale, zp, bias, g_idx=None): + if self.use_optimum_format: + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() int_weight = int_weight.to(self.device) if self.use_optimum_format and zp is None: # to avoid overflow @@ -290,8 +291,8 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): assert scale.shape == self.scales.shape, "Scale shape is mismatched." self.scales = scale.type(self.float_type).to(self.device) if not self.use_optimum_format and self.compression_dim == 0: - int_weight = int_weight.T - self.qweight = self.qweight.T + int_weight = int_weight.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() origin_shape = int_weight.shape target_shape = self.qweight.shape assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." 
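quant_weight_w_scale above likewise divides, offsets, and rounds the weight group by group in place, so callers that still need the fp32 weight should pass a copy. A rough usage sketch pairing it with quant_tensor's return_int path from the same utility module (toy shapes, asym int4):

    import torch
    from neural_compressor.torch.algorithms.weight_only.utility import quant_tensor, quant_weight_w_scale

    weight = torch.randn(8, 64)
    group_size = 32
    # derive per-group scale/zp first, then requantize the fp32 weight with them
    _, scale, zp = quant_tensor(weight.clone(), bits=4, group_size=group_size,
                                scheme="asym", return_int=True)
    int_weight = quant_weight_w_scale(weight.clone(), scale, zp, group_size=group_size)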
@@ -307,15 +308,15 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qweight[:, j] |= tmp[:, e] if not self.use_optimum_format and self.compression_dim == 0: - self.qweight = self.qweight.T + self.qweight = self.qweight.t_().contiguous() if zp is not None: zp = zp.to(self.device) if self.use_optimum_format: zp -= 1 if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - self.qzeros = self.qzeros.T + zp = zp.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() assert hasattr(self, "qzeros"), "zp is not set when initializing." target_shape = self.qzeros.shape for j in range(target_shape[1]): @@ -327,16 +328,16 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None): tmp[:, e] = tmp[:, e] << (self.bits * e) self.qzeros[:, j] |= tmp[:, e] if self.use_optimum_format or self.compression_dim == 0: - self.qzeros = self.qzeros.T + self.qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format: - self.scales = self.scales.T - self.qweight = self.qweight.T - self.qzeros = self.qzeros.T + self.scales = self.scales.t_().contiguous() + self.qweight = self.qweight.t_().contiguous() + self.qzeros = self.qzeros.t_().contiguous() def recover(self): logger.debug(f"Recovering {self} weight") - scales = self.scales.T if self.use_optimum_format else self.scales - qweight = self.qweight.T if self.use_optimum_format else self.qweight + scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales + qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight device = scales.device fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) @@ -351,8 +352,8 @@ def recover(self): # unpack weight weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T - qweight = qweight.T + weight = weight.t_().contiguous() + qweight = qweight.t_().contiguous() origin_shape = weight.shape target_shape = qweight.shape for j in range(target_shape[1]): @@ -367,7 +368,7 @@ def recover(self): tmp &= mask # remove sign bit weight[:, index] = tmp.type(weight_dtype) if not self.use_optimum_format and self.compression_dim == 0: - weight = weight.T + weight = weight.t_().contiguous() if "int" not in self.dtype: new_weight = torch.zeros(self.out_features, self.in_features).to(device) for k, v in self.int2float_mapping.items(): @@ -377,10 +378,10 @@ def recover(self): if hasattr(self, "qzeros"): zp_dtype = self.compression_dtype # to avoid overflow when weight-zp zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) - qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T - qzeros = qzeros.T + zp = zp.t_().contiguous() + qzeros = qzeros.t_().contiguous() origin_shape = zp.shape target_shape = qzeros.shape for j in range(target_shape[1]): @@ -394,7 +395,7 @@ def recover(self): tmp &= mask zp[:, index] = tmp.type(zp_dtype) if self.use_optimum_format or self.compression_dim == 0: - zp = zp.T + zp = zp.t_().contiguous() if self.use_optimum_format: # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1 zp += 1 @@ -409,12 +410,13 @@ def recover(self): return fp32_weight def forward(self, input): - weight = self.recover() - device = self.scales.device - if weight.dtype == 
torch.float16 and device.type == "cpu": - weight = weight.float() - self.bias = self.bias.float() if self.bias is not None else None - if level == DEBUG: + if not hasattr(self, "weight"): + weight = self.recover() + device = self.scales.device + if weight.dtype == torch.float16 and device.type == "cpu": + weight = weight.float() + self.bias = self.bias.float() if self.bias is not None else None + if True: # keep reusing self.weight due to recover is too slow. if not hasattr(self, "weight"): self.weight = weight input = input.type(self.weight.dtype) @@ -454,7 +456,7 @@ def forward(ctx, inputs, num_bits=4, group_size=1024, scheme="asym"): Returns: outputs: A Tensor of type output_dtype """ - return quant_weight(inputs, num_bits, group_size, scheme) + return quant_tensor(inputs, num_bits, group_size, scheme) @staticmethod def backward(ctx, grad_outputs): diff --git a/neural_compressor/torch/utils/__init__.py b/neural_compressor/torch/utils/__init__.py index 8989ae9d722..a8e2bb95c8a 100644 --- a/neural_compressor/torch/utils/__init__.py +++ b/neural_compressor/torch/utils/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .utility import *
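With these changes, RTN weight-only quantization rewrites the model's weight tensors directly instead of allocating quantized copies. An illustrative call, assuming the rtn_quantize signature shown in rtn.py above; the deepcopy is only needed when the original fp32 weights must be kept for comparison, since they are overwritten in place:

    import copy
    import torch
    from neural_compressor.torch.algorithms.weight_only import rtn_quantize

    model = torch.nn.Sequential(torch.nn.Linear(256, 64))
    fp32_model = copy.deepcopy(model)          # optional: preserve fp32 weights
    q_model = rtn_quantize(model, num_bits=4, group_size=32, scheme="asym", dtype="int")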