diff --git a/aqt/common/aqt_config.py b/aqt/common/aqt_config.py
index fb62d0b5..1de53035 100644
--- a/aqt/common/aqt_config.py
+++ b/aqt/common/aqt_config.py
@@ -146,13 +146,15 @@ class StatsConfig(_BaseConfig):
   safe_divide: bool = False
 
-  def validate(self, data_shape: List[Optional[int]]):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
+  def validate(self, data_shape: List[Optional[int]],  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
+               dynamic: bool = False):
     """Validates this StatsConfig for the provided data shape.
 
     Args:
       data_shape: the shape of the input tensor which will be quantized with
         self as the statistics configuration. If an entry is None, this
         indicates a dimension whose size is unknown at graph compilation time.
+      dynamic: whether quantization statistics are computed at runtime.
 
     Raises:
       ConfigError: if any of the specified share_stats_axes are not between
@@ -169,12 +171,13 @@ def validate(self, data_shape: List[Optional[int]]):  # pytype: disable=signatur
       raise ConfigError(
           f'share_stats_axes ({self.share_stats_axes}) must be strictly sorted')
 
-    unknown_axes = {i for i, dim in enumerate(data_shape) if dim is None}
-    shared_axes = set(self.share_stats_axes)
-    if not unknown_axes.issubset(shared_axes):
-      raise ConfigError(f'expected share_stats_axes ({self.share_stats_axes}) '
-                        'to contain unknown axes for given data shape '
-                        f'({data_shape})')
+    if not dynamic:
+      unknown_axes = {i for i, dim in enumerate(data_shape) if dim is None}
+      shared_axes = set(self.share_stats_axes)
+      if not unknown_axes.issubset(shared_axes):
+        raise ConfigError(f'expected share_stats_axes ({self.share_stats_axes})'
+                          ' to contain unknown axes for given data shape '
+                          f'({data_shape})')
 
     if self.ema_update_count < 1:
       raise ConfigError(
diff --git a/aqt/tensorflow/aqt_tensor.py b/aqt/tensorflow/aqt_tensor.py
index 01edd84c..ef4365f4 100644
--- a/aqt/tensorflow/aqt_tensor.py
+++ b/aqt/tensorflow/aqt_tensor.py
@@ -192,10 +192,31 @@ def _get_stats_shape(
   return stats_shape
 
 
+def _init_dynamic_stats(
+    stats_config: aqt_config.StatsConfig,
+    x: tf.Tensor,
+    init_value: float = 0.0,
+) -> tf.Tensor:
+  """Initializes a statistics tensor for a dynamically shaped input."""
+  # x may have a dynamic shape; build a tensor shaped like x, except that
+  # each shared statistics axis is reduced to size one.
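+  # E.g., for x of shape [None, 4] with share_stats_axes=[0], the slices
+  # select x[0:1, :], so the result is a [1, 4] tensor filled with init_value.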
+  rank = len(x.shape.as_list())
+  indices = []
+  for i in range(rank):
+    if i in stats_config.share_stats_axes:
+      indices.append(slice(0, 1))
+    else:
+      indices.append(slice(None))
+  ones = tf.ones_like(x[indices], dtype=x.dtype)
+  return ones * init_value
+
+
 def _bound(
     calibration_config: aqt_config.CalibrationConfig,
     lp_order: int,
-    stats_shape: Iterable[int],
+    init_bound: tf.Tensor,
     sum_of_ones: Optional[tf.Tensor],
     max_of_abs_vals: Optional[tf.Tensor],
     sum_of_l1_vals: Optional[tf.Tensor],
@@ -203,7 +224,7 @@ def _bound(
     divide_fn: _DivideFn,
 ) -> tf.Tensor:
   """Computes the upper bound."""
-  bound = tf.ones(stats_shape) * calibration_config.const_bound_coeff
+  bound = init_bound + calibration_config.const_bound_coeff
   if calibration_config.l1_dev_coeff:
     l1_dev = divide_fn(sum_of_l1_vals, sum_of_ones)
     bound += calibration_config.l1_dev_coeff * l1_dev
@@ -223,8 +244,8 @@ def _dynamic_bound(
   weight: Optional[tf.Tensor],
 ) -> tf.Tensor:
   """Compute the upper bound on input tensor values dynamically."""
-  config.validate(x.shape.as_list())
-  stats_shape = _get_stats_shape(config, x.shape.as_list())
+  config.validate(x.shape.as_list(), dynamic=True)
+  init_bound = _init_dynamic_stats(config, x, init_value=0.0)
   divide_fn = tf.math.divide_no_nan if config.safe_divide else tf.divide
   sum_of_ones = max_of_abs_vals = sum_of_l1_vals = sum_of_lp_vals = None
   if any([
@@ -242,7 +263,7 @@ def _dynamic_bound(
   return _bound(
       calibration_config,
       config.lp_order,
-      stats_shape,
+      init_bound,
       sum_of_ones,
       max_of_abs_vals,
       sum_of_l1_vals,
@@ -336,7 +357,7 @@ def bound(
     #
     return _bound(
         calibration_config,
         self._config.lp_order,
-        self.stats_shape,
+        tf.zeros(self.stats_shape, dtype=tf.float32),
        self._sum_of_ones.read_value(),
        self._max_of_abs_vals,
        self._sum_of_l1_vals.read_value(),
@@ -917,8 +938,7 @@ def dynamic_clip_range(
 
     def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
       if isinstance(config.quant_config, aqt_config.FloatConfig):
-        stats_shape = _get_stats_shape(self.config.stats_config, sample.shape)
-        return tf.zeros(stats_shape, dtype=sample.dtype)
+        return tf.zeros_like(inv_scale, dtype=inv_scale.dtype)
 
       # We return the range derived from the inverse scale, rather than
       # from the stats themselves, to respect freezing settings and
@@ -943,12 +963,8 @@ def _fresh_dynamic_scale(
       # We shouldn't return the scale if the given config contains FloatConfig
       # and no emulation;
       # fill with a poison value if we get into this situation.
-      stats_shape = _get_stats_shape(self.config.stats_config, sample.shape)
-      nan = tf.constant(
-          float('nan'),
-          sample.dtype,
-          stats_shape,
-      )
+      nan = _init_dynamic_stats(
+          self.config.stats_config, sample, init_value=float('nan'))
       return nan, nan
 
     x_bound = _dynamic_bound(
@@ -975,9 +991,7 @@ def _get_dynamic_quant_scale(
 
     # We intentionally initialize scale to zero to fail loudly if someone uses
    # a parameter such as scale without properly update()-ing it.
-    stats_shape = _get_stats_shape(self.config.stats_config, sample.shape)
-    zeros = tf.zeros(stats_shape, dtype=sample.dtype)
-
+    zeros = _init_dynamic_stats(self.config.stats_config, sample, init_value=0)
     def case_fn(config):
       # only need to update the event_count for dynamic quantizer
       updates = []
@@ -1005,3 +1019,16 @@ def case_fn(config):
 
     should_scale = self._should_scale(train)
     return self._maybe_fallback_to_ones(should_scale, scale, inv_scale)
+
+
+def get_tensor_quantizer(
+    data_shape: Iterable[Optional[int]],
+    config: aqt_config.AqtScheduleConfig,
+    name: str,
+) -> TensorQuantizer | DynamicTensorQuantizer:
+  """Returns a TensorQuantizer, or a DynamicTensorQuantizer if needed."""
+  stats_shape = _get_stats_shape(config.stats_config, data_shape)
+  # Use the dynamic quantizer when the stats shape has axes of unknown size.
+  if None in stats_shape:
+    return DynamicTensorQuantizer(data_shape, config, name=name)
+  return TensorQuantizer(data_shape, config, name=name)
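
Usage note: a runnable toy sketch (hypothetical helper _toy_stats_shape, not
part of the patch) of the stats-shape rule the new get_tensor_quantizer factory
keys on: shared statistics axes collapse to size one, so only an unknown size
on a non-shared axis leaves a None behind and forces the dynamic quantizer.

from typing import Iterable, List, Optional

def _toy_stats_shape(
    share_stats_axes: List[int],
    data_shape: Iterable[Optional[int]],
) -> List[Optional[int]]:
  # Shared statistics axes collapse to size one; every other axis keeps its
  # (possibly unknown, i.e. None) size.
  return [1 if i in share_stats_axes else d for i, d in enumerate(data_shape)]

# An unknown batch dimension that is shared away leaves no None behind, so
# get_tensor_quantizer would return the static TensorQuantizer.
assert _toy_stats_shape([0], [None, 16]) == [1, 16]

# The same unknown dimension on a non-shared axis survives as None, so the
# factory would return a DynamicTensorQuantizer instead.
assert _toy_stats_shape([1], [None, 16]) == [None, 1]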