
Commit 22042ea

clean
jenniew committed Aug 26, 2024
1 parent 9334431 commit 22042ea
Showing 2 changed files with 0 additions and 107 deletions.
105 changes: 0 additions & 105 deletions python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -243,37 +243,6 @@ def __init__(
print("start compiling")
self.compile()

# def repeat_kv(self, hidden_states, n_rep, transpose=False):
# if n_rep == 1:
# return hidden_states
# if not transpose:
# hidden_states = self.reshape(
# hidden_states,
# [self.batch_size, self.num_key_value_heads, 1, self.kv_seq_len, self.head_dim],
# )
# hidden_states = self.broadcast(
# hidden_states,
# [self.batch_size, self.num_key_value_heads, n_rep, self.kv_seq_len, self.head_dim],
# )
# hidden_states = self.reshape(
# hidden_states,
# [self.batch_size, n_rep * self.num_key_value_heads, self.kv_seq_len, self.head_dim],
# )
# else:
# hidden_states = self.reshape(
# hidden_states,
# [self.batch_size, self.num_key_value_heads, 1, self.head_dim, self.kv_seq_len],
# )
# hidden_states = self.broadcast(
# hidden_states,
# [self.batch_size, self.num_key_value_heads, n_rep, self.head_dim, self.kv_seq_len],
# )
# hidden_states = self.reshape(
# hidden_states,
# [self.batch_size, n_rep * self.num_key_value_heads, self.head_dim, self.kv_seq_len],
# )
# return hidden_states
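
For reference, the deleted comments above mirrored the usual grouped-query-attention KV repetition: insert a repeat axis, broadcast it to `n_rep`, then fold it back into the head axis. A minimal standalone PyTorch sketch of the same idea (plain tensors, not the NPU graph ops used in this class; names and shapes are illustrative assumptions):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # hidden_states: [batch, num_kv_heads, seq_len, head_dim]
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    # Add a repeat axis, broadcast it to n_rep, then merge it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
```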

def build_decoder(
self,
hidden_states,
@@ -325,44 +294,6 @@ def build_decoder(
else:
value_states = self.transpose(value_states, [0, 2, 1, 3])

# query_states = self.linear(
# input_2d,
# self.num_heads * self.head_dim,
# self.hidden_size,
# bias=False,
# wt_dtype=self.dtype,
# )
# key_states = self.linear(
# input_2d,
# self.num_key_value_heads * self.head_dim,
# self.hidden_size,
# bias=False,
# wt_dtype=self.dtype,
# )
# value_states = self.linear(
# input_2d,
# self.num_key_value_heads * self.head_dim,
# self.hidden_size,
# bias=False,
# wt_dtype=self.dtype,
# )

# query_states = self.reshape(
# query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]
# )
# key_states = self.reshape(
# key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]
# )
# value_states = self.reshape(
# value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]
# )
#
# query_states = self.transpose(query_states, [0, 2, 1, 3])
# key_states = self.transpose(key_states, [0, 2, 1, 3])
# if self.transpose_value:
# value_states = self.transpose(value_states, [0, 2, 3, 1])
# else:
# value_states = self.transpose(value_states, [0, 2, 1, 3])
cos = self.unsqueeze(self.squeeze(self.cos), [0])
sin = self.unsqueeze(self.squeeze(self.sin), [0])
query_states, key_states = self.apply_rotary_pos_emb(
@@ -378,36 +309,13 @@
else:
value_states = self.concat(past_value, value_states, axis=-2)

# key_states = self.repeat_kv(key_states, self.num_key_value_groups)
# value_states = self.repeat_kv(value_states, self.num_key_value_groups, self.transpose_value)

# if query_states.size(2) == key_states.size(2):
# # first token
# from intel_npu_acceleration_library.functional import scaled_dot_product_attention
# attn_output = scaled_dot_product_attention(
# query_states,
# key_states,
# value_states,
# attn_mask=attention_mask
# )
# attn_weights = None
# else:
attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim))
attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1)
attn_weight = self.convert_to_fp16(attn_weight)
attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)

# attn_weight = self.matmul(query_states, key_states, False, True) / (
# math.sqrt(self.head_dim)
# )
# attn_weight = self.eltwise_add(attn_weight, attention_mask)
# attn_weight = self.convert_to_fp32(attn_weight)
# attn_weight = self.softmax(attn_weight, -1)
# attn_weight = self.convert_to_fp16(attn_weight)
# attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)

attn_output = self.transpose(attn_output, [0, 2, 1, 3])
attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])
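
The retained path above is plain eager attention: scores = QKᵀ / sqrt(head_dim), add the additive causal mask, softmax in fp32 for numerical stability, cast back to fp16, multiply by V, then transpose and reshape back to [batch, seq, hidden]. A standalone PyTorch sketch of the same computation (an illustrative assumption about shapes and dtypes, not the NPU graph ops):

```python
import math
import torch

def eager_attention(query, key, value, attention_mask):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    head_dim = query.shape[-1]
    scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim)
    scores = scores + attention_mask  # additive causal mask
    # Softmax in fp32, then cast back to the working dtype (e.g. fp16).
    weights = torch.softmax(scores.float(), dim=-1).to(query.dtype)
    output = weights @ value          # [batch, num_heads, seq_len, head_dim]
    # Merge heads back into the hidden dimension: [batch, seq_len, num_heads * head_dim]
    return output.transpose(1, 2).reshape(query.shape[0], query.shape[2], -1)
```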

@@ -878,16 +786,6 @@ def run_decode(
past_key_values = input_queue.get()
else:
t0 = time.perf_counter()
# past_seen_tokens = past_key_values.get_seq_length()
# attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
# cache_position = torch.arange(
# past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device
# )
#
# position_ids = position_ids = cache_position.unsqueeze(0)
# causal_mask = model.model._update_causal_mask(
# attention_mask, hidden_states, cache_position, past_seen_tokens
# )
past_key_values_length = past_key_values.get_seq_length()
seq_length_with_past = 1 + past_key_values_length
position_ids = torch.arange(
@@ -1210,9 +1108,6 @@ def baichuan_fused_model_forward(

seq_length_with_past = seq_length
past_key_values_length = 0
# if past_key_values is not None:
# past_key_values_length = past_key_values.get_seq_length()
# seq_length_with_past = seq_length_with_past + past_key_values_length

# ipex-llm changes start
from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
2 changes: 0 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -77,12 +77,10 @@ def optimize_llm(
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)

qwen2_model_forward = gen_qwen2_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, Qwen2Model, qwen2_model_forward)

from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
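
Here `convert_forward` rebinds the generated forward functions onto the matching modules of the loaded model. A minimal sketch of what such a helper typically does (an assumption about the general pattern, not ipex_llm's actual implementation):

```python
import types

def convert_forward(model, target_cls, new_forward):
    # Walk the module tree and bind new_forward as the instance forward
    # of every submodule that is an instance of target_cls.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
```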
