From eab1a4accf56bd01b3050b091fb74cfb42bf3c6e Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Thu, 12 Sep 2024 18:54:01 +0800 Subject: [PATCH] Replace itrex qbits to ipex woq linear (#549) Co-authored-by: Casper --- README.md | 35 ++++--- awq/models/auto.py | 4 +- awq/models/base.py | 51 +++++----- awq/modules/fused/attn.py | 48 ++++++---- awq/modules/fused/block.py | 1 + awq/modules/fused/norm.py | 23 +++-- awq/modules/linear/__init__.py | 2 +- awq/modules/linear/gemm_ipex.py | 110 ++++++++++++++++++++++ awq/modules/linear/gemm_qbits.py | 155 ------------------------------- awq/utils/fused_utils.py | 10 +- awq/utils/utils.py | 2 +- docs/examples.md | 4 +- docs/index.md | 5 +- examples/benchmark.py | 21 ++--- setup.py | 2 +- tests/test_ipex_cpu.py | 59 ++++++++++++ tests/test_qbits_cpu.py | 73 --------------- 17 files changed, 279 insertions(+), 326 deletions(-) create mode 100644 awq/modules/linear/gemm_ipex.py delete mode 100644 awq/modules/linear/gemm_qbits.py create mode 100644 tests/test_ipex_cpu.py delete mode 100644 tests/test_qbits_cpu.py diff --git a/README.md b/README.md index c3b2757c..7e16de7d 100644 --- a/README.md +++ b/README.md @@ -228,26 +228,25 @@ GPU: 2x NVIDIA GeForce RTX 4090 ### CPU -- CPU: INTEL(R) XEON(R) PLATINUM 8592+ with 8-channel 4800MT/s memory. +- CPU: 48 cores SPR (Intel 4th Gen Xeon CPU) - Command: `python examples/benchmark.py --model_path --batch_size 1` -| Model | Size | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (RAM) | -|--------:|------:|-----------:|-------------:|-----------------:|----------------:|---------------:|:------------------| -| Mixtral | 7B | 1 | 64 | 64 | 389.24 | 16.01 | 5.59 GB (0.02%) | -| Mixtral | 7B | 1 | 2048 | 2048 | 1412 | 17.76 | 6.29 GB (0.03%) | -| Vicuna | 7B | 1 | 64 | 64 | 346 | 18.13 | 8.18 GB (0.03%) | -| Vicuna | 7B | 1 | 2048 | 2048 | 1023.4 | 18.18 | 8.80 GB (0.04%) | -| LLaMA2 | 13B | 1 | 64 | 64 | 160.24 | 9.87 | 14.65 GB (0.06%) | -| LLaMA2 | 13B | 1 | 2048 | 2048 | 592.35 | 9.93 | 16.87 GB (0.07%) | -| Mosaicml | 7B | 1 | 64 | 64 | 433.17 | 18.79 | 4.60 GB (0.02%) | -| Mosaicml | 7B | 1 | 2048 | 2048 | 404.25 | 19.91 | 4.75 GB (0.02%) | -| Falcon | 7B | 1 | 64 | 64 | 303.16 | 14.41 | 5.18 GB (0.02%) | -| Falcon | 7B | 1 | 2048 | 2048 | 634.57 | 15.55 | 5.80 GB (0.02%) | -| CodeLlama | 34B | 1 | 64 | 64 | 153.73 | 4.23 | 29.00 GB (0.12%) | -| CodeLlama | 34B | 1 | 2048 | 2048 | 274.25 | 4.38 | 35.21 GB (0.15%) | -| Deepseek-coder | 33B | 1 | 64 | 64 | 83.08 | 4.07 | 22.16 GB (0.09%) | -| Deepseek-coder | 33B | 1 | 2048 | 2048 | 296.04 | 4.33 | 37.05 GB | - +| Model | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory | +|-------|---------|------------|----------------|---------------|-------------------|------------------|---------------| +| Llama 2 7B | gemm | 1 | 32 | 32 | 817.86 | 70.93 | 1.94 GB (0.00%) | +| Llama 2 7B | gemm | 1 | 2048 | 2048 | 5279.15 | 36.83 | 2.31 GB (0.00%) | +| Falcon | gemm | 1 | 32 | 32 | 337.51 | 26.41 | 9.57 GB (0.01%) | +| Falcon | gemm | 1 | 2048 | 2048 | 546.71 | 18.8 | 13.46 GB (0.01%) | +| Mistral | gemm | 1 | 32 | 32 | 343.08 | 28.46 | 9.74 GB (0.01%) | +| Mistral | gemm | 1 | 2048 | 2048 | 1135.23 | 13.23 | 10.35 GB (0.01%) | +| Vicuna | gemm | 1 | 32 | 32 | 340.73 | 28.86 | 9.59 GB (0.01%) | +| Vicuna | gemm | 1 | 2048 | 2048 | 1143.19 | 11.14 | 10.98 GB (0.01%) | +| Llama 2 13B | gemm | 1 | 32 | 32 | 220.79 | 18.14 | 17.46 GB 
(0.02%) | +| Llama 2 13B | gemm | 1 | 2048 | 2048 | 650.94 | 6.54 | 19.84 GB (0.02%) | +| DeepSeek Coder 33B | gemm | 1 | 32 | 32 | 101.61 | 8.58 | 40.80 GB (0.04%) | +| DeepSeek Coder 33B | gemm | 1 | 2048 | 2048 | 245.02 | 3.48 | 41.72 GB (0.04%) | +| Phind CodeLlama 34B | gemm | 1 | 32 | 32 | 102.47 | 9.04 | 41.70 GB (0.04%) | +| Phind CodeLlama 34B | gemm | 1 | 2048 | 2048 | 237.57 | 3.48 | 42.47 GB (0.04%) | ## Reference diff --git a/awq/models/auto.py b/awq/models/auto.py index af2580d5..1ce1b21d 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -88,7 +88,7 @@ def from_quantized( fuse_layers=True, use_exllama=False, use_exllama_v2=False, - use_qbits=False, + use_ipex=False, batch_size=1, safetensors=True, device_map="balanced", @@ -116,7 +116,7 @@ def from_quantized( fuse_layers=fuse_layers, use_exllama=use_exllama, use_exllama_v2=use_exllama_v2, - use_qbits=use_qbits, + use_ipex=use_ipex, safetensors=safetensors, device_map=device_map, max_memory=max_memory, diff --git a/awq/models/base.py b/awq/models/base.py index a76bb293..f2396bb5 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -17,7 +17,7 @@ from awq.modules.linear import ( WQLinear_GEMM, WQLinear_GEMV, - WQLinear_QBits, + WQLinear_IPEX, WQLinear_Marlin, WQLinear_Exllama, WQLinear_ExllamaV2, @@ -25,7 +25,7 @@ marlin_post_init, exllama_post_init, exllamav2_post_init, - qbits_post_init, + ipex_post_init, ) from awq.utils.module import ( get_named_linears, @@ -33,7 +33,7 @@ exclude_layers_to_not_quantize, try_import, ) -from awq.utils.utils import get_best_device, qbits_available +from awq.utils.utils import get_best_device, ipex_available from transformers import ( AutoConfig, PreTrainedModel, @@ -52,9 +52,6 @@ from awq.quantize.quantizer import AwqQuantizer from awq.utils.module import get_named_linears, set_op_by_name -if qbits_available: - from intel_extension_for_transformers.qbits import check_isa_supported - # Since we support different `AutoModelForxxx` from transformers # we need to define a custom mapping dict as below: @@ -440,8 +437,8 @@ def from_quantized( use_exllama_v2: Annotated[ bool, Doc("Whether to map the weights to ExLlamaV2 kernels.") ] = False, - use_qbits: Annotated[ - bool, Doc("Whether to map the weights to qbits kernels for CPU device.") + use_ipex: Annotated[ + bool, Doc("Whether to map the weights to ipex kernels for CPU device.") ] = False, device_map: Annotated[ Union[str, Dict], @@ -494,17 +491,11 @@ def from_quantized( trust_remote_code=trust_remote_code, ) - use_cpu_qbits = use_qbits or get_best_device() == "cpu" - if use_cpu_qbits: - if not qbits_available: - raise ImportError( - "Please install intel-extension-for-transformers with " - "`pip install intel-extension-for-transformers` for 'qbits' kernel!" - ) - - fuse_layers = False - logging.warn( - "Unsupport fuse_layers featrue for CPU device with QBits backend!" + use_cpu_ipex = use_ipex or get_best_device() == "cpu" + if use_cpu_ipex and not ipex_available: + raise ImportError( + "Please install intel_extension_for_pytorch with " + "`pip install intel_extension_for_pytorch` for 'ipex' kernel!" 
) # Prepare WQLinear layers, replace nn.Linear self._load_quantized_modules( @@ -514,7 +505,7 @@ def from_quantized( quant_config.version, use_exllama=use_exllama, use_exllama_v2=use_exllama_v2, - use_qbits=use_cpu_qbits, + use_ipex=use_cpu_ipex, ) model.tie_weights() @@ -539,11 +530,11 @@ def from_quantized( else: self.fuse_layers(model) - if use_cpu_qbits: - dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32 + if use_cpu_ipex: + dtype = torch.bfloat16 model.to(dtype=dtype, device="cpu") - # repack qweight to match the QBits kernel. - model = qbits_post_init(model) + # repack qweight to match the ipex kernel. + model = ipex_post_init(model) elif quant_config.version == "marlin": model = marlin_post_init(model) elif use_exllama: @@ -631,11 +622,11 @@ def _load_config( return model_weights_path, config, quant_config def _load_quantized_modules( - self, model, quant_config, version, use_exllama, use_exllama_v2, use_qbits=False + self, model, quant_config, version, use_exllama, use_exllama_v2, use_ipex=False ): # Real quantization of weights assert not ( - version == "gemv" and (use_exllama or use_exllama_v2 or use_qbits) + version == "gemv" and (use_exllama or use_exllama_v2 or use_ipex) ), "Exllama kernels only support GEMM version." # Get blocks of model @@ -657,8 +648,8 @@ def _load_quantized_modules( # Replace nn.Linear with WQLinear for name, module in named_linears.items(): - if use_qbits: - q_linear_module = WQLinear_QBits + if use_ipex: + q_linear_module = WQLinear_IPEX elif version == "marlin": q_linear_module = WQLinear_Marlin elif use_exllama: @@ -672,7 +663,7 @@ def _load_quantized_modules( elif version == "gemv_fast": q_linear_module = WQLinear_GEMVFast - if use_qbits: + if use_ipex: q_linear = q_linear_module.from_linear( module, quant_config.w_bit, @@ -687,7 +678,7 @@ def _load_quantized_modules( q_linear.to(next(layer.parameters()).device) set_op_by_name(layer, name, q_linear) - if not use_qbits: + if not use_ipex: torch.cuda.empty_cache() gc.collect() diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index 3a8bfc6f..d4ce6c4e 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -176,6 +176,7 @@ def __init__( self.is_neox = kwargs["is_neox"] self.attn_logit_softcapping = attn_logit_softcapping + self.use_sdpa = kwargs.get("use_sdpa", False) def forward( self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs @@ -266,7 +267,6 @@ def forward( xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) values = values.transpose(1, 2) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) # Used in Gemma2 if self.attn_logit_softcapping is not None: @@ -274,21 +274,37 @@ def forward( scores = torch.tanh(scores) scores = scores * self.attn_logit_softcapping - if self.use_alibi: - scores = self.alibi.forward(scores, seqlen) - - # When seqlen is 1, there is nothing else to attend to - if attention_mask is not None and seqlen > 1: - # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we - # need to slice it - if attention_mask.shape[-1] != seqlen: - attention_mask = attention_mask[:, :, :seqlen, :seqlen] - - scores = ( - scores + attention_mask - ) # (bs, n_local_heads, slen, cache_len + slen) - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + if self.use_sdpa: + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : 
keys.shape[-2]] + is_causal = True if causal_mask is None and seqlen > 1 else False + output = torch.nn.functional.scaled_dot_product_attention( + xq, + keys, + values, + attn_mask=causal_mask, + dropout_p=0.0, + is_causal=is_causal, + ) + else: + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if self.use_alibi: + scores = self.alibi.forward(scores, seqlen) + + # When seqlen is 1, there is nothing else to attend to + if attention_mask is not None and seqlen > 1: + # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we + # need to slice it + if attention_mask.shape[-1] != seqlen: + attention_mask = attention_mask[:, :, :seqlen, :seqlen] + + scores = ( + scores + attention_mask + ) # (bs, n_local_heads, slen, cache_len + slen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) + attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) else: xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"]) diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py index 9d7889b7..a4a02f2b 100644 --- a/awq/modules/fused/block.py +++ b/awq/modules/fused/block.py @@ -106,6 +106,7 @@ def __init__( rope_theta=rope_theta, partial_rotary_factor=partial_rotary_factor, head_dim=head_dim, + use_sdpa=True, ).to(dev) self.norm_2 = norm_2.to(dev) self.mlp = mlp.to(dev) diff --git a/awq/modules/fused/norm.py b/awq/modules/fused/norm.py index 7a552afd..17710b08 100644 --- a/awq/modules/fused/norm.py +++ b/awq/modules/fused/norm.py @@ -8,6 +8,13 @@ except: AWQ_INSTALLED = False +try: + import intel_extension_for_pytorch as ipex # with IPEX kernels + + IPEX_INSTALLED = True +except: + IPEX_INSTALLED = False + class FasterTransformerRMSNorm(nn.Module): def __init__(self, weight, eps=1e-6): @@ -16,12 +23,14 @@ def __init__(self, weight, eps=1e-6): self.variance_epsilon = eps def forward(self, x): - assert AWQ_INSTALLED, ( - "AWQ kernels could not be loaded. " - "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels" - ) - - output = torch.empty_like(x) - awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) + if IPEX_INSTALLED: + output = ipex.llm.functional.rms_norm(x, self.weight, self.variance_epsilon) + else: + assert AWQ_INSTALLED, ( + "AWQ kernels could not be loaded. 
" + "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels" + ) + output = torch.empty_like(x) + awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon) return output diff --git a/awq/modules/linear/__init__.py b/awq/modules/linear/__init__.py index f0439eae..f2845d34 100644 --- a/awq/modules/linear/__init__.py +++ b/awq/modules/linear/__init__.py @@ -1,7 +1,7 @@ from .exllama import WQLinear_Exllama, exllama_post_init from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init from .gemm import WQLinear_GEMM -from .gemm_qbits import WQLinear_QBits, qbits_post_init +from .gemm_ipex import WQLinear_IPEX, ipex_post_init from .gemv import WQLinear_GEMV from .marlin import WQLinear_Marlin, marlin_post_init from .gemv_fast import WQLinear_GEMVFast diff --git a/awq/modules/linear/gemm_ipex.py b/awq/modules/linear/gemm_ipex.py new file mode 100644 index 00000000..399b98d3 --- /dev/null +++ b/awq/modules/linear/gemm_ipex.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn + +try: + from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear + assert hasattr(WeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" + IPEX_INSTALLED = True +except: + IPEX_INSTALLED = False + + +class WQLinear_IPEX(nn.Module): + + def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev): + super().__init__() + assert IPEX_INSTALLED, \ + "Please install IPEX package with `pip install intel_extension_for_pytorch`." + assert w_bit == 4, "Only 4 bit are supported for now." + + self.use_bf16 = True # Intel platform support bf16 even without amx. + + self.in_features = in_features + self.out_features = out_features + self.w_bit = w_bit + self.group_size = group_size if group_size != -1 else in_features + self.zero_point = zero_point + self.scale_dtype = torch.float32 + + # quick sanity check (make sure aligment) + assert self.in_features % self.group_size == 0 + assert out_features % (32 // self.w_bit) == 0 + self.pack_num = 32 // self.w_bit + + self.register_buffer( + "qzeros", + torch.zeros( + (in_features // self.group_size, out_features // self.pack_num), + dtype=torch.int8, + device=dev, + ) if self.zero_point else None, + ) + self.register_buffer( + "scales", + torch.zeros( + (in_features // self.group_size, out_features), + dtype=torch.bfloat16 if self.use_bf16 else torch.float32, + device=dev, + )) + if bias: + self.register_buffer( + "bias", + torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev), + ) + else: + self.register_buffer( + "bias", + None, + ) + qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev) + self.register_buffer("qweight", qweight) + + def post_init(self): + assert self.qweight.device.type == "cpu" + self.ipex_linear = WeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ + self.in_features, self.out_features, None, self.bias, \ + self.group_size, None, 0, 1) + + @classmethod + def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False): + awq_linear = cls( + w_bit, + group_size, + linear.in_features, + linear.out_features, + linear.bias is not None, + has_zero_points, + linear.weight.device, + ) + if init_only: # just prepare for loading sd + return awq_linear + + raise NotImplementedError("Only inference is supported for IPEX kernels") + + @torch.no_grad() + def 
+    def forward(self, x):
+        assert IPEX_INSTALLED, (
+            "IPEX kernels could not be loaded. "
+            "Please install with `pip install intel_extension_for_pytorch` and "
+            "refer to the details at https://github.com/intel/intel-extension-for-pytorch/tree/main")
+
+        outputs = self.ipex_linear(x)
+
+        return outputs
+
+    def extra_repr(self) -> str:
+        return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
+            self.in_features,
+            self.out_features,
+            self.bias is not None,
+            self.w_bit,
+            self.group_size,
+        ))
+
+
+def ipex_post_init(model):
+    for _, submodule in model.named_modules():
+        if isinstance(submodule, WQLinear_IPEX):
+            submodule.post_init()
+
+    return model
diff --git a/awq/modules/linear/gemm_qbits.py b/awq/modules/linear/gemm_qbits.py
deleted file mode 100644
index 126aad29..00000000
--- a/awq/modules/linear/gemm_qbits.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import torch
-import torch.nn as nn
-from awq.utils.module import try_import
-from ...utils.packing_utils import reverse_awq_order, unpack_awq
-
-intel_extension_for_transformers, msg = try_import("intel_extension_for_transformers")
-if intel_extension_for_transformers is not None:
-    qbits = getattr(intel_extension_for_transformers, 'qbits')
-
-BITS_DTYPE_MAPPING = {
-    4: "int4_clip",
-    8: "int8",
-}
-
-
-def convert_dtype_torch2str(dtype):
-    if dtype == torch.int8:
-        return "int8"
-    elif dtype == torch.float:
-        return "fp32"
-    elif dtype == torch.float16:
-        return "fp16"
-    elif dtype == torch.bfloat16:
-        return "bf16"
-    elif isinstance(dtype, str) and dtype in ["int8", "fp32", "fp16", "bf16"]:
-        return dtype
-    else:
-        assert False, "Unsupported pytorch dtype {} to str dtype".format(dtype)
-
-
-class WQLinear_QBits(nn.Module):
-
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
-        super().__init__()
-        if intel_extension_for_transformers is None:
-            raise ModuleNotFoundError("Please install ITREX qbits package with `pip install intel-extension-for-transformers`."
+ msg) - - self.use_bf16 = qbits.check_isa_supported("AMX") - - if w_bit not in [2, 3, 4, 8]: - raise NotImplementedError("Only 2, 3, 4, 8 bits are supported for now.") - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.zero_point = zero_point - self.scale_dtype = torch.float32 - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - self.pack_num = 32 // self.w_bit - - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // self.pack_num), - dtype=torch.int8, - device=dev, - ) if self.zero_point else None, - ) - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.bfloat16 if self.use_bf16 else torch.float32, - device=dev, - )) - if bias: - self.register_buffer( - "bias", - torch.zeros((out_features), dtype=torch.bfloat16 if self.use_bf16 else torch.float32, device=dev), - ) - else: - self.register_buffer( - "bias", - None, - ) - qweight = torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev) - self.register_buffer("qweight", qweight) - - def post_init(self): - assert self.qweight.device.type == "cpu" - - intweight, zeros = unpack_awq(self.qweight, self.qzeros, self.w_bit) # weight: k x n zeros: k / group_size x n - intweight, zeros = reverse_awq_order(intweight, zeros, self.w_bit) # weight: k x n zeros: k / group_size x n - if self.zero_point: - intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - (2**(self.w_bit - 1)) - zeros = torch.bitwise_and(zeros, (2**self.w_bit) - 1) - (2**(self.w_bit - 1)) - else: - intweight = torch.bitwise_and(intweight, (2**self.w_bit) - 1) - g_idx = torch.empty(0, dtype=torch.int32) - - self.qweight = qbits.repack_quantized_weight(intweight, self.scales.float(), zeros, g_idx, - BITS_DTYPE_MAPPING[self.w_bit], - convert_dtype_torch2str(self.scale_dtype), - convert_dtype_torch2str(self.scales.dtype), self.zero_point, - self.group_size) - - @classmethod - def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None, has_zero_points=False): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - has_zero_points, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - raise NotImplementedError("Only inference is supported for Exllama kernels") - - @torch.no_grad() - def forward(self, x): - if intel_extension_for_transformers is None: - raise ModuleNotFoundError( - "QBits kernels could not be loaded. 
" - "Please install with `pip install intel-extension-for-transformers` and " - "refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md" - ) - - input_dtype = x.dtype - out_shape = x.shape[:-1] + (self.out_features,) - x = x.view(-1, x.shape[-1]) # convert xd to 2d - out_2d_shape = x.shape[:-1] + (self.out_features,) - - outputs = torch.zeros(out_2d_shape, dtype=input_dtype) - bias = self.bias if self.bias is not None else torch.empty( - 0, dtype=torch.bfloat16 if self.use_bf16 else torch.float32) - - qbits.woq_linear(x, self.qweight, bias, outputs, convert_dtype_torch2str(input_dtype), - BITS_DTYPE_MAPPING[self.w_bit], convert_dtype_torch2str(self.scale_dtype), self.zero_point) - - return outputs.view(out_shape) - - def extra_repr(self) -> str: - return ("in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - )) - - -def qbits_post_init(model): - for _, submodule in model.named_modules(): - if isinstance(submodule, WQLinear_QBits): - submodule.post_init() - - return model diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py index 946c6904..a78fead1 100644 --- a/awq/utils/fused_utils.py +++ b/awq/utils/fused_utils.py @@ -7,7 +7,7 @@ WQLinear_Exllama, WQLinear_ExllamaV2, WQLinear_GEMVFast, - WQLinear_QBits, + WQLinear_IPEX, ) @@ -79,10 +79,10 @@ def fuse_qkv(module, q_proj, k_proj, v_proj): q_linear = WQLinear_Marlin elif isinstance(q_proj, WQLinear_GEMVFast): q_linear = WQLinear_GEMVFast - elif isinstance(q_proj, WQLinear_QBits): - q_linear = WQLinear_QBits + elif isinstance(q_proj, WQLinear_IPEX): + q_linear = WQLinear_IPEX - if isinstance(q_proj, WQLinear_QBits): + if isinstance(q_proj, WQLinear_IPEX): qkv_layer = q_linear( q_proj.w_bit, q_proj.group_size, @@ -113,7 +113,7 @@ def fuse_qkv(module, q_proj, k_proj, v_proj): [q_proj.scales, k_proj.scales, v_proj.scales], dim=0 ) qkv_layer.split_k_iters = q_proj.split_k_iters - elif isinstance(q_proj, WQLinear_GEMM) or isinstance(q_proj, WQLinear_QBits): + elif isinstance(q_proj, WQLinear_GEMM) or isinstance(q_proj, WQLinear_IPEX): qkv_layer.qweight = torch.cat( [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1 ) diff --git a/awq/utils/utils.py b/awq/utils/utils.py index f15ade08..7553c5df 100644 --- a/awq/utils/utils.py +++ b/awq/utils/utils.py @@ -4,7 +4,7 @@ import accelerate -qbits_available = importlib.util.find_spec("intel_extension_for_transformers") is not None +ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None def get_module_by_name_suffix(model, module_name: str): diff --git a/docs/examples.md b/docs/examples.md index 86983dd3..2fc0259c 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -316,14 +316,14 @@ generation_output = model.generate( ``` ### Inference With CPU -To run inference with CPU , you should specify `use_qbits=True`. QBits is the backend for CPU including kernel for operators. QBits is a module of the intel-extension-for-transformers package. Up to now, the feature of fusing layers hasn't been ready, you should run model with `fuse_layers=False`. +To run inference with CPU , you should specify `use_ipex=True`. ipex is the backend for CPU including kernel for operators. ipex is intel_extension_for_pytorch package. 
 ```python
 from awq import AutoAWQForCausalLM
 
 quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"
 
 # Load model
-model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False, use_qbits=True)
+model = AutoAWQForCausalLM.from_quantized(quant_path, use_ipex=True)
 ```
 
 ### Transformers
diff --git a/docs/index.md b/docs/index.md
index dedeb807..2bcf36a1 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -22,12 +22,11 @@ Example inference speed (RTX 4090, Ryzen 9 7950X, 64 tokens):
         use_exllama_v2=True
     )
     ```
-- For CPU device, you should install intel-extension-for-transformers with `pip install intel-extension-for-transformers`. And the latest version of torch is required since "intel-extension-for-transformers(ITREX)" was built with the latest version of torch(now ITREX 1.4 was build with torch 2.2). If you build ITREX from source code, then you need to ensure the consistency of the torch version. And you should use "use_qbits=True" for CPU device. Up to now, the feature of fuse_layers hasn't been ready for CPU device.
+- For CPU devices, install intel_extension_for_pytorch with `pip install intel_extension_for_pytorch`. A recent torch is required, since intel_extension_for_pytorch (IPEX) is built against a specific torch release (IPEX 2.4 is built with torch 2.4). If you build IPEX from source, make sure the torch versions match. Then pass `use_ipex=True` for CPU devices.
     ```python
     model = AutoAWQForCausalLM.from_quantized(
         ...,
-        fuse_layers=False,
-        use_qbits=True
+        use_ipex=True
     )
     ```
diff --git a/examples/benchmark.py b/examples/benchmark.py
index 0be30d7f..f3ff44ec 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -7,17 +7,16 @@
 import psutil
 from awq import AutoAWQForCausalLM
 from awq.models.base import BaseAWQForCausalLM
-from awq.utils.utils import get_best_device, qbits_available
+from awq.utils.utils import get_best_device, ipex_available
 from transformers import AutoTokenizer, GenerationConfig, LogitsProcessor, LogitsProcessorList
 
 DEVICE = get_best_device()
 
 if DEVICE == "cpu":
-    if qbits_available:
-        from intel_extension_for_transformers.qbits import check_isa_supported
-        torch_dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32
+    if ipex_available:
+        torch_dtype = torch.bfloat16
     else:
-        raise ImportError("Please import intel-extension-for-transformers "
-                          "by `pip install intel-extension-for-transformers`")
+        raise ImportError("Please install intel_extension_for_pytorch "
+                          "with `pip install intel_extension_for_pytorch`")
 else:
     torch_dtype = torch.float16
 
@@ -86,8 +85,8 @@ def generate_hf(model: BaseAWQForCausalLM, input_ids, n_generate):
         min_new_tokens=n_generate,
         max_new_tokens=n_generate,
         use_cache=True,
-        forced_eos_token_id=-100,
-        eos_token_id=-100,
+        forced_eos_token_id=1,
+        eos_token_id=1,
     )
 
     time_processor = TimeMeasuringLogitsProcessor()
@@ -115,9 +114,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
         )
     else:
         model = AutoAWQForCausalLM.from_quantized(
-            model_path, quant_file, fuse_layers=False if DEVICE == "cpu" else True,
-            max_seq_len=n_generate, batch_size=batch_size,
-            safetensors=not no_safetensors
+            model_path, quant_file, max_seq_len=n_generate, batch_size=batch_size, safetensors=not no_safetensors
         )
 
     print(f" -- Warming up...")
@@ -163,7 +160,7 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
         decode_tokens_per_second = 'OOM'
 
     if pretrained:
-        version = "FP16" if DEVICE != "cpu" else 
"BF16" if check_isa_supported("AMX") else "FP32" + version = "FP16" if DEVICE != "cpu" else "BF16" else: version = model.quant_config.version diff --git a/setup.py b/setup.py index 95319129..3b631933 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ requirements.append("autoawq-kernels") elif IS_CPU_ONLY: - requirements.append("intel-extension-for-transformers>=1.4.2") + requirements.append("intel-extension-for-pytorch>=2.4.0") force_extension = os.getenv("PYPI_FORCE_TAGS", "0") if force_extension == "1": diff --git a/tests/test_ipex_cpu.py b/tests/test_ipex_cpu.py new file mode 100644 index 00000000..f0397e31 --- /dev/null +++ b/tests/test_ipex_cpu.py @@ -0,0 +1,59 @@ +import torch +from awq.utils.packing_utils import dequantize_gemm +from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear + +assert hasattr(WeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" +torch.manual_seed(0) + +in_features = 256 +out_features = 128 +w_bit = 4 +group_size = 32 +torch_dtype = torch.bfloat16 + +MAX_INT32 = 0x7fffffff +MIN_INT32 = -MAX_INT32 - 1 + +qweight = torch.randint( + MIN_INT32, + MAX_INT32, + (in_features, out_features // (32 // w_bit)), + dtype=torch.int32, + device="cpu", +) + +qzeros = torch.randint( + MIN_INT32, + MAX_INT32, + (in_features // group_size, out_features // (32 // w_bit)), + dtype=torch.int32, + device="cpu", +) + +scales = torch.randn( + (in_features // group_size, out_features), + dtype=torch_dtype, + device="cpu", +) + +with torch.no_grad(): + fp_weight = dequantize_gemm( + qweight, + qzeros, + scales, + w_bit, + group_size + ) + + ipex_linear = WeightOnlyQuantizedLinear.from_weight(qweight, scales, qzeros, \ + in_features, out_features, None, None, \ + group_size, None, 0, 1) + + + input = torch.rand(1, in_features, dtype=torch_dtype) + torch_out = torch.matmul(input, fp_weight) + + ipex_dst = ipex_linear(input) + results = torch.amax(ipex_dst - torch_out) + + assert(torch.allclose(ipex_dst, torch_out, rtol=0.06)) \ No newline at end of file diff --git a/tests/test_qbits_cpu.py b/tests/test_qbits_cpu.py deleted file mode 100644 index f07e63de..00000000 --- a/tests/test_qbits_cpu.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -from awq.utils.packing_utils import unpack_awq, reverse_awq_order -from awq.modules.linear.gemm_qbits import BITS_DTYPE_MAPPING, convert_dtype_torch2str -from awq.utils.packing_utils import dequantize_gemm -from intel_extension_for_transformers import qbits -torch.manual_seed(0) - -in_features = 256 -out_features = 128 -w_bit = 4 -group_size = 32 -torch_dtype = torch.bfloat16 if qbits.check_isa_supported("AMX") else torch.float32 - -MAX_INT32 = 0x7fffffff -MIN_INT32 = -MAX_INT32 - 1 - -qweight = torch.randint( - MIN_INT32, - MAX_INT32, - (in_features, out_features // (32 // w_bit)), - dtype=torch.int32, - device="cpu", -) - -qzeros = torch.randint( - MIN_INT32, - MAX_INT32, - (in_features // group_size, out_features // (32 // w_bit)), - dtype=torch.int32, - device="cpu", -) - -scales = torch.randn( - (in_features // group_size, out_features), - dtype=torch_dtype, - device="cpu", -) - -with torch.no_grad(): - fp_weight = dequantize_gemm( - qweight, - qzeros, - scales, - w_bit, - group_size - ) - intweight, zeros = unpack_awq(qweight, qzeros, w_bit) # weight: k x n zeros: k / group_size x n - intweight, zeros = reverse_awq_order(intweight, zeros, w_bit) # weight: k x n zeros: k / group_size x n - # overflow checks - intweight = torch.bitwise_and(intweight, (2**w_bit) 
- 1) - (2**(w_bit - 1)) - zeros = torch.bitwise_and(zeros, (2**w_bit) - 1) - (2**(w_bit - 1)) - g_idx = torch.empty(0, dtype=torch.int32) - qbits_qweight = qbits.repack_quantized_weight(intweight, scales.float().contiguous(), zeros, g_idx, - BITS_DTYPE_MAPPING[w_bit], - "fp32", - convert_dtype_torch2str(torch_dtype), - True, - group_size) - qbits_out = torch.zeros(in_features, out_features, dtype=torch.float32) - qbits.dequantize_packed_weight( - qbits_qweight, qbits_out, False, convert_dtype_torch2str(torch_dtype), BITS_DTYPE_MAPPING[w_bit], "fp32") - qbits_out = qbits_out.to(torch_dtype) - assert(torch.allclose(qbits_out, fp_weight, rtol=0.0001)) - - input = torch.rand(1, in_features, dtype=torch_dtype) - torch_out = torch.matmul(input, fp_weight) - - qbits_dst = torch.zeros(1, out_features, dtype=torch.bfloat16) - qbits.woq_linear( - input, qbits_qweight, torch.empty(0), qbits_dst, convert_dtype_torch2str(torch_dtype), "int4_clip", "fp32", True) - results = torch.amax(qbits_dst - torch_out) - - assert(torch.allclose(qbits_dst, torch_out, rtol=0.03)) \ No newline at end of file
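
For reference, below is a minimal end-to-end sketch of the CPU path this patch introduces. It reuses the `TheBloke/Mistral-7B-Instruct-v0.2-AWQ` checkpoint from the docs example above and assumes intel_extension_for_pytorch >= 2.4 is installed; the prompt and generation settings are illustrative assumptions, not part of the patch.

```python
# Illustrative sketch of CPU inference via the new use_ipex=True path.
# Assumes intel_extension_for_pytorch >= 2.4; prompt and max_new_tokens are example values.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

# Map the quantized linear layers to WQLinear_IPEX and repack them for the ipex kernels.
model = AutoAWQForCausalLM.from_quantized(quant_path, use_ipex=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

# Tokenize a short prompt and generate on CPU.
inputs = tokenizer("What does AWQ quantization do?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```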