diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 5c5ca0ce195..8507c13c0c7 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -48,70 +48,11 @@ import torch.multiprocessing as mp
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from ipex_llm.transformers.npu_models.mp_models_base import run_model
+from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 
 
-@torch.no_grad()
-def run_model(
-    x: Union[torch.Tensor, List[torch.Tensor]],
-    weights: List[torch.Tensor],
-    backend_cls: Any,
-    op_id: str,
-    replica: int = 1,
-) -> torch.Tensor:
-    global _model_cache
-    import time
-
-    t0 = time.perf_counter()
-
-    # Use or not op_id depending on the class used
-    op_kwargs = {"op_id": op_id} if op_id else {}
-
-    if not isinstance(x, (list, tuple)):
-        x = [x]
-
-    # Reshape input
-    input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
-    op_args = []
-    op_args_flatten = []
-    for w in weights:
-        if isinstance(w, tuple):  # from QuantizedLinear
-            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
-            op_args_flatten.append(op_args[-1][0])
-            op_args_flatten.append(op_args[-1][1])
-        else:
-            op_args.append(set_contiguous(w).to(torch.float16).numpy())
-            op_args_flatten.append(op_args[-1])
-
-    shape_dtype_signature = "_".join(
-        ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten]
-    )
-    key = f"{backend_cls.func.__name__}_{shape_dtype_signature}"
-    models = _model_cache.get(key, None)
-
-    input_shapes = [elem.shape for elem in x_np]
-    if models is None:
-        _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)])
-    elif len(models) < 1:
-        _model_cache[key].append(backend_cls(*input_shapes))
-    else:
-        _model_cache[key].rotate(1)
-
-    # Get the model
-    model = _model_cache[key][0]
-
-    with record_function(f"npu_factory_mul_{key}"):
-        ret = model.run(x_np, *op_args, **op_kwargs)
-
-    if isinstance(ret, list):
-        results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret]
-    else:
-        results = adapt_output_tensor(ret, ret.shape, input_dtype)
-
-    return results
-
-
-class LowBitLlamaMultiDecoderlayer(NNFactory):
+class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
     def __init__(
         self,
         # batch_size: int,
@@ -135,7 +76,11 @@ def __init__(
         rms_norm_eps,
         intermediate_size,
     ):
-        super().__init__(profile, device)
+        super().__init__(max_seq_len=max_seq_len,
+                         transpose_value=transpose_value,
+                         dtype=dtype,
+                         profile=profile,
+                         device=device)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -145,6 +90,7 @@ def __init__(
         self.mode = mode
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
+        self.num_layers = num_layers
 
         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -164,28 +110,28 @@ def __init__(
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 
        # define input, the order self.parameter matters
-        input = self.parameter((self.batch_size, self.seq_len, self.hidden_size))
+        input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
 
         # Self Attention
         if mode == "decode":
-            attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1))
+            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
         else:
-            attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len))
+            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
 
-        position_ids = self.parameter((self.batch_size, self.seq_len))
+        position_ids = self.create_input_op((self.batch_size, self.seq_len))
         past_keys = []
         past_values = []
         if mode == "decode":
             for i in range(num_layers):
-                past_key = self.parameter(
+                past_key = self.create_cache_op(
                     (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
                 )
                 if transpose_value:
-                    past_value = self.parameter(
+                    past_value = self.create_cache_op(
                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
                    )
                 else:
-                    past_value = self.parameter(
+                    past_value = self.create_cache_op(
                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
                    )
                 past_keys.append(past_key)
@@ -199,7 +145,7 @@ def __init__(
             post_attn_layernorm_weights = []
             for i in range(num_layers):
                 input_layernorm_weights.append(
-                    self.parameter(
+                    self.create_input_op(
                         (
                             1,
                             self.hidden_size,
@@ -207,7 +153,7 @@ def __init__(
                     )
                 )
                 post_attn_layernorm_weights.append(
-                    self.parameter(
+                    self.create_input_op(
                         (
                             1,
                             self.hidden_size,
@@ -243,37 +189,6 @@ def __init__(
         print("start compiling")
         self.compile()
 
-    def repeat_kv(self, hidden_states, n_rep, transpose=False):
-        if n_rep == 1:
-            return hidden_states
-        if not transpose:
-            hidden_states = self.reshape(
-                hidden_states,
-                [self.batch_size, self.num_key_value_heads, 1, self.kv_seq_len, self.head_dim],
-            )
-            hidden_states = self.broadcast(
-                hidden_states,
-                [self.batch_size, self.num_key_value_heads, n_rep, self.kv_seq_len, self.head_dim],
-            )
-            hidden_states = self.reshape(
-                hidden_states,
-                [self.batch_size, n_rep * self.num_key_value_heads, self.kv_seq_len, self.head_dim],
-            )
-        else:
-            hidden_states = self.reshape(
-                hidden_states,
-                [self.batch_size, self.num_key_value_heads, 1, self.head_dim, self.kv_seq_len],
-            )
-            hidden_states = self.broadcast(
-                hidden_states,
-                [self.batch_size, self.num_key_value_heads, n_rep, self.head_dim, self.kv_seq_len],
-            )
-            hidden_states = self.reshape(
-                hidden_states,
-                [self.batch_size, n_rep * self.num_key_value_heads, self.head_dim, self.kv_seq_len],
-            )
-        return hidden_states
-
     def build_decoder(
         self,
         hidden_states,
@@ -286,157 +201,31 @@ def __init__(
     ):
 
         residual = hidden_states
-
         input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
-
-        # input layernorm
-        input_2d = self.convert_to_fp32(input_2d)
-        variance = self.reduce_mean(
-            self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))),
-            -1,
-            keep_dims=True,
-        )
-        eps = self.constant(self.rms_norm_eps)
-        input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps)))
-        input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
-        input_2d = self.eltwise_mul(input_layernorm_weight, input_2d)
-        input_2d = self.convert_to_fp16(input_2d)
-
-        # attention
-        query_states = self.linear(
-            input_2d,
-            self.num_heads * self.head_dim,
-            self.hidden_size,
-            bias=False,
-            wt_dtype=self.dtype,
-        )
-        key_states = self.linear(
-            input_2d,
-            self.num_key_value_heads * self.head_dim,
-            self.hidden_size,
-            bias=False,
-            wt_dtype=self.dtype,
-        )
-        value_states = self.linear(
-            input_2d,
-            self.num_key_value_heads * self.head_dim,
-            self.hidden_size,
-            bias=False,
-            wt_dtype=self.dtype,
-        )
-
-        query_states = self.reshape(
-            query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]
-        )
-        key_states = self.reshape(
-            key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]
-        )
-        value_states = self.reshape(
-            value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]
-        )
-
-        query_states = self.transpose(query_states, [0, 2, 1, 3])
-        key_states = self.transpose(key_states, [0, 2, 1, 3])
-        if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
-        else:
-            value_states = self.transpose(value_states, [0, 2, 1, 3])
-
-        query_states, key_states = self.apply_rotary_pos_emb(
-            query_states, key_states, self.cos, self.sin, position_ids
-        )
-        new_key_states = key_states
-        new_value_states = value_states
-
-        if self.mode == "decode":
-            key_states = self.concat(past_key, key_states, axis=-2)
-            if self.transpose_value:
-                value_states = self.concat(past_value, value_states, axis=-1)
-            else:
-                value_states = self.concat(past_value, value_states, axis=-2)
-
-        key_states = self.repeat_kv(key_states, self.num_key_value_groups)
-        value_states = self.repeat_kv(value_states, self.num_key_value_groups, self.transpose_value)
-
-        attn_weight = self.matmul(query_states, key_states, False, True) / (
-            math.sqrt(self.head_dim)
-        )
-        attn_weight = self.eltwise_add(attn_weight, attention_mask)
-        attn_weight = self.convert_to_fp32(attn_weight)
-        attn_weight = self.softmax(attn_weight, -1)
-        attn_weight = self.convert_to_fp16(attn_weight)
-        attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
-
-        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
-        attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])
-
-        attn_output = self.linear(
-            attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+        input_2d = self.layer_norm(input_2d, input_layernorm_weight)
+        attn_output, new_key_states, new_value_states = self.attention(
+            hidden_states=input_2d,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key=past_key,
+            past_value=past_value,
+            cos=self.cos,
+            sin=self.sin,
+            mode=self.mode,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            head_dim=self.head_dim,
+            seq_len=self.seq_len,
         )
-        hidden_states = self.eltwise_add(residual, attn_output)
-
-        # Fully Connected
         residual = hidden_states
-        # post_attention_layernorm forward
-
-        hidden_states = self.convert_to_fp32(hidden_states)
-        variance = self.reduce_mean(
-            self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))),
-            -1,
-            keep_dims=True,
-        )
-        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
-        post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight)
-        hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states)
-        hidden_states = self.convert_to_fp16(hidden_states)
-
-        # mlp
-        mm1 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )
-        mm2 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )  # type: ignore[attr-defined]
-        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        hidden_states = self.linear(
-            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
-        )
-
+        hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
 
         return hidden_states, new_key_states, new_value_states
 
-    def rotate_half(self, x):
-        x1 = self.slice(
-            x,
-            [0, 0, 0, 0],
-            [self.batch_size, self.num_heads, self.seq_len, self.head_dim // 2],
-        )
-        x2 = self.slice(
-            x,
-            [0, 0, 0, self.head_dim // 2],
-            [self.batch_size, self.num_heads, self.seq_len, self.head_dim],
-        )
-        return self.concat(self.negative(x2), x1, axis=-1)
-
-    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
-        position_ids = self.squeeze(position_ids)
-        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
-        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
-        cos = self.unsqueeze(cos, [1])
-        sin = self.unsqueeze(sin, [1])
-
-        q_embed = self.eltwise_add(
-            self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin)
-        )
-        k_embed = self.eltwise_add(
-            self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin)
-        )
-
-        return q_embed, k_embed
-
 
 class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
@@ -479,8 +268,6 @@ def __init__(
         self.intra_stages = intra_stages
         self.layer_indexes = layer_indexes
-        self.num_layers_1 = len(self.layer_indexes) // 2
-        self.num_layers_0 = len(self.layer_indexes) - self.num_layers_1
         num_layers = len(self.layer_indexes) // intra_stages
         self.layer_ranges = []
         for i in range(intra_stages):
@@ -515,16 +302,7 @@ def __init__(
 
         for i in range(intra_stages):
             start, end = self.layer_ranges[i]
-            num_intra_layers = end - start
-            self.backend_decoders[i].setWeights(
-                3 + (num_intra_layers) * 2, self.op_id, *op_parameters[start * 7:end * 7]
-            )
-            with FileLock(f"decoder_run.lock"):
-                backend_lib.run(self.backend_decoders[i]._mm)
-
-        self.kv_cache_c_parameter_handel = []
-        self.kv_cache_parameters = []
-        self.kv_cache_prefetched = False
+            self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7])
 
     def forward(
         self,
@@ -544,76 +322,22 @@ def forward(
             position_ids,
         )
 
-        if len(self.kv_cache_parameters) > 0:
-            # the case kv cache changed
-            cached_prt = self.kv_cache_parameters[0].storage().data_ptr()
-            current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr()
-            if cached_prt != current_ptr:
-                self.kv_cache_parameters = []
-                self.kv_cache_c_parameter_handel = []
-                self.kv_cache_prefetched = False
-
-        if len(self.kv_cache_parameters) == 0:
-            for idx in self.layer_indexes:
-                past_key = past_key_value.key_cache[idx]
-                past_value = past_key_value.value_cache[idx]
-
-                invalidInputError(
-                    past_key.dtype == torch.float16, f"past_key dtype is {past_key.dtype}"
-                )
-
-                new_size = (past_key.size(0), past_key.size(1), self.max_seq_len, past_key.size(3))
-                past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
-                invalidInputError(past_key.is_contiguous(), "past_key is not contiguous")
-                past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)
-                if self.transpose_value:
-                    past_value = past_value.transpose(-1, -2)
-                invalidInputError(past_value.is_contiguous(), "past_value is not contiguous")
-
-                self.kv_cache_parameters.append(past_key)
-                self.kv_cache_parameters.append(past_value)
-
-            for i in range(self.intra_stages):
-                start, end = self.layer_ranges[i]
-                layer_kv_cache = self.kv_cache_parameters[start * 2:end * 2]
-                layer_kv_cache = [p.numpy() for p in layer_kv_cache]
-                handle = self.backend_decoders[i].create_parameters(layer_kv_cache)
-                self.kv_cache_c_parameter_handel.append(handle)
-
-        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
-
-        with record_function(f"npu_factory"):
-            if not self.kv_cache_prefetched:
-                for i in range(self.intra_stages):
-                    self.backend_decoders[i].load_wt_fn(
-                        len(inputs),
-                        self.backend_decoders[i]._mm,
-                        self.kv_cache_c_parameter_handel[i],
-                    )
-
-            array_type = ctypes.POINTER(ctypes.c_char) * self.intra_stages
-            models_ptr = array_type(
-                *[self.backend_decoders[i]._mm for i in range(self.intra_stages)]
-            )
-            inputs_ptr = (ctypes.c_void_p * 3)(
-                x_np[0].ctypes.data_as(ctypes.c_void_p),
-                x_np[1].ctypes.data_as(ctypes.c_void_p),
-                x_np[2].ctypes.data_as(ctypes.c_void_p),
-            )
-            t0 = time.perf_counter()
-            backend_lib.run_decoders(models_ptr, inputs_ptr, self.intra_stages, 3)
-            t1 = time.perf_counter()
+        for i in range(self.intra_stages):
+            start, end = self.layer_ranges[i]
+            self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])
 
-        hidden_states = self.backend_decoders[-1].torch_out[0]
+        hidden_states, new_keys, new_values = LowBitLlamaMultiDecoderlayer.run_decoders(
+            inputs,
+            decoders=self.backend_decoders)
 
         if self.do_print:
             print("outputs:", hidden_states)
 
         outputs = (hidden_states,)
-        outputs += (past_key_value,)
-        return outputs, t1 - t0
+        outputs += (past_key_value, new_keys, new_values)
+        return outputs
 
-    def post_forward(self, past_key_value, cache_position):
+    def post_forward(self, past_key_value, new_keys, new_values, cache_position):
         key_value_states = []
         for i in range(self.intra_stages):
             for j in range(1, len(self.backend_decoders[i].torch_out)):
@@ -626,17 +350,14 @@ def post_forward(self, past_key_value, cache_position):
         }
         for i in range(len(self.layer_indexes)):
             key_states, value_states = past_key_value.update(
-                key_value_states[2 * i],
-                key_value_states[2 * i + 1],
+                new_keys[i],
+                new_values[i],
                 self.layer_indexes[i],
                 cache_kwargs,
             )
 
         for i in range(self.intra_stages):
-            self.backend_decoders[i].load_wt_fn(
-                3, self.backend_decoders[i]._mm, self.kv_cache_c_parameter_handel[i]
-            )
-        self.kv_cache_prefetched = True
+            self.backend_decoders[i].load_cache_async()
 
 
 class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
@@ -843,7 +564,7 @@ def run_decode(
             padded_causal_mask[:, :, :, -1] = 0.0
             dist.recv(hidden_states, src=rank - 1)
             t1 = time.perf_counter()
-            layer_outputs, elapse = multi_decoder(
+            layer_outputs = multi_decoder(
                 hidden_states,
                 attention_mask=padded_causal_mask,
                 position_ids=position_ids,
@@ -857,7 +578,10 @@ def run_decode(
             t3 = time.perf_counter()
             dist.send(hidden_states, dst=(rank + 1) % world_size)
             t4 = time.perf_counter()
-            multi_decoder.post_forward(past_key_values, cache_position)
+            past_key_values = layer_outputs[1]
+            new_keys = layer_outputs[2]
+            new_values = layer_outputs[3]
+            multi_decoder.post_forward(past_key_values, new_keys, new_values, cache_position)
 
 
 class DecodeRunner:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
new file mode 100644
index 00000000000..a3abcca83ce
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -0,0 +1,418 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from intel_npu_acceleration_library.backend.factory import NNFactory
+from typing import List, Union, Any
+from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function
+from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache
+from collections import deque
+from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
+from ipex_llm.utils.common import invalidInputError
+from transformers.utils import logging
+from filelock import FileLock
+import ctypes
+import math
+import numpy as np
+
+logger = logging.get_logger(__name__)
+
+
+@torch.no_grad()
+def run_model(
+    x: Union[torch.Tensor, List[torch.Tensor]],
+    weights: List[torch.Tensor],
+    backend_cls: Any,
+    op_id: str,
+    replica: int = 1,
+) -> torch.Tensor:
+    global _model_cache
+    import time
+
+    t0 = time.perf_counter()
+
+    # Use or not op_id depending on the class used
+    op_kwargs = {"op_id": op_id} if op_id else {}
+
+    if not isinstance(x, (list, tuple)):
+        x = [x]
+
+    # Reshape input
+    input_dtype = x[0].dtype
+    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
+    op_args = []
+    op_args_flatten = []
+    for w in weights:
+        if isinstance(w, tuple):  # from QuantizedLinear
+            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
+            op_args_flatten.append(op_args[-1][0])
+            op_args_flatten.append(op_args[-1][1])
+        else:
+            op_args.append(set_contiguous(w).to(torch.float16).numpy())
+            op_args_flatten.append(op_args[-1])
+
+    shape_dtype_signature = "_".join(
+        ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten]
+    )
+    key = f"{backend_cls.func.__name__}_{shape_dtype_signature}"
+    models = _model_cache.get(key, None)
+
+    input_shapes = [elem.shape for elem in x_np]
+    if models is None:
+        _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)])
+    elif len(models) < 1:
+        _model_cache[key].append(backend_cls(*input_shapes))
+    else:
+        _model_cache[key].rotate(1)
+
+    # Get the model
+    model = _model_cache[key][0]
+
+    with record_function(f"npu_factory_mul_{key}"):
+        ret = model.run(x_np, *op_args, **op_kwargs)
+
+    if isinstance(ret, list):
+        results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret]
+    else:
+        results = adapt_output_tensor(ret, ret.shape, input_dtype)
+
+    return results
+
+
+class LLMBaseNNFactory(NNFactory):
+
+    def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU"):
+        super().__init__(profile, device)
+        self.cache_parameter_ops = []
+        self.input_ops = []
+        self.linear_ops = []
+        self.kv_cache_c_handle = None
+        self.kv_cache_torch = []
+        self.max_seq_len = max_seq_len
+        self.transpose_value = transpose_value
+        self.dtype = dtype
+
+    def attention(self,
+                  *,
+                  hidden_states,
+                  position_ids,
+                  attention_mask,
+                  past_key,
+                  past_value,
+                  cos,
+                  sin,
+                  mode,
+                  num_heads,
+                  num_key_value_heads,
+                  head_dim,
+                  seq_len):
+        hidden_size = num_heads * head_dim
+        num_key_value_groups = num_heads // num_key_value_heads
+        query_states = self.linear(
+            hidden_states,
+            num_heads * head_dim,
+            hidden_size,
+            bias=False,
+            wt_dtype=self.dtype,
+        )
+        key_states = self.linear(
+            hidden_states,
+            num_key_value_heads * head_dim,
+            hidden_size,
+            bias=False,
+            wt_dtype=self.dtype,
+        )
+        value_states = self.linear(
+            hidden_states,
+            num_key_value_heads * head_dim,
+            hidden_size,
+            bias=False,
+            wt_dtype=self.dtype,
+        )
+
+        query_states = self.reshape(
+            query_states, [1, seq_len, num_heads, head_dim]
+        )
+        key_states = self.reshape(
+            key_states, [1, seq_len, num_key_value_heads, head_dim]
+        )
+        value_states = self.reshape(
+            value_states, [1, seq_len, num_key_value_heads, head_dim]
+        )
+
+        query_states = self.transpose(query_states, [0, 2, 1, 3])
+        key_states = self.transpose(key_states, [0, 2, 1, 3])
+        if self.transpose_value:
+            value_states = self.transpose(value_states, [0, 2, 3, 1])
+        else:
+            value_states = self.transpose(value_states, [0, 2, 1, 3])
+
+        query_states, key_states = self.apply_rotary_pos_emb(
+            q=query_states,
+            k=key_states,
+            cos=cos,
+            sin=sin,
+            position_ids=position_ids,
+            num_heads=num_heads,
+            seq_len=seq_len,
+            head_dim=head_dim,
+        )
+        new_key_states = key_states
+        new_value_states = value_states
+
+        if mode == "decode":
+            key_states = self.concat(past_key, key_states, axis=-2)
+            if self.transpose_value:
+                value_states = self.concat(past_value, value_states, axis=-1)
+            else:
+                value_states = self.concat(past_value, value_states, axis=-2)
+            kv_seq_len = self.max_seq_len + 1
+        else:
+            kv_seq_len = seq_len
+
+        key_states = self.repeat_kv(hidden_states=key_states,
+                                    n_rep=num_key_value_groups,
+                                    num_key_value_heads=num_key_value_heads,
+                                    kv_seq_len=kv_seq_len,
+                                    head_dim=head_dim,)
+        value_states = self.repeat_kv(hidden_states=value_states,
+                                      n_rep=num_key_value_groups,
+                                      num_key_value_heads=num_key_value_heads,
+                                      kv_seq_len=kv_seq_len,
+                                      head_dim=head_dim,)
+        attn_weight = self.matmul(query_states, key_states, False, True) / (
+            math.sqrt(head_dim)
+        )
+        attn_weight = self.eltwise_add(attn_weight, attention_mask)
+        attn_weight = self.convert_to_fp32(attn_weight)
+        attn_weight = self.softmax(attn_weight, -1)
+        attn_weight = self.convert_to_fp16(attn_weight)
+        attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
+
+        attn_output = self.transpose(attn_output, [0, 2, 1, 3])
+        attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
+
+        attn_output = self.linear(
+            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
+        )
+
+        return attn_output, new_key_states, new_value_states
+
+    def mlp(self, hidden_states):
+        mm1 = self.linear(
+            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+        )
+        mm2 = self.linear(
+            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+        )  # type: ignore[attr-defined]
+        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+        hidden_states = self.linear(
+            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
+        )
+        return hidden_states
+
+    def layer_norm(self, hidden_states, layernorm_weight):
+        hidden_states = self.convert_to_fp32(hidden_states)
+        variance = self.reduce_mean(
+            self.power(hidden_states, self.constant(np.array([[2]], dtype=np.float32))),
+            -1,
+            keep_dims=True,
+        )
+        eps = self.constant(self.rms_norm_eps)
+        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
+        layernorm_weight = self.convert_to_fp32(layernorm_weight)
+        hidden_states = self.eltwise_mul(layernorm_weight, hidden_states)
+        hidden_states = self.convert_to_fp16(hidden_states)
+        return hidden_states
+
+    def rotate_half(self, x, *, num_heads, seq_len, head_dim):
+        x1 = self.slice(
+            x,
+            [0, 0, 0, 0],
+            [1, num_heads, seq_len, head_dim // 2],
+        )
+        x2 = self.slice(
+            x,
+            [0, 0, 0, head_dim // 2],
+            [1, num_heads, seq_len, head_dim],
+        )
+        return self.concat(self.negative(x2), x1, axis=-1)
+
+    def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids,
+                             num_heads, seq_len, head_dim):
+        position_ids = self.squeeze(position_ids)
+        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
+        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
+        cos = self.unsqueeze(cos, [1])
+        sin = self.unsqueeze(sin, [1])
+
+        rotate_half_q = self.rotate_half(q,
+                                         num_heads=num_heads,
+                                         seq_len=seq_len,
+                                         head_dim=head_dim)
+        rotate_half_k = self.rotate_half(k,
+                                         num_heads=num_heads,
+                                         seq_len=seq_len,
+                                         head_dim=head_dim)
+
+        q_embed = self.eltwise_add(
+            self.eltwise_mul(q, cos), self.eltwise_mul(rotate_half_q, sin)
+        )
+        k_embed = self.eltwise_add(
+            self.eltwise_mul(k, cos), self.eltwise_mul(rotate_half_k, sin)
+        )
+
+        return q_embed, k_embed
+
+    def repeat_kv(self, *, hidden_states, n_rep, num_key_value_heads,
+                  kv_seq_len, head_dim, transpose=False):
+        if n_rep == 1:
+            return hidden_states
+        if not transpose:
+            hidden_states = self.reshape(
+                hidden_states,
+                [1, num_key_value_heads, 1, kv_seq_len, head_dim],
+            )
+            hidden_states = self.broadcast(
+                hidden_states,
+                [1, num_key_value_heads, n_rep, kv_seq_len, head_dim],
+            )
+            hidden_states = self.reshape(
+                hidden_states,
+                [1, n_rep * num_key_value_heads, kv_seq_len, head_dim],
+            )
+        else:
+            hidden_states = self.reshape(
+                hidden_states,
+                [1, num_key_value_heads, 1, head_dim, kv_seq_len],
+            )
+            hidden_states = self.broadcast(
+                hidden_states,
+                [1, num_key_value_heads, n_rep, head_dim, kv_seq_len],
+            )
+            hidden_states = self.reshape(
+                hidden_states,
+                [1, n_rep * num_key_value_heads, head_dim, kv_seq_len],
+            )
+        return hidden_states
+
+    def create_cache_op(self, shape):
+        invalidInputError(len(self.linear_ops) == 0,
+                          "create_cache_op should be called before any linear op")
+        op = super().parameter(shape)
+        self.cache_parameter_ops.append(op)
+        return op
+
+    def create_input_op(self, shape):
+        invalidInputError(len(self.cache_parameter_ops) == 0,
+                          "create_input_op should be called before any create_cache_op")
+        invalidInputError(len(self.linear_ops) == 0,
+                          "create_input_op should be called before any linear op")
+
+        op = super().parameter(shape)
+        self.input_ops.append(op)
+        return op
+
+    def linear(self, *args, **kwargs):
+        op = super().linear(*args, **kwargs)
+        self.linear_ops.append(op)
+        return op
+
+    def parameter(self, shape):
+        invalidInputError(False,
+                          ("parameter should not be called directly, "
+                           "use create_cache_op or create_input_op instead"))
+
+    def update_cache(self, past_key_value, indexes):
+
+        if self.kv_cache_c_handle is not None:
+            curr_ptr = self.kv_cache_torch[0].storage().data_ptr()
+            new_ptr = past_key_value.key_cache[indexes[0]].storage().data_ptr()
+            if curr_ptr != new_ptr:
+                backend_lib.destroyParameters(self.kv_cache_c_handle)
+                self.kv_cache_c_handle = None
+                self.kv_cache_torch = []
+        if self.kv_cache_c_handle is None:
+            for idx in indexes:
+                past_key = past_key_value.key_cache[idx]
+                past_value = past_key_value.value_cache[idx]
+                invalidInputError(
+                    past_key.dtype == torch.float16, f"past_key dtype is {past_key.dtype}"
+                )
+                new_size = (past_key.size(0), past_key.size(1), self.max_seq_len, past_key.size(3))
+                past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
+                invalidInputError(past_key.is_contiguous(), "past_key is not contiguous")
+                past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)
+                if self.transpose_value:
+                    past_value = past_value.transpose(-1, -2)
+                invalidInputError(past_value.is_contiguous(), "past_value is not contiguous")
+
+                self.kv_cache_torch.append(past_key)
+                self.kv_cache_torch.append(past_value)
+
+            layer_kv_cache_np = [p.numpy() for p in self.kv_cache_torch]
+            invalidInputError(len(self.cache_parameter_ops) == len(layer_kv_cache_np),
+                              (f"kv_cache size does not match graph, "
+                               f"with kv_cache size: {len(layer_kv_cache_np)} and"
+                               f" graph size: {len(self.cache_parameter_ops)}")
+                              )
+            self.kv_cache_c_handle = self.create_parameters(layer_kv_cache_np)
+            self.load_cache_async()
+
+    def load_cache_async(self):
+        self.load_wt_fn(len(self.input_ops), self._mm, self.kv_cache_c_handle)
+
+    def set_weights(self, op_id, weights):
+        self.set_weights_async(op_id, weights)
+        with FileLock(f"decoder_run.lock"):
+            backend_lib.run(self._mm)
+
+    def set_weights_async(self, op_id, weights):
+        offset = len(self.input_ops) + len(self.cache_parameter_ops)
+        invalidInputError(len(weights) == len(self.linear_ops),
+                          (f"weights size does not match graph, "
+                           f"with weights size: {len(weights)} and "
+                           f" graph linear size: {len(self.linear_ops)}"))
+        self.setWeights(offset, op_id, *weights)
+
+    @staticmethod
+    def run_decoders(inputs, decoders):
+        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
+
+        num_decoders = len(decoders)
+        num_inputs = len(x_np)
+
+        with record_function(f"npu_factory"):
+
+            array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
+            models_ptr = array_type(
+                *[decoders[i]._mm for i in range(num_decoders)]
+            )
+            inputs_ptr = (ctypes.c_void_p * num_inputs)(
+                *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
+            )
+            backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)
+
+        hidden_states = decoders[-1].torch_out[0]
+        new_key_states = []
+        new_value_states = []
+        for i in range(num_decoders):
+            for j in range(1, len(decoders[i].torch_out)):
+                if j % 2 == 1:
+                    new_key_states.append(decoders[i].torch_out[j])
+                else:
+                    new_value_states.append(decoders[i].torch_out[j])
+        return hidden_states, new_key_states, new_value_states