diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index b5a09dcbf40..25cb790db99 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -48,70 +48,11 @@ import torch.multiprocessing as mp
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from ipex_llm.transformers.npu_models.mp_models_base import run_model
+from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 
 
-@torch.no_grad()
-def run_model(
-    x: Union[torch.Tensor, List[torch.Tensor]],
-    weights: List[torch.Tensor],
-    backend_cls: Any,
-    op_id: str,
-    replica: int = 1,
-) -> torch.Tensor:
-    global _model_cache
-    import time
-
-    t0 = time.perf_counter()
-
-    # Use or not op_id depending on the class used
-    op_kwargs = {"op_id": op_id} if op_id else {}
-
-    if not isinstance(x, (list, tuple)):
-        x = [x]
-
-    # Reshape input
-    input_dtype = x[0].dtype
-    x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x]
-    op_args = []
-    op_args_flatten = []
-    for w in weights:
-        if isinstance(w, tuple):  # from QuantizedLinear
-            op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
-            op_args_flatten.append(op_args[-1][0])
-            op_args_flatten.append(op_args[-1][1])
-        else:
-            op_args.append(set_contiguous(w).to(torch.float16).numpy())
-            op_args_flatten.append(op_args[-1])
-
-    shape_dtype_signature = "_".join(
-        ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten]
-    )
-    key = f"{backend_cls.func.__name__}_{shape_dtype_signature}"
-    models = _model_cache.get(key, None)
-
-    input_shapes = [elem.shape for elem in x_np]
-    if models is None:
-        _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)])
-    elif len(models) < 1:
-        _model_cache[key].append(backend_cls(*input_shapes))
-    else:
-        _model_cache[key].rotate(1)
-
-    # Get the model
-    model = _model_cache[key][0]
-
-    with record_function(f"npu_factory_mul_{key}"):
-        ret = model.run(x_np, *op_args, **op_kwargs)
-
-    if isinstance(ret, list):
-        results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret]
-    else:
-        results = adapt_output_tensor(ret, ret.shape, input_dtype)
-
-    return results
-
-
-class LowBitLlamaMultiDecoderlayer(NNFactory):
+class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
     def __init__(
         self,
         # batch_size: int,
@@ -135,7 +76,11 @@ def __init__(
         rms_norm_eps,
         intermediate_size,
     ):
-        super().__init__(profile, device)
+        super().__init__(max_seq_len=max_seq_len,
+                         transpose_value=transpose_value,
+                         dtype=dtype,
+                         profile=profile,
+                         device=device)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -145,6 +90,7 @@ def __init__(
         self.mode = mode
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
+        self.num_layers = num_layers
 
         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -158,34 +104,33 @@ def __init__(
             self.kv_seq_len = self.seq_len
 
         self.num_heads = num_heads
-        # self.num_key_value_heads = num_key_value_heads
 
         self.head_dim = self.hidden_size // self.num_heads
-        # self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 
         # define input, the order self.parameter matters
-        input = self.parameter((self.batch_size, self.seq_len, self.hidden_size))
+        input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
 
         # Self Attention
         if mode == "decode":
-            attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1))
+            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
         else:
-            attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len))
+            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
 
-        position_ids = self.parameter((self.batch_size, self.seq_len))
+        position_ids = self.create_input_op((self.batch_size, self.seq_len))
+        # self.num_key_value_heads = num_key_value_heads
         past_keys = []
         past_values = []
         if mode == "decode":
             for i in range(num_layers):
-                past_key = self.parameter(
+                past_key = self.create_cache_op(
                     (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
                 )
                 if transpose_value:
-                    past_value = self.parameter(
+                    past_value = self.create_cache_op(
                         (self.batch_size, self.num_heads, self.head_dim, self.max_seq_len)
                     )
                 else:
-                    past_value = self.parameter(
+                    past_value = self.create_cache_op(
                         (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim)
                     )
                 past_keys.append(past_key)
@@ -199,7 +144,7 @@ def __init__(
             post_attn_layernorm_weights = []
             for i in range(num_layers):
                 input_layernorm_weights.append(
-                    self.parameter(
+                    self.create_input_op(
                         (
                             1,
                             self.hidden_size,
@@ -207,7 +152,7 @@ def __init__(
                     )
                 )
                 post_attn_layernorm_weights.append(
-                    self.parameter(
+                    self.create_input_op(
                         (
                             1,
                             self.hidden_size,
@@ -243,63 +188,56 @@ def __init__(
         print("start compiling")
         self.compile()
 
-    def build_decoder(
-        self,
-        hidden_states,
-        attention_mask,
-        position_ids,
-        input_layernorm_weight,
-        post_attention_layernorm_weight,
-        past_key=None,
-        past_value=None,
-    ):
-
-        residual = hidden_states
-
-        input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
-
-        # input layernorm
-        input_2d = self.convert_to_fp32(input_2d)
-        variance = self.reduce_mean(
-            self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))),
-            -1,
-            keep_dims=True,
+    def attention(self,
+                  *,
+                  hidden_states,
+                  position_ids,
+                  attention_mask,
+                  past_key,
+                  past_value,
+                  cos,
+                  sin,
+                  mode,
+                  num_heads,
+                  head_dim,
+                  seq_len,
+                  q_bias=None,
+                  k_bias=None,
+                  v_bias=None):
+        hidden_size = num_heads * head_dim
+        proj = self.linear(
+            hidden_states,
+            3 * hidden_size,
+            hidden_size,
+            bias=False,
+            wt_dtype=self.dtype
         )
-        eps = self.constant(self.rms_norm_eps)
-        input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps)))
-        input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight)
-        input_2d = self.eltwise_mul(input_layernorm_weight, input_2d)
-        input_2d = self.convert_to_fp16(input_2d)
-
-        # attention
-        proj = self.linear(input_2d, 3 * self.hidden_size,
-                           self.hidden_size, bias=False, wt_dtype=self.dtype)
-        # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
-        proj = self.reshape(proj, [-1, 3, self.hidden_size])  # b*s, 3, h
+        proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
         proj = self.transpose(proj, [2, 1, 0, 3])  # 3, s, b, h
         proj = self.squeeze(proj)  # 3, b*s, h
-        proj = self.unsqueeze(proj, [1])
-        # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        query_states = self.reshape(proj[0, ...], [self.batch_size,
-                                    self.seq_len, self.num_heads, self.head_dim])
+        query_states = self.reshape(proj[0, ...], [1, self.seq_len, num_heads, head_dim])
         query_states = self.transpose(query_states, [0, 2, 1, 3])
-        # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.reshape(proj[1, ...], [self.batch_size,
-                                  self.seq_len, self.num_heads, self.head_dim])
+        key_states = self.reshape(proj[1, ...], [1, self.seq_len, num_heads, head_dim])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
-        # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.reshape(proj[2, ...], [self.batch_size,
-                                    self.seq_len, self.num_heads, self.head_dim])
+        value_states = self.reshape(proj[2, ...], [1, self.seq_len, num_heads, head_dim])
         if self.transpose_value:
             value_states = self.transpose(value_states, [0, 2, 3, 1])
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
 
-        cos = self.unsqueeze(self.squeeze(self.cos), [0])
-        sin = self.unsqueeze(self.squeeze(self.sin), [0])
+        cos = self.unsqueeze(self.squeeze(cos), [0])
+        sin = self.unsqueeze(self.squeeze(sin), [0])
+
         query_states, key_states = self.apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids
+            q=query_states,
+            k=key_states,
+            cos=cos,
+            sin=sin,
+            position_ids=position_ids,
+            num_heads=num_heads,
+            seq_len=seq_len,
+            head_dim=head_dim,
         )
         new_key_states = key_states
         new_value_states = value_states
@@ -320,95 +258,55 @@ def build_decoder(
         attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
 
         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
-        attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size])
+        attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
 
         attn_output = self.linear(
-            attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
        )
+        return attn_output, new_key_states, new_value_states
 
-        hidden_states = self.eltwise_add(residual, attn_output)
+    def build_decoder(
+        self,
+        hidden_states,
+        attention_mask,
+        position_ids,
+        input_layernorm_weight,
+        post_attention_layernorm_weight,
+        past_key=None,
+        past_value=None,
+    ):
 
-        # Fully Connected
         residual = hidden_states
-        # post_attention_layernorm forward
-        hidden_states = self.convert_to_fp32(hidden_states)
-        variance = self.reduce_mean(
-            self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))),
-            -1,
-            keep_dims=True,
-        )
-        hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps)))
-        post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight)
-        hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states)
-        hidden_states = self.convert_to_fp16(hidden_states)
+        input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size))
+        input_2d = self.layer_norm(input_2d, input_layernorm_weight)
 
-        # mlp
-        # gate proj
-        mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
-                          bias=False, wt_dtype=self.dtype)
-        # up proj
-        mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size,
-                          bias=False, wt_dtype=self.dtype)  # type: ignore[attr-defined]
-        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        # down proj
-        hidden_states = self.linear(mm1, self.hidden_size,
-                                    self.intermediate_size, bias=False, wt_dtype=self.dtype)
+        # attention
+        attn_output, new_key_states, new_value_states = self.attention(
+            hidden_states=input_2d,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key=past_key,
+            past_value=past_value,
+            cos=self.cos,
+            sin=self.sin,
+            mode=self.mode,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            seq_len=self.seq_len,
+        )
+        hidden_states = self.eltwise_add(residual, attn_output)
+        residual = hidden_states
+        hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
+        hidden_states = self.mlp(hidden_states)
         hidden_states = self.eltwise_add(residual, hidden_states)
 
         hidden_states = self.convert_to_fp16(hidden_states)
 
         return hidden_states, new_key_states, new_value_states
 
-    def rotate_half(self, x):
-        x1 = self.slice(
-            x,
-            [0, 0, 0, 0],
-            [self.batch_size, self.num_heads, self.seq_len, self.head_dim // 2],
-        )
-        x2 = self.slice(
-            x,
-            [0, 0, 0, self.head_dim // 2],
-            [self.batch_size, self.num_heads, self.seq_len, self.head_dim],
-        )
-        return self.concat(self.negative(x2), x1, axis=-1)
-
-    def apply_rotary_pos_emb2(self, q, k, cos, sin, position_ids):
-
-        cos = self.squeeze(cos)  # [seq_len, dim]
-        sin = self.squeeze(sin)  # [seq_len, dim]
-        # cos = cos[position_ids]
-        cos = self.unsqueeze(cos, [0, 1])  # [bs, 1, seq_len, dim]
-        # sin = sin[position_ids]
-        sin = self.unsqueeze(sin, [0, 1])  # [bs, 1, seq_len, dim]
-
-        q_embed = self.eltwise_add(
-            self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin)
-        )
-        k_embed = self.eltwise_add(
-            self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin)
-        )
-
-        return q_embed, k_embed
-
-    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
-        position_ids = self.squeeze(position_ids)
-        cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0)
-        sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0)
-        cos = self.unsqueeze(cos, [1])
-        sin = self.unsqueeze(sin, [1])
-
-        q_embed = self.eltwise_add(
-            self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin)
-        )
-        k_embed = self.eltwise_add(
-            self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin)
-        )
-
-        return q_embed, k_embed
-
-
-class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
+class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
 
     def __init__(
         self,
@@ -449,8 +347,6 @@ def __init__(
 
         self.intra_stages = intra_stages
         self.layer_indexes = layer_indexes
-        self.num_layers_1 = len(self.layer_indexes) // 2
-        self.num_layers_0 = len(self.layer_indexes) - self.num_layers_1
         num_layers = len(self.layer_indexes) // intra_stages
         self.layer_ranges = []
         for i in range(intra_stages):
@@ -465,7 +361,7 @@ def __init__(
             start, end = self.layer_ranges[i]
             lm_0 = input_laynorm_weights[start:end]
             lm_1 = post_attn_layernorm_weights[start:end]
-            decoder = LowBitLlamaMultiDecoderlayer(
+            decoder = LowBitBaichuanMultiDecoderlayer(
                 [1, 1, num_heads * head_dim],
                 input_layernorm_weights=lm_0,
                 post_attn_layernorm_weights=lm_1,
@@ -485,16 +381,7 @@ def __init__(
 
         for i in range(intra_stages):
             start, end = self.layer_ranges[i]
-            num_intra_layers = end - start
-            self.backend_decoders[i].setWeights(
-                3 + (num_intra_layers) * 2, self.op_id, *op_parameters[start * 5:end * 5]
-            )
-            with FileLock(f"decoder_run.lock"):
-                backend_lib.run(self.backend_decoders[i]._mm)
-
-        self.kv_cache_c_parameter_handel = []
-        self.kv_cache_parameters = []
-        self.kv_cache_prefetched = False
+            self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 5:end * 5])
 
     def forward(
         self,
@@ -512,102 +399,45 @@ def forward(
             position_ids,
         )
 
-        if len(self.kv_cache_parameters) > 0:
-            # the case kv cache changed
-            cached_prt = self.kv_cache_parameters[0].storage().data_ptr()
-            current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr()
-            if cached_prt != current_ptr:
-                self.kv_cache_parameters = []
-                self.kv_cache_c_parameter_handel = []
-                self.kv_cache_prefetched = False
-
-        if len(self.kv_cache_parameters) == 0:
-            for idx in self.layer_indexes:
-                past_key = past_key_value.key_cache[idx]
-                past_value = past_key_value.value_cache[idx]
-
-                invalidInputError(
-                    past_key.dtype == torch.float16, f"past_key dtype is {past_key.dtype}"
-                )
-
-                new_size = (past_key.size(0), past_key.size(1), self.max_seq_len, past_key.size(3))
-                past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0)
-                invalidInputError(past_key.is_contiguous(), "past_key is not contiguous")
-                past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0)
-                if self.transpose_value:
-                    past_value = past_value.transpose(-1, -2)
-                invalidInputError(past_value.is_contiguous(), "past_value is not contiguous")
-
-                self.kv_cache_parameters.append(past_key)
-                self.kv_cache_parameters.append(past_value)
-
-            for i in range(self.intra_stages):
-                start, end = self.layer_ranges[i]
-                layer_kv_cache = self.kv_cache_parameters[start * 2:end * 2]
-                layer_kv_cache = [p.numpy() for p in layer_kv_cache]
-                handle = self.backend_decoders[i].create_parameters(layer_kv_cache)
-                self.kv_cache_c_parameter_handel.append(handle)
-
-        x_np = [elem.to(torch.float16).numpy() for elem in inputs]
-
-        with record_function(f"npu_factory"):
-            if not self.kv_cache_prefetched:
-                for i in range(self.intra_stages):
-                    self.backend_decoders[i].load_wt_fn(
-                        len(inputs),
-                        self.backend_decoders[i]._mm,
-                        self.kv_cache_c_parameter_handel[i],
-                    )
-
-            array_type = ctypes.POINTER(ctypes.c_char) * self.intra_stages
-            models_ptr = array_type(
-                *[self.backend_decoders[i]._mm for i in range(self.intra_stages)]
-            )
-            inputs_ptr = (ctypes.c_void_p * 3)(
-                x_np[0].ctypes.data_as(ctypes.c_void_p),
-                x_np[1].ctypes.data_as(ctypes.c_void_p),
-                x_np[2].ctypes.data_as(ctypes.c_void_p),
-            )
-            t0 = time.perf_counter()
-            backend_lib.run_decoders(models_ptr, inputs_ptr, self.intra_stages, 3)
-            t1 = time.perf_counter()
+        for i in range(self.intra_stages):
+            start, end = self.layer_ranges[i]
+            self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])
 
-        hidden_states = self.backend_decoders[-1].torch_out[0]
+        hidden_states, new_keys, new_values = LowBitBaichuanMultiDecoderlayer.run_decoders(
+            inputs,
+            decoders=self.backend_decoders)
 
         if self.do_print:
             print("outputs:", hidden_states)
 
         outputs = (hidden_states,)
-        outputs += (past_key_value,)
-        return outputs, t1 - t0
+        outputs += (past_key_value, new_keys, new_values)
+        return outputs
 
-    def post_forward(self, past_key_value):
+    def post_forward(self, past_key_value, new_keys, new_values):
         key_value_states = []
         for i in range(self.intra_stages):
             for j in range(1, len(self.backend_decoders[i].torch_out)):
                 key_value_states.append(self.backend_decoders[i].torch_out[j])
 
         cache_kwargs = {
-            # "cache_position": cache_position,
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
         }
+
         for i in range(len(self.layer_indexes)):
             key_states, value_states = past_key_value.update(
-                key_value_states[2 * i],
-                key_value_states[2 * i + 1],
+                new_keys[i],
+                new_values[i],
                 self.layer_indexes[i],
                 cache_kwargs,
             )
 
         for i in range(self.intra_stages):
-            self.backend_decoders[i].load_wt_fn(
-                3, self.backend_decoders[i]._mm, self.kv_cache_c_parameter_handel[i]
-            )
-            self.kv_cache_prefetched = True
+            self.backend_decoders[i].load_cache_async()
 
 
-class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
+class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
     """LLAMA MLP operation NPU backend."""
 
     def __init__(
@@ -638,7 +468,7 @@ def __init__(
         np_dtype = np.float16
 
         self.backend_cls_prefill = partial(
-            LowBitLlamaMultiDecoderlayer,
+            LowBitBaichuanMultiDecoderlayer,
             num_heads=num_heads,
             # num_key_value_heads=num_key_value_heads,
             num_layers=1,
@@ -664,8 +494,6 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        # cache_position: Optional[torch.LongTensor] = None,
-        # **kwargs,
     ) -> torch.Tensor:
         """Torch module forward method.
 
@@ -685,7 +513,6 @@ def forward(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
         )
         cache_kwargs = {
-            # "cache_position": cache_position,
            "max_seq_len": self.max_seq_len,
            "transpose": self.transpose_value,
        }
@@ -756,7 +583,7 @@ def run_decode(
         input_layer_norm_weights.append(layer_norm_0)
         post_attn_layernorm_weights.append(layer_norm_1)
 
-    multi_decoder = FusedLlamaLowBitMultiDecoderlayer(
+    multi_decoder = FusedBaichuanLowBitMultiDecoderlayer(
         parameters=layer_weights,
         input_laynorm_weights=input_layer_norm_weights,
         post_attn_layernorm_weights=post_attn_layernorm_weights,
@@ -810,7 +637,7 @@ def run_decode(
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
                 t1 = time.perf_counter()
-                layer_outputs, elapse = multi_decoder(
+                layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
                     position_ids=position_ids,
@@ -823,7 +650,10 @@ def run_decode(
                 t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
                 t4 = time.perf_counter()
-                multi_decoder.post_forward(past_key_values)
+                past_key_values = layer_outputs[1]
+                new_keys = layer_outputs[2]
+                new_values = layer_outputs[3]
+                multi_decoder.post_forward(past_key_values, new_keys, new_values)
 
 
 class DecodeRunner:
@@ -956,7 +786,7 @@ def run_prefill(
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
         layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
 
-        new_decoderlayer = FusedLlamaLowBitDecoderlayer(
+        new_decoderlayer = FusedBaichuanLowBitDecoderlayer(
             weights,
             num_heads=num_heads,
             # num_key_value_heads=num_key_value_heads,