Skip to content

Commit

Permalink
Add DynamicTensorQuantizer class and use it for gradient quantization…
Browse files Browse the repository at this point in the history
… in aqt_einsum

PiperOrigin-RevId: 569188954
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Sep 29, 2023
1 parent db2bdaa commit f4c88da
Show file tree
Hide file tree
Showing 5 changed files with 776 additions and 298 deletions.
64 changes: 22 additions & 42 deletions aqt/tensorflow/aqt_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,29 +324,6 @@ def get_einsum_transpose(eq: str, swap_ans: bool = False) -> str:
return '{},{}->{}'.format(out_dims, y_dims, x_dims)


def _maybe_random(
random_gen: Optional[tf.random.Generator],
shape: Iterable[int],
dtype: tf.dtypes.DType,
) -> Optional[tf.Tensor]:
"""Maybe generate random floats in [-0.5, 0.5] to perturb gradients."""
if random_gen is None:
return None
return random_gen.uniform(shape, -0.5, 0.5, dtype=dtype)


def _round(
x_quantizer: aqt_tensor.TensorQuantizer,
x: tf.Tensor,
random: Optional[tf.Tensor],
train: bool,
) -> tf.Tensor:
if random is None:
return x_quantizer._to_quant(x, train=train)
assert x.shape == random.shape, (x.shape, random.shape)
return x_quantizer._to_quant(x + random, train=train)


def einsum(
eq: str, #
lhs_quantizer: aqt_tensor.TensorQuantizer,
Expand All @@ -355,9 +332,8 @@ def einsum(
rhs: tf.Tensor,
train: bool = True,
quantize_bwd: bool = False,
lhs_grad_quantizer: Optional[aqt_tensor.TensorQuantizer] = None,
rhs_grad_quantizer: Optional[aqt_tensor.TensorQuantizer] = None,
random_gen: Optional[tf.random.Generator] = None,
lhs_grad_quantizer: Optional[aqt_tensor.DynamicTensorQuantizer] = None,
rhs_grad_quantizer: Optional[aqt_tensor.DynamicTensorQuantizer] = None,
**tf_einsum_kwargs,
) -> tf.Tensor:
"""Performs a quantized two-argument :py:func:`tf.einsum`.
Expand All @@ -382,9 +358,6 @@ def einsum(
the einsum equation, `grad,rhs->lhs_grad`, in the backward pass.
rhs_grad_quantizer: A `TensorQuantizer` for grad, which is used to quantize
the einsum equation, `grad,lhs->rhs_grad`, in the backward pass.
random_gen: A `tf.random.Generator` used to generate random numbers between
[0.5, 0.5] added to the gradients before quantization in the backward
pass.
**tf_einsum_kwargs: Keyword arguments to pass onto `einsum`.
Returns:
Expand Down Expand Up @@ -526,9 +499,17 @@ def bwd(grad: tf.Tensor) -> tf.Tensor:
lhs_scaled = lhs_scale * lhs
rhs_scaled = rhs_scale * rhs

if quantize_bwd:
# Stochastic rounding is necessary for gradient quantization. We
# call uniform() once and share it across both scaled gradients to
# avoid potential bottlenecks with random number generation.
random = tf.random.uniform(
tf.shape(grad), -0.5, 0.5, dtype=grad.dtype
)

def _bwd(
eq: str,
grad_quantizer: Optional[aqt_tensor.TensorQuantizer],
grad_quantizer: Optional[aqt_tensor.DynamicTensorQuantizer],
y_quantizer: aqt_tensor.TensorQuantizer,
grad: tf.Tensor,
qy: tf.Tensor,
Expand All @@ -552,19 +533,18 @@ def _bwd(
# We assume the backward-pass quantization is dynamic so no need
# to pass weight when updating stats but still need _last_update
# to switch tensor configs.
update = grad_quantizer.update(
grad,
weight=None,
event_count=lhs_quantizer._last_update,
grad_scale, grad_inv_scale = (
grad_quantizer._get_dynamic_quant_scale(
grad,
weight=None,
event_count=lhs_quantizer._last_update,
train=train,
)
)
with tf.control_dependencies([update]):
grad_scale, grad_inv_scale = grad_quantizer._get_quant_scale(
train
)
grad_scaled = grad_scale * grad
random = _maybe_random(random_gen, grad.shape, grad.dtype)
qgrad = _round(grad_quantizer, grad_scaled, random, train)
assert len(grad_inv_scale.shape) == len(qgrad.shape)
grad_scaled = grad_scale * grad
grad_scaled = grad_scaled + random
qgrad = grad_quantizer._to_quant(grad_scaled, train=train)
assert len(grad_inv_scale.shape) == len(qgrad.shape)
else:
qgrad = grad
grad_inv_scale = None
Expand Down
50 changes: 16 additions & 34 deletions aqt/tensorflow/aqt_einsum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def _einsum_op(
quantize_bwd: bool = False,
lhs_bwd_config: Optional[aqt_config.AqtScheduleConfig] = None,
rhs_bwd_config: Optional[aqt_config.AqtScheduleConfig] = None,
random_noise_seed: Optional[int] = 1234,
**einsum_kwargs,
) -> tf.Tensor:
"""Updates quantizers at event_count=0 and computes einsum."""
Expand All @@ -105,17 +104,13 @@ def _einsum_op(
lhs_bwd_tq, rhs_bwd_tq = None, None
grad_shape = aqt_einsum.get_out_shape(eq, lhs.shape, rhs.shape)
if lhs_bwd_config:
lhs_bwd_tq = aqt_tensor.TensorQuantizer(
lhs_bwd_tq = aqt_tensor.DynamicTensorQuantizer(
grad_shape, lhs_bwd_config, name="lhs_bwd"
)
if rhs_bwd_config:
rhs_bwd_tq = aqt_tensor.TensorQuantizer(
rhs_bwd_tq = aqt_tensor.DynamicTensorQuantizer(
grad_shape, rhs_bwd_config, name="rhs_bwd"
)
if quantize_bwd and random_noise_seed is not None:
random_gen = tf.random.Generator.from_seed(random_noise_seed)
else:
random_gen = None

event_count = tf.constant(0, tf.int64)
updates = [
Expand All @@ -133,7 +128,6 @@ def _einsum_op(
quantize_bwd,
lhs_bwd_tq,
rhs_bwd_tq,
random_gen=random_gen,
**einsum_kwargs,
)

Expand Down Expand Up @@ -623,6 +617,9 @@ def _get_grad_config(eq: str,
bwd_eq = aqt_einsum.get_einsum_transpose(eq, swap_ans=swap_ans)
# 16 bits to preserve gradients
grad_config, _ = _exact_schedule_config(16, bwd_eq, 1.0)
grad_config.use_quantized_variable = False
for tc in grad_config.tensor_configs:
tc.freeze_scale_at_begin = False
return grad_config

lhs_bwd_config = _get_grad_config(eq, False)
Expand Down Expand Up @@ -721,16 +718,14 @@ def test_vars_over_inputs_at_inference(self, eq, quantize_bwd):
rhs_tq = aqt_tensor.TensorQuantizer(rhs.shape, rhs_config, name="rhs")
if quantize_bwd:
grad_shape = aqt_einsum.get_out_shape(eq, lhs.shape, rhs.shape)
lhs_bwd_tq = aqt_tensor.TensorQuantizer(
lhs_bwd_tq = aqt_tensor.DynamicTensorQuantizer(
grad_shape, lhs_bwd_config, name="lhs_bwd"
)
rhs_bwd_tq = aqt_tensor.TensorQuantizer(
rhs_bwd_tq = aqt_tensor.DynamicTensorQuantizer(
grad_shape, rhs_bwd_config, name="rhs_bwd"
)
random_gen = tf.random.Generator.from_seed(1234)
else:
lhs_bwd_tq = rhs_bwd_tq = None
random_gen = None

# Update at least once to initialize scale, then grab the expected
# value while in training mode.
Expand All @@ -750,7 +745,6 @@ def test_vars_over_inputs_at_inference(self, eq, quantize_bwd):
quantize_bwd=quantize_bwd,
lhs_grad_quantizer=lhs_bwd_tq,
rhs_grad_quantizer=rhs_bwd_tq,
random_gen=random_gen,
)

with self.cached_session() as sess, sess.as_default():
Expand All @@ -767,7 +761,6 @@ def test_vars_over_inputs_at_inference(self, eq, quantize_bwd):
quantize_bwd=quantize_bwd,
lhs_grad_quantizer=lhs_bwd_tq,
rhs_grad_quantizer=rhs_bwd_tq,
random_gen=random_gen,
)

self.assertAllEqual(actual, expected)
Expand All @@ -792,16 +785,14 @@ def test_float_config_not_save_quantized_var(self, eq, quantize_bwd):
rhs_tq = aqt_tensor.TensorQuantizer(rhs.shape, rhs_config, name="rhs")
if quantize_bwd:
grad_shape = aqt_einsum.get_out_shape(eq, lhs.shape, rhs.shape)
lhs_bwd_tq = aqt_tensor.TensorQuantizer(
lhs_bwd_tq = aqt_tensor.DynamicTensorQuantizer(
grad_shape, lhs_bwd_config, name="lhs_bwd"
)
rhs_bwd_tq = aqt_tensor.TensorQuantizer(
rhs_bwd_tq = aqt_tensor.DynamicTensorQuantizer(
grad_shape, rhs_bwd_config, name="rhs_bwd"
)
random_gen = tf.random.Generator.from_seed(1234)
else:
lhs_bwd_tq = rhs_bwd_tq = None
random_gen = None

event_count = tf.constant(0, tf.int64)

Expand All @@ -820,7 +811,6 @@ def test_float_config_not_save_quantized_var(self, eq, quantize_bwd):
quantize_bwd=quantize_bwd,
lhs_grad_quantizer=lhs_bwd_tq,
rhs_grad_quantizer=rhs_bwd_tq,
random_gen=random_gen,
)
# Although the input tensors are non-zeros, the result of einsum with
# inference mode should be zeros because lhs uses zero-initialized
Expand Down Expand Up @@ -875,7 +865,6 @@ def test_exact_grads(self, eq, quantize_bwd):
)
)

random_noise_seed = 1234 if quantize_bwd else None
actual_fwd = _einsum_op(
eq,
lhs,
Expand All @@ -885,7 +874,6 @@ def test_exact_grads(self, eq, quantize_bwd):
quantize_bwd=quantize_bwd,
lhs_bwd_config=lhs_bwd_config,
rhs_bwd_config=rhs_bwd_config,
random_noise_seed=random_noise_seed,
)
expected_fwd = tf.einsum(eq, lhs, rhs)

Expand Down Expand Up @@ -980,7 +968,7 @@ def test_consistent_bwd_improves_grads(self, eq):
eq, quantize_bwd=True, dynamic_bwd_quant=True,
)
)
def get_perturbed_gradients(random_noise_seed):
def get_perturbed_gradients(step_i):
actual_fwd = _einsum_op(
eq,
lhs,
Expand All @@ -990,17 +978,13 @@ def get_perturbed_gradients(random_noise_seed):
quantize_bwd=True,
lhs_bwd_config=lhs_bwd_config,
rhs_bwd_config=rhs_bwd_config,
random_noise_seed=random_noise_seed,
varscope_name=f"einsum_seed_{random_noise_seed}",
varscope_name=f"einsum_seed_{step_i}",
)
return tf.gradients([actual_fwd], [lhs, rhs])

exact_fwd = tf.einsum(eq, lhs, rhs)
exact = tf.gradients([exact_fwd], [lhs, rhs])

biased = get_perturbed_gradients(None)
biased_errors = [tf.linalg.norm(i - j) for i, j in zip(biased, exact)]

num_samples = 8
qgrad_samples = [get_perturbed_gradients(i) for i in range(num_samples)]
estimate1 = qgrad_samples[0]
Expand All @@ -1015,13 +999,11 @@ def get_error(estimate):
with self.cached_session() as sess, sess.as_default():
tf.global_variables_initializer().run()

for biased_g, exact_g, sample_error, ensemble_err, biased_err in zip(
biased, exact, sample_errors, ensemble_errors, biased_errors
):
# Check dynamic backward quant is inexact
self.assertNotAllEqual(biased_g, exact_g)
# unbiased estimate should have smaller errors than the biased one
self.assertAllLess(ensemble_err, biased_err)
for estimate1_g, exact_g, sample_error, ensemble_err in zip(
estimate1, exact, sample_errors, ensemble_errors
):
# Check dynamic backward quant should be close
self.assertAllClose(estimate1_g, exact_g, rtol=1e-2)
# the unbiased estimate should eventually converge or make improvement
self.assertAllLess(ensemble_err, sample_error)

Expand Down
Loading

0 comments on commit f4c88da

Please sign in to comment.