Quickfix: Accelerate YAML and LoRA Fused Ops (#92)
* make accelerate.yaml 1.0 compatible

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* missed out biases in fused-lora

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* minor fix

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* make compare with ref logic more robust

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* again improve display bench logic

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* relax mixed precision settings for full, regular-peft

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* left out one bf16

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* another improvement on compare_results

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update bench with full foak

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* improve comments

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* put back bf16 for baseline bnb

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim authored Oct 14, 2024
1 parent 4b2dfbc commit 97fc3c1
Showing 12 changed files with 209 additions and 187 deletions.
2 changes: 1 addition & 1 deletion plugins/accelerated-peft/README.md
@@ -11,7 +11,7 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks


### Key Points
- fix upcasting (resulting in slowdown) issue for `bnb` plugin, originally discovered by inventors of [Unsloth](https://unsloth.ai/blog/mistral-benchmark).
- fix upcasting (resulting in slowdown) issue for `bnb` plugin, originally discovered by inventors of [Unsloth](https://unsloth.ai/blog/mistral-benchmark). **NOTE**: we recommend using *mixed precision* when using 4bit quant for better performance, as per our benchmarks.
- `bnb` properly configured to work with FSDP following [this guide](https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora).
- `triton_v2` kernels are not yet properly integrated into huggingface optimum.
- `triton_v2` kernels are [the only 4bit kernels that work for training](https://github.com/AutoGPTQ/AutoGPTQ/issues/633).
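
The **NOTE** above about pairing 4-bit quantization with mixed precision can be illustrated with a minimal sketch (not part of this diff); the model name and hyperparameters are placeholders, and the plugin's actual wiring may differ:

```python
# Minimal sketch: 4-bit bnb quantization together with bf16 mixed precision,
# as recommended in the README note above. Model name and values are placeholders.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute dtype for the 4-bit weights
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=bnb_config
)
train_args = TrainingArguments(
    output_dir="./out",
    bf16=True,  # mixed precision, per the benchmark-backed recommendation
)
```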
@@ -230,6 +230,11 @@ def augmentation(
train_args: TrainingArguments,
modifiable_args: Tuple[LoraConfig],
):
# - when using our prepare peft, we will enforce the mixed precision settings
assert (
train_args.bf16 is True or train_args.fp16 is True
), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"

(peft_config,) = modifiable_args # unpack modifiable args

# some assertions
8 changes: 5 additions & 3 deletions plugins/fused-ops-and-kernels/README.md
@@ -13,18 +13,20 @@ This library contains fused operations and custom kernels, to be expanded over t

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE (**Disabled**) | Contains extracted code | | ✅
[fast_kernels](./src/fms_accelerate_foak/framework_plugin_fast_kernels.py) | Enhanced version of `fast_quantized_peft`, also works for full-FT and non-quant peft | Contains extracted code | | ✅

### Supported DataType Settings
**Compatibility Matrix with Mixed Precision**
torch_dtype | Mixed Precision | Full-FT-FOAK | PEFT-FOAK | QPEFT-FOAK
-- | -- | -- | -- | --
FLOAT16 | - | ✗ Not Allowed | ✗| ✗
FLOAT16 | - | **Compatible** | **Compatible** | ✗
FLOAT16 | FP16 | ValueError: <br>Attempting to <br>unscale FP16 gradients. <br>[See here](https://github.com/huggingface/peft/blob/main/docs/source/developer_guides/troubleshooting.md) | **Compatible** | **Compatible**
BFLOAT16 | - | ✗ | ✗ | ✗
BFLOAT16 | - | **Compatible** | **Compatible** | ✗
BFLOAT16 | BF16 | **Compatible** | **Compatible** | [Less Performant](https://github.com/foundation-model-stack/fms-acceleration/issues/84)

NOTE: this chart is also a good reference for supported types, even for the non-FOAK case.
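
As a quick illustration of the matrix (a hedged sketch, not taken from this repo's benchmarks), one compatible combination pairs a bfloat16 base model with `--bf16` mixed precision; float16 plus `--fp16` on full fine-tuning is the combination documented above to raise the gradient-unscaling ValueError. Model name and paths below are placeholders:

```python
# Sketch of the BFLOAT16 + BF16 row ("Compatible" for Full-FT-FOAK and PEFT-FOAK).
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
args = TrainingArguments(output_dir="./out", bf16=True)

# Per the table, torch_dtype=torch.float16 with fp16=True on full fine-tuning is
# the combination that trips "ValueError: Attempting to unscale FP16 gradients."
```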

### Code Extracted from Unsloth


@@ -126,10 +126,14 @@ def augmentation(
train_args: TrainingArguments,
modifiable_args: Tuple[LoraConfig],
):
# assert that plugin requires mixed precision to be set
assert (
train_args.bf16 is True or train_args.fp16 is True
), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"
has_quant = getattr(model, "quantization_method", None)

if has_quant:
# - only in the quantized case do we enforce the mixed precision settings
# - this is mostly for the fused-loras
assert (
train_args.bf16 is True or train_args.fp16 is True
), f"{self.__class__} requires mixed precision argument `--fp16` or `--bf16`"

# This is designed to be a passthrough if training scenario is
# full finetuning or standard peft, fused-lora rules (only meant for qpeft)
@@ -138,7 +142,7 @@

# some logic to omit terms from the filter if logic precludes
omitted = set()
if getattr(model, "quantization_method", None) is None:
if has_quant is None:
# - fused_lora only required for quant-peft
omitted.add("fused_lora")

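The gist of the new gating in `augmentation` above, sketched as a standalone check (assuming the model exposes `quantization_method` the way HF quantized models do):

```python
# Standalone sketch of the gating logic: mixed precision is only enforced when
# the model is quantized (the fused-lora path); otherwise the plugin passes through.
def check_mixed_precision(model, bf16: bool, fp16: bool) -> None:
    has_quant = getattr(model, "quantization_method", None)
    if has_quant:
        # fused-lora kernels assume a bf16/fp16 autocast context
        assert bf16 or fp16, "quantized peft requires `--bf16` or `--fp16`"
```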
@@ -69,11 +69,11 @@ def forward(ctx, X : torch.Tensor,

e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS, dropout=dropout_gate)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS, dropout=dropout_up)
e += gate_bias
g += up_bias
if gate_bias is not None: e += gate_bias
if up_bias is not None: g += up_bias
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS, dropout=dropout_down)
i += down_bias
if down_bias is not None: i += down_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
@@ -261,9 +261,9 @@ def forward(ctx, X : torch.Tensor,
K = matmul_lora(X, KW, KW_quant, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS, dropout=dropout_V)

Q += Q_bias
K += K_bias
V += V_bias
if Q_bias is not None: Q += Q_bias
if K_bias is not None: K += K_bias
if V_bias is not None: V += V_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
@@ -406,7 +406,7 @@ def forward(ctx, X : torch.Tensor,
W, W_quant, bias, A, B, S, dropout_O):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S, dropout=dropout_O)
XW += bias
if bias is not None: XW += bias

# Extract post-dropout X for use in backward computation
if dropout_O is not None:
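The bias guards added above all follow one pattern; a self-contained sketch of a LoRA matmul with an optional bias (plain PyTorch, ignoring quantization and dropout for brevity):

```python
# Sketch of the guarded-bias pattern used in the fused ops above:
# only add the bias when the underlying Linear actually has one.
import torch

def lora_matmul(X, W, A, B, s, bias=None):
    out = X @ W.t() + s * ((X @ A.t()) @ B.t())  # base weight + scaled LoRA update
    if bias is not None:  # bias may legitimately be absent (e.g. a bias=False Linear)
        out += bias
    return out

X = torch.randn(2, 8)
W, A, B = torch.randn(8, 8), torch.randn(4, 8), torch.randn(8, 4)
print(lora_matmul(X, W, A, B, s=0.5).shape)                    # torch.Size([2, 8])
print(lora_matmul(X, W, A, B, 0.5, torch.zeros(8)).shape)      # with bias
```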
@@ -247,8 +247,8 @@ def forward(
e = matmul_lora(X, gateW, gateA, gateB, gateS, dropout=dropout_gate)
upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
g = matmul_lora(X, upW, upA, upB, upS, dropout=dropout_up)
e += gate_bias
g += up_bias
if gate_bias is not None: e += gate_bias
if up_bias is not None: g += up_bias
# f = torch.nn.functional.silu(e)
# h = f * g
h = swiglu_fg_kernel(e, g)
@@ -257,7 +257,7 @@
down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
)
i = matmul_lora(h, downW, downA, downB, downS, dropout=dropout_down)
i += down_bias
if down_bias is not None: i += down_bias

ctx.custom_saved_tensors = (
gate_qweight,
@@ -529,9 +529,9 @@ def forward(
K = matmul_lora(X, KW, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VA, VB, VS, dropout=dropout_V)

Q += Q_bias
K += K_bias
V += V_bias
if Q_bias is not None: Q += Q_bias
if K_bias is not None: K += K_bias
if V_bias is not None: V += V_bias

ctx.custom_saved_tensors = (
Q_qweight,
@@ -774,7 +774,7 @@ def forward(
):
W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
XW = matmul_lora(X, W, A, B, S, dropout=dropout_O)
XW += O_bias
if O_bias is not None: XW += O_bias
del W
ctx.custom_saved_tensors = (
O_qweight,
@@ -843,6 +843,6 @@ def apply_lora_o(self, X):
# added by [email protected]
# this version can be directly patched on the output linear
def apply_lora_o_v2(self, X):
Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self.o_proj)
Oqstate, O_bias, OA, OB, OS, dropout = get_lora_parameters(self)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), O_bias, OA, OB, OS, dropout)
return O
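
A hypothetical illustration of what "patched directly on the output linear" could look like; this repo's actual patching mechanism may differ, and `apply_lora_o_v2` is referenced only in the commented usage:

```python
# Hypothetical sketch (not from this repo): binding a fused forward implementation
# onto a LoRA-wrapped output projection so module(X) routes through the fused path.
import types

def patch_output_linear(o_proj, fused_forward):
    # instance-level attribute shadows the class forward used by nn.Module.__call__
    o_proj.forward = types.MethodType(fused_forward, o_proj)
    return o_proj

# usage sketch:
# for layer in model.model.layers:
#     patch_output_linear(layer.self_attn.o_proj, apply_lora_o_v2)
```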
3 changes: 2 additions & 1 deletion scripts/benchmarks/accelerate.yaml
@@ -14,9 +14,10 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP

# this controls the FSDP pipelining
fsdp_backward_prefetch_policy: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline
fsdp_backward_prefetch: BACKWARD_PRE # set to BACKWARD_PRE for the most time-efficient pipeline
# but requires the most memory. BACKWARD_POST is the less
# memory intensive option
fsdp_backward_prefetch_policy: BACKWARD_PRE # for backward compatibility accelerate<1.0

# setting this to true will increase forward memory by prefetching the next FSDP all-gather, while performing
# the current forward pass.
28 changes: 17 additions & 11 deletions scripts/benchmarks/compare_with_reference.py
@@ -63,24 +63,30 @@ def compare_results(df, ref, plot_columns, threshold_ratio=0.1):
ref_series = ref[column].fillna(0)
df_series = df[column].fillna(0)
# Extract outliers based on some threshold % difference on reference
ds = abs(df_series - ref_series) / (ref_series + 1e-9)
outliers = ds.index[ds > threshold_ratio].to_list()
cmp = ref_series.to_frame()
cmp['metric'] = column
cmp = cmp.join(df_series.to_frame(), lsuffix='_ref')
cmp = cmp.rename(columns={f'{column}_ref': 'reference', column: 'new'})
cmp['ds'] = cmp.apply(
lambda x: (
abs(x.reference - x.new) / (x.reference + 1e-9)
), axis=1
)
outliers = cmp[cmp.ds > threshold_ratio]
outliers = outliers.drop('ds', axis=1)

plot_chart(
ax,
ref_series,
df_series,
cmp['reference'],
cmp['new'],
title=f"Metric: {column}",
xlabel="Reference",
ylabel="New",
)
charts.append((ax, f"compare-{column}.jpg"))
total_outliers += [
[column, *outlier, ref_series[outlier].item(), df_series[outlier].item()]
for outlier in outliers
]
outliers_df = pd.DataFrame(
total_outliers, columns=["scenario", *df.index.names, "reference", "new"]
)
total_outliers.append(outliers)

outliers_df = pd.concat(total_outliers)
return outliers_df, outliers, charts


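The reworked comparison above joins the reference and new series before computing the relative difference; a toy, self-contained example of that join-then-threshold idea (column and index names here are made up):

```python
# Toy illustration of the join-then-threshold comparison: two result series are
# aligned by index, the relative difference is computed, and rows beyond the
# threshold are reported as outliers. Names are illustrative only.
import pandas as pd

ref = pd.DataFrame({"train_tokens_per_second": [100.0, 200.0, 300.0]},
                   index=["exp-a", "exp-b", "exp-c"])
new = pd.DataFrame({"train_tokens_per_second": [101.0, 150.0, 305.0]},
                   index=["exp-a", "exp-b", "exp-c"])

column = "train_tokens_per_second"
cmp = ref[column].to_frame().join(new[column].to_frame(), lsuffix="_ref")
cmp = cmp.rename(columns={f"{column}_ref": "reference", column: "new"})
cmp["ds"] = (cmp["reference"] - cmp["new"]).abs() / (cmp["reference"] + 1e-9)
outliers = cmp[cmp["ds"] > 0.1].drop("ds", axis=1)
print(outliers)  # only exp-b (a 25% drop) exceeds the 10% threshold
```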
7 changes: 6 additions & 1 deletion scripts/benchmarks/display_bench_results.py
@@ -38,7 +38,12 @@ def main(
df[c] = constant[c]
kept += 1

df = df.reset_index(drop=True).drop("output_dir", axis=1)
df = df.reset_index(drop=True)
try:
df = df.drop("output_dir", axis=1)
except KeyError:
pass # output_dir not found

df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False)
print("***************** Report Created ******************")
print(f"Total lines: '{len(df)}'")
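An aside on the design choice (a sketch, not part of this commit): pandas' `DataFrame.drop` also accepts `errors="ignore"`, which folds the try/except into a single call while keeping the same behavior:

```python
# Equivalent effect using pandas' built-in option: silently skip missing columns.
import pandas as pd

df = pd.DataFrame({"metric": [1.0, 2.0]})
df = df.reset_index(drop=True).drop("output_dir", axis=1, errors="ignore")
```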