From 22042ea23cd4c7fd73af1fe1e2138fa7187a05a0 Mon Sep 17 00:00:00 2001
From: jenniew
Date: Mon, 26 Aug 2024 14:01:35 -0700
Subject: [PATCH] clean

---
 .../transformers/npu_models/baichuan_mp.py | 105 ------------------
 .../transformers/npu_models/convert_mp.py  |   2 -
 2 files changed, 107 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index f94f8715e7a..b436c317e5f 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -243,37 +243,6 @@ def __init__(
         print("start compiling")
         self.compile()
 
-    # def repeat_kv(self, hidden_states, n_rep, transpose=False):
-    #     if n_rep == 1:
-    #         return hidden_states
-    #     if not transpose:
-    #         hidden_states = self.reshape(
-    #             hidden_states,
-    #             [self.batch_size, self.num_key_value_heads, 1, self.kv_seq_len, self.head_dim],
-    #         )
-    #         hidden_states = self.broadcast(
-    #             hidden_states,
-    #             [self.batch_size, self.num_key_value_heads, n_rep, self.kv_seq_len, self.head_dim],
-    #         )
-    #         hidden_states = self.reshape(
-    #             hidden_states,
-    #             [self.batch_size, n_rep * self.num_key_value_heads, self.kv_seq_len, self.head_dim],
-    #         )
-    #     else:
-    #         hidden_states = self.reshape(
-    #             hidden_states,
-    #             [self.batch_size, self.num_key_value_heads, 1, self.head_dim, self.kv_seq_len],
-    #         )
-    #         hidden_states = self.broadcast(
-    #             hidden_states,
-    #             [self.batch_size, self.num_key_value_heads, n_rep, self.head_dim, self.kv_seq_len],
-    #         )
-    #         hidden_states = self.reshape(
-    #             hidden_states,
-    #             [self.batch_size, n_rep * self.num_key_value_heads, self.head_dim, self.kv_seq_len],
-    #         )
-    #     return hidden_states
-
     def build_decoder(
         self,
         hidden_states,
@@ -325,44 +294,6 @@ def build_decoder(
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
 
-        # query_states = self.linear(
-        #     input_2d,
-        #     self.num_heads * self.head_dim,
-        #     self.hidden_size,
-        #     bias=False,
-        #     wt_dtype=self.dtype,
-        # )
-        # key_states = self.linear(
-        #     input_2d,
-        #     self.num_key_value_heads * self.head_dim,
-        #     self.hidden_size,
-        #     bias=False,
-        #     wt_dtype=self.dtype,
-        # )
-        # value_states = self.linear(
-        #     input_2d,
-        #     self.num_key_value_heads * self.head_dim,
-        #     self.hidden_size,
-        #     bias=False,
-        #     wt_dtype=self.dtype,
-        # )
-
-        # query_states = self.reshape(
-        #     query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]
-        # )
-        # key_states = self.reshape(
-        #     key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]
-        # )
-        # value_states = self.reshape(
-        #     value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]
-        # )
-        #
-        # query_states = self.transpose(query_states, [0, 2, 1, 3])
-        # key_states = self.transpose(key_states, [0, 2, 1, 3])
-        # if self.transpose_value:
-        #     value_states = self.transpose(value_states, [0, 2, 3, 1])
-        # else:
-        #     value_states = self.transpose(value_states, [0, 2, 1, 3])
         cos = self.unsqueeze(self.squeeze(self.cos), [0])
         sin = self.unsqueeze(self.squeeze(self.sin), [0])
         query_states, key_states = self.apply_rotary_pos_emb(
@@ -378,20 +309,6 @@ def build_decoder(
         else:
             value_states = self.concat(past_value, value_states, axis=-2)
 
-        # key_states = self.repeat_kv(key_states, self.num_key_value_groups)
-        # value_states = self.repeat_kv(value_states, self.num_key_value_groups, self.transpose_value)
-
-        # if query_states.size(2) == key_states.size(2):
-        #     # first token
-        #     from intel_npu_acceleration_library.functional import scaled_dot_product_attention
-        #     attn_output = scaled_dot_product_attention(
-        #         query_states,
-        #         key_states,
-        #         value_states,
-        #         attn_mask=attention_mask
-        #     )
-        #     attn_weights = None
-        # else:
         attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
@@ -399,15 +316,6 @@ def build_decoder(
         attn_weight = self.convert_to_fp16(attn_weight)
         attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
 
-        # attn_weight = self.matmul(query_states, key_states, False, True) / (
-        #     math.sqrt(self.head_dim)
-        # )
-        # attn_weight = self.eltwise_add(attn_weight, attention_mask)
-        # attn_weight = self.convert_to_fp32(attn_weight)
-        # attn_weight = self.softmax(attn_weight, -1)
-        # attn_weight = self.convert_to_fp16(attn_weight)
-        # attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
-
         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])
 
@@ -878,16 +786,6 @@ def run_decode(
             past_key_values = input_queue.get()
         else:
             t0 = time.perf_counter()
-            # past_seen_tokens = past_key_values.get_seq_length()
-            # attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
-            # cache_position = torch.arange(
-            #     past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
-            # )
-            #
-            # position_ids = position_ids = cache_position.unsqueeze(0)
-            # causal_mask = model.model._update_causal_mask(
-            #     attention_mask, hidden_states, cache_position, past_seen_tokens
-            # )
             past_key_values_length = past_key_values.get_seq_length()
             seq_length_with_past = 1 + past_key_values_length
             position_ids = torch.arange(
@@ -1210,9 +1108,6 @@ def baichuan_fused_model_forward(
     seq_length_with_past = seq_length
     past_key_values_length = 0
 
-    # if past_key_values is not None:
-    #     past_key_values_length = past_key_values.get_seq_length()
-    #     seq_length_with_past = seq_length_with_past + past_key_values_length
     # ipex-llm changes start
     from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index d04a5c9c0be..1b7a9f4505c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -77,12 +77,10 @@ def optimize_llm(
             max_prompt_len=max_prompt_len,
             transpose_value_cache=transpose_value_cache,
         )
-
         qwen2_model_forward = gen_qwen2_fused_model_forward(
             prefill_runner=prefill_runner, decode_runner=decode_runner
        )
        convert_forward(model, Qwen2Model, qwen2_model_forward)
-
        from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
        from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
        convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
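
For reference, the decode attention path kept by the baichuan_mp.py hunks above (the commented-out scaled_dot_product_attention branch is removed) stays on the eager formulation: matmul(Q, K^T) scaled by 1/sqrt(head_dim), add the attention mask, softmax in fp32, cast back to fp16, then matmul with the optionally pre-transposed value cache. The snippet below is a minimal plain-PyTorch sketch of that same math, assuming fp16 tensors shaped [batch, num_heads, seq, head_dim]; the function name and the transpose_value layout are illustrative assumptions, not part of the ipex-llm NPU graph API.

    # Hedged sketch: eager attention equivalent to the retained graph ops above.
    import math
    import torch

    def eager_attention(query, key, value, attn_mask, transpose_value=False):
        # query: [batch, heads, q_len, head_dim]; key: [batch, heads, kv_len, head_dim]
        attn_weight = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
        attn_weight = attn_weight + attn_mask
        # softmax in fp32 for numerical stability, then cast back (mirrors convert_to_fp32/fp16)
        attn_weight = torch.softmax(attn_weight.float(), dim=-1).to(query.dtype)
        if transpose_value:
            # assumed layout: value cache stored as [batch, heads, head_dim, kv_len]
            value = value.transpose(-1, -2)
        return torch.matmul(attn_weight, value)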