Update baichuan_mp.py
lzivan committed Dec 23, 2024
1 parent 9ab5879 · commit fb064a2
Showing 1 changed file with 5 additions and 2 deletions.
@@ -102,6 +102,7 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers
+        self.asym = asym

         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -234,7 +235,8 @@ def attention(self,
             wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )

         proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
@@ -302,7 +304,8 @@ def attention(self,
             attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
             n_splits=self.n_splits_linear,
             scale_factor=(self.group_size == 0),
-            is_prefill=(mode == "prefill")
+            is_prefill=(mode == "prefill"),
+            asym=self.asym
         )
         return attn_output, new_key_states, new_value_states
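The net effect of the diff is that the asym flag (presumably whether asymmetric weight quantization is used) is stored on the layer in __init__ and then forwarded to both projection builders inside attention, so every linear is constructed with the same quantization mode. Below is a minimal, self-contained sketch of that pattern; AttentionSketch and make_linear are hypothetical stand-ins for illustration only, not the actual helpers used in baichuan_mp.py.

# Sketch only: threads an `asym` flag from __init__ into the projection
# builders, mirroring the three added lines in the diff above.
def make_linear(in_features, out_features, bias=False, wt_dtype="int4",
                n_splits=1, scale_factor=False, is_prefill=False, asym=False):
    # Hypothetical stand-in for the real projection builder; it simply
    # records the configuration it would build with.
    return {"in": in_features, "out": out_features, "bias": bias,
            "wt_dtype": wt_dtype, "n_splits": n_splits,
            "scale_factor": scale_factor, "is_prefill": is_prefill,
            "asym": asym}


class AttentionSketch:
    def __init__(self, hidden_size, n_splits_linear=1, group_size=0, asym=False):
        self.hidden_size = hidden_size
        self.n_splits_linear = n_splits_linear
        self.group_size = group_size
        # Mirrors the line added to __init__: remember the quantization mode.
        self.asym = asym

    def attention(self, mode="decode"):
        # Both projections now receive the stored flag, mirroring the two
        # modified call sites in attention().
        qkv_proj = make_linear(self.hidden_size, 3 * self.hidden_size,
                               n_splits=self.n_splits_linear,
                               scale_factor=(self.group_size == 0),
                               is_prefill=(mode == "prefill"),
                               asym=self.asym)
        out_proj = make_linear(self.hidden_size, self.hidden_size,
                               n_splits=self.n_splits_linear,
                               scale_factor=(self.group_size == 0),
                               is_prefill=(mode == "prefill"),
                               asym=self.asym)
        return qkv_proj, out_proj


if __name__ == "__main__":
    layer = AttentionSketch(hidden_size=4096, asym=True)
    qkv, out = layer.attention(mode="prefill")
    print(qkv["asym"], out["asym"])  # prints: True True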
