From 2b86e5061fcefca615369d40ecd921d5e8780205 Mon Sep 17 00:00:00 2001 From: xinhe Date: Wed, 28 Feb 2024 16:36:04 +0800 Subject: [PATCH] support hardware scale for gaudi2 (#1637) Signed-off-by: xin3he --- .../language-modeling/quantization/habana_fp8/run_llm.py | 6 ++++-- neural_compressor/torch/algorithms/habana_fp8/modules.py | 5 ++++- neural_compressor/torch/amp/fp8/functions.py | 9 +++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py index 3c0f91b9f58..d7fb14c89d2 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py @@ -1,7 +1,8 @@ import os os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False" os.environ["USE_GAUDI2_SCALE"] = "True" -os.environ.pop("USE_GAUDI2_SCALE") # gaudi scale work +# USE_GAUDI2_SCALE requires PT_USE_FP8_AMAX for torch.mm/bmm, or got failure +os.environ["PT_USE_FP8_AMAX"] = "True" # os.environ["GRAPH_VISUALIZATION"] = "True" # import shutil # shutil.rmtree(".graph_dumps", ignore_errors=True) @@ -173,7 +174,7 @@ args.model, trust_remote_code=args.trust_remote_code ) - +tokenizer.pad_token = tokenizer.eos_token user_model.eval() @@ -219,6 +220,7 @@ def calib_func(model): user_model = quantize(user_model, qconfig, calib_func, inplace=True) # saving + print(user_model) if args.save and local_rank in [-1, 0]: user_model.save("saved_results") diff --git a/neural_compressor/torch/algorithms/habana_fp8/modules.py b/neural_compressor/torch/algorithms/habana_fp8/modules.py index 6e74c46870e..0d710bb5143 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/modules.py +++ b/neural_compressor/torch/algorithms/habana_fp8/modules.py @@ -55,7 +55,7 @@ def forward(self, x): ##################### FP8 modules ####################### def _map_guadi2_scale(scale): - USE_GAUDI2_SCALE = os.environ.get("USE_GAUDI2_SCALE") + USE_GAUDI2_SCALE = bool(os.getenv("USE_GAUDI2_SCALE", False)) if USE_GAUDI2_SCALE: scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256]) for i in scale_list: @@ -135,6 +135,7 @@ def forward(self, inp): if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: if self.use_amax: input_scale = self.dtype_amax / inp.abs().max() + input_scale = _map_guadi2_scale(input_scale) input_scale_inv = torch.reciprocal(input_scale) else: input_scale, input_scale_inv = None, None @@ -183,6 +184,7 @@ def forward(self, input1, input2): self.out_dtype = input1.dtype if self.use_amax: input1_scale = self.dtype_amax / input1.data.abs().max() + input1_scale = _map_guadi2_scale(input1_scale) input1_scale_inv = torch.reciprocal(input1_scale) else: input1_scale, input1_scale_inv = None, None @@ -195,6 +197,7 @@ def forward(self, input1, input2): self.out_dtype = input2.dtype if self.use_amax: input2_scale = self.dtype_amax / input2.data.abs().max() + input2_scale = _map_guadi2_scale(input2_scale) input2_scale_inv = torch.reciprocal(input2_scale) else: input2_scale, input2_scale_inv = None, None diff --git a/neural_compressor/torch/amp/fp8/functions.py b/neural_compressor/torch/amp/fp8/functions.py index 49427f921f1..6411f627c4e 100644 --- a/neural_compressor/torch/amp/fp8/functions.py +++ b/neural_compressor/torch/amp/fp8/functions.py @@ -20,7 +20,8 @@ import torch from torch.nn import functional as F -from neural_compressor.common import logger +from neural_compressor.torch.algorithms.habana_fp8.modules import _map_guadi2_scale +from neural_compressor.torch.utils import logger _F_linear = F.linear _torch_matmul = torch.matmul @@ -32,7 +33,7 @@ E5M2_AMAX = torch.tensor(57344, dtype=torch.float).to("hpu") DTYPE_AMAX = E4M3_AMAX if DATA_TYPE == torch.float8_e4m3fn else E5M2_AMAX -USE_AMAX = False if os.getenv("PT_USE_FP8_AMAX") is None else True +USE_AMAX = bool(os.getenv("PT_USE_FP8_AMAX", False)) def fp8_linear_forward(input, weight, bias=None): @@ -44,6 +45,7 @@ def fp8_linear_forward(input, weight, bias=None): out_dtype = input.dtype if USE_AMAX: input_scale = DTYPE_AMAX / input.data.abs().max() + input_scale = _map_guadi2_scale(input_scale) input_scale_inv = torch.reciprocal(input_scale) else: input_scale, input_scale_inv = None, None @@ -56,6 +58,7 @@ def fp8_linear_forward(input, weight, bias=None): out_dtype = weight.dtype if USE_AMAX: weight_scale = DTYPE_AMAX / weight.data.abs().max() + weight_scale = _map_guadi2_scale(weight_scale) weight_scale_inv = torch.reciprocal(weight_scale) else: weight_scale, weight_scale_inv = None, None @@ -86,6 +89,7 @@ def fp8_matmul(input1, input2): out_dtype = input1.dtype if USE_AMAX: input1_scale = DTYPE_AMAX / input1.data.abs().max() + input1_scale = _map_guadi2_scale(input1_scale) input1_scale_inv = torch.reciprocal(input1_scale) else: input1_scale, input1_scale_inv = None, None @@ -98,6 +102,7 @@ def fp8_matmul(input1, input2): out_dtype = input2.dtype if USE_AMAX: input2_scale = DTYPE_AMAX / input2.data.abs().max() + input2_scale = _map_guadi2_scale(input2_scale) input2_scale_inv = torch.reciprocal(input2_scale) else: input2_scale, input2_scale_inv = None, None