
Commit

fix a quantization bug (#34647)
XGZhang11 authored Aug 10, 2021
1 parent 4f4662b commit cfd49ac
Showing 2 changed files with 3 additions and 0 deletions.
@@ -578,6 +578,7 @@ def _sample_mse(self):
         var_tensor = _load_variable_data(self._scope, var_name)
         var_tensor = var_tensor.flatten()
         abs_max_value = float(np.max(np.abs(var_tensor)))
+        abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
         s = 0.3
         if var_name not in self._best_mse_loss:
             self._best_mse_loss[var_name] = float('inf')
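The added guard protects the MSE-based scale search: when a sampled tensor is all zeros, abs_max_value is 0.0, every candidate scale derived from it is also 0.0, and the simulated quantization divides by zero. The following is a minimal, self-contained sketch of such a search; the loop constants and the best_mse_threshold helper are assumptions for illustration, not code copied from the repository.

import numpy as np

def best_mse_threshold(var_tensor, weight_bits=8):
    # Scan candidate thresholds and keep the one with the lowest
    # quantize-dequantize mean squared error.
    bnt = (1 << (weight_bits - 1)) - 1
    abs_max_value = float(np.max(np.abs(var_tensor)))
    abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value  # the fix
    best_loss, best_threshold = float('inf'), abs_max_value
    s = 0.3
    while s <= 1.0:
        scale = s * abs_max_value
        s += 0.02
        # Quantize, then dequantize, with the candidate scale.
        quant = np.clip(np.round(var_tensor / scale * bnt), -bnt - 1, bnt)
        dequant = quant * scale / bnt
        loss = float(((var_tensor - dequant) ** 2).mean())
        if loss < best_loss:
            best_loss, best_threshold = loss, scale
    return best_threshold

# With an all-zero tensor, var_tensor / scale would otherwise be a 0/0
# division (NaN); the guard keeps the search finite.
print(best_mse_threshold(np.zeros(16, dtype=np.float32)))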
@@ -1312,6 +1312,7 @@ def _insert_post_dequant_op(self, graph, op_node):
             assert self._is_float(
                 scale_v), 'The scale of parameter %s is not a float.' % (
                     original_var_name)
+            scale_v = 1e-8 if scale_v == 0.0 else scale_v
             max_range *= param_range / scale_v
         else:
             max_range *= act_range
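This guard matters because scale_v is asserted to be a plain Python float, so param_range / scale_v raises ZeroDivisionError as soon as a parameter's recorded scale is 0.0 (for example, an all-zero weight). A small sketch of the max_range accumulation follows; the compute_max_range helper and the dict of input scales are assumptions for illustration, only the two multiply branches mirror the diff.

def compute_max_range(input_scales, param_range=127.0, act_range=127.0):
    # input_scales maps each input name to its weight scale, or to None for
    # activation inputs (hypothetical structure, for illustration only).
    max_range = 1.0
    for name, scale_v in input_scales.items():
        if scale_v is not None:
            scale_v = 1e-8 if scale_v == 0.0 else scale_v  # the fix
            max_range *= param_range / scale_v
        else:
            max_range *= act_range
    return max_range

# Without the guard, a weight whose scale collapsed to 0.0 raises
# ZeroDivisionError here; with it, the factor is merely very large but finite.
print(compute_max_range({'conv_weight': 0.0, 'conv_input': None}))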
@@ -1413,6 +1414,7 @@ def _clip(x, scale):
                 x[:, i] = _clip(x[:, i], s)
                 x[:, i] = np.round(x[:, i] / s * bnt)
         else:
+            scale = 1e-8 if scale == 0.0 else scale
             x = _clip(x, scale)
             x = np.round(x / scale * bnt)
         return x
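The last guard covers the per-tensor branch of the weight quantization helper: x is a NumPy array, so dividing by a zero scale does not raise an exception, it produces NaN (0/0) or inf and corrupts the quantized values. A minimal sketch of that branch follows; the _clip body and the quantize_weight wrapper are reconstructions for illustration and may differ from the original.

import numpy as np

def _clip(x, scale):
    # Clamp values into [-scale, scale] before rounding.
    return np.clip(x, -scale, scale)

def quantize_weight(x, scale, weight_bits=8):
    # Per-tensor simulated quantization: map [-scale, scale] onto the
    # signed integer grid defined by weight_bits.
    bnt = (1 << (weight_bits - 1)) - 1
    scale = 1e-8 if scale == 0.0 else scale  # the fix: keep x / scale finite
    x = _clip(x, scale)
    x = np.round(x / scale * bnt)
    return x

# An all-zero weight tensor has scale 0.0; without the guard the result is a
# matrix of NaNs, with the guard the expected all-zero values come back.
print(quantize_weight(np.zeros((2, 2), dtype=np.float32), scale=0.0))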
