Skip to content

Commit

Permalink
[Bug] fix alpha-space generation (#1465)
Browse files Browse the repository at this point in the history
Signed-off-by: Lu, Yintong <[email protected]>
(cherry picked from commit 33ece90)
  • Loading branch information
yintong-lu authored and chensuyue committed Dec 23, 2023
1 parent 111b3ce commit 2d8ea6a
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
logger = logging.getLogger()
from collections import UserDict, defaultdict

import numpy
from tqdm import tqdm


Expand Down Expand Up @@ -976,15 +977,10 @@ def _auto_tune_alpha(
:return:
"""
logger.info("start sq auto tuning")
alpha_scale = 100
alpha_space = list(
range(
round(alpha_min * alpha_scale),
round((alpha_max + alpha_step) * alpha_scale),
round(alpha_step * alpha_scale),
)
round_num = max(
len(str(alpha_min).split(".")[1]), len(str(alpha_max).split(".")[1]), len(str(alpha_step).split(".")[1])
)
alpha_space = [alpha / alpha_scale for alpha in alpha_space]
alpha_space = numpy.round(numpy.arange(alpha_min, alpha_max + alpha_step, alpha_step), round_num).tolist()
##wrapper new module
self._qdq_model_wrapper_for_auto(save_q_input=True)
##set alpha to 0.5 as default
Expand Down Expand Up @@ -1189,7 +1185,6 @@ def transform(
self.insert_mul, self.allow_absorb = True, False
if isinstance(alpha, float) and (alpha < 0 or alpha > 1):
logger.warning("reset alpha to in range [0.0, 1.0]")
import numpy

alpha = numpy.clip(alpha, 0.0, 1.0)

Expand Down

0 comments on commit 2d8ea6a

Please sign in to comment.