Quickfix: Accelerate YAML and LoRA Fused Ops (#92)
* make accelerate.yaml 1.0 compatible
* missed out biases in fused-lora
* minor fix
* make compare with ref logic more robust
* again improve display bench logic
* relax mixed precision settings for full, regular-peft
* left out one bf16
* another improvement on compare_results
* update bench with full foak
* improve comments
* put back bf16 for baseline bnb

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Showing 12 changed files with 209 additions and 187 deletions.
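Several items in the message ("relax mixed precision settings", "left out one bf16", "put back bf16 for baseline bnb") concern bf16 mixed precision in the benchmark configurations. As a rough illustration only of what bf16 mixed precision means in general, and not the repository's actual accelerate.yaml or benchmark settings, the forward pass can run under a bfloat16 autocast context while parameters stay in fp32:

```python
# Illustrative only: plain PyTorch bf16 autocast, not the repo's benchmark config.
import torch

linear = torch.nn.Linear(16, 16)   # parameters remain float32
x = torch.randn(4, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = linear(x)                   # matmul runs in bfloat16 under autocast

print(linear.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16
```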
@@ -247,8 +247,8 @@ def forward(
         e = matmul_lora(X, gateW, gateA, gateB, gateS, dropout=dropout_gate)
         upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
         g = matmul_lora(X, upW, upA, upB, upS, dropout=dropout_up)
-        e += gate_bias
-        g += up_bias
+        if gate_bias is not None: e += gate_bias
+        if up_bias is not None: g += up_bias
         # f = torch.nn.functional.silu(e)
         # h = f * g
         h = swiglu_fg_kernel(e, g)
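This hunk (and the analogous ones below) guards each bias addition: quantized base layers frequently carry no bias, so the bias slot is None and an unconditional `+=` would raise a TypeError. A minimal sketch of the pattern, independent of the repository's fused kernels and with illustrative names only:

```python
# Minimal sketch of the guarded-bias pattern; not the repository's kernel code.
import torch

def add_bias_(x: torch.Tensor, bias):
    # Add in place only when a bias tensor actually exists.
    if bias is not None:
        x += bias
    return x

e = torch.randn(2, 8)
g = torch.randn(2, 8)
e = add_bias_(e, None)               # bias-less projection: no-op instead of a TypeError
g = add_bias_(g, torch.ones(8))      # projection with a bias: applied as before
h = torch.nn.functional.silu(e) * g  # reference SwiGLU, mirroring the commented-out lines
```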
@@ -257,7 +257,7 @@ def forward(
             down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
         )
         i = matmul_lora(h, downW, downA, downB, downS, dropout=dropout_down)
-        i += down_bias
+        if down_bias is not None: i += down_bias

         ctx.custom_saved_tensors = (
             gate_qweight,
@@ -529,9 +529,9 @@ def forward(
         K = matmul_lora(X, KW, KA, KB, KS, dropout=dropout_K)
         V = matmul_lora(X, VW, VA, VB, VS, dropout=dropout_V)

-        Q += Q_bias
-        K += K_bias
-        V += V_bias
+        if Q_bias is not None: Q += Q_bias
+        if K_bias is not None: K += K_bias
+        if V_bias is not None: V += V_bias

         ctx.custom_saved_tensors = (
             Q_qweight,
@@ -774,7 +774,7 @@ def forward(
     ):
         W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
         XW = matmul_lora(X, W, A, B, S, dropout=dropout_O)
-        XW += O_bias
+        if O_bias is not None: XW += O_bias
         del W
         ctx.custom_saved_tensors = (
             O_qweight,
@@ -843,6 +843,6 @@ def apply_lora_o(self, X):
 # added by [email protected]
 # this version can be directly patched on the output linear
 def apply_lora_o_v2(self, X):
-    Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self.o_proj)
+    Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self)
     O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), O_bias, OA, OB, OS, dropout)
     return O
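The last hunk makes `apply_lora_o_v2` read LoRA parameters from `self` rather than `self.o_proj`: per its comment, v2 is meant to be patched directly onto the output linear, so inside the call `self` already is that projection and has no `o_proj` attribute. A small sketch of that binding, with placeholder names rather than the repository's `get_lora_parameters` API:

```python
# Illustrative only: shows why a method patched onto the projection reads from `self`.
import types

class OutputProjection:
    """Stand-in for the output linear that apply_lora_o_v2 is patched onto."""
    lora_params = ("A", "B", "scaling")

def apply_lora_o_v2(self, X):
    # `self` is the projection itself, so its own LoRA state is used directly.
    A, B, scaling = self.lora_params
    return f"LoRA({A}, {B}, {scaling}) applied to {X}"

proj = OutputProjection()
proj.apply_lora_o_v2 = types.MethodType(apply_lora_o_v2, proj)
print(proj.apply_lora_o_v2("X"))  # -> LoRA(A, B, scaling) applied to X
```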