enable fp4 fused mlp and qkv #10531

Merged: 3 commits, Mar 26, 2024
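
In short, this PR centralizes the decoding-fast-path quantization check into a shared `decoding_fast_path_qtype_check(proj)` helper in `models/utils.py` and adds FP4 to the set of low-bit formats that can take the fused QKV decoding fast path; `mlp_fusion_check` likewise gains FP4 so FP4 models can use the fused MLP. Call sites now pass the projection module itself (e.g. `self.q_proj`) instead of its raw `qtype`. The sketch below is illustrative only, using stand-in string constants and a dummy projection object rather than the real `ipex_llm` classes, to show the pattern the diff introduces.

```python
# Illustrative sketch, not the real ipex_llm code: string tags stand in for the
# ggml qtype ids imported from ipex_llm.transformers.low_bit_linear.
SYM_INT4, FP8E5, FP4 = "sym_int4", "fp8_e5m2", "fp4"


class DummyLowBitLinear:
    """Minimal stand-in for a quantized q/k/v projection layer."""
    def __init__(self, qtype):
        self.qtype = qtype


def decoding_fast_path_qtype_check(proj):
    # Same shape as the new helper in models/utils.py: read the layer's qtype and gate on it.
    qtype = getattr(proj, "qtype", None)
    return qtype in [SYM_INT4, FP8E5, FP4]


def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
    # New-style call sites pass the projection module, not proj.qtype.
    return decoding_fast_path_qtype_check(proj) and use_fuse_rope and enough_kv_room and bs == 1


q_proj = DummyLowBitLinear(FP4)
print(use_decoding_fast_path(q_proj, use_fuse_rope=True, enough_kv_room=True, bs=1))  # True
```
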
python/llm/src/ipex_llm/transformers/models/gemma.py (4 additions, 3 deletions)
@@ -41,6 +41,7 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

KV_CACHE_ALLOC_BLOCK_LENGTH = 256

@@ -74,8 +75,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
return use_fuse_rope


-def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type in [SYM_INT4, FP8E5] and \
+def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
+    return decoding_fast_path_qtype_check(proj) and \
        use_fuse_rope and enough_kv_room and bs == 1


@@ -137,7 +138,7 @@ def gemma_attention_forward(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
python/llm/src/ipex_llm/transformers/models/llama.py (12 additions, 11 deletions)
@@ -50,7 +50,7 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
-from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
+from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.utils.common import invalidInputError

@@ -64,6 +64,12 @@
logger = logging.get_logger(__name__)


+def llama_decoding_fast_path_qtype_check(proj):
+    # IQ2_XXS only can be used in Llama-like model
+    qtype = getattr(proj, "qtype", None)
+    return qtype in [SYM_INT4, FP8E5, IQ2_XXS, FP4]


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states
@@ -329,8 +335,7 @@ def llama_attention_forward_4_31_quantized(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
+qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
@@ -463,8 +468,7 @@ def llama_attention_forward_4_31_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
+qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
@@ -692,8 +696,7 @@ def llama_attention_selective_batching_forward_4_31(
# TODO: decoding fast path
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
+qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
bsz * q_len == 1)
@@ -911,8 +914,7 @@ def llama_attention_forward_4_36_quantized(
device = hidden_states.device
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5]
+qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
@@ -1093,8 +1095,7 @@ def llama_attention_forward_4_36_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5, IQ2_XXS]
+qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
enough_kv_room and bsz * q_len == 1)
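
`llama.py` keeps a model-specific variant, `llama_decoding_fast_path_qtype_check`, which additionally admits IQ2_XXS because (per the comment in the new function) that format is only used with Llama-like models; `mistral.py` below imports this variant rather than the generic helper. A self-contained comparison, using stand-in qtype tags rather than the real ggml constants:

```python
# Stand-in qtype tags, not the real ggml_tensor_qtype integer ids.
SYM_INT4, FP8E5, FP4, IQ2_XXS = "sym_int4", "fp8_e5m2", "fp4", "gguf_iq2_xxs"

GENERIC_FAST_PATH_QTYPES = [SYM_INT4, FP8E5, FP4]               # models/utils.py helper
LLAMA_FAST_PATH_QTYPES = GENERIC_FAST_PATH_QTYPES + [IQ2_XXS]   # models/llama.py helper

for qtype in (FP4, IQ2_XXS):
    print(f"{qtype}: generic={qtype in GENERIC_FAST_PATH_QTYPES}, "
          f"llama={qtype in LLAMA_FAST_PATH_QTYPES}")
# fp4: generic=True, llama=True
# gguf_iq2_xxs: generic=False, llama=True
```
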
python/llm/src/ipex_llm/transformers/models/mistral.py (7 additions, 6 deletions)
@@ -53,6 +53,7 @@
is_enough_kv_cache_room_4_36
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
try:
from transformers.cache_utils import Cache
except ImportError:
@@ -81,8 +82,8 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
return use_fuse_rope


-def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type in [SYM_INT4, FP8E5, IQ2_XXS] and \
+def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
+    return llama_decoding_fast_path_qtype_check(proj) and \
        use_fuse_rope and enough_kv_room and bs == 1


@@ -200,7 +201,7 @@ def mistral_attention_forward_quantized(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
-decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
@@ -375,7 +376,7 @@ def mistral_attention_forward_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
-decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
@@ -551,7 +552,7 @@ def mistral_attention_forward_4_36_quantized(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
-decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
@@ -731,7 +732,7 @@ def mistral_attention_forward_4_36_original(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
python/llm/src/ipex_llm/transformers/models/mixtral.py (1 addition, 1 deletion)
@@ -155,7 +155,7 @@ def mixtral_attention_forward(

use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-decoding_fast_path = use_decoding_fast_path(self.q_proj.qtype,
+decoding_fast_path = use_decoding_fast_path(self.q_proj,
use_fuse_rope,
enough_kv_room,
bsz * q_len)
python/llm/src/ipex_llm/transformers/models/qwen.py (5 additions, 2 deletions)
@@ -43,6 +43,7 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
from ipex_llm.utils.common import invalidInputError, invalidOperationError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -137,7 +138,8 @@ def qwen_attention_forward_original(
original_dtype = hidden_states.dtype

use_fuse_rope = should_use_fuse_rope(self, hidden_states)
-decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
+qtype_check = decoding_fast_path_qtype_check(self.q_proj)
+decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
cache_k, cache_v = layer_past[0], layer_past[1]
@@ -332,7 +334,8 @@ def qwen_attention_forward_quantized(
device = hidden_states.device

use_fuse_rope = should_use_fuse_rope(self, hidden_states)
-# TODO: use when decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
+# qtype_check = decoding_fast_path_qtype_check(self.q_proj)
+# TODO: use when decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
decoding_fast_path = False
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
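
For Qwen, the original (non-quantized-KV) forward previously enabled the decoding fast path whenever fused RoPE applied and a single token was being decoded; the diff adds the shared qtype gate in front, so the fast path is taken only for formats that `decoding_fast_path_qtype_check` accepts. A toy sketch of the combined condition with placeholder inputs rather than the real attention state:

```python
# Placeholder inputs standing in for the real Qwen attention state.
def takes_fast_path(qtype_supported: bool, use_fuse_rope: bool, bsz: int, q_len: int) -> bool:
    # Mirrors: decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
    return qtype_supported and use_fuse_rope and bsz * q_len == 1


print(takes_fast_path(True, True, 1, 1))    # True  - single-token decode with a supported qtype
print(takes_fast_path(False, True, 1, 1))   # False - unsupported qtype now bypasses the fast path
```
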
python/llm/src/ipex_llm/transformers/models/qwen2.py (3 additions, 4 deletions)
@@ -57,6 +57,7 @@
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
+from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

try:
from transformers.cache_utils import Cache, DynamicCache
@@ -431,8 +432,7 @@ def qwen2_attention_forward_origin(
device = hidden_states.device

enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5]
+qtype_check = decoding_fast_path_qtype_check(self.q_proj)
decoding_fast_path = (qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
if decoding_fast_path:
@@ -601,8 +601,7 @@ def qwen2_sdpa_attention_forward(
device = hidden_states.device

enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-qtype = getattr(self.q_proj, "qtype", None)
-qtype_check = qtype in [SYM_INT4, FP8E5]
+qtype_check = decoding_fast_path_qtype_check(self.q_proj)
decoding_fast_path = (qtype_check and use_fuse_rope
and enough_kv_room and bsz * q_len == 1)
if decoding_fast_path:
python/llm/src/ipex_llm/transformers/models/utils.py (7 additions, 6 deletions)
@@ -19,18 +19,20 @@
from ipex_llm.utils.common import invalidInputError
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.transformers.utils import get_ipex_version, get_xpu_device_type
+from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8, FP8E5, IQ2_XXS, FP4, FP8E4

FP8_KV_ALLOC_LENGTH = 512
-SYM_INT4 = ggml_tensor_qtype["sym_int4"]
-SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-FP8E4 = ggml_tensor_qtype["fp8_e4m3"]
-FP8E5 = ggml_tensor_qtype["fp8_e5m2"]

# used in fused mlp forward
SILU = 0
GELU = 1


+def decoding_fast_path_qtype_check(proj):
+    qtype = getattr(proj, "qtype", None)
+    return qtype in [SYM_INT4, FP8E5, FP4]


def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
key_cache_storage = torch.empty(batch_size, num_heads,
max_length, head_dim,
@@ -335,8 +337,7 @@ def mlp_fusion_check(x, qtype, training):
return False
if x.device.type != 'xpu':
return False
-if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"],
-                 ggml_tensor_qtype["gguf_iq2_xxs"]]:
+if qtype not in [SYM_INT4, FP8E5, FP4, IQ2_XXS]:
return False
if training or x.requires_grad:
return False
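
On the `utils.py` side, the qtype constants are now imported from `low_bit_linear` instead of being rebuilt from `ggml_tensor_qtype`, and `mlp_fusion_check` accepts FP4 alongside SYM_INT4, FP8E5 and IQ2_XXS, which is what enables the fused MLP for FP4 models. A simplified, illustrative version of that gate (the real check also looks at the input tensor, device type and training/`requires_grad`, as the hunk above shows):

```python
# Simplified illustration of the fused-MLP gate; string tags stand in for ggml qtype ids.
SYM_INT4, FP8E5, FP4, IQ2_XXS = "sym_int4", "fp8_e5m2", "fp4", "gguf_iq2_xxs"
FUSED_MLP_QTYPES = [SYM_INT4, FP8E5, FP4, IQ2_XXS]


def may_fuse_mlp(qtype, on_xpu, training):
    # Fused MLP only on XPU, only at inference time, and only for supported low-bit formats.
    return on_xpu and not training and qtype in FUSED_MLP_QTYPES


print(may_fuse_mlp(FP4, on_xpu=True, training=False))   # True - FP4 MLPs can now fuse
print(may_fuse_mlp(FP4, on_xpu=True, training=True))    # False - training path is excluded
```
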
python/llm/src/ipex_llm/transformers/models/yuan.py (3 additions, 2 deletions)
@@ -36,12 +36,13 @@
restore_fp8_kv_cache, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
+from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check

KV_CACHE_ALLOC_BLOCK_LENGTH = 256


-def use_decoding_fast_path(q_type, use_fuse_rope, enough_kv_room, bs):
-    return q_type in [SYM_INT4, FP8E5] and \
+def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
+    return decoding_fast_path_qtype_check(proj) and \
        use_fuse_rope and enough_kv_room and bs == 1

