diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
index e51905c2201..05a575187cf 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py
@@ -102,6 +102,7 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers
+        self.asym = asym

         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -234,7 +235,8 @@ def attention(self,
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
@@ -302,7 +304,8 @@ def attention(self,
             attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         return attn_output, new_key_states, new_value_states
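
For context, the diff threads an `asym` flag from the decoder constructor into the fused QKV and output projections, so each low-bit linear can be built with either symmetric or asymmetric quantization. Below is a minimal sketch of the standard symmetric-vs-asymmetric INT4 scheme, assuming that is what the flag selects; `quantize_int4` is an illustrative helper, not the actual ipex-llm NPU kernel.

```python
import numpy as np

def quantize_int4(w: np.ndarray, asym: bool):
    """Illustrative 4-bit quantization of a weight tensor (not the ipex-llm kernel).

    Symmetric:  q = round(w / scale), zero point fixed at 0, range [-8, 7].
    Asymmetric: q = round(w / scale) + zero_point, so the full [0, 15] range
    covers [w.min(), w.max()] even when the weight distribution is skewed.
    """
    if asym:
        scale = (w.max() - w.min()) / 15.0           # map full range onto [0, 15]
        zero_point = int(round(-w.min() / scale))
        q = np.clip(np.round(w / scale) + zero_point, 0, 15)
    else:
        scale = np.abs(w).max() / 7.0                # map |w| range onto [-8, 7]
        zero_point = 0
        q = np.clip(np.round(w / scale), -8, 7)
    return q.astype(np.int8), scale, zero_point
```

Dequantization then recovers `w ≈ scale * (q - zero_point)`; passing `asym` through each `self.linear(...)` call, as the diff does, lets the graph builder pick the matching dequantize path per projection.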