Allow configuring the dtype for quantization scales and statistics
PiperOrigin-RevId: 570395082
Cerebra Catalyst Team authored and copybara-github committed Oct 3, 2023
1 parent abd8c47 commit de05a73
1 changed file: aqt/tensorflow/aqt_tensor.py (21 additions, 14 deletions).
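
This commit threads a dtype argument through the Stats, TensorQuantizer, and DynamicTensorQuantizer constructors so that quantization scales and calibration statistics can be stored in a type other than float32. A minimal usage sketch (not taken from the AQT repository): the schedule config is only hinted at with a placeholder, the bfloat16 choice is purely illustrative, and tensorflow.compat.v1 is assumed because the module uses the TF1-style variable API.

import tensorflow.compat.v1 as tf
from aqt.tensorflow import aqt_tensor

# Placeholder: building an aqt_config.AqtScheduleConfig is outside this diff.
schedule_config = ...

quantizer = aqt_tensor.TensorQuantizer(
    data_shape=[None, 128],  # batch dimension left dynamic
    config=schedule_config,
    name='tensor_quantizer',
    dtype=tf.bfloat16,  # scale, inv_scale and statistics variables use bfloat16
)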
@@ -144,9 +144,9 @@ def _sum_of_ones(
   # unbiased estimation of non-sparse mean l1 and lp.
   # This clips away less of the distribution of inputs.
   if stats_config.filter_zeros:
-    ones = tf.cast(tf.math.not_equal(x, 0), dtype=tf.float32)
+    ones = tf.cast(tf.math.not_equal(x, 0), dtype=x.dtype)
   else:
-    ones = tf.ones_like(x)
+    ones = tf.ones_like(x, dtype=x.dtype)
   return _reduce_fn(stats_config, ones, weight)
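
As a quick, standalone illustration of the change in this hunk (not taken from the AQT tests): with filter_zeros enabled, the mask of non-zero entries now inherits the input's dtype instead of being hard-coded to float32.

import tensorflow as tf

x = tf.constant([[0.0, 1.5], [2.0, 0.0]], dtype=tf.bfloat16)
# Mask of non-zero entries; after this change it stays in x.dtype (bfloat16 here).
ones = tf.cast(tf.math.not_equal(x, 0), dtype=x.dtype)
assert ones.dtype == tf.bfloat16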


@@ -183,7 +183,7 @@ def _sum_of_lp_vals(
   return _reduce_fn(stats_config, px**stats_config.lp_order, weight)


-def _get_stats_shape(
+def get_stats_shape(
     stats_config: aqt_config.StatsConfig, data_shape: Iterable[int]
 ) -> List[int]:
   stats_shape = list(data_shape)
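
For context on the rename above, the function copies the data shape and, per the loop removed later in this diff, collapses every axis listed in StatsConfig.share_stats_axes to size 1. A small sketch that mirrors that logic with illustrative values (the helper name is hypothetical):

from typing import Iterable, List

def stats_shape_sketch(data_shape: Iterable[int], share_stats_axes: List[int]) -> List[int]:
  # Shared axes collapse to 1 so a single statistic covers the whole axis.
  shape = list(data_shape)
  for axis in share_stats_axes:
    shape[axis] = 1
  return shape

print(stats_shape_sketch([4, 128], share_stats_axes=[0]))  # [1, 128]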
@@ -278,23 +278,23 @@ def __init__(
       *,
       data_shape: Iterable[Optional[int]],
       config: aqt_config.StatsConfig,
-      get_variable: GetVariable):
+      get_variable: GetVariable,
+      dtype=tf.float32):
     self._data_shape = list(data_shape)
     config.validate(self._data_shape)
     if config.lp_order > 30:
       raise NotImplementedError('For higher norms we should add stabilization.')
     self._config = config
     self._ema_update_count = self._config.ema_update_count

-    self.stats_shape = self._data_shape[:]
-    for axis in self._config.share_stats_axes:
-      self.stats_shape[axis] = 1
+    self.stats_shape = get_stats_shape(self._config, self._data_shape)

     self.divide = (tf.math.divide_no_nan if self._config.safe_divide
                    else tf.math.divide)
+    self.dtype = dtype

     def mk_var(name, init_val):
-      return get_variable(name, self.stats_shape, tf.float32, init_val)
+      return get_variable(name, self.stats_shape, self.dtype, init_val)

     self._sum_of_ones = mk_var('sum_of_ones', self._config.update_count_prior)
     self._sum_of_vals = mk_var(
@@ -320,7 +320,8 @@ def update(self, x: tf.Tensor, weight: Optional[tf.Tensor]) -> tf.Operation:
     def update_var(var, update_fn):
       s = update_fn(self._config, x, weight)
       rate = 1.0 / self._ema_update_count
-      return var.assign((1.0 - rate) * var.read_value() + rate * s)
+      ema = (1.0 - rate) * var.read_value() + rate * s
+      return var.assign(tf.cast(ema, var.dtype))

     return tf.group([
         update_var(self._sum_of_ones, _sum_of_ones),
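
The hunk above splits the exponential-moving-average update into two steps so the result can be cast back to the variable's (possibly lower-precision) dtype before assignment. A plain-Python sketch of the same update rule with illustrative numbers:

ema_update_count = 100
rate = 1.0 / ema_update_count          # rate = 0.01
old_var, s = 2.0, 4.0                  # current EMA value and fresh statistic
new_var = (1.0 - rate) * old_var + rate * s   # 0.99 * 2.0 + 0.01 * 4.0 = 2.02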
@@ -355,7 +356,7 @@ def bound( #
     return _bound(
         calibration_config,
         self._config.lp_order,
-        tf.zeros(self.stats_shape, dtype=tf.float32),
+        tf.zeros(self.stats_shape, dtype=self.dtype),
         self._sum_of_ones.read_value(),
         self._max_of_abs_vals,
         self._sum_of_l1_vals.read_value(),
@@ -441,11 +442,13 @@ def __init__(
       config: aqt_config.AqtScheduleConfig,
       get_variable: GetVariable = default_get_variable,
       name: str = 'tensor_quantizer_base',
+      dtype: tf.dtypes.DType = tf.float32,
   ):
     self.data_shape = list(data_shape)
     config.fill_gaps_with_float_config()
     config.validate(self.data_shape)
     self.config = config
+    self._dtype = dtype

     with tf.variable_scope(name):
       # This variable maintains the most recent event count at which this
@@ -514,7 +517,7 @@ def qparams(config: aqt_config.AqtTensorConfig,
       if isinstance(config.quant_config, aqt_config.FloatConfig):
         params.clip_bound = tf.where_v2(config_active, float('inf'), 0.0)
       elif isinstance(config.quant_config, aqt_config.IntQuantConfig):
-        config_active = tf.cast(config_active, tf.float32)
+        config_active = tf.cast(config_active, self._dtype)
         params.should_quantize += config_active

         # TODO(vladf): some serving environments, such as adbrain,
@@ -687,12 +690,14 @@ def __init__(
       config: aqt_config.AqtScheduleConfig,
       get_variable: GetVariable = default_get_variable,
       name: str = 'tensor_quantizer',
+      dtype: tf.DType = tf.float32,
   ):
     super().__init__(
         data_shape=data_shape,
         config=config,
         get_variable=get_variable,
         name=name,
+        dtype=dtype,
     )

     with tf.variable_scope(name):
@@ -705,11 +710,11 @@ def __init__(
       # We intentionally initialize scale to zero to fail loudly if someone uses
       # a parameter such as scale without properly update()-ing it.
       self._scale = get_variable(
-          'scale', self._stats.stats_shape, tf.float32, 0
+          'scale', self._stats.stats_shape, self._dtype, 0
       )
       # Save the inverse scale so that we don't recompute it at inference time.
       self._inv_scale = get_variable(
-          'inv_scale', self._stats.stats_shape, tf.float32, 0
+          'inv_scale', self._stats.stats_shape, self._dtype, 0
       )

       # Variable to save or read quantized tensors to, if the config says so.
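
Both the scale and its cached inverse now live in the configured dtype. The actual scale computation happens in _fresh_scale further below; the snippet here only sketches the cached-inverse idea with a placeholder scale tensor, using divide_no_nan in the spirit of the safe_divide option seen earlier in this diff (whether _fresh_scale uses exactly this call is not shown in the hunk):

import tensorflow as tf

scale = tf.constant([[2.0], [4.0]], dtype=tf.bfloat16)   # placeholder scale values
one = tf.constant(1.0, dtype=scale.dtype)
# Cache the inverse so inference does not redo the division; divide_no_nan maps
# a zero scale to a zero inverse instead of producing inf/NaN.
inv_scale = tf.math.divide_no_nan(one, scale)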
@@ -742,7 +747,7 @@ def _fresh_scale(
       # We shouldn't update the scale if the given config contains FloatConfig
       # and no emulation;
       # fill with a poison value if we get into this situation.
-      nan = tf.constant(float('nan'), tf.float32, self._stats.stats_shape)
+      nan = tf.constant(float('nan'), self._dtype, self._stats.stats_shape)
       return nan, nan

     x_bound = self._stats.bound(config.calibration_config)
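
A quick illustration of the poison-value fill above in a reduced-precision dtype (the shape is arbitrary): any downstream use of such a scale propagates NaNs, which makes accidental use of an un-updated scale obvious.

import tensorflow as tf

stats_shape = [1, 4]   # illustrative stats shape
nan_scale = tf.constant(float('nan'), tf.bfloat16, stats_shape)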
@@ -909,13 +914,15 @@ def __init__(
       config: aqt_config.AqtScheduleConfig,
       get_variable: GetVariable = default_get_variable,
       name: str = 'dynamic_tensor_quantizer',
+      dtype: tf.DType = tf.float32,
   ):
     validate_dynamic(config)
     super().__init__(
         data_shape=data_shape,
         config=config,
         get_variable=get_variable,
         name=name,
+        dtype=dtype,
     )

   def calibration_variables(self) -> Dict[str, tf.Variable]:
