support hardware scale for gaudi2 (#1637)
Signed-off-by: xin3he <[email protected]>
xin3he authored Feb 28, 2024
1 parent e6664b0 commit 2b86e50
Showing 3 changed files with 15 additions and 5 deletions.
@@ -1,7 +1,8 @@
import os
os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False"
os.environ["USE_GAUDI2_SCALE"] = "True"
os.environ.pop("USE_GAUDI2_SCALE")  # use the default Gaudi scale
# USE_GAUDI2_SCALE requires PT_USE_FP8_AMAX for torch.mm/bmm, otherwise it fails
os.environ["PT_USE_FP8_AMAX"] = "True"
# os.environ["GRAPH_VISUALIZATION"] = "True"
# import shutil
# shutil.rmtree(".graph_dumps", ignore_errors=True)
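USE_GAUDI2_SCALE and PT_USE_FP8_AMAX above gate the new FP8 scale behavior: the first opts into the hardware-supported scale values, and the comment notes it must be paired with the second for torch.mm/bmm. A minimal sketch of enabling the hardware-scale path (the pairing is taken from that comment; the rest of the run script is unchanged):

# Hedged sketch: opt into the Gaudi2 hardware-scale path before importing the
# FP8 quantization code, per the comment in the script above.
import os
os.environ["USE_GAUDI2_SCALE"] = "True"   # snap scales to hardware values (16, 1, 1/16, 1/256)
os.environ["PT_USE_FP8_AMAX"] = "True"    # required companion flag for torch.mm/bmm
# ...then continue with model loading, calibration, and quantize() as in the script.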
@@ -173,7 +174,7 @@
args.model,
trust_remote_code=args.trust_remote_code
)

tokenizer.pad_token = tokenizer.eos_token

user_model.eval()

@@ -219,6 +220,7 @@ def calib_func(model):

user_model = quantize(user_model, qconfig, calib_func, inplace=True)
# saving
print(user_model)
if args.save and local_rank in [-1, 0]:
user_model.save("saved_results")

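The hunk above passes a user-defined calib_func into quantize(). For readers unfamiliar with the flow, a minimal sketch of such a callback (the dataloader name and batch count are illustrative assumptions, not part of this commit):

# Hypothetical calibration callback: run a few batches through the model so
# the FP8 observers can record amax statistics before conversion.
def calib_func(model):
    for i, batch in enumerate(calib_dataloader):  # calib_dataloader is assumed to exist
        model(**batch)
        if i >= 10:  # a handful of batches is typically enough for amax calibration
            break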
neural_compressor/torch/algorithms/habana_fp8/modules.py (5 changes: 4 additions & 1 deletion)
@@ -55,7 +55,7 @@ def forward(self, x):

##################### FP8 modules #######################
def _map_guadi2_scale(scale):
- USE_GAUDI2_SCALE = os.environ.get("USE_GAUDI2_SCALE")
+ USE_GAUDI2_SCALE = bool(os.getenv("USE_GAUDI2_SCALE", False))
if USE_GAUDI2_SCALE:
scale_list = torch.tensor([16, 1, 1 / 16, 1 / 256])
for i in scale_list:
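The body of _map_guadi2_scale is collapsed in this view; only the candidate scale_list and the start of the loop are visible. A plausible reading, written as a hedged sketch (the selection rule below is an assumption, not the verbatim source), is that the amax-derived scale is snapped to a value Gaudi2 supports in hardware:

# Assumed completion of _map_guadi2_scale: snap an arbitrary scale onto one of
# the hardware-supported values (16, 1, 1/16, 1/256), checked largest first.
import torch

def map_gaudi2_scale_sketch(scale: torch.Tensor) -> torch.Tensor:
    scale_list = torch.tensor([16.0, 1.0, 1 / 16, 1 / 256])
    for candidate in scale_list:
        if scale >= candidate:   # pick the largest supported value not exceeding scale
            return candidate
    return scale_list[-1]        # fall back to the smallest supported value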
@@ -135,6 +135,7 @@ def forward(self, inp):
if inp.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
if self.use_amax:
input_scale = self.dtype_amax / inp.abs().max()
input_scale = _map_guadi2_scale(input_scale)
input_scale_inv = torch.reciprocal(input_scale)
else:
input_scale, input_scale_inv = None, None
@@ -183,6 +184,7 @@ def forward(self, input1, input2):
self.out_dtype = input1.dtype
if self.use_amax:
input1_scale = self.dtype_amax / input1.data.abs().max()
input1_scale = _map_guadi2_scale(input1_scale)
input1_scale_inv = torch.reciprocal(input1_scale)
else:
input1_scale, input1_scale_inv = None, None
@@ -195,6 +197,7 @@ def forward(self, input1, input2):
self.out_dtype = input2.dtype
if self.use_amax:
input2_scale = self.dtype_amax / input2.data.abs().max()
input2_scale = _map_guadi2_scale(input2_scale)
input2_scale_inv = torch.reciprocal(input2_scale)
else:
input2_scale, input2_scale_inv = None, None
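Each of these forward methods follows the same pattern: compute a scale from the tensor's amax (dtype_amax / abs().max()), optionally snap it with _map_guadi2_scale, and keep its reciprocal for dequantization. A small CPU-only illustration of why both scale and scale_inv are carried around (a plain multiply stands in for the device-side FP8 cast, and the amax constant is a placeholder):

# Illustration only: simulate the quantize/dequantize round trip that the
# device-side FP8 ops perform with (scale, scale_inv).
import torch

dtype_amax = torch.tensor(240.0)       # placeholder for the real DTYPE_AMAX constant
x = torch.randn(4, 4) * 3.0

scale = dtype_amax / x.abs().max()     # stretch the tensor to fill the FP8 range
scale_inv = torch.reciprocal(scale)

x_scaled = x * scale                   # this is what would be cast to FP8 on the HPU
x_restored = x_scaled * scale_inv      # dequantize back to the original range
assert torch.allclose(x, x_restored)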
neural_compressor/torch/amp/fp8/functions.py (9 changes: 7 additions & 2 deletions)
@@ -20,7 +20,8 @@
import torch
from torch.nn import functional as F

- from neural_compressor.common import logger
+ from neural_compressor.torch.algorithms.habana_fp8.modules import _map_guadi2_scale
+ from neural_compressor.torch.utils import logger

_F_linear = F.linear
_torch_matmul = torch.matmul
@@ -32,7 +33,7 @@
E5M2_AMAX = torch.tensor(57344, dtype=torch.float).to("hpu")

DTYPE_AMAX = E4M3_AMAX if DATA_TYPE == torch.float8_e4m3fn else E5M2_AMAX
- USE_AMAX = False if os.getenv("PT_USE_FP8_AMAX") is None else True
+ USE_AMAX = bool(os.getenv("PT_USE_FP8_AMAX", False))


def fp8_linear_forward(input, weight, bias=None):
@@ -44,6 +45,7 @@ def fp8_linear_forward(input, weight, bias=None):
out_dtype = input.dtype
if USE_AMAX:
input_scale = DTYPE_AMAX / input.data.abs().max()
input_scale = _map_guadi2_scale(input_scale)
input_scale_inv = torch.reciprocal(input_scale)
else:
input_scale, input_scale_inv = None, None
@@ -56,6 +58,7 @@ def fp8_linear_forward(input, weight, bias=None):
out_dtype = weight.dtype
if USE_AMAX:
weight_scale = DTYPE_AMAX / weight.data.abs().max()
weight_scale = _map_guadi2_scale(weight_scale)
weight_scale_inv = torch.reciprocal(weight_scale)
else:
weight_scale, weight_scale_inv = None, None
@@ -86,6 +89,7 @@ def fp8_matmul(input1, input2):
out_dtype = input1.dtype
if USE_AMAX:
input1_scale = DTYPE_AMAX / input1.data.abs().max()
input1_scale = _map_guadi2_scale(input1_scale)
input1_scale_inv = torch.reciprocal(input1_scale)
else:
input1_scale, input1_scale_inv = None, None
@@ -98,6 +102,7 @@ def fp8_matmul(input1, input2):
out_dtype = input2.dtype
if USE_AMAX:
input2_scale = DTYPE_AMAX / input2.data.abs().max()
input2_scale = _map_guadi2_scale(input2_scale)
input2_scale_inv = torch.reciprocal(input2_scale)
else:
input2_scale, input2_scale_inv = None, None
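functions.py stores the original implementations as _F_linear and _torch_matmul near the top of the file, which implies the FP8 versions above are swapped in at runtime. A hedged sketch of that patching pattern (the helper names below are assumptions; only fp8_linear_forward, fp8_matmul, _F_linear and _torch_matmul come from the file itself):

# Sketch of the monkey-patch pattern implied by the saved references; the real
# enable/disable entry points in neural_compressor may be named differently.
import torch
from torch.nn import functional as F

def _enable_fp8_ops():
    F.linear = fp8_linear_forward
    torch.matmul = fp8_matmul

def _disable_fp8_ops():
    F.linear = _F_linear           # restore the originals captured at import time
    torch.matmul = _torch_matmul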
