diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 4123035ec87..b3ffb9c91a5 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1308,9 +1308,6 @@ def _optimize_post(model, lightweight_bmm=False): from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward - convert_forward(model, - module.Qwen2Model, - qwen2_model_forward) convert_forward(model, module.Qwen2ForCausalLM, qwen2_causal_lm_forward) @@ -1326,6 +1323,12 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.Qwen2SdpaAttention, qwen2_attention_forward) + if version.parse(trans_version) >= version.parse("4.42"): + from ipex_llm.transformers.models.qwen2 import qwen2_model_forward_4_42 + convert_forward(model, module.Qwen2Model, qwen2_model_forward_4_42) + else: + from ipex_llm.transformers.models.qwen2 import qwen2_model_forward + convert_forward(model, module.Qwen2Model, qwen2_model_forward) elif model.config.model_type == "qwen2_moe": # for Qwen1.5-MOE-A2.7B modeling_module_name = model.__class__.__module__ @@ -1356,6 +1359,8 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.Qwen2MoeSdpaAttention, qwen2_attention_forward) + elif model.config.model_type == "qwen2_audio": + _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm) elif model.config.model_type == "cohere": # for CohereForAI/c4ai-command-r-v01 invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"), diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index b80f4cff86f..32d838cb128 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -55,8 +55,6 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv -from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa -from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.cache_utils import Cache from transformers import logging @@ -76,12 +74,15 @@ def qwen2_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, # for transformers >= 4.42 ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else \ - self.config.output_attentions + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -90,8 +91,7 @@ def qwen2_model_forward( # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: invalidInputError(False, - "You cannot specify both decoder_input_ids and " - "decoder_inputs_embeds at the same time") + "You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: @@ -159,6 +159,9 @@ def qwen2_model_forward( "the input. " ) + from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa + from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask + # ipex-llm changes start: don't generate `attention_mask` in specific cases if seq_length == 1 or batch_size == 1 and use_sdp_causal( seq_length, seq_length + past_key_values_length, @@ -259,6 +262,138 @@ def qwen2_model_forward( ) +def qwen2_model_forward_4_42( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + invalidInputError( + (input_ids is None) ^ (inputs_embeds is None), + "You cannot specify both input_ids and inputs_embeds at the same time, " + "and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. " + "Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # ipex-llm changes start + # IPEX-LLM OPT: kv cache and quantize kv cache + use_quantize_kv = ( + self.config.hidden_size != 3584 # disable quantize kv in specific model + and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds, + self.config.num_attention_heads//self.config.num_key_value_heads) + ) + if use_cache: + if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): + past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) + elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + # ipex-llm changes end + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # ipex-llm changes start: remove `to_legacy_cache` + next_cache = None + if use_cache: + next_cache = next_decoder_cache + # ipex-llm changes end + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def qwen2_causal_lm_forward( self, input_ids: torch.LongTensor = None, @@ -271,6 +406,7 @@ def qwen2_causal_lm_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, # for transformers >= 4.42 ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = ( output_attentions if output_attentions is not None @@ -293,6 +429,7 @@ def qwen2_causal_lm_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0]