Skip to content

Commit

Permalink
enable fp4 fused mlp and qkv
Browse files Browse the repository at this point in the history
  • Loading branch information
qiuxin2012 committed Mar 25, 2024
1 parent 5b76f88 commit fc6f180
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 32 deletions.
7 changes: 4 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

Expand Down Expand Up @@ -74,8 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
return use_fuse_rope


def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
return q_type in [SYM_INT4, FP8E5] and \
def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
return decoding_fast_path_qtype_check(proj) and \
use_fuse_rope and enough_kv_room and bs == 1


Expand Down Expand Up @@ -137,7 +138,7 @@ def gemma_attention_forward(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
Expand Down
23 changes: 12 additions & 11 deletions python/llm/src/ipex_llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.utils.common import invalidInputError

Expand All @@ -64,6 +64,12 @@
logger = logging.get_logger(__name__)


def llama_decoding_fast_path_qtype_check(proj):
# IQ2_XXS only can be used in Llama-like model
qtype = getattr(proj, "qtype", None)
return qtype in [SYM_INT4, FP8E5, IQ2_XXS, FP4]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
Expand Down Expand Up @@ -329,8 +335,7 @@ def llama_attention_forward_4_31_quantized(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
Expand Down Expand Up @@ -463,8 +468,7 @@ def llama_attention_forward_4_31_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
Expand Down Expand Up @@ -692,8 +696,7 @@ def llama_attention_selective_batching_forward_4_31(
# TODO: decoding fast path
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
bsz * q_len == 1)
Expand Down Expand Up @@ -911,8 +914,7 @@ def llama_attention_forward_4_36_quantized(
device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5]
qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
Expand Down Expand Up @@ -1093,8 +1095,7 @@ def llama_attention_forward_4_36_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
Expand Down
13 changes: 7 additions & 6 deletions python/llm/src/ipex_llm/transformers/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
is_enough_kv_cache_room_4_36
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
try:
from transformers.cache_utils import Cache
except ImportError:
Expand Down Expand Up @@ -81,8 +82,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
return use_fuse_rope


def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
return llama_decoding_fast_path_qtype_check(proj) and \
use_fuse_rope and enough_kv_room and bs == 1


Expand Down Expand Up @@ -200,7 +201,7 @@ def mistral_attention_forward_quantized(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
Expand Down Expand Up @@ -375,7 +376,7 @@ def mistral_attention_forward_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
Expand Down Expand Up @@ -551,7 +552,7 @@ def mistral_attention_forward_4_36_quantized(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
Expand Down Expand Up @@ -731,7 +732,7 @@ def mistral_attention_forward_4_36_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
Expand Down
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def mixtral_attention_forward(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
Expand Down
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
from ipex_llm.utils.common import invalidInputError, invalidOperationError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from transformers.modeling_outputs import BaseModelOutputWithPast
Expand Down Expand Up @@ -137,7 +138,8 @@ def qwen_attention_forward_original(
original_dtype = hidden_states.dtype

use_fuse_rope = should_use_fuse_rope(self, hidden_states)
decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
cache_k, cache_v = layer_past[0], layer_past[1]
Expand Down
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

try:
from transformers.cache_utils import Cache, DynamicCache
Expand Down Expand Up @@ -431,8 +432,7 @@ def qwen2_attention_forward_origin(
device = hidden_states.device

enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
qtype = getattr(self.q_proj, "qtype", None)
qtype_check = qtype in [SYM_INT4, FP8E5]
qtype_check = decoding_fast_path_qtype_check(self.q_proj)
decoding_fast_path = (qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
if decoding_fast_path:
Expand Down
13 changes: 7 additions & 6 deletions python/llm/src/ipex_llm/transformers/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4

FP8_KV_ALLOC_LENGTH = 512
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
SYM_INT8 = ggml_tensor_qtype["sym_int8"]
FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]

# used in fused mlp forward
SILU = 0
GELU = 1


def decoding_fast_path_qtype_check(proj):
qtype = getattr(proj, "qtype", None)
return qtype in [SYM_INT4, FP8E5, FP4]


def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
key_cache_storage = torch.empty(batch_size, num_heads,
max_length, head_dim,
Expand Down Expand Up @@ -335,8 +337,7 @@ def mlp_fusion_check(x, qtype, training):
return False
if x.device.type != 'xpu':
return False
if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"],
ggml_tensor_qtype["gguf_iq2_xxs"]]:
if qtype not in [SYM_INT4, FP8E5, FP4, IQ2_XXS]:
return False
if training or x.requires_grad:
return False
Expand Down
5 changes: 3 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/yuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,13 @@
restore_fp8_kv_cache, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

KV_CACHE_ALLOC_BLOCK_LENGTH = 256


def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
return q_type in [SYM_INT4, FP8E5] and \
def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
return decoding_fast_path_qtype_check(proj) and \
use_fuse_rope and enough_kv_room and bs == 1


Expand Down

0 comments on commit fc6f180

Please sign in to comment.