From fc6f1803dea8d8c83df4717805aedbd118282192 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Mon, 25 Mar 2024 15:03:11 +0800 Subject: [PATCH 1/3] enable fp4 fused mlp and qkv --- .../src/ipex_llm/transformers/models/gemma.py | 7 +++--- .../src/ipex_llm/transformers/models/llama.py | 23 ++++++++++--------- .../ipex_llm/transformers/models/mistral.py | 13 ++++++----- .../ipex_llm/transformers/models/mixtral.py | 2 +- .../src/ipex_llm/transformers/models/qwen.py | 4 +++- .../src/ipex_llm/transformers/models/qwen2.py | 4 ++-- .../src/ipex_llm/transformers/models/utils.py | 13 ++++++----- .../src/ipex_llm/transformers/models/yuan.py | 5 ++-- 8 files changed, 39 insertions(+), 32 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py index 26934f03120..f6bf2db50bd 100644 --- a/python/llm/src/ipex_llm/transformers/models/gemma.py +++ b/python/llm/src/ipex_llm/transformers/models/gemma.py @@ -41,6 +41,7 @@ from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5 +from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -74,8 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids): return use_fuse_rope -def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs): - return q_type in [SYM_INT4, FP8E5] and \ +def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs): + return decoding_fast_path_qtype_check(proj) and \ use_fuse_rope and enough_kv_room and bs == 1 @@ -137,7 +138,7 @@ def gemma_attention_forward( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, enough_kv_room, bsz * q_len) diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 45d944c5e4d..19e51e32da9 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -50,7 +50,7 @@ from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.llama.modeling_llama import LlamaModel -from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS +from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4 from ipex_llm.ggml.quantize import ggml_tensor_qtype from ipex_llm.utils.common import invalidInputError @@ -64,6 +64,12 @@ logger = logging.get_logger(__name__) +def llama_decoding_fast_path_qtype_check(proj): + # IQ2_XXS only can be used in Llama-like model + qtype = getattr(proj, "qtype", None) + return qtype in [SYM_INT4, FP8E5, IQ2_XXS, FP4] + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states @@ -329,8 +335,7 @@ def llama_attention_forward_4_31_quantized( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS] + qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj) no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) @@ -463,8 +468,7 @@ def llama_attention_forward_4_31_original( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS] + qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj) no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) @@ -692,8 +696,7 @@ def llama_attention_selective_batching_forward_4_31( # TODO: decoding fast path use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0]) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS] + qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj) no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and bsz * q_len == 1) @@ -911,8 +914,7 @@ def llama_attention_forward_4_36_quantized( device = hidden_states.device use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5] + qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj) no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) @@ -1093,8 +1095,7 @@ def llama_attention_forward_4_36_original( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS] + qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj) no_tp = not self.config.pretraining_tp > 1 decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 5c7a63438e1..5cba7f0e2a6 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -53,6 +53,7 @@ is_enough_kv_cache_room_4_36 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp +from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check try: from transformers.cache_utils import Cache except ImportError: @@ -81,8 +82,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids): return use_fuse_rope -def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs): - return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \ +def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs): + return llama_decoding_fast_path_qtype_check(proj) and \ use_fuse_rope and enough_kv_room and bs == 1 @@ -200,7 +201,7 @@ def mistral_attention_forward_quantized( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) - decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, enough_kv_room, bsz * q_len) @@ -375,7 +376,7 @@ def mistral_attention_forward_original( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value) - decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, enough_kv_room, bsz * q_len) @@ -551,7 +552,7 @@ def mistral_attention_forward_4_36_quantized( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len) - decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, enough_kv_room, bsz * q_len) @@ -731,7 +732,7 @@ def mistral_attention_forward_4_36_original( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, enough_kv_room, bsz * q_len) diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py index f5c836acfe7..967625127db 100644 --- a/python/llm/src/ipex_llm/transformers/models/mixtral.py +++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py @@ -155,7 +155,7 @@ def mixtral_attention_forward( use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype, + decoding_fast_path = use_decoding_fast_path(self.q_proj, use_fuse_rope, enough_kv_room, bsz * q_len) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py index 833ff866512..08401767369 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen.py @@ -43,6 +43,7 @@ from ipex_llm.transformers.models.utils import mlp_fusion_check from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp +from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check from ipex_llm.utils.common import invalidInputError, invalidOperationError from ipex_llm.ggml.quantize import ggml_tensor_qtype from transformers.modeling_outputs import BaseModelOutputWithPast @@ -137,7 +138,8 @@ def qwen_attention_forward_original( original_dtype = hidden_states.dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states) - decoding_fast_path = (use_fuse_rope and bsz * q_len == 1) + qtype_check = decoding_fast_path_qtype_check(self.q_proj) + decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1) if decoding_fast_path: hidden_states = hidden_states.view(1, -1) cache_k, cache_v = layer_past[0], layer_past[1] diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index faf14a8750c..81e564a93fe 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -57,6 +57,7 @@ from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast +from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check try: from transformers.cache_utils import Cache, DynamicCache @@ -431,8 +432,7 @@ def qwen2_attention_forward_origin( device = hidden_states.device enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5] + qtype_check = decoding_fast_path_qtype_check(self.q_proj) decoding_fast_path = (qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) if decoding_fast_path: diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 1a4e1f0b94d..2bebc600d94 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -19,18 +19,20 @@ from ipex_llm.utils.common import invalidInputError from ipex_llm.ggml.quantize import ggml_tensor_qtype from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type +from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4 FP8_KV_ALLOC_LENGTH = 512 -SYM_INT4 = ggml_tensor_qtype["sym_int4"] -SYM_INT8 = ggml_tensor_qtype["sym_int8"] -FP8E4 = ggml_tensor_qtype["fp8_e4m3"] -FP8E5 = ggml_tensor_qtype["fp8_e5m2"] # used in fused mlp forward SILU = 0 GELU = 1 +def decoding_fast_path_qtype_check(proj): + qtype = getattr(proj, "qtype", None) + return qtype in [SYM_INT4, FP8E5, FP4] + + def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device): key_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim, @@ -335,8 +337,7 @@ def mlp_fusion_check(x, qtype, training): return False if x.device.type != 'xpu': return False - if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"], - ggml_tensor_qtype["gguf_iq2_xxs"]]: + if qtype not in [SYM_INT4, FP8E5, FP4, IQ2_XXS]: return False if training or x.requires_grad: return False diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py index f17b0ec7b04..71f4d81765b 100644 --- a/python/llm/src/ipex_llm/transformers/models/yuan.py +++ b/python/llm/src/ipex_llm/transformers/models/yuan.py @@ -36,12 +36,13 @@ restore_fp8_kv_cache, use_quantize_kv_cache from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5 +from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check KV_CACHE_ALLOC_BLOCK_LENGTH = 256 -def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs): - return q_type in [SYM_INT4, FP8E5] and \ +def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs): + return decoding_fast_path_qtype_check(proj) and \ use_fuse_rope and enough_kv_room and bs == 1 From d629c8b121bba90aafe28beb06b0a728ca226200 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Mon, 25 Mar 2024 17:05:20 +0800 Subject: [PATCH 2/3] update qwen --- python/llm/src/ipex_llm/transformers/models/qwen.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py index 08401767369..3a0eb0cc3c3 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen.py @@ -334,7 +334,8 @@ def qwen_attention_forward_quantized( device = hidden_states.device use_fuse_rope = should_use_fuse_rope(self, hidden_states) - # TODO: use when decoding_fast_path = (use_fuse_rope and bsz * q_len == 1) + # qtype_check = decoding_fast_path_qtype_check(self.q_proj) + # TODO: use when decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1) decoding_fast_path = False if decoding_fast_path: hidden_states = hidden_states.view(1, -1) From 15587d0e5000986479a6f83e2a2dfc1b6fd1eeb0 Mon Sep 17 00:00:00 2001 From: qiuxin2012 Date: Mon, 25 Mar 2024 17:07:10 +0800 Subject: [PATCH 3/3] update qwen2 --- python/llm/src/ipex_llm/transformers/models/qwen2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index 81e564a93fe..8e557b235e5 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -601,8 +601,7 @@ def qwen2_sdpa_attention_forward( device = hidden_states.device enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx) - qtype = getattr(self.q_proj, "qtype", None) - qtype_check = qtype in [SYM_INT4, FP8E5] + qtype_check = decoding_fast_path_qtype_check(self.q_proj) decoding_fast_path = (qtype_check and use_fuse_rope and enough_kv_room and bsz * q_len == 1) if decoding_fast_path: