From 9637f5fbccf465b4a56588bb25d1531a60bfdd48 Mon Sep 17 00:00:00 2001 From: Cerebra Catalyst Team Date: Thu, 19 Sep 2024 16:19:33 -0700 Subject: [PATCH] Support for "middle" scale in "matmul" like SmoothQuant, AWQ, etc. PiperOrigin-RevId: 676603741 --- aqt/jax/v2/aqt_dot_general.py | 131 ++++++++++- aqt/jax/v2/aqt_dot_general_test.py | 94 +++++++- aqt/jax/v2/config.py | 84 ++++++- aqt/jax/v2/config_test.py | 209 +++++++++++++++++- .../gptq/examples/gptq_flax_e2e_model.py | 9 +- aqt/jax/v2/transpose.py | 2 +- aqt/jax/v2/utils.py | 11 +- 7 files changed, 506 insertions(+), 34 deletions(-) diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index 0885558a..7146e787 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -131,7 +131,17 @@ def dot_general_raw_make( rhs = aqt_quantizer.quantizer_make( rhs_bits, initialize_calibration=initialize_calibration ) - dg_quantizer = DefaultDotGeneralQuantizer(lhs=lhs, rhs=rhs) + # TODO(lew): This code (and surrounding code)is duplicated. + # We should dedup. + lhs_mid = aqt_quantizer.quantizer_make( + None, initialize_calibration=initialize_calibration + ) + rhs_mid = aqt_quantizer.quantizer_make( + None, initialize_calibration=initialize_calibration + ) + dg_quantizer = DefaultDotGeneralQuantizer( + lhs=lhs, rhs=rhs, lhs_mid=lhs_mid, rhs_mid=rhs_mid + ) return DotGeneralRaw( lhs=lhs_cfg, @@ -418,9 +428,36 @@ class DefaultDotGeneralQuantizer(DotGeneralQuantizer): lhs: aqt_quantizer.Quantizer rhs: aqt_quantizer.Quantizer + # Quantizers for "middle" scale in "matmul" like SmoothQuant, AWQ, etc. + lhs_mid: aqt_quantizer.Quantizer + rhs_mid: aqt_quantizer.Quantizer + + # The amount (exponent) of the scales that should be transferred to the + # other side. 0.0 = nothing, 1.0 = all. + lhs_mid_alpha: float | None = None + rhs_mid_alpha: float | None = None + + # This is a hack to make QTensors compatible with the current + # _qtensor_dot_general. + # The QTensors that are returned do not include the mid-scales. + # But this is ok, because the skipped mid-scales are reciprocal of each other + # and they would cancel out in _qtensor_dot_general anyway. + # A good design would be to hardcode skip_mid_scales=False because it would + # maintain semantics of QTensor (QTensor.dequant). + # This also is needed to send a correct QTensor to backprop + # in use_fwd_quant=True mode. + # We don't do it now mostly because it would require a separate mechanism + # in _qtensor_dot_general to skip the mid-scales + # (which do cancel each other mathematically). 
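# ---- Editor's illustrative sketch (not part of this patch) ----
# Why it's here: the comment block above defines lhs_mid_alpha / rhs_mid_alpha
# as "the amount (exponent) of the scales that should be transferred to the
# other side. 0.0 = nothing, 1.0 = all." This is a minimal, self-contained
# check of that identity for the one-sided case (only lhs_mid_alpha set).
# All shapes and names below are hypothetical, chosen only for illustration;
# this is not the library's code path, just the algebra it relies on.
import jax.numpy as jnp

lhs = jnp.array([[1.0, 8.0], [2.0, 16.0]])   # (batch, contraction)
rhs = jnp.array([[4.0, 1.0], [32.0, 1.0]])   # (contraction, features)

alpha = 0.5
s = jnp.max(jnp.abs(lhs), axis=0)            # per-contraction-channel scale
smoothed_lhs = lhs / s[None, :] ** alpha     # a fraction of s leaves the lhs ...
smoothed_rhs = rhs * s[:, None] ** alpha     # ... and is absorbed into the rhs

# The matmul result is unchanged; only the per-operand value ranges moved.
assert jnp.allclose(lhs @ rhs, smoothed_lhs @ smoothed_rhs)
# ---- end of editor's sketch ----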
+ skip_mid_scales: bool = True + def init_calibration(self): self.lhs.init_calibration() self.rhs.init_calibration() + if self.lhs_mid_alpha is not None: + self.lhs_mid.init_calibration() + if self.rhs_mid_alpha is not None: + self.rhs_mid.init_calibration() def calibrate( self, @@ -432,10 +469,10 @@ def calibrate( ) -> tuple[ tuple[jax.Array, aqt_tensor.QTensor], tuple[jax.Array, aqt_tensor.QTensor] ]: - if dimension_numbers is None: - lhs_calib, rhs_calib = None, None - else: + if dimension_numbers is not None: (lhs_ca, rhs_ca), (lhs_ba, rhs_ba) = dimension_numbers + lhs_ra = utils.get_remaining_axes(lhs.ndim, lhs_ca, lhs_ba) + rhs_ra = utils.get_remaining_axes(rhs.ndim, rhs_ca, rhs_ba) def _get_calibration_axes( mode: CalibrationMode, @@ -447,17 +484,76 @@ def _get_calibration_axes( match mode: case CalibrationMode.REMAINING_AXIS: calibration_axes = utils.get_remaining_axes(ndim, ca, ba) + m = 'mid-quantization is not supported for REMAINING_AXIS mode' + assert self.lhs_mid_alpha is None and self.rhs_mid_alpha is None, m case CalibrationMode.CONTRACTING_AXIS: calibration_axes = ca case _: raise ValueError(f'Unknown calibration mode: {mode}') return calibration_axes - lhs_calib = _get_calibration_axes(lhs_mode, lhs.ndim, lhs_ca, lhs_ba) - rhs_calib = _get_calibration_axes(rhs_mode, rhs.ndim, rhs_ca, rhs_ba) + lhs_calib_axes = _get_calibration_axes(lhs_mode, lhs.ndim, lhs_ca, lhs_ba) + rhs_calib_axes = _get_calibration_axes(rhs_mode, rhs.ndim, rhs_ca, rhs_ba) + else: + (lhs_ra, rhs_ra) = (None, None) + lhs_calib_axes = None + rhs_calib_axes = None + + def dezero(x): + return jnp.where(x == 0.0, jnp.ones_like(x), x) + + if self.lhs_mid_alpha is not None: + assert self.lhs_mid is not None + lhs_mid_qt = self.lhs_mid.calibrate(lhs, calibration_axes=lhs_ra) + assert len(lhs_mid_qt.scale) == 1, 'you must set some numerics' + lhs_mid_scale = dezero(lhs_mid_qt.scale[0]) + lhs_mid_scale = lhs_mid_scale**self.lhs_mid_alpha + lhs_mid_scale_t = transpose.lhs_scale_transpose_for_rhs_input( + lhs_mid_scale, dimension_numbers, rhs.shape + ) + else: + lhs_mid_scale = 1.0 + lhs_mid_scale_t = 1.0 + + if self.rhs_mid_alpha is not None: + assert self.rhs_mid is not None + rhs_mid_qt = self.rhs_mid.calibrate(rhs, calibration_axes=rhs_ra) + assert len(rhs_mid_qt.scale) == 1, 'you must set some numerics' + rhs_mid_scale = dezero(rhs_mid_qt.scale[0]) + rhs_mid_scale = rhs_mid_scale**self.rhs_mid_alpha + rhs_mid_scale_t = transpose.rhs_scale_transpose_for_lhs_input( + rhs_mid_scale, dimension_numbers, lhs.shape + ) + else: + rhs_mid_scale = 1.0 + rhs_mid_scale_t = 1.0 + + # This condition can be considered an optimization. + # ATM it also allows us to not deal with 1.0 being a scalar. + if self.lhs_mid_alpha is not None or self.rhs_mid_alpha is not None: + # Combined SmoothQuant scales + lhs_mid_scale_combined = lhs_mid_scale / rhs_mid_scale_t + rhs_mid_scale_combined = rhs_mid_scale / lhs_mid_scale_t + + # Apply the combined scales before per-tensor calibration + lhs_mid = lhs / lhs_mid_scale_combined + rhs_mid = rhs / rhs_mid_scale_combined + + # Per-tensor calibration, same as "else" branch. + lhs_qt = self.lhs.calibrate(lhs_mid, calibration_axes=lhs_calib_axes) + rhs_qt = self.rhs.calibrate(rhs_mid, calibration_axes=rhs_calib_axes) + + # To maintain QTensor.dequant semantics, we need to append the combined + # scales. 
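# ---- Editor's illustrative sketch (not part of this patch) ----
# Why it's here: the calibrate() hunk above builds
# lhs_mid_scale_combined = lhs_mid_scale / rhs_mid_scale_t and
# rhs_mid_scale_combined = rhs_mid_scale / lhs_mid_scale_t and divides each
# operand by its combined scale. This sketch shows, on hypothetical 2-D data
# with alpha = 1.0, why the two combined scales are reciprocal along the
# contraction axis and therefore cancel in the dot product (which is also why
# skip_mid_scales=True is mathematically safe inside _qtensor_dot_general).
import jax.numpy as jnp

lhs = jnp.array([[1.0, 32.0], [2.0, 64.0]])     # (m, k)
rhs = jnp.array([[16.0, 8.0], [0.5, 0.25]])     # (k, n)

alpha = 1.0
s_lhs = jnp.max(jnp.abs(lhs), axis=0) ** alpha  # (k,), from lhs remaining axes
s_rhs = jnp.max(jnp.abs(rhs), axis=1) ** alpha  # (k,), from rhs remaining axes

lhs_combined = s_lhs / s_rhs                    # divides the lhs, per k channel
rhs_combined = s_rhs / s_lhs                    # divides the rhs, per k channel

lhs_mid = lhs / lhs_combined[None, :]
rhs_mid = rhs / rhs_combined[:, None]

# The factors meet on the contraction axis and cancel term by term.
assert jnp.allclose(lhs @ rhs, lhs_mid @ rhs_mid)
# ---- end of editor's sketch ----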
+ assert lhs_qt.scale is not None + assert rhs_qt.scale is not None + if not self.skip_mid_scales: + lhs_qt.scale.append(lhs_mid_scale_combined) + rhs_qt.scale.append(rhs_mid_scale_combined) + else: + lhs_qt = self.lhs.calibrate(lhs, calibration_axes=lhs_calib_axes) + rhs_qt = self.rhs.calibrate(rhs, calibration_axes=rhs_calib_axes) - lhs_qt = self.lhs.calibrate(lhs, calibration_axes=lhs_calib) - rhs_qt = self.rhs.calibrate(rhs, calibration_axes=rhs_calib) return ((lhs, lhs_qt), (rhs, rhs_qt)) def calculate_qvalue( @@ -595,18 +691,24 @@ def _maybe_use_fwd_quant( A tuple of updated (lhs, rhs). If use_fwd_quant is True, lhs is multiplied with rhs scale, while rhs is set to the original rhs's qvalue. """ - fwd_quantized = rhs.qx.scale is not None and len(rhs.qx.scale) == 1 + scale_count = -1 + if rhs.qx.scale is not None: + scale_count = len(rhs.qx.scale) msg = ( - f'Found use_fwd_quant is {use_fwd_quant} in bwd, but fwd is not' - ' quantized.' + f'Found use_fwd_quant is {use_fwd_quant} in bwd. ' + 'It is supported only if there is exactly one scale in a good shape.\n' + f'{scale_count=}' ) if use_fwd_quant: - assert fwd_quantized, msg + assert scale_count == 1, msg if rhs.qx.bias: raise NotImplementedError( 'Quantization biases are not supported in forward quantization.' ) + # It this transpose fails or the multpilication below fails, + # we have some misconfiguration. One way to deal with it is + # set use_fwd_quant to False. scale_t = transpose.rhs_scale_transpose_for_lhs_input( rhs.qx.scale[0], dimension_numbers, lhs.shape ) @@ -769,6 +871,11 @@ def _maybe_dequant( dtypes_can_be_scaled = [jnp.bfloat16, jnp.float32, jnp.float64] + # If the transposes below fail it might be because of a misconfiguration. + # For instance "mid" quantization in DefaultDotGeneralQuantizer is not + # compatible with DequantMode.OTHER_INPUT + # TODO(lew): One way to deal with it is to have a per-scale DequantMode. + if cfg.lhs.dequant_mode == DequantMode.OTHER_INPUT: assert rhs_qin.dtype in dtypes_can_be_scaled for scale in lhs_qt.scale: diff --git a/aqt/jax/v2/aqt_dot_general_test.py b/aqt/jax/v2/aqt_dot_general_test.py index c97e13b4..0c53e62f 100644 --- a/aqt/jax/v2/aqt_dot_general_test.py +++ b/aqt/jax/v2/aqt_dot_general_test.py @@ -24,7 +24,6 @@ from aqt.jax.v2 import stochastic_rounding from aqt.jax.v2 import tiled_dot_general from aqt.jax.v2 import utils - import aqt.jax.v2.aqt_dot_general as aqt from aqt.jax.v2.numerics import int_numerics from aqt.jax.v2.numerics import no_numerics @@ -112,14 +111,27 @@ def rand_unif(shape, maxval, seed, dtype=jnp.float32): def test_eq(name, a, b): + assert a.shape == b.shape, (a.shape, b.shape) # TODO(lew): use library function. 
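# ---- Editor's illustrative sketch (not part of this patch) ----
# Why it's here: _maybe_use_fwd_quant above now asserts scale_count == 1
# before moving the forward-pass rhs scale onto the lhs. This is a minimal
# check of the identity that makes that move valid: if the stored rhs is
# qvalue * scale with exactly one scale that varies only along the
# contraction axis of this dot_general, the scale can be applied to the lhs
# instead. Shapes and values are hypothetical; contraction is on the last
# lhs axis / first rhs axis.
import jax.numpy as jnp

lhs = jnp.array([[1.0, -2.0, 3.0]])            # (1, k)
qvalue = jnp.array([[3.0, -1.0, 2.0]]).T       # (k, 1), "quantized" values
scale = jnp.array([[0.5], [0.25], [2.0]])      # (k, 1), one scale per k

rhs = qvalue * scale                           # dequantized rhs
scale_t = scale.reshape(1, -1)                 # scale aligned to the lhs layout

assert jnp.allclose(lhs @ rhs, (lhs * scale_t) @ qvalue)
# ---- end of editor's sketch ----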
mean_err = jnp.mean(jnp.abs(a - b)) if mean_err != 0.0: print("mean_err =", mean_err) print(a.shape) - print(a[:3, :3]) + match a.ndim: + case 1: + print(a[:3]) + case 2: + print(a[:3, :3]) + print("sum =", jnp.sum(a)) + print(b.shape) - print(b[:3, :3]) + match b.ndim: + case 1: + print(b[:3]) + case 2: + print(b[:3, :3]) + print("sum =", jnp.sum(b)) + print(f"FAIL: {name}") assert False @@ -206,6 +218,7 @@ def _modify_dg( disable_rounding: bool = False, fwd_lhs_tricky_clip_and_round: bool = False, local_aqt: aqt.LocalAqt | None = None, + use_mid_quant: bool = False, clip_gradient: bool = False, ) -> aqt.DotGeneral: dg = copy.deepcopy(readonly_dg) @@ -267,6 +280,14 @@ def disable_quant(c): if not isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics): dg.dlhs.rhs.use_fwd_quant = use_fwd_quant + if use_mid_quant: + config.set_use_mid_quant( + dg, + fwd_mid_alpha_both=1.0, + dlhs_mid_alpha_both=1.0, + drhs_mid_alpha_both=1.0, + ) + if local_aqt is not None: # Currently we are not supporting local_aqt in fwd pass # dg.fwd.local_aqt = local_aqt @@ -291,6 +312,7 @@ def _aqt_dg_full_lr_diff( disable_rounding: bool = False, fwd_lhs_tricky_clip_and_round: bool = False, local_aqt: aqt.LocalAqt | None = None, + use_mid_quant: bool = False, *, readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, @@ -306,6 +328,7 @@ def _aqt_dg_full_lr_diff( disable_rounding=disable_rounding, fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round, local_aqt=local_aqt, + use_mid_quant=use_mid_quant, clip_gradient=clip_gradient, ) dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None) @@ -319,6 +342,7 @@ def _aqt_dg_full( disable_rounding: bool = False, fwd_lhs_tricky_clip_and_round: bool = False, local_aqt: aqt.LocalAqt | None = None, + use_mid_quant: bool = False, *, readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, @@ -333,6 +357,7 @@ def _aqt_dg_full( disable_rounding=disable_rounding, fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round, local_aqt=local_aqt, + use_mid_quant=use_mid_quant, readonly_dg=readonly_dg, dims=dims, clip_gradient=clip_gradient, @@ -598,6 +623,27 @@ def test_dot_general_calibration_with_contracting_axis( ), ]) + check([ + ( + "midQ FQ ", + aqt_dg_full( + aqt.DequantMode.THIS_INPUT, + use_mid_quant=True, + use_fwd_quant=False, + ), + dict(), + ), + ( + "midQ ", + aqt_dg_full( + aqt.DequantMode.OUTPUT, + use_mid_quant=True, + use_fwd_quant=False, + ), + dict(), + ), + ]) + check([ ( "fwd_quant=F", @@ -1129,6 +1175,48 @@ def test_per_subchannel(self): x = qx.dequant() self.assertEqual(x.shape, (4, 4, 4)) + def test_mid_quantization(self): + def make_binary_dg(use_mid): + mid_alpha: str | float = 0.5 if use_mid else config.SKIP + bits = 1 + dg = config.config_v4( + fwd_bits=bits, + dlhs_bits=bits, + drhs_bits=bits, + fwd_mid_alpha_both=mid_alpha, + dlhs_mid_alpha_both=mid_alpha, + drhs_mid_alpha_both=mid_alpha, + ) + # for exact equality + dg.fwd.dg_quantizer.lhs.numerics.preserve_max_val = True + dg.fwd.dg_quantizer.rhs.numerics.preserve_max_val = True + # PO2 scales for exact equality + dg.fwd.dg_quantizer.lhs.calibration = functools.partial( + dg.fwd.dg_quantizer.lhs.calibration, po2_scale=True + ) + dg.fwd.dg_quantizer.rhs.calibration = functools.partial( + dg.fwd.dg_quantizer.rhs.calibration, po2_scale=True + ) + return dg + + # Note that we are testing with mid_alpha = 0.5, and po2 scales. 
+ a = jnp.array([[1.0, 2.0, 4.0], [1.0, 4.0, 16.0]]) + b = jnp.array([[4.0, 2.0, 1.0], [16.0, 4.0, 1.0]]) + ret = jnp.array([4.0, 16.0]) * 3.0 + dimension_numbers = (((1,), (1,)), ((0,), (0,))) + + # Sanity check. + test_eq("", jax.lax.dot_general(a, b, dimension_numbers), ret) + + # Without mid quantization all values in a, b will be rounded up + # to 4.0 or 8.0 because of binary quantization. + ret_no_mid = jnp.array([3 * 4.0**2, 3 * 16.0**2]) + test_eq("", make_binary_dg(False)(a, b, dimension_numbers), ret_no_mid) + + # With mid scales all values in a, b will be equal to 2.0 and + # binary quantization will be lossless. + test_eq("", make_binary_dg(True)(a, b, dimension_numbers), ret) + if __name__ == "__main__": absltest.main() diff --git a/aqt/jax/v2/config.py b/aqt/jax/v2/config.py index f1358f61..c36459cc 100644 --- a/aqt/jax/v2/config.py +++ b/aqt/jax/v2/config.py @@ -13,6 +13,10 @@ # limitations under the License. """Configuration dataclasses.""" +# pylint: disable=g-importing-member +# pylint: disable=unused-import +# pylint: disable=g-explicit-bool-comparison + import copy import functools from typing import Literal, Optional, TypeAlias, Union @@ -24,8 +28,6 @@ from aqt.jax.v2 import utils # Temporary re-export from aqt.jax.v2.aqt_dot_general # TODO(lew): Remove these imports, use setters instead -# pylint: disable=g-importing-member -# pylint: disable=unused-import from aqt.jax.v2.aqt_conv_general import conv_general_dilated_make from aqt.jax.v2.aqt_dot_general import CalibrationMode from aqt.jax.v2.aqt_dot_general import DequantMode @@ -260,12 +262,76 @@ def set_use_fwd_quant( dlhs_use_fwd_quant: Union[bool, None, SkipT], drhs_use_fwd_quant: Union[bool, None, SkipT], ): + """Enable resusing of fwd pass quantization for backprop.""" + msg = 'use_fwd_quant is incompatible with use_mid_quant right now.' + assert cfg.fwd.dg_quantizer.lhs_mid_alpha is None, msg + assert cfg.fwd.dg_quantizer.rhs_mid_alpha is None, msg + assert cfg.dlhs.dg_quantizer.lhs_mid_alpha is None, msg + assert cfg.dlhs.dg_quantizer.rhs_mid_alpha is None, msg + assert cfg.drhs.dg_quantizer.lhs_mid_alpha is None, msg + assert cfg.drhs.dg_quantizer.rhs_mid_alpha is None, msg if dlhs_use_fwd_quant != SKIP: cfg.dlhs.rhs.use_fwd_quant = dlhs_use_fwd_quant if drhs_use_fwd_quant != SKIP: cfg.drhs.rhs.use_fwd_quant = drhs_use_fwd_quant +def set_use_mid_quant( + cfg: DotGeneral, + fwd_mid_alpha_both: Union[SkipT, float], + dlhs_mid_alpha_both: Union[SkipT, float], + drhs_mid_alpha_both: Union[SkipT, float], +): + """Enable middle quantization. Variant of SmoothQuant / AWQ.""" + assert isinstance( + cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer + ) + assert isinstance( + cfg.dlhs.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer + ) + assert isinstance( + cfg.drhs.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer + ) + + msg = 'use_fwd_quant is incompatible with use_mid_quant right now.' + assert not cfg.dlhs.rhs.use_fwd_quant, msg + assert not cfg.drhs.rhs.use_fwd_quant, msg + + @utils.flax_slots_kw_only_dataclass + class DummyNumerics(numerics.AqtNumerics): + """DummyNumerics for mid-quantization.""" + + def get_quant_bound(self): + return 1.0 + + def get_dtype(self): + assert False, 'Should not request dtype for mid-quantization.' 
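# ---- Editor's illustrative sketch (not part of this patch) ----
# Why it's here: a worked check of the test_mid_quantization data above,
# under the editor's reading of the mid path. With mid_alpha = 0.5 and no
# remaining axes on either operand, the mid scales are per-element, so each
# smoothed operand becomes sqrt(a * b) elementwise. That is constant within
# every row of the test data, which is why even 1-bit quantization is
# lossless there while the batched dot product is unchanged.
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0, 4.0], [1.0, 4.0, 16.0]])
b = jnp.array([[4.0, 2.0, 1.0], [16.0, 4.0, 1.0]])

smoothed_a = a * jnp.sqrt(b) / jnp.sqrt(a)   # == sqrt(a * b), rows are constant
smoothed_b = b * jnp.sqrt(a) / jnp.sqrt(b)   # same values as smoothed_a

assert jnp.allclose(smoothed_a, jnp.array([[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]))
assert jnp.allclose(
    jnp.sum(a * b, axis=1), jnp.sum(smoothed_a * smoothed_b, axis=1)
)
# ---- end of editor's sketch ----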
+ + def vjp_fwd(self, x, context): + res = () + return x, res + + def vjp_bwd(self, res, grad): + assert res == () + return grad + + if fwd_mid_alpha_both != SKIP: + cfg.fwd.dg_quantizer.lhs_mid.numerics = DummyNumerics() + cfg.fwd.dg_quantizer.rhs_mid.numerics = DummyNumerics() + cfg.fwd.dg_quantizer.lhs_mid_alpha = fwd_mid_alpha_both + cfg.fwd.dg_quantizer.rhs_mid_alpha = fwd_mid_alpha_both + if dlhs_mid_alpha_both != SKIP: + cfg.dlhs.dg_quantizer.lhs_mid.numerics = DummyNumerics() + cfg.dlhs.dg_quantizer.rhs_mid.numerics = DummyNumerics() + cfg.dlhs.dg_quantizer.lhs_mid_alpha = dlhs_mid_alpha_both + cfg.dlhs.dg_quantizer.rhs_mid_alpha = dlhs_mid_alpha_both + if drhs_mid_alpha_both != SKIP: + cfg.drhs.dg_quantizer.lhs_mid.numerics = DummyNumerics() + cfg.drhs.dg_quantizer.rhs_mid.numerics = DummyNumerics() + cfg.drhs.dg_quantizer.lhs_mid_alpha = drhs_mid_alpha_both + cfg.drhs.dg_quantizer.rhs_mid_alpha = drhs_mid_alpha_both + + def set_int_numerics_preserve_zero(cfg: DotGeneral, preserve_zero: bool): """Set preserve_zero for int_numerics.""" assert isinstance( @@ -466,7 +532,10 @@ def dg_raw_cfg(jax_scope_name: str) -> DotGeneralRaw: lhs=tensor_cfg(), rhs=tensor_cfg(), dg_quantizer=aqt_dot_general.DefaultDotGeneralQuantizer( - lhs=quantizer(), rhs=quantizer() + lhs=quantizer(), + rhs=quantizer(), + lhs_mid=quantizer(), + rhs_mid=quantizer(), ), dg_accumulator_dtype=None, local_aqt=None, @@ -624,6 +693,9 @@ def config_v4( drhs_accumulator_dtype: Union[jnp.dtype, None, SkipT] = SKIP, dlhs_use_fwd_quant: Union[bool, None, SkipT] = SKIP, drhs_use_fwd_quant: Union[bool, None, SkipT] = SKIP, + fwd_mid_alpha_both: Union[SkipT, float] = SKIP, + dlhs_mid_alpha_both: Union[SkipT, float] = SKIP, + drhs_mid_alpha_both: Union[SkipT, float] = SKIP, ) -> DotGeneral: """Version 4 of user-visible AQT config.""" cfg = default_unquantized_config() @@ -661,6 +733,12 @@ def config_v4( dlhs_use_fwd_quant=dlhs_use_fwd_quant, drhs_use_fwd_quant=drhs_use_fwd_quant, ) + set_use_mid_quant( + cfg, + fwd_mid_alpha_both=fwd_mid_alpha_both, + dlhs_mid_alpha_both=dlhs_mid_alpha_both, + drhs_mid_alpha_both=drhs_mid_alpha_both, + ) assert cfg.fwd.local_aqt is None, 'local_aqt is not yet supported in fwd.' 
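# ---- Editor's illustrative sketch (not part of this patch) ----
# Why it's here: a usage sketch of the new config surface, mirroring the call
# pattern in test_mid_quantization above. The bit widths and alpha values are
# example choices, not recommendations; SKIP leaves a pass untouched.
from aqt.jax.v2 import config

cfg = config.config_v4(
    fwd_bits=8,
    dlhs_bits=8,
    drhs_bits=8,
    fwd_mid_alpha_both=0.5,           # SmoothQuant-style smoothing in fwd
    dlhs_mid_alpha_both=config.SKIP,  # leave the backprop passes untouched
    drhs_mid_alpha_both=config.SKIP,
)
# Or, on an existing DefaultDotGeneralQuantizer-based config:
# config.set_use_mid_quant(cfg, fwd_mid_alpha_both=0.5,
#                          dlhs_mid_alpha_both=config.SKIP,
#                          drhs_mid_alpha_both=config.SKIP)
# Note that the patch makes use_fwd_quant and mid quantization mutually
# exclusive, so enable at most one of them on a given config.
# ---- end of editor's sketch ----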
return cfg diff --git a/aqt/jax/v2/config_test.py b/aqt/jax/v2/config_test.py index 26b6ed18..b8a6ce4b 100644 --- a/aqt/jax/v2/config_test.py +++ b/aqt/jax/v2/config_test.py @@ -28,6 +28,8 @@ def _dot_general_full_init_calibration(cfg): cfg.drhs.dg_quantizer.lhs.init_calibration() cfg.drhs.dg_quantizer.rhs.init_calibration() +# TODO(lew): We should use go/minion, go/blessing or other go/golden-diff-test + class AqtConfigTest(parameterized.TestCase): @@ -90,7 +92,28 @@ def test_config_v4(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=, local_aqt=None, jax_scope_name='aqt_fwd', @@ -135,7 +158,28 @@ def test_config_v4(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=, local_aqt=LocalAqt(contraction_axis_shard_count=2, contraction_axis_shard_size=None, @@ -182,7 +226,28 @@ def test_config_v4(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=, local_aqt=LocalAqt(contraction_axis_shard_count=3, contraction_axis_shard_size=None, @@ -234,7 +299,28 @@ def test_configv4_original(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=, local_aqt=None, jax_scope_name='aqt_fwd', @@ -279,7 +365,28 @@ def test_configv4_original(self): clipping_scale=None), context=Context(key=None, train_step=None, - 
quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=, local_aqt=None, jax_scope_name='aqt_dlhs', @@ -312,7 +419,28 @@ def test_configv4_original(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=None, local_aqt=None, jax_scope_name='aqt_drhs', @@ -355,7 +483,28 @@ def test_config_fwd_fp8(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=, local_aqt=None, jax_scope_name='aqt_fwd', @@ -388,7 +537,28 @@ def test_config_fwd_fp8(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=None, local_aqt=None, jax_scope_name='aqt_dlhs', @@ -421,7 +591,28 @@ def test_config_fwd_fp8(self): clipping_scale=None), context=Context(key=None, train_step=None, - quant_mode=))), + quant_mode=)), + lhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + rhs_mid=Quantizer(numerics=NoNumerics(noise_fn=None, + dtype=None), + calib_shared_axes=None, + scale_stop_grad=True, + calibration=, + _calibrator=None, + context=Context(key=None, + train_step=None, + quant_mode=)), + lhs_mid_alpha=None, + rhs_mid_alpha=None, + skip_mid_scales=True), dg_accumulator_dtype=None, local_aqt=None, jax_scope_name='aqt_drhs', diff --git 
a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py index 3c5f14ad..5f10c10e 100644 --- a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py +++ b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py @@ -51,8 +51,15 @@ def update_cfg_with_gptq(aqt_cfg: aqt_dot_general.DotGeneral) -> None: rhs_bits = aqt_cfg.fwd.dg_quantizer.rhs.numerics.bits lhs = aqt_quantizer.quantizer_make(lhs_bits, initialize_calibration=False) rhs = aqt_quantizer.quantizer_make(rhs_bits, initialize_calibration=False) + lhs_mid = aqt_quantizer.quantizer_make(lhs_bits, initialize_calibration=False) + rhs_mid = aqt_quantizer.quantizer_make(rhs_bits, initialize_calibration=False) gptq_dg_quantizer = gptq_dot_general_quantizer.GptqDotGeneralQuantizer( - lhs=lhs, rhs=rhs, sharding_axes=None, quant_collection='gptq' + lhs=lhs, + rhs=rhs, + lhs_mid=lhs_mid, + rhs_mid=rhs_mid, + sharding_axes=None, + quant_collection='gptq', ) aqt_cfg.fwd.dg_quantizer = gptq_dg_quantizer diff --git a/aqt/jax/v2/transpose.py b/aqt/jax/v2/transpose.py index f6d32197..f29fe232 100644 --- a/aqt/jax/v2/transpose.py +++ b/aqt/jax/v2/transpose.py @@ -100,7 +100,7 @@ def _scale_trans(x, ca, ba): ca = list(ca) ba = list(ba) for i in ca: - assert x.shape[i] == 1 + assert x.shape[i] == 1, (x.shape, ca, ba) ra = utils.get_remaining_axes(x.ndim, ca, ba) x = transpose(x, ba + ra + ca) # TODO(lew): x = jnp.squeeze(x, axis=range(len(ba+ra): len(x.shape)) diff --git a/aqt/jax/v2/utils.py b/aqt/jax/v2/utils.py index ada7fe64..4a93f752 100644 --- a/aqt/jax/v2/utils.py +++ b/aqt/jax/v2/utils.py @@ -97,11 +97,12 @@ def dynamic_field(**kwargs): return flax.struct.field(pytree_node=True, **kwargs) -def print_diff(str_a: str, str_b: str): - diff_generator = difflib.context_diff(str_a.split(' '), str_b.split(' ')) - print('Diff:') - for diff in diff_generator: - print(diff) +def print_diff(str_a: str, str_b: str, do_print_diff=False): + if do_print_diff: + diff_generator = difflib.context_diff(str_a.split(' '), str_b.split(' ')) + print('Diff:') + for diff in diff_generator: + print(diff) print(f'first string (actual):\n{str_a}')
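# ---- Editor's illustrative sketch (not part of this patch) ----
# Why it's here: both the mid-scale calibration (lhs_ra / rhs_ra) and
# _scale_trans above lean on utils.get_remaining_axes. This hypothetical
# stand-in shows the intended semantics: the remaining axes of a dot_general
# operand are the axes that are neither contracted nor batched, and the mid
# scale is reduced over them so it varies only along the contraction axes.
def remaining_axes(ndim, contracting, batch):
  return [a for a in range(ndim) if a not in (*contracting, *batch)]

# An lhs of shape (batch, seq, features), contracting on axis 2, batched on 0:
assert remaining_axes(3, contracting=(2,), batch=(0,)) == [1]
# ---- end of editor's sketch ----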