Commit
updates for quanto fp8 support, clarifications on error messages and unblocking multigpu PEFT for non-fp8 training
bghira committed Oct 1, 2024
1 parent 83c13e0 commit 9f46450
Showing 3 changed files with 21 additions and 18 deletions.
13 changes: 8 additions & 5 deletions helpers/training/default_settings/safety_check.py
@@ -17,19 +17,22 @@ def safety_check(args, accelerator):
     # mulit-gpu safety checks & warnings
     if args.model_type == "lora" and args.lora_type == "standard":
         # multi-gpu PEFT checks & warnings
-        if "quanto" in args.base_model_precision:
+        if args.base_model_precision in ["fp8-quanto"]:
             logger.error(
-                "Quanto is incompatible with multi-GPU training on PEFT adapters. Use LORA_TYPE (--lora_type) lycoris for quantised multi-GPU training of LoKr models."
+                f"{args.base_model_precision} is incompatible with multi-GPU training on PEFT LoRA."
+                " Use LORA_TYPE (--lora_type) lycoris for quantised multi-GPU training of LoKr models in FP8."
             )
-            sys.exit(1)
+            args.base_model_precision = "int8-quanto"
+            # sys.exit(1)
         if (
             args.base_model_precision in ["fp8-quanto", "int4-quanto"]
             and accelerator.state.dynamo_plugin.backend.lower() == "inductor"
         ):
             logger.warning(
-                f"{args.base_model_precision} is not supported with Dynamo backend. Switching to int8-quanto instead."
+                f"{args.base_model_precision} is not supported with Dynamo backend. Disabling Dynamo."
             )
-            args.base_model_precision = "int8-quanto"
+            from accelerate.utils import DynamoBackend
+            accelerator.state.dynamo_plugin.backend = DynamoBackend.NO
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError(
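Taken together, this hunk changes the failure mode rather than the feature set: a standard PEFT LoRA run that asks for fp8-quanto is downgraded to int8-quanto instead of exiting, and an fp8/int4-quanto run still using the inductor Dynamo backend gets Dynamo switched off. A minimal sketch of that decision flow follows (not part of the commit; the SimpleNamespace args and the plain string standing in for accelerator.state.dynamo_plugin.backend are hypothetical stand-ins):

# Hypothetical stand-ins for the trainer's parsed arguments and accelerate state.
from types import SimpleNamespace

args = SimpleNamespace(
    model_type="lora",
    lora_type="standard",
    base_model_precision="fp8-quanto",
)
dynamo_backend = "inductor"  # stands in for accelerator.state.dynamo_plugin.backend

if args.model_type == "lora" and args.lora_type == "standard":
    if args.base_model_precision in ["fp8-quanto"]:
        # previously sys.exit(1); the run now continues on int8 weights
        args.base_model_precision = "int8-quanto"
    if (
        args.base_model_precision in ["fp8-quanto", "int4-quanto"]
        and dynamo_backend.lower() == "inductor"
    ):
        # in safety_check this becomes accelerator.state.dynamo_plugin.backend = DynamoBackend.NO
        dynamo_backend = "no"

print(args.base_model_precision, dynamo_backend)
# -> int8-quanto inductor: the fp8 request was downgraded before the Dynamo
#    check, so that branch does not fire; an int4-quanto request would instead
#    keep its precision and have Dynamo disabled.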
22 changes: 13 additions & 9 deletions helpers/training/quantisation/quanto_workarounds.py
@@ -6,8 +6,12 @@
 import optimum
 from optimum.quanto.library.extensions.cuda import ext as quanto_ext
 
-@torch.library.impl("quanto::gemm_f16f8_marlin", ["CUDA"])
-def fp8_marlin_gemm(
+# torch tells us to do this because
+torch._dynamo.config.optimize_ddp=False
+# Save the original operator
+original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin
+
+def fp8_marlin_gemm_wrapper(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
     b_scales: torch.Tensor,
@@ -17,11 +21,11 @@ def fp8_marlin_gemm(
     size_n: int,
     size_k: int,
 ) -> torch.Tensor:
-    assert b_scales.dtype == torch.float16 or b_scales.dtype == torch.bfloat16
-    assert b_q_weight.dim() == 2
-    assert b_q_weight.dtype == torch.int32
-    return quanto_ext.lib.fp8_marlin_gemm(
-        a.to(b_scales.dtype),
+    # Ensure 'a' has the correct dtype
+    a = a.to(b_scales.dtype)
+    # Call the original operator
+    return original_gemm_f16f8_marlin(
+        a,
         b_q_weight,
         b_scales,
         workspace,
@@ -31,8 +35,8 @@
         size_k,
     )
 
-optimum.quanto.library.extensions.cuda.fp8_marlin_gemm = fp8_marlin_gemm
-
+# Monkey-patch the operator
+torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper
 class TinyGemmQBitsLinearFunction(
     optimum.quanto.tensor.function.QuantizedLinearFunction
 ):
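The new wrapper keeps the original marlin kernel and only normalises the activation dtype (a is cast to b_scales.dtype) before delegating to the saved op handle; the patch is installed by reassigning torch.ops.quanto.gemm_f16f8_marlin rather than re-registering the op. A hedged usage sketch, assuming optimum-quanto with its CUDA extension is installed, the workarounds module above has been imported so the patch is active, and the GPU actually selects the fp8 marlin path; the layer shape and tensor sizes are illustrative only:

# Illustrative only: bf16 activations through an fp8-quantised linear layer.
# Assumes the quanto_workarounds patch shown above is already active.
import torch
from optimum.quanto import freeze, qfloat8, quantize

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False)).cuda().to(torch.bfloat16)
quantize(model, weights=qfloat8)  # fp8 weight quantisation via optimum-quanto
freeze(model)

x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    # may route through the patched gemm_f16f8_marlin on supported GPUs;
    # the wrapper casts the activation to b_scales.dtype before the kernel runs
    y = model(x)
print(y.shape, y.dtype)  # expected: torch.Size([4, 64]) torch.bfloat16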
4 changes: 0 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.
