From c410d9cf73d080fe08dc9f06c2dc692a4c257892 Mon Sep 17 00:00:00 2001 From: Zijie Li Date: Mon, 23 Dec 2024 20:17:50 -0500 Subject: [PATCH] [NPU] support asym_int4 for baichuan (#12576) * add npu support for baichuan * Update baichuan_mp.py * Update baichuan_mp.py --- .../transformers/npu_models/baichuan_mp.py | 55 ++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index e2dd913c9a4..dc2fcbe1df0 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -80,7 +80,8 @@ def __init__( intermediate_size, n_splits_linear: int = 1, n_splits_down_proj: int = 1, - group_size: int = 0 + group_size: int = 0, + asym: bool = False, ): super().__init__(max_seq_len=max_seq_len, transpose_value=transpose_value, @@ -89,7 +90,8 @@ def __init__( device=device, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size) + group_size=group_size, + asym=asym) self.max_seq_len = max_seq_len self.intermediate_size = intermediate_size self.dtype = dtype @@ -100,6 +102,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.transpose_value = transpose_value self.num_layers = num_layers + self.asym = asym cos = self.constant(self.cached_cos) self.cos = self.unsqueeze(cos, axis=0) @@ -232,7 +235,8 @@ def attention(self, wt_dtype=self.dtype, n_splits=self.n_splits_linear, scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill") + is_prefill=(mode == "prefill"), + asym=self.asym ) proj = self.reshape(proj, [-1, 3, hidden_size]) # b*s, 3, h @@ -300,7 +304,8 @@ def attention(self, attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype, n_splits=self.n_splits_linear, scale_factor=(self.group_size == 0), - is_prefill=(mode == "prefill") + is_prefill=(mode == "prefill"), + asym=self.asym ) return attn_output, new_key_states, new_value_states @@ -368,7 +373,8 @@ def __init__( do_print: bool = False, n_splits_linear: int = 1, n_splits_down_proj: int = 1, - group_size: int = 0 + group_size: int = 0, + asym: bool = False, ): super().__init__() @@ -376,8 +382,10 @@ def __init__( op_parameters = [] for w in parameters: - if isinstance(w, tuple): # from QuantizedLinear + if isinstance(w, tuple) and not asym: # from QuantizedLinear op_parameters.append((w[0].numpy(), w[1].numpy())) + elif isinstance(w, tuple) and asym: # from QuantizedLinear + op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy())) elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight op_parameters.append(w.numpy()) elif isinstance(w, np.ndarray): # scale @@ -430,7 +438,8 @@ def __init__( dtype=np_dtype, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym, ) self.backend_decoders.append(decoder) @@ -506,7 +515,8 @@ def __init__( transpose_value: bool = False, n_splits_linear: int = 1, n_splits_down_proj: int = 1, - group_size: int = 0 + group_size: int = 0, + asym: bool = False, ): super().__init__() self.op_parameters = parameters @@ -537,7 +547,8 @@ def __init__( dtype=np_dtype, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 @@ -620,6 +631,7 @@ def run_decode( layer_indexs = range(layer_start, layer_end) n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + asym = getattr(model.config, "asym", False) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn @@ -631,10 +643,17 @@ def run_decode( mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] + zeros = [] for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -663,7 +682,8 @@ def run_decode( do_print=False, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym, ) dist.barrier() @@ -827,6 +847,7 @@ def run_prefill( layer_indexs = range(layer_start, layer_end) n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + asym = getattr(model.config, "asym", False) for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn @@ -838,10 +859,17 @@ def run_prefill( mlp_layer.down_proj_dq_list]: l_weights = [] scales = [] + zeros = [] for l in layer_list: l_weights.append(l.weight) scales.append(l.scale) - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -864,7 +892,8 @@ def run_prefill( transpose_value=transpose_value_cache, n_splits_linear=n_splits_linear, n_splits_down_proj=n_splits_down_proj, - group_size=group_size + group_size=group_size, + asym=asym ) layer_weights.extend(weights)