Skip to content

Commit

Permalink
Add option use_one_hot_case to use one hot matmul instead of case
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560107892
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Aug 25, 2023
1 parent cf54757 commit 2726aad
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions aqt/tensorflow/aqt_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ def default_get_variable(name: str, shape: Iterable[int],
use_resource=True)


def one_hot_case(cases, exclusive=True):
predictors, fns = zip(*cases)
cases = tf.stack([f() for f in fns]) # shape of [num_cases, ...]
one_hot_vec = tf.cast(tf.stack(predictors), # shape of [num_cases]
dtype=cases.dtype)
dependencies = []
if exclusive:
dependencies.append(tf.debugging.assert_equal(
tf.reduce_sum(one_hot_vec), tf.constant(1, cases.dtype),
message='predictors must be exclusive.'))
with tf.control_dependencies(dependencies):
return tf.einsum('i...,i...->...', cases, one_hot_vec)


class Stats:
"""Manages efficient gathering of running statistics."""

Expand Down Expand Up @@ -327,6 +341,9 @@ def __init__(
self._last_update = get_variable('last_update', [], tf.int64,
tf.int64.min)

def last_update(self) -> tf.Tensor:
return self._last_update.read_value()

def tracked_variables(self) -> Dict[str, tf.Variable]:
"""Returns variables used to track updates and calibration variables."""
variables = {
Expand Down Expand Up @@ -363,7 +380,7 @@ def _fresh_scale(
inv_scale = x_bound / clip_bound
return new_scale, inv_scale

def clip_range(self) -> tf.Tensor:
def clip_range(self, use_one_hot_case: bool = False) -> tf.Tensor:
"""Returns the tensor clip range or zeros if no int config is active."""

def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
Expand All @@ -380,7 +397,8 @@ def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
clip_bound = aqt_common.get_clip_bound(config.quant_config)
return self._inv_scale.read_value() * clip_bound

return self._config_case(case_fn, self._last_update.read_value())
return self._config_case(case_fn, self._last_update.read_value(),
use_one_hot_case=use_one_hot_case)

def update(self, #
sample: Optional[tf.Tensor],
Expand Down Expand Up @@ -415,6 +433,7 @@ def _config_case(
self, #
case_fn: Callable[[aqt_config.AqtTensorConfig], tf.Tensor],
event_count: tf.Tensor,
use_one_hot_case: bool = False,
) -> tf.Tensor:
"""Switches over configs, applying case_fn to active one at event_count."""
assert self.config.tensor_configs, 'There must be at least one config.'
Expand All @@ -424,6 +443,8 @@ def make_case(config):
return pred, lambda: case_fn(config)

cases = [make_case(c) for c in self.config.tensor_configs]
if use_one_hot_case:
return one_hot_case(cases, exclusive=True)
return tf.case(cases, exclusive=True)

def _update_state_config(
Expand Down

0 comments on commit 2726aad

Please sign in to comment.