add fp8 autotune ut and fix bug in autotune (#1638)
Signed-off-by: xin3he <[email protected]>
xin3he authored Feb 28, 2024
1 parent b4e37b7 commit e6664b0
Showing 4 changed files with 53 additions and 6 deletions.
1 change: 1 addition & 0 deletions neural_compressor/torch/quantization/__init__.py
@@ -31,6 +31,7 @@
     get_default_hqq_config,
     FP8Config,
     get_default_fp8_config,
+    get_default_fp8_config_set,
 )
 
 from neural_compressor.torch.quantization.autotune import (
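With the re-export in place, the new helper is importable straight from the package root. A minimal sketch (assuming an installed neural_compressor with the FP8 backend):

from neural_compressor.torch.quantization import get_default_fp8_config_set

# Yields the default FP8 config set used for tuning (defined in config.py below).
config_set = get_default_fp8_config_set()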
7 changes: 5 additions & 2 deletions neural_compressor/torch/quantization/autotune.py
@@ -72,10 +72,13 @@ def autotune(
         tuning_logger.trial_end(trial_index)
         if tuning_monitor.need_stop():
             logger.info("Stopped tuning.")
+            del q_model  # maybe gc.collect() is needed for memory release
             best_quant_config: BaseConfig = tuning_monitor.get_best_quant_config()
             # !!! Make sure to use deepcopy only when inplace is set to `True`.
-            quantize(deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True)
-            best_quant_model = model  # quantize model inplace
+            q_model = quantize(
+                deepcopy(model), quant_config=best_quant_config, run_fn=run_fn, run_args=run_args, inplace=True
+            )
+            best_quant_model = q_model  # quantize model inplace
             break
     tuning_logger.tuning_end()
     return best_quant_model
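The substance of the fix: quantize(..., inplace=True) mutates and returns the deep copy, not the caller's model, so the old best_quant_model = model handed back the original FP32 module. Capturing the return value means autotune now returns the quantized model. A caller-side sketch of the corrected behavior (float_model, tune_config, calib_func, and eval_func are placeholders; the call pattern mirrors the new unit test below):

best_model = autotune(
    model=float_model, tune_config=tune_config, run_fn=calib_func, eval_fns=eval_func
)
# Before this fix, best_model was the caller's untouched FP32 module;
# after it, best_model is the quantized deepcopy returned by quantize.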
13 changes: 11 additions & 2 deletions neural_compressor/torch/quantization/config.py
@@ -941,14 +941,23 @@ def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]]:
 
 
 def get_default_fp8_config() -> FP8Config:
-    """Generate the default gptq config.
+    """Generate the default fp8 config.
 
     Returns:
-        the default gptq config.
+        the default fp8 config.
     """
     return FP8Config()
 
 
+def get_default_fp8_config_set() -> FP8Config:
+    """Generate the default fp8 config set.
+
+    Returns:
+        the default fp8 config.
+    """
+    return FP8Config.get_config_set_for_tuning()
+
+
 ##################### Algo Configs End ###################################
 
 
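The new helper plugs straight into TuningConfig. A minimal usage sketch, mirroring the unit test added below (my_model, my_calib_fn, and my_eval_fn are user-supplied placeholders; tolerable_loss bounds the accuracy loss the tuner accepts relative to its baseline):

from neural_compressor.torch.quantization import (
    TuningConfig,
    autotune,
    get_default_fp8_config_set,
)

tune_config = TuningConfig(
    config_set=get_default_fp8_config_set(),  # default FP8 search space
    tolerable_loss=0.01,
)
best_model = autotune(
    model=my_model,
    tune_config=tune_config,
    run_fn=my_calib_fn,   # calibration: runs representative data through the model
    eval_fns=my_eval_fn,  # returns an accuracy-like float per trial
)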
38 changes: 36 additions & 2 deletions test/3x/torch/quantization/habana_fp8/test_fp8.py
@@ -18,8 +18,14 @@
     FP8Matmul,
     Matmul,
 )
-from neural_compressor.torch.quantization import quantize
-from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config
+from neural_compressor.torch.quantization import (
+    FP8Config,
+    TuningConfig,
+    autotune,
+    get_default_fp8_config,
+    get_default_fp8_config_set,
+    quantize,
+)
 
 torch.set_grad_enabled(False)
 
@@ -164,3 +170,31 @@ def calib_func(model):
         assert isinstance(m.fc1, FP8Linear), "Unexpected result. Please double check."
         assert isinstance(m.mm, FP8Matmul), "Unexpected result. Please double check."
         assert isinstance(m.bmm, FP8BatchMatmul), "Unexpected result. Please double check."
+
+    def test_autotune(self):
+        m = copy.deepcopy(self.model)
+        inp = self.inp
+        fp32_out = m(inp)
+
+        def calib_func(model):
+            model(inp)
+
+        accu_list = [1.0, 0.9, 0.99]
+
+        def eval_func(model):
+            nonlocal accu_list
+            return accu_list.pop()
+
+        tune_config = TuningConfig(
+            config_set=get_default_fp8_config_set(),
+            tolerable_loss=0.01,
+        )
+        best_model = autotune(
+            model=m,
+            tune_config=tune_config,
+            run_fn=calib_func,
+            eval_fns=eval_func,
+        )
+        assert isinstance(best_model.fc1, FP8Linear), "Unexpected result. Please double check."
+        assert isinstance(best_model.mm, FP8Matmul), "Unexpected result. Please double check."
+        assert isinstance(best_model.bmm, FP8BatchMatmul), "Unexpected result. Please double check."
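A note on the test's shape: eval_func pops from the end of accu_list, so successive evaluations return 0.99, then 0.9, then 1.0, which presumably lets the tuner walk more than one trial before stopping. The final asserts inspect best_model rather than m, and that is exactly what the autotune fix above guarantees: the returned model, not the caller's input, carries the FP8Linear/FP8Matmul/FP8BatchMatmul modules.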
