support and optimize qwen2-audio #11809

Merged · 1 commit · Aug 15, 2024
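For context, a minimal usage sketch of what this PR enables: loading a Qwen2-Audio checkpoint and letting ipex-llm optimize the Qwen2 language model inside it. The checkpoint id and the optimize_model entry point are illustrative assumptions, not taken from this diff.

import torch
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from ipex_llm import optimize_model  # assumed generic ipex-llm entry point

model_path = "Qwen/Qwen2-Audio-7B-Instruct"  # hypothetical checkpoint id
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_path,
                                                           torch_dtype=torch.float16)

# _optimize_post() sees model.config.model_type == "qwen2_audio" and recurses into
# model.language_model, so only the Qwen2 decoder receives the patched forwards below.
model = optimize_model(model)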
11 changes: 8 additions & 3 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1308,9 +1308,6 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
-        convert_forward(model,
-                        module.Qwen2Model,
-                        qwen2_model_forward)
         convert_forward(model,
                         module.Qwen2ForCausalLM,
                         qwen2_causal_lm_forward)
@@ -1326,6 +1323,12 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2SdpaAttention,
                         qwen2_attention_forward)
+        if version.parse(trans_version) >= version.parse("4.42"):
+            from ipex_llm.transformers.models.qwen2 import qwen2_model_forward_4_42
+            convert_forward(model, module.Qwen2Model, qwen2_model_forward_4_42)
+        else:
+            from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
+            convert_forward(model, module.Qwen2Model, qwen2_model_forward)
     elif model.config.model_type == "qwen2_moe":
         # for Qwen1.5-MOE-A2.7B
         modeling_module_name = model.__class__.__module__
@@ -1356,6 +1359,8 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2MoeSdpaAttention,
                         qwen2_attention_forward)
+    elif model.config.model_type == "qwen2_audio":
+        _optimize_post(model.language_model, lightweight_bmm=lightweight_bmm)
     elif model.config.model_type == "cohere":
         # for CohereForAI/c4ai-command-r-v01
         invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
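convert_forward itself is not shown in this diff; as a rough sketch of the pattern the calls above rely on (an assumed stand-in, not the real ipex-llm helper), it can be read as "walk the model and rebind forward on every instance of the target class":

import types

def convert_forward_sketch(model, target_class, new_forward):
    # Hypothetical stand-in for ipex_llm.transformers.convert.convert_forward:
    # rebind `forward` on every submodule that is an instance of target_class.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)

The new version gate above then only decides which replacement gets bound to Qwen2Model: qwen2_model_forward_4_42 on transformers >= 4.42 (built around _update_causal_mask and cache_position), and the existing qwen2_model_forward otherwise.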
153 changes: 145 additions & 8 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -55,8 +55,6 @@
 
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
-from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
-from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.cache_utils import Cache
 from transformers import logging
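The two _prepare_4d_causal_attention_mask* imports move out of module scope, presumably because transformers 4.42 no longer exposes them in modeling_qwen2 (the 4.42 path below relies on self._update_causal_mask instead); they are re-imported lazily inside qwen2_model_forward further down, which only runs on older transformers. A minimal sketch of that deferred-import pattern, with a hypothetical helper name:

def _legacy_mask_helpers():
    # Importing here rather than at module scope keeps qwen2.py importable on
    # transformers >= 4.42, where these symbols no longer exist.
    from transformers.models.qwen2.modeling_qwen2 import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
    return _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa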
@@ -76,12 +74,15 @@ def qwen2_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,  # for transformers >= 4.42
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = output_attentions if output_attentions is not None else \
-        self.config.output_attentions
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
     output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else
-        self.config.output_hidden_states
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
     )
     use_cache = use_cache if use_cache is not None else self.config.use_cache
 
@@ -90,8 +91,7 @@ def qwen2_model_forward(
     # retrieve input_ids and inputs_embeds
     if input_ids is not None and inputs_embeds is not None:
         invalidInputError(False,
-                          "You cannot specify both decoder_input_ids and "
-                          "decoder_inputs_embeds at the same time")
+                          "You cannot specify both input_ids and inputs_embeds at the same time")
     elif input_ids is not None:
         batch_size, seq_length = input_ids.shape
     elif inputs_embeds is not None:
@@ -159,6 +159,9 @@ def qwen2_model_forward(
                 "the input. "
             )
 
+    from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
+    from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
+
     # ipex-llm changes start: don't generate `attention_mask` in specific cases
     if seq_length == 1 or batch_size == 1 and use_sdp_causal(
             seq_length, seq_length + past_key_values_length,
@@ -259,6 +262,138 @@ def qwen2_model_forward(
     )
 
 
+def qwen2_model_forward_4_42(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    invalidInputError(
+        (input_ids is None) ^ (inputs_embeds is None),
+        "You cannot specify both input_ids and inputs_embeds at the same time, "
+        "and must specify either one"
+    )
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. "
+                "Setting `use_cache=False`..."
+            )
+            use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    # ipex-llm changes start
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    use_quantize_kv = (
+        self.config.hidden_size != 3584  # disable quantize kv in specific model
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
+                                  self.config.num_attention_heads//self.config.num_key_value_heads)
+    )
+    if use_cache:
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+    # ipex-llm changes end
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    # ipex-llm changes start: remove `to_legacy_cache`
+    next_cache = None
+    if use_cache:
+        next_cache = next_decoder_cache
+    # ipex-llm changes end
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
 def qwen2_causal_lm_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -271,6 +406,7 @@ def qwen2_causal_lm_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,  # for transformers >= 4.42
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     output_attentions = (
         output_attentions if output_attentions is not None
@@ -293,6 +429,7 @@ def qwen2_causal_lm_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )
 
     hidden_states = outputs[0]
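Finally, a hedged end-to-end sketch of exercising the new code path: on transformers >= 4.42 convert.py binds qwen2_model_forward_4_42 to Qwen2Model, and generate() supplies cache_position itself, so nothing changes on the caller side. The ipex-llm AutoModelForCausalLM wrapper and the model id below are assumptions for illustration.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM  # assumed ipex-llm wrapper

model_id = "Qwen/Qwen2-7B-Instruct"  # hypothetical checkpoint id
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("What does Qwen2-Audio add on top of Qwen2?", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))

Per the "remove `to_legacy_cache`" change above, the patched forward also keeps past_key_values as a DynamicNormalCache/DynamicFp8Cache object between decoding steps instead of converting it back to legacy tuples.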