diff --git a/aqt/tensorflow/aqt_tensor.py b/aqt/tensorflow/aqt_tensor.py
index dfbfbd48..2e1c9fa8 100644
--- a/aqt/tensorflow/aqt_tensor.py
+++ b/aqt/tensorflow/aqt_tensor.py
@@ -89,6 +89,20 @@ def default_get_variable(name: str, shape: Iterable[int],
       use_resource=True)
 
 
+def one_hot_case(cases, exclusive=True):
+  predictors, fns = zip(*cases)
+  cases = tf.stack([f() for f in fns])  # shape of [num_cases, ...]
+  one_hot_vec = tf.cast(tf.stack(predictors),  # shape of [num_cases]
+                        dtype=cases.dtype)
+  dependencies = []
+  if exclusive:
+    dependencies.append(tf.debugging.assert_equal(
+        tf.reduce_sum(one_hot_vec), tf.constant(1, cases.dtype),
+        message='predictors must be exclusive.'))
+  with tf.control_dependencies(dependencies):
+    return tf.einsum('i...,i...->...', cases, one_hot_vec)
+
+
 class Stats:
   """Manages efficient gathering of running statistics."""
 
@@ -327,6 +341,9 @@ def __init__(
     self._last_update = get_variable('last_update', [], tf.int64,
                                      tf.int64.min)
 
+  def last_update(self) -> tf.Tensor:
+    return self._last_update.read_value()
+
   def tracked_variables(self) -> Dict[str, tf.Variable]:
     """Returns variables used to track updates and calibration variables."""
     variables = {
@@ -363,7 +380,7 @@ def _fresh_scale(
     inv_scale = x_bound / clip_bound
     return new_scale, inv_scale
 
-  def clip_range(self) -> tf.Tensor:
+  def clip_range(self, use_one_hot_case: bool = False) -> tf.Tensor:
     """Returns the tensor clip range or zeros if no int config is active."""
 
     def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
@@ -380,7 +397,8 @@ def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
       clip_bound = aqt_common.get_clip_bound(config.quant_config)
       return self._inv_scale.read_value() * clip_bound
 
-    return self._config_case(case_fn, self._last_update.read_value())
+    return self._config_case(case_fn, self._last_update.read_value(),
+                             use_one_hot_case=use_one_hot_case)
 
   def update(self,  #
              sample: Optional[tf.Tensor],
@@ -415,6 +433,7 @@ def _config_case(
       self,  #
       case_fn: Callable[[aqt_config.AqtTensorConfig], tf.Tensor],
       event_count: tf.Tensor,
+      use_one_hot_case: bool = False,
   ) -> tf.Tensor:
     """Switches over configs, applying case_fn to active one at event_count."""
     assert self.config.tensor_configs, 'There must be at least one config.'
@@ -424,6 +443,8 @@ def make_case(config):
       return pred, lambda: case_fn(config)
 
     cases = [make_case(c) for c in self.config.tensor_configs]
+    if use_one_hot_case:
+      return one_hot_case(cases, exclusive=True)
    return tf.case(cases, exclusive=True)
 
   def _update_state_config(
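
For reference, a minimal usage sketch of the new one_hot_case helper. This snippet is illustrative and not part of the patch: it assumes the module is importable as aqt.tensorflow.aqt_tensor (the path in the diff) and uses TF1-style graph execution via tensorflow.compat.v1, which is an assumption about the surrounding code; the step and x values are made-up placeholders. Unlike tf.case, which lowers to control-flow ops, one_hot_case evaluates every branch, stacks the results, and contracts them against the predicates cast to a 0/1 vector, so exactly one branch's value survives (enforced by the exclusivity assertion).

    # Illustrative sketch only, not part of the patch.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    from aqt.tensorflow import aqt_tensor  # module touched by this diff

    step = tf.constant(7, tf.int64)        # placeholder "event count"
    x = tf.constant([1.0, -2.0, 3.0])

    # Two mutually exclusive predicates, each paired with a branch fn.
    cases = [
        (step < 10, lambda: tf.nn.relu(x)),      # active for step == 7
        (step >= 10, lambda: tf.zeros_like(x)),  # inactive
    ]

    # Every branch is evaluated; the einsum inside one_hot_case weights the
    # stacked branch outputs by the boolean predicates cast to their dtype.
    y = aqt_tensor.one_hot_case(cases, exclusive=True)

    with tf.Session() as sess:
      print(sess.run(y))  # [1. 0. 3.]

With exclusive=True the helper also asserts that the predicates sum to exactly one, mirroring tf.case(..., exclusive=True), so a configuration error surfaces at run time rather than silently blending branches.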