diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index d9f6b60e9cf1..9dab60782b42 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -162,18 +162,16 @@ def from_pretrained(cls,
                 ggml_tensor_qtype, FP4Params
 
             if isinstance(model.lm_head, torch.nn.Linear):
-                new_linear = LowBitLinear(
-                    model.lm_head.in_features,
-                    model.lm_head.out_features,
-                    ggml_tensor_qtype["sym_int4"],
-                    False
-                )
+                new_linear = LowBitLinear(model.lm_head.in_features,
+                                          model.lm_head.out_features,
+                                          ggml_tensor_qtype["sym_int4"],
+                                          False)
                 paramsLowBit = FP4Params(data=model.lm_head.weight.data,
-                                      requires_grad=False,
-                                      quantized=False,
-                                      _shape=None,
-                                      qtype=ggml_tensor_qtype["sym_int4"],
-                                      in_features=model.lm_head.in_features).to("cpu")
+                                         requires_grad=False,
+                                         quantized=False,
+                                         _shape=None,
+                                         qtype=ggml_tensor_qtype["sym_int4"],
+                                         in_features=model.lm_head.in_features).to("cpu")
                 new_linear._parameters['weight'] = paramsLowBit
                 model.lm_head = new_linear
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/kv.py b/python/llm/src/ipex_llm/transformers/npu_models/kv.py
index 0f2affd401ce..ce5b29ee2945 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/kv.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/kv.py
@@ -25,8 +25,8 @@ def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_len
                                     max_length, head_dim,
                                     dtype=dtype, device=device)
     value_cache_storage = torch.zeros(batch_size, num_heads,
-                                    max_length, head_dim,
-                                    dtype=dtype, device=device)
+                                      max_length, head_dim,
+                                      dtype=dtype, device=device)
 
     key_cache = key_cache_storage.as_strided((batch_size, num_heads,
                                               current_length, head_dim),
@@ -57,9 +57,9 @@ class DynamicFusedNormalCache(DynamicCache):
     KV_ALLOC_BLOCK_LENGTH = 256
 
     def __init__(self) -> None:
-        self.key_cache: Dict[int, torch.Tensor] = {}
+        self.key_cache: Dict[int, torch.Tensor] = {}
         self.value_cache: Dict[int, torch.Tensor] = {}
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep how many tokens the cache has seen
 
     def update(
         self,
@@ -85,7 +85,8 @@ def update(
         # Update the cache
         # if len(self.key_cache) <= layer_idx:
         if layer_idx not in self.key_cache:
-            max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH
+            max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + \
+                self.KV_ALLOC_BLOCK_LENGTH
             k_cache, v_cache = init_fused_kv_cache(
                 batch_size, num_heads, head_dim, 0, max_len,
                 key_states.dtype, key_states.device,
@@ -107,7 +108,8 @@ def update(
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        """Returns the sequence length of the cached states.
+        A layer index can be optionally passed."""
         for idx, layer in self.key_cache.items():
             return layer.shape[-2]
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama.py b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
index 7323b9e2915f..a322d731e51c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
@@ -232,7 +232,7 @@ def llama_fused_model_forward(
 
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0)
-    
+
     causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
                                            cache_position, past_seen_tokens)
 
@@ -247,21 +247,17 @@ def llama_fused_model_forward(
 
     seq_len = hidden_states.size(1)
     if seq_len == 1:
-        # assert hasattr(self, "multi_decoder")
         # multi_decoder = self.layers[(self.layer_end + 1) % num_layers]
         layer_outputs = self.multi_decoder(hidden_states,
-                                       attention_mask=causal_mask,
-                                       position_ids=position_ids,
-                                       past_key_value=past_key_values,
-                                       output_attentions=output_attentions,
-                                       use_cache=use_cache,
-                                       cache_position=cache_position,)
+                                           attention_mask=causal_mask,
+                                           position_ids=position_ids,
+                                           past_key_value=past_key_values,
+                                           output_attentions=output_attentions,
+                                           use_cache=use_cache,
+                                           cache_position=cache_position,)
 
         hidden_states = layer_outputs[0]
-        assert use_cache
         next_decoder_cache = layer_outputs[1]
-
-        assert not output_attentions
     else:
         for decoder_layer in self.layers:
             if output_hidden_states:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py
index 91d5c34313ba..69614439338c 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py
@@ -276,7 +276,7 @@ def pipeline_parallel_generate(self,
     bs = inputs_tensor.shape[0]
     if model_kwargs.get("attention_mask", None) is None:
         model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
+            inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)
     if self.config.is_encoder_decoder:
         input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
             batch_size=bs,
@@ -289,7 +289,7 @@ def pipeline_parallel_generate(self,
     else:
         input_ids = inputs_tensor if model_input_name == "input_ids" \
             else model_kwargs.pop("input_ids")
-    
+
     local_rank = dist.get_rank()
     pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
     next_rank = (local_rank + 1) % self.pipeline_parallel_stages
@@ -325,7 +325,7 @@ def pipeline_parallel_generate(self,
 
         if _input_ids is None:
             _input_ids = input_ids
-        
+
         model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)
 
         tic = time.time()
@@ -360,8 +360,8 @@ def pipeline_parallel_generate(self,
 
         output_ids = torch.cat([output_ids, next_ids], dim=-1)
         model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
-            )
+            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+        )
 
         # finished sentences should have their next token be a padding token
         next_ids = next_ids.squeeze()
@@ -602,7 +602,7 @@ def glm4_conditional_generation_forward_lowmem(
     hidden_states = transformer_outputs[0]
     if return_last_logit:
         hidden_states = hidden_states[:, -1:]
-    
+
     device = hidden_states.device
     # ipex-llm change starts
     if device.type == "xpu":