Support for a "middle" scale in matmul, as in SmoothQuant, AWQ, etc.
PiperOrigin-RevId: 676603741
Cerebra Catalyst Team authored and copybara-github committed Oct 1, 2024
1 parent 9acd003 commit 9637f5f
Showing 7 changed files with 506 additions and 34 deletions.
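For context, the "middle" scale added in this commit is the scale-migration idea from SmoothQuant/AWQ: a per-channel factor along the contracting dimension is divided out of one matmul operand and multiplied into the other, which leaves the product unchanged while making both operands easier to quantize. A minimal sketch of that identity in plain JAX (illustrative names, not the AQT API; the alpha exponent plays the role of the lhs_mid_alpha/rhs_mid_alpha knobs added below, though the exact scale combination in DefaultDotGeneralQuantizer differs):

import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 4.0], [1.0, 4.0, 16.0]])    # activations
w = jnp.array([[4.0, 16.0], [2.0, 4.0], [1.0, 1.0]])  # weights

# Per-contracting-channel smoothing factor (SmoothQuant rule);
# 0.0 = migrate nothing, 1.0 = migrate everything to the other operand.
alpha = 0.5
s = jnp.max(jnp.abs(x), axis=0) ** alpha / jnp.max(jnp.abs(w), axis=1) ** (1.0 - alpha)

# The product is unchanged: dot(x, w) == dot(x / s, w * s).
# The rescaled operands are what actually get quantized.
lhs_mid = x / s
rhs_mid = w * s[:, None]
assert jnp.allclose(x @ w, lhs_mid @ rhs_mid)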
131 changes: 119 additions & 12 deletions aqt/jax/v2/aqt_dot_general.py
@@ -131,7 +131,17 @@ def dot_general_raw_make(
rhs = aqt_quantizer.quantizer_make(
rhs_bits, initialize_calibration=initialize_calibration
)
dg_quantizer = DefaultDotGeneralQuantizer(lhs=lhs, rhs=rhs)
# TODO(lew): This code (and surrounding code) is duplicated.
# We should dedup.
lhs_mid = aqt_quantizer.quantizer_make(
None, initialize_calibration=initialize_calibration
)
rhs_mid = aqt_quantizer.quantizer_make(
None, initialize_calibration=initialize_calibration
)
dg_quantizer = DefaultDotGeneralQuantizer(
lhs=lhs, rhs=rhs, lhs_mid=lhs_mid, rhs_mid=rhs_mid
)

return DotGeneralRaw(
lhs=lhs_cfg,
@@ -418,9 +428,36 @@ class DefaultDotGeneralQuantizer(DotGeneralQuantizer):
lhs: aqt_quantizer.Quantizer
rhs: aqt_quantizer.Quantizer

# Quantizers for the "middle" scale in matmul, as in SmoothQuant, AWQ, etc.
lhs_mid: aqt_quantizer.Quantizer
rhs_mid: aqt_quantizer.Quantizer

# The amount (exponent) of the scales that should be transferred to the
# other side. 0.0 = nothing, 1.0 = all.
lhs_mid_alpha: float | None = None
rhs_mid_alpha: float | None = None

# This is a hack to make QTensors compatible with the current
# _qtensor_dot_general.
# The QTensors that are returned do not include the mid-scales.
# But this is ok, because the skipped mid-scales are reciprocal of each other
# and they would cancel out in _qtensor_dot_general anyway.
# A good design would be to hardcode skip_mid_scales=False because it would
# maintain the semantics of QTensor (QTensor.dequant).
# This is also needed to send a correct QTensor to backprop
# in use_fwd_quant=True mode.
# We don't do it now mostly because it would require a separate mechanism
# in _qtensor_dot_general to skip the mid-scales
# (which do cancel each other mathematically).
skip_mid_scales: bool = True

def init_calibration(self):
self.lhs.init_calibration()
self.rhs.init_calibration()
if self.lhs_mid_alpha is not None:
self.lhs_mid.init_calibration()
if self.rhs_mid_alpha is not None:
self.rhs_mid.init_calibration()

def calibrate(
self,
@@ -432,10 +469,10 @@ def calibrate(
) -> tuple[
tuple[jax.Array, aqt_tensor.QTensor], tuple[jax.Array, aqt_tensor.QTensor]
]:
if dimension_numbers is None:
lhs_calib, rhs_calib = None, None
else:
if dimension_numbers is not None:
(lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers
lhs_ra = utils.get_remaining_axes(lhs.ndim, lhs_ca, lhs_ba)
rhs_ra = utils.get_remaining_axes(rhs.ndim, rhs_ca, rhs_ba)

def _get_calibration_axes(
mode: CalibrationMode,
@@ -447,17 +484,76 @@ def _get_calibration_axes(
match mode:
case CalibrationMode.REMAINING_AXIS:
calibration_axes = utils.get_remaining_axes(ndim, ca, ba)
m = 'mid-quantization is not supported for REMAINING_AXIS mode'
assert self.lhs_mid_alpha is None and self.rhs_mid_alpha is None, m
case CalibrationMode.CONTRACTING_AXIS:
calibration_axes = ca
case _:
raise ValueError(f'Unknown calibration mode: {mode}')
return calibration_axes

lhs_calib = _get_calibration_axes(lhs_mode, lhs.ndim, lhs_ca, lhs_ba)
rhs_calib = _get_calibration_axes(rhs_mode, rhs.ndim, rhs_ca, rhs_ba)
lhs_calib_axes = _get_calibration_axes(lhs_mode, lhs.ndim, lhs_ca, lhs_ba)
rhs_calib_axes = _get_calibration_axes(rhs_mode, rhs.ndim, rhs_ca, rhs_ba)
else:
(lhs_ra, rhs_ra) = (None, None)
lhs_calib_axes = None
rhs_calib_axes = None

def dezero(x):
return jnp.where(x == 0.0, jnp.ones_like(x), x)

if self.lhs_mid_alpha is not None:
assert self.lhs_mid is not None
lhs_mid_qt = self.lhs_mid.calibrate(lhs, calibration_axes=lhs_ra)
assert len(lhs_mid_qt.scale) == 1, 'you must set some numerics'
lhs_mid_scale = dezero(lhs_mid_qt.scale[0])
lhs_mid_scale = lhs_mid_scale**self.lhs_mid_alpha
lhs_mid_scale_t = transpose.lhs_scale_transpose_for_rhs_input(
lhs_mid_scale, dimension_numbers, rhs.shape
)
else:
lhs_mid_scale = 1.0
lhs_mid_scale_t = 1.0

if self.rhs_mid_alpha is not None:
assert self.rhs_mid is not None
rhs_mid_qt = self.rhs_mid.calibrate(rhs, calibration_axes=rhs_ra)
assert len(rhs_mid_qt.scale) == 1, 'you must set some numerics'
rhs_mid_scale = dezero(rhs_mid_qt.scale[0])
rhs_mid_scale = rhs_mid_scale**self.rhs_mid_alpha
rhs_mid_scale_t = transpose.rhs_scale_transpose_for_lhs_input(
rhs_mid_scale, dimension_numbers, lhs.shape
)
else:
rhs_mid_scale = 1.0
rhs_mid_scale_t = 1.0

# This condition can be considered an optimization.
# At the moment it also lets us avoid dealing with 1.0 being a scalar.
if self.lhs_mid_alpha is not None or self.rhs_mid_alpha is not None:
# Combined SmoothQuant scales
lhs_mid_scale_combined = lhs_mid_scale / rhs_mid_scale_t
rhs_mid_scale_combined = rhs_mid_scale / lhs_mid_scale_t

# Apply the combined scales before per-tensor calibration
lhs_mid = lhs / lhs_mid_scale_combined
rhs_mid = rhs / rhs_mid_scale_combined

# Per-tensor calibration, same as "else" branch.
lhs_qt = self.lhs.calibrate(lhs_mid, calibration_axes=lhs_calib_axes)
rhs_qt = self.rhs.calibrate(rhs_mid, calibration_axes=rhs_calib_axes)

# To maintain QTensor.dequant semantics, we need to append the combined
# scales.
assert lhs_qt.scale is not None
assert rhs_qt.scale is not None
if not self.skip_mid_scales:
lhs_qt.scale.append(lhs_mid_scale_combined)
rhs_qt.scale.append(rhs_mid_scale_combined)
else:
lhs_qt = self.lhs.calibrate(lhs, calibration_axes=lhs_calib_axes)
rhs_qt = self.rhs.calibrate(rhs, calibration_axes=rhs_calib_axes)

lhs_qt = self.lhs.calibrate(lhs, calibration_axes=lhs_calib)
rhs_qt = self.rhs.calibrate(rhs, calibration_axes=rhs_calib)
return ((lhs, lhs_qt), (rhs, rhs_qt))

def calculate_qvalue(
@@ -595,18 +691,24 @@ def _maybe_use_fwd_quant(
A tuple of updated (lhs, rhs). If use_fwd_quant is True, lhs is multiplied
with rhs scale, while rhs is set to the original rhs's qvalue.
"""
fwd_quantized = rhs.qx.scale is not None and len(rhs.qx.scale) == 1
scale_count = -1
if rhs.qx.scale is not None:
scale_count = len(rhs.qx.scale)
msg = (
f'Found use_fwd_quant is {use_fwd_quant} in bwd, but fwd is not'
' quantized.'
f'Found use_fwd_quant is {use_fwd_quant} in bwd. '
'It is supported only if there is exactly one scale in a good shape.\n'
f'{scale_count=}'
)
if use_fwd_quant:
assert fwd_quantized, msg
assert scale_count == 1, msg
if rhs.qx.bias:
raise NotImplementedError(
'Quantization biases are not supported in forward quantization.'
)

# If this transpose fails or the multiplication below fails,
# we have some misconfiguration. One way to deal with it is
# to set use_fwd_quant to False.
scale_t = transpose.rhs_scale_transpose_for_lhs_input(
rhs.qx.scale[0], dimension_numbers, lhs.shape
)
@@ -769,6 +871,11 @@ def _maybe_dequant(

dtypes_can_be_scaled = [jnp.bfloat16, jnp.float32, jnp.float64]

# If the transposes below fail, it might be because of a misconfiguration.
# For instance, "mid" quantization in DefaultDotGeneralQuantizer is not
# compatible with DequantMode.OTHER_INPUT.
# TODO(lew): One way to deal with it is to have a per-scale DequantMode.

if cfg.lhs.dequant_mode == DequantMode.OTHER_INPUT:
assert rhs_qin.dtype in dtypes_can_be_scaled
for scale in lhs_qt.scale:
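The config plumbing that exposes the new mid-alpha knobs lives in files not shown in this excerpt (e.g. aqt/jax/v2/config.py). Based only on the calls made in the test file below (config.config_v4 with *_mid_alpha_both arguments and config.set_use_mid_quant), a usage sketch might look like this; the bit widths and alpha values are illustrative and the signatures are assumed from the test code:

from aqt.jax.v2 import config

# Build a config with mid quantization enabled from the start
# (argument names as used in test_mid_quantization below).
dg = config.config_v4(
    fwd_bits=8,
    dlhs_bits=8,
    drhs_bits=8,
    fwd_mid_alpha_both=0.5,   # 0.0 = migrate nothing, 1.0 = migrate all
    dlhs_mid_alpha_both=0.5,
    drhs_mid_alpha_both=0.5,
)

# Alternatively, enable it on an existing DotGeneral, as _modify_dg below does.
config.set_use_mid_quant(
    dg,
    fwd_mid_alpha_both=1.0,
    dlhs_mid_alpha_both=1.0,
    drhs_mid_alpha_both=1.0,
)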
94 changes: 91 additions & 3 deletions aqt/jax/v2/aqt_dot_general_test.py
@@ -24,7 +24,6 @@
from aqt.jax.v2 import stochastic_rounding
from aqt.jax.v2 import tiled_dot_general
from aqt.jax.v2 import utils

import aqt.jax.v2.aqt_dot_general as aqt
from aqt.jax.v2.numerics import int_numerics
from aqt.jax.v2.numerics import no_numerics
@@ -112,14 +111,27 @@ def rand_unif(shape, maxval, seed, dtype=jnp.float32):


def test_eq(name, a, b):
assert a.shape == b.shape, (a.shape, b.shape)
# TODO(lew): use library function.
mean_err = jnp.mean(jnp.abs(a - b))
if mean_err != 0.0:
print("mean_err =", mean_err)
print(a.shape)
print(a[:3, :3])
match a.ndim:
case 1:
print(a[:3])
case 2:
print(a[:3, :3])
print("sum =", jnp.sum(a))

print(b.shape)
print(b[:3, :3])
match b.ndim:
case 1:
print(b[:3])
case 2:
print(b[:3, :3])
print("sum =", jnp.sum(b))

print(f"FAIL: {name}")
assert False

@@ -206,6 +218,7 @@ def _modify_dg(
disable_rounding: bool = False,
fwd_lhs_tricky_clip_and_round: bool = False,
local_aqt: aqt.LocalAqt | None = None,
use_mid_quant: bool = False,
clip_gradient: bool = False,
) -> aqt.DotGeneral:
dg = copy.deepcopy(readonly_dg)
@@ -267,6 +280,14 @@ def disable_quant(c):
if not isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics):
dg.dlhs.rhs.use_fwd_quant = use_fwd_quant

if use_mid_quant:
config.set_use_mid_quant(
dg,
fwd_mid_alpha_both=1.0,
dlhs_mid_alpha_both=1.0,
drhs_mid_alpha_both=1.0,
)

if local_aqt is not None:
# Currently we are not supporting local_aqt in fwd pass
# dg.fwd.local_aqt = local_aqt
@@ -291,6 +312,7 @@ def _aqt_dg_full_lr_diff(
disable_rounding: bool = False,
fwd_lhs_tricky_clip_and_round: bool = False,
local_aqt: aqt.LocalAqt | None = None,
use_mid_quant: bool = False,
*,
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
@@ -306,6 +328,7 @@
disable_rounding=disable_rounding,
fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round,
local_aqt=local_aqt,
use_mid_quant=use_mid_quant,
clip_gradient=clip_gradient,
)
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
@@ -319,6 +342,7 @@ def _aqt_dg_full(
disable_rounding: bool = False,
fwd_lhs_tricky_clip_and_round: bool = False,
local_aqt: aqt.LocalAqt | None = None,
use_mid_quant: bool = False,
*,
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
@@ -333,6 +357,7 @@
disable_rounding=disable_rounding,
fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round,
local_aqt=local_aqt,
use_mid_quant=use_mid_quant,
readonly_dg=readonly_dg,
dims=dims,
clip_gradient=clip_gradient,
@@ -598,6 +623,27 @@ def test_dot_general_calibration_with_contracting_axis(
),
])

check([
(
"midQ FQ ",
aqt_dg_full(
aqt.DequantMode.THIS_INPUT,
use_mid_quant=True,
use_fwd_quant=False,
),
dict(),
),
(
"midQ ",
aqt_dg_full(
aqt.DequantMode.OUTPUT,
use_mid_quant=True,
use_fwd_quant=False,
),
dict(),
),
])

check([
(
"fwd_quant=F",
@@ -1129,6 +1175,48 @@ def test_per_subchannel(self):
x = qx.dequant()
self.assertEqual(x.shape, (4, 4, 4))

def test_mid_quantization(self):
def make_binary_dg(use_mid):
mid_alpha: str | float = 0.5 if use_mid else config.SKIP
bits = 1
dg = config.config_v4(
fwd_bits=bits,
dlhs_bits=bits,
drhs_bits=bits,
fwd_mid_alpha_both=mid_alpha,
dlhs_mid_alpha_both=mid_alpha,
drhs_mid_alpha_both=mid_alpha,
)
# for exact equality
dg.fwd.dg_quantizer.lhs.numerics.preserve_max_val = True
dg.fwd.dg_quantizer.rhs.numerics.preserve_max_val = True
# PO2 scales for exact equality
dg.fwd.dg_quantizer.lhs.calibration = functools.partial(
dg.fwd.dg_quantizer.lhs.calibration, po2_scale=True
)
dg.fwd.dg_quantizer.rhs.calibration = functools.partial(
dg.fwd.dg_quantizer.rhs.calibration, po2_scale=True
)
return dg

# Note that we are testing with mid_alpha = 0.5 and po2 scales.
a = jnp.array([[1.0, 2.0, 4.0], [1.0, 4.0, 16.0]])
b = jnp.array([[4.0, 2.0, 1.0], [16.0, 4.0, 1.0]])
ret = jnp.array([4.0, 16.0]) * 3.0
dimension_numbers = (((1,), (1,)), ((0,), (0,)))

# Sanity check.
test_eq("", jax.lax.dot_general(a, b, dimension_numbers), ret)

# Without mid quantization all values in a, b will be rounded up
# to 4.0 or 8.0 because of binary quantization.
ret_no_mid = jnp.array([3 * 4.0**2, 3 * 16.0**2])
test_eq("", make_binary_dg(False)(a, b, dimension_numbers), ret_no_mid)

# With mid scales all values in a, b will be equal to 2.0 and
# binary quantization will be lossless.
test_eq("", make_binary_dg(True)(a, b, dimension_numbers), ret)


if __name__ == "__main__":
absltest.main()
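As a quick cross-check of the expected values in test_mid_quantization above (plain JAX, no AQT involved): the dimension numbers batch over axis 0 and contract axis 1, so the result is a per-row dot product, and the no-mid expectation ret_no_mid = [3 * 4.0**2, 3 * 16.0**2] is consistent with every entry of a row collapsing to that row's abs-max under 1-bit quantization.

import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0, 4.0], [1.0, 4.0, 16.0]])
b = jnp.array([[4.0, 2.0, 1.0], [16.0, 4.0, 1.0]])
dims = (((1,), (1,)), ((0,), (0,)))  # contract axis 1, batch over axis 0
# Row 0: 1*4 + 2*2 + 4*1 = 12 = 4 * 3;  row 1: 1*16 + 4*4 + 16*1 = 48 = 16 * 3.
assert jnp.allclose(jax.lax.dot_general(a, b, dims), jnp.array([4.0, 16.0]) * 3.0)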
