
Commit

Merge pull request #1013 from bghira/dependency/quanto-fp8-overhead-fix
fp8-quanto fixes, unblocking PEFT multi-GPU LoRA training for other precision levels
bghira authored Oct 1, 2024
2 parents e615252 + 9f46450 commit d05f048
Showing 7 changed files with 4,098 additions and 36 deletions.
13 changes: 8 additions & 5 deletions helpers/training/default_settings/safety_check.py
@@ -17,19 +17,22 @@ def safety_check(args, accelerator):
     # mulit-gpu safety checks & warnings
     if args.model_type == "lora" and args.lora_type == "standard":
         # multi-gpu PEFT checks & warnings
-        if "quanto" in args.base_model_precision:
+        if args.base_model_precision in ["fp8-quanto"]:
             logger.error(
-                "Quanto is incompatible with multi-GPU training on PEFT adapters. Use LORA_TYPE (--lora_type) lycoris for quantised multi-GPU training of LoKr models."
+                f"{args.base_model_precision} is incompatible with multi-GPU training on PEFT LoRA."
+                " Use LORA_TYPE (--lora_type) lycoris for quantised multi-GPU training of LoKr models in FP8."
             )
-            sys.exit(1)
+            args.base_model_precision = "int8-quanto"
+            # sys.exit(1)
     if (
         args.base_model_precision in ["fp8-quanto", "int4-quanto"]
         and accelerator.state.dynamo_plugin.backend.lower() == "inductor"
     ):
         logger.warning(
-            f"{args.base_model_precision} is not supported with Dynamo backend. Switching to int8-quanto instead."
+            f"{args.base_model_precision} is not supported with Dynamo backend. Disabling Dynamo."
         )
-        args.base_model_precision = "int8-quanto"
+        from accelerate.utils import DynamoBackend
+        accelerator.state.dynamo_plugin.backend = DynamoBackend.NO
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError(
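Taken together, this hunk changes the failure mode: fp8-quanto is downgraded to int8-quanto for standard PEFT LoRA multi-GPU runs instead of aborting, and an incompatible inductor Dynamo backend is disabled rather than forcing a precision switch. A minimal, self-contained sketch of that fallback logic, where a stand-in SimpleNamespace and a plain backend string replace the project's real argparse/Accelerate objects:

    from types import SimpleNamespace

    def apply_precision_fallbacks(args, dynamo_backend):
        # multi-GPU PEFT LoRA cannot train on fp8-quanto; degrade to int8-quanto
        # instead of exiting, mirroring the change above
        if args.model_type == "lora" and args.lora_type == "standard":
            if args.base_model_precision == "fp8-quanto":
                args.base_model_precision = "int8-quanto"
        # fp8/int4 quanto do not work under the inductor backend; turn Dynamo
        # off rather than changing the user's precision
        if (
            args.base_model_precision in ("fp8-quanto", "int4-quanto")
            and dynamo_backend == "inductor"
        ):
            dynamo_backend = "no"
        return args, dynamo_backend

    lora_args = SimpleNamespace(
        model_type="lora", lora_type="standard", base_model_precision="fp8-quanto"
    )
    lora_args, backend = apply_precision_fallbacks(lora_args, "inductor")
    print(lora_args.base_model_precision, backend)  # int8-quanto inductor

    lycoris_args = SimpleNamespace(
        model_type="lora", lora_type="lycoris", base_model_precision="int4-quanto"
    )
    lycoris_args, backend = apply_precision_fallbacks(lycoris_args, "inductor")
    print(lycoris_args.base_model_precision, backend)  # int4-quanto no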
38 changes: 23 additions & 15 deletions helpers/training/quantisation/quanto_workarounds.py
@@ -6,10 +6,12 @@
 import optimum
 from optimum.quanto.library.extensions.cuda import ext as quanto_ext
 
-@torch.library.custom_op(
-    "quanto::fp8_marlin_gemm", mutates_args=(), device_types=["cuda"]
-)
-def fp8_marlin_gemm(
+# torch tells us to do this because
+torch._dynamo.config.optimize_ddp = False
+# Save the original operator
+original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin
+
+def fp8_marlin_gemm_wrapper(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
     b_scales: torch.Tensor,
@@ -19,11 +21,11 @@ def fp8_marlin_gemm(
     size_n: int,
     size_k: int,
 ) -> torch.Tensor:
-    assert b_scales.dtype == torch.float16 or b_scales.dtype == torch.bfloat16
-    assert b_q_weight.dim() == 2
-    assert b_q_weight.dtype == torch.int32
-    return quanto_ext.lib.fp8_marlin_gemm(
-        a.to(b_scales.dtype),
+    # Ensure 'a' has the correct dtype
+    a = a.to(b_scales.dtype)
+    # Call the original operator
+    return original_gemm_f16f8_marlin(
+        a,
         b_q_weight,
         b_scales,
         workspace,
@@ -33,9 +35,11 @@ def fp8_marlin_gemm(
         size_k,
     )
 
-optimum.quanto.library.extensions.cuda.fp8_marlin_gemm = fp8_marlin_gemm
-
-class TinyGemmQBitsLinearFunction(optimum.quanto.tensor.function.QuantizedLinearFunction):
+# Monkey-patch the operator
+torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper
+class TinyGemmQBitsLinearFunction(
+    optimum.quanto.tensor.function.QuantizedLinearFunction
+):
     @staticmethod
     def forward(ctx, input, other, bias):
         ctx.save_for_backward(input, other)
@@ -45,12 +49,16 @@ def forward(ctx, input, other, bias):
         out_features = other.shape[0]
         output_shape = input.shape[:-1] + (out_features,)
         output = torch._weight_int4pack_mm(
-            input.view(-1, in_features).to(dtype=other.dtype), other._data._data, other._group_size, other._scale_shift
+            input.view(-1, in_features).to(dtype=other.dtype),
+            other._data._data,
+            other._group_size,
+            other._scale_shift,
         )
         output = output.view(output_shape)
         if bias is not None:
             output = output + bias
         return output
 
-from optimum.quanto.tensor.weights import tinygemm
-tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction
+from optimum.quanto.tensor.weights import tinygemm
+
+tinygemm.qbits.TinyGemmQBitsLinearFunction = TinyGemmQBitsLinearFunction
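The workaround above follows a save-then-wrap pattern: keep a reference to the original torch.ops.quanto.gemm_f16f8_marlin, define a wrapper that casts the activation to the scales' dtype, then assign the wrapper back over the op so every existing call site picks it up. A self-contained sketch of the same pattern against a tiny stand-in namespace (not the real quanto op registry):

    import torch
    from types import SimpleNamespace

    def _strict_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # stand-in for the CUDA kernel: it rejects mismatched dtypes,
        # much as the fp8 marlin path does without the wrapper
        assert a.dtype == b.dtype, "dtype mismatch"
        return a @ b

    fake_ops = SimpleNamespace(gemm=_strict_gemm)

    # 1) save the original operator
    original_gemm = fake_ops.gemm

    # 2) wrap it so the activation is cast before delegating
    def gemm_wrapper(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return original_gemm(a.to(b.dtype), b)

    # 3) monkey-patch the attribute so callers transparently get the wrapper
    fake_ops.gemm = gemm_wrapper

    out = fake_ops.gemm(
        torch.randn(2, 3, dtype=torch.float32),
        torch.randn(3, 4, dtype=torch.float64),
    )
    print(out.dtype)  # torch.float64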
8 changes: 4 additions & 4 deletions install/apple/poetry.lock

Some generated files are not rendered by default.

