diff --git a/aqt/tensorflow/aqt_tensor.py b/aqt/tensorflow/aqt_tensor.py index dfbfbd48..2e1c9fa8 100644 --- a/aqt/tensorflow/aqt_tensor.py +++ b/aqt/tensorflow/aqt_tensor.py @@ -89,6 +89,20 @@ def default_get_variable(name: str, shape: Iterable[int], use_resource=True) +def one_hot_case(cases, exclusive=True): + predictors, fns = zip(*cases) + cases = tf.stack([f() for f in fns]) # shape of [num_cases, ...] + one_hot_vec = tf.cast(tf.stack(predictors), # shape of [num_cases] + dtype=cases.dtype) + dependencies = [] + if exclusive: + dependencies.append(tf.debugging.assert_equal( + tf.reduce_sum(one_hot_vec), tf.constant(1, cases.dtype), + message='predictors must be exclusive.')) + with tf.control_dependencies(dependencies): + return tf.einsum('i...,i...->...', cases, one_hot_vec) + + class Stats: """Manages efficient gathering of running statistics.""" @@ -327,6 +341,9 @@ def __init__( self._last_update = get_variable('last_update', [], tf.int64, tf.int64.min) + def last_update(self) -> tf.Tensor: + return self._last_update.read_value() + def tracked_variables(self) -> Dict[str, tf.Variable]: """Returns variables used to track updates and calibration variables.""" variables = { @@ -363,7 +380,7 @@ def _fresh_scale( inv_scale = x_bound / clip_bound return new_scale, inv_scale - def clip_range(self) -> tf.Tensor: + def clip_range(self, use_one_hot_case: bool = False) -> tf.Tensor: """Returns the tensor clip range or zeros if no int config is active.""" def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor: @@ -380,7 +397,8 @@ def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor: clip_bound = aqt_common.get_clip_bound(config.quant_config) return self._inv_scale.read_value() * clip_bound - return self._config_case(case_fn, self._last_update.read_value()) + return self._config_case(case_fn, self._last_update.read_value(), + use_one_hot_case=use_one_hot_case) def update(self, # sample: Optional[tf.Tensor], @@ -415,6 +433,7 @@ def _config_case( self, # case_fn: Callable[[aqt_config.AqtTensorConfig], tf.Tensor], event_count: tf.Tensor, + use_one_hot_case: bool = False, ) -> tf.Tensor: """Switches over configs, applying case_fn to active one at event_count.""" assert self.config.tensor_configs, 'There must be at least one config.' @@ -424,6 +443,8 @@ def make_case(config): return pred, lambda: case_fn(config) cases = [make_case(c) for c in self.config.tensor_configs] + if use_one_hot_case: + return one_hot_case(cases, exclusive=True) return tf.case(cases, exclusive=True) def _update_state_config( diff --git a/aqt/tensorflow/aqt_tensor_test.py b/aqt/tensorflow/aqt_tensor_test.py index 8f695c04..88f64d98 100644 --- a/aqt/tensorflow/aqt_tensor_test.py +++ b/aqt/tensorflow/aqt_tensor_test.py @@ -229,6 +229,41 @@ def alt_get_var(name, shape, dtype, init): tf.global_variables_initializer().run() self.assertAllEqual(var_qx, alt_qx) + def test_one_hot_case_methods(self): + """Validates clip_range is the same whether use_one_hot_case.""" + sc = aqt_config.StatsConfig( + ema_update_count=2, + share_stats_axes=[1], + tpu_cross_replica_sum=False, + ) + config = aqt_config.AqtTensorConfig( + begin_at_event=4, + quant_config=aqt_config.IntQuantConfig(bits=8), + calibration_config=aqt_config.CalibrationConfig(const_bound_coeff=1), + freeze_scale_at_begin=False) + config = aqt_config.AqtScheduleConfig(sc, [config]) + config.fill_gaps_with_float_config() + + batches = 5 + + rng = np.random.default_rng(1234) + data_shape = [4, 4] + x = rng.integers(-10, 10, + size=(batches, *data_shape) + ).astype(np.float32) + + with tf.Graph().as_default(): + quant = aqt_tensor.TensorQuantizer(data_shape=data_shape, config=config) + event_count = tf.Variable(0, trainable=False, dtype=tf.int64) + with self.cached_session(): + tf.global_variables_initializer().run() + for i in range(batches): + event_count.assign_add(1) + quant.update(f32(x[i]), None, 1).run() + cr1 = quant.clip_range(use_one_hot_case=True).eval() + cr2 = quant.clip_range(use_one_hot_case=False).eval() + self.assertAllEqual(cr1, cr2) + def extract_referenced_variables(t: tf.Tensor) -> Set[str]: """Returns a set of all referenced variable names in a tensor's graph."""