Skip to content

Commit

Permalink
Fix dynamic shape when creating zeros or ones
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570135092
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Oct 2, 2023
1 parent 452ab61 commit b1f0503
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 24 deletions.
17 changes: 10 additions & 7 deletions aqt/common/aqt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,15 @@ class StatsConfig(_BaseConfig):

safe_divide: bool = False

def validate(self, data_shape: List[Optional[int]]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
def validate(self, data_shape: List[Optional[int]], # pytype: disable=signature-mismatch # overriding-parameter-count-checks
dynamic: bool = False):
"""Validates this StatsConfig for the provided data shape.
Args:
data_shape: the shape of the input tensor which will be quantized with
self as the statistics configuration. If an entry is None, this
indicates a dimension whose size is unknown at graph compilation time.
dynamic: whether the quantization is dynamic.
Raises:
ConfigError: if any of the specified share_stats_axes are not between
Expand All @@ -169,12 +171,13 @@ def validate(self, data_shape: List[Optional[int]]): # pytype: disable=signatur
raise ConfigError(
f'share_stats_axes ({self.share_stats_axes}) must be strictly sorted')

unknown_axes = {i for i, dim in enumerate(data_shape) if dim is None}
shared_axes = set(self.share_stats_axes)
if not unknown_axes.issubset(shared_axes):
raise ConfigError(f'expected share_stats_axes ({self.share_stats_axes}) '
'to contain unknown axes for given data shape '
f'({data_shape})')
if not dynamic:
unknown_axes = {i for i, dim in enumerate(data_shape) if dim is None}
shared_axes = set(self.share_stats_axes)
if not unknown_axes.issubset(shared_axes):
raise ConfigError(f'expected share_stats_axes ({self.share_stats_axes})'
' to contain unknown axes for given data shape '
f'({data_shape})')

if self.ema_update_count < 1:
raise ConfigError(
Expand Down
58 changes: 41 additions & 17 deletions aqt/tensorflow/aqt_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,37 @@ def _get_stats_shape(
return stats_shape


def _init_dynamic_stats(
    stats_config: aqt_config.StatsConfig,
    x: tf.Tensor,
    init_value: float = 0.0,
) -> tf.Tensor:
  """Builds a constant statistics tensor shaped to match `x`.

  The result has the (possibly dynamic) shape of `x`, except that each axis
  listed in `stats_config.share_stats_axes` is collapsed to size one, and
  every element equals `init_value`.

  Args:
    stats_config: statistics configuration providing the shared-stats axes.
    x: tensor whose (dynamic) shape the statistics should follow.
    init_value: value to fill the resulting tensor with.

  Returns:
    A tensor of `x.dtype` filled with `init_value`.
  """
  # Slice x down to a single element along each shared-stats axis, keeping
  # the other axes whole; ones_like then tracks x's dynamic dimensions at
  # graph-construction time.
  rank = len(x.shape.as_list())
  shared = set(stats_config.share_stats_axes)
  index = [slice(0, 1) if axis in shared else slice(None)
           for axis in range(rank)]
  return tf.ones_like(x[index], dtype=x.dtype) * init_value


def _bound(
calibration_config: aqt_config.CalibrationConfig,
lp_order: int,
stats_shape: Iterable[int],
init_bound: tf.Tensor,
sum_of_ones: Optional[tf.Tensor],
max_of_abs_vals: Optional[tf.Tensor],
sum_of_l1_vals: Optional[tf.Tensor],
sum_of_lp_vals: Optional[tf.Tensor],
divide_fn: _DivideFn,
) -> tf.Tensor:
"""Computes the upper bound."""
bound = tf.ones(stats_shape) * calibration_config.const_bound_coeff
bound = init_bound + calibration_config.const_bound_coeff
if calibration_config.l1_dev_coeff:
l1_dev = divide_fn(sum_of_l1_vals, sum_of_ones)
bound += calibration_config.l1_dev_coeff * l1_dev
Expand All @@ -223,8 +242,8 @@ def _dynamic_bound(
weight: Optional[tf.Tensor],
) -> tf.Tensor:
"""Compute the upper bound on input tensor values dynamically."""
config.validate(x.shape.as_list())
stats_shape = _get_stats_shape(config, x.shape.as_list())
config.validate(x.shape.as_list(), dynamic=True)
init_bound = _init_dynamic_stats(config, x, init_value=0.0)
divide_fn = tf.math.divide_no_nan if config.safe_divide else tf.divide
sum_of_ones = max_of_abs_vals = sum_of_l1_vals = sum_of_lp_vals = None
if any([
Expand All @@ -242,7 +261,7 @@ def _dynamic_bound(
return _bound(
calibration_config,
config.lp_order,
stats_shape,
init_bound,
sum_of_ones,
max_of_abs_vals,
sum_of_l1_vals,
Expand Down Expand Up @@ -336,7 +355,7 @@ def bound( #
return _bound(
calibration_config,
self._config.lp_order,
self.stats_shape,
tf.zeros(self.stats_shape, dtype=tf.float32),
self._sum_of_ones.read_value(),
self._max_of_abs_vals,
self._sum_of_l1_vals.read_value(),
Expand Down Expand Up @@ -917,8 +936,7 @@ def dynamic_clip_range(

def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
if isinstance(config.quant_config, aqt_config.FloatConfig):
stats_shape = _get_stats_shape(self.config.stats_config, sample.shape)
return tf.zeros(stats_shape, dtype=sample.dtype)
return tf.zeros_like(inv_scale, dtype=inv_scale.dtype)

# We return the range derived from the inverse scale, rather than
# from the stats themselves, to respect freezing settings and
Expand All @@ -943,12 +961,8 @@ def _fresh_dynamic_scale(
# We shouldn't return the scale if the given config contains FloatConfig
# and no emulation;
# fill with a poison value if we get into this situation.
stats_shape = _get_stats_shape(self.config.stats_config, sample.shape)
nan = tf.constant(
float('nan'),
sample.dtype,
stats_shape,
)
nan = _init_dynamic_stats(
self.config.stats_config, sample, init_value=float('nan'))
return nan, nan

x_bound = _dynamic_bound(
Expand All @@ -975,9 +989,7 @@ def _get_dynamic_quant_scale(

# We intentionally initialize scale to zero to fail loudly if someone uses
# a parameter such as scale without properly update()-ing it.
stats_shape = _get_stats_shape(self.config.stats_config, sample.shape)
zeros = tf.zeros(stats_shape, dtype=sample.dtype)

zeros = _init_dynamic_stats(self.config.stats_config, sample, init_value=0)
def case_fn(config):
# only need to update the event_count for dynamic quantizer
updates = []
Expand Down Expand Up @@ -1005,3 +1017,15 @@ def case_fn(config):
should_scale = self._should_scale(train)

return self._maybe_fallback_to_ones(should_scale, scale, inv_scale)


def get_tensor_quantizer(
    data_shape: Iterable[Optional[int]],
    config: aqt_config.AqtScheduleConfig,
    name: str,
) -> TensorQuantizer | DynamicTensorQuantizer:
  """Creates a quantizer for `data_shape`, picking dynamic when required.

  Args:
    data_shape: shape of the tensor to quantize; `None` entries mark
      dimensions unknown at graph-construction time.
    config: schedule configuration (including statistics configuration).
    name: name for the created quantizer.

  Returns:
    A DynamicTensorQuantizer when any statistics dimension is unknown at
    graph-construction time, otherwise a TensorQuantizer.
  """
  # A None in the stats shape means statistics cannot be held in a
  # statically-shaped variable, so they must be computed dynamically.
  if None in _get_stats_shape(config.stats_config, data_shape):
    return DynamicTensorQuantizer(data_shape, config, name=name)
  return TensorQuantizer(data_shape, config, name=name)

0 comments on commit b1f0503

Please sign in to comment.