Add option use_one_hot_case to use one hot matmul instead of case
PiperOrigin-RevId: 560107892
Cerebra Catalyst Team authored and copybara-github committed Aug 25, 2023
1 parent cf54757 commit 7406135
Showing 2 changed files with 58 additions and 2 deletions.
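The idea behind the new option: instead of selecting among per-config branches with data-dependent control flow (`tf.case`), every branch is evaluated, stacked, and contracted against a one-hot vector built from the predicates, so the selection becomes a dense einsum. Below is a minimal standalone sketch of that selection pattern, not part of the commit; it assumes TF2 eager execution for brevity, and `select_by_one_hot` is a hypothetical helper name rather than the `one_hot_case` added in the diff.

# Sketch only: mirrors the one-hot selection pattern, assuming TF2 eager mode.
import tensorflow as tf


def select_by_one_hot(cases):
  """Selects one branch result given (scalar bool predicate, fn) pairs."""
  predicates, fns = zip(*cases)
  stacked = tf.stack([fn() for fn in fns])                # [num_cases, ...]
  one_hot = tf.cast(tf.stack(predicates), stacked.dtype)  # [num_cases]
  # With exactly one true predicate, this weighted sum picks that branch.
  return tf.einsum('i...,i->...', stacked, one_hot)


x = tf.constant([1.0, 2.0, 3.0])
branches = [
    (tf.constant(False), lambda: x * 10.0),
    (tf.constant(True), lambda: x + 1.0),
]
print(select_by_one_hot(branches).numpy())        # [2. 3. 4.]
print(tf.case(branches, exclusive=True).numpy())  # same result: [2. 3. 4.]

Unlike tf.case, all branches are computed unconditionally, trading extra FLOPs for the absence of data-dependent control flow; the commit's `one_hot_case` additionally asserts that the predicates are exclusive (exactly one is true).
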
25 changes: 23 additions & 2 deletions aqt/tensorflow/aqt_tensor.py
@@ -89,6 +89,20 @@ def default_get_variable(name: str, shape: Iterable[int],
use_resource=True)


def one_hot_case(cases, exclusive=True):
predictors, fns = zip(*cases)
cases = tf.stack([f() for f in fns]) # shape of [num_cases, ...]
one_hot_vec = tf.cast(tf.stack(predictors), # shape of [num_cases]
dtype=cases.dtype)
dependencies = []
if exclusive:
dependencies.append(tf.debugging.assert_equal(
tf.reduce_sum(one_hot_vec), tf.constant(1, cases.dtype),
message='predictors must be exclusive.'))
with tf.control_dependencies(dependencies):
return tf.einsum('i...,i...->...', cases, one_hot_vec)


class Stats:
"""Manages efficient gathering of running statistics."""

@@ -327,6 +341,9 @@ def __init__(
self._last_update = get_variable('last_update', [], tf.int64,
tf.int64.min)

def last_update(self) -> tf.Tensor:
return self._last_update.read_value()

def tracked_variables(self) -> Dict[str, tf.Variable]:
"""Returns variables used to track updates and calibration variables."""
variables = {
@@ -363,7 +380,7 @@ def _fresh_scale(
inv_scale = x_bound / clip_bound
return new_scale, inv_scale

def clip_range(self) -> tf.Tensor:
def clip_range(self, use_one_hot_case: bool = False) -> tf.Tensor:
"""Returns the tensor clip range or zeros if no int config is active."""

def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
@@ -380,7 +397,8 @@ def case_fn(config: aqt_config.AqtTensorConfig) -> tf.Tensor:
clip_bound = aqt_common.get_clip_bound(config.quant_config)
return self._inv_scale.read_value() * clip_bound

return self._config_case(case_fn, self._last_update.read_value())
return self._config_case(case_fn, self._last_update.read_value(),
use_one_hot_case=use_one_hot_case)

def update(self, #
sample: Optional[tf.Tensor],
@@ -415,6 +433,7 @@ def _config_case(
self, #
case_fn: Callable[[aqt_config.AqtTensorConfig], tf.Tensor],
event_count: tf.Tensor,
use_one_hot_case: bool = False,
) -> tf.Tensor:
"""Switches over configs, applying case_fn to active one at event_count."""
assert self.config.tensor_configs, 'There must be at least one config.'
@@ -424,6 +443,8 @@ def make_case(config):
return pred, lambda: case_fn(config)

cases = [make_case(c) for c in self.config.tensor_configs]
if use_one_hot_case:
return one_hot_case(cases, exclusive=True)
return tf.case(cases, exclusive=True)

def _update_state_config(
35 changes: 35 additions & 0 deletions aqt/tensorflow/aqt_tensor_test.py
@@ -229,6 +229,41 @@ def alt_get_var(name, shape, dtype, init):
tf.global_variables_initializer().run()
self.assertAllEqual(var_qx, alt_qx)

def test_one_hot_case_methods(self):
"""Validates clip_range is the same whether use_one_hot_case."""
sc = aqt_config.StatsConfig(
ema_update_count=2,
share_stats_axes=[1],
tpu_cross_replica_sum=False,
)
config = aqt_config.AqtTensorConfig(
begin_at_event=4,
quant_config=aqt_config.IntQuantConfig(bits=8),
calibration_config=aqt_config.CalibrationConfig(const_bound_coeff=1),
freeze_scale_at_begin=False)
config = aqt_config.AqtScheduleConfig(sc, [config])
config.fill_gaps_with_float_config()

batches = 5

rng = np.random.default_rng(1234)
data_shape = [4, 4]
x = rng.integers(-10, 10,
size=(batches, *data_shape)
).astype(np.float32)

with tf.Graph().as_default():
quant = aqt_tensor.TensorQuantizer(data_shape=data_shape, config=config)
event_count = tf.Variable(0, trainable=False, dtype=tf.int64)
with self.cached_session():
tf.global_variables_initializer().run()
for i in range(batches):
event_count.assign_add(1)
quant.update(f32(x[i]), None, 1).run()
cr1 = quant.clip_range(use_one_hot_case=True).eval()
cr2 = quant.clip_range(use_one_hot_case=False).eval()
self.assertAllEqual(cr1, cr2)


def extract_referenced_variables(t: tf.Tensor) -> Set[str]:
"""Returns a set of all referenced variable names in a tensor's graph."""
