Commit

Add basic sparsity support.
PiperOrigin-RevId: 693889997
Cerebra Catalyst Team authored and copybara-github committed Nov 7, 2024
1 parent 6134c4b commit 97e8f29
Showing 6 changed files with 47 additions and 25 deletions.
2 changes: 2 additions & 0 deletions aqt/jax/v2/aqt_dot_general.py
@@ -851,6 +851,8 @@ def _maybe_dequant(
output = input_qtensor.dequant()
else:
output = input_qtensor.qvalue
if input_qtensor.sparsity_mask is not None:
output = output * input_qtensor.sparsity_mask
return output

# Dequantize before the lax dg call if in fake quant mode
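The mask is applied by plain elementwise multiplication, so a 0 in the mask drops the corresponding (possibly still-quantized) value before it reaches the dot_general, while a 1 passes it through unchanged. A minimal, self-contained sketch of that semantics (illustrative values only, not AQT API):

import jax.numpy as jnp

# Zeros in the sparsity mask drop values; ones pass them through unchanged.
output = jnp.array([[3.0, -7.0], [5.0, 1.0]])
sparsity_mask = jnp.array([[1.0, 0.0], [0.0, 1.0]])
masked_output = output * sparsity_mask  # [[3., 0.], [0., 1.]]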
8 changes: 6 additions & 2 deletions aqt/jax/v2/aqt_quantizer.py
@@ -134,8 +134,11 @@ def calibrate(
)
assert self.calibration is not None, msg
assert self._calibrator is not None, "forgot self.init_calibration()?"
scale, bias = self._calibrator.get_scale_and_bias(
x, shared_axes, self.numerics, self.context

scale, bias, sparsity_mask = (
self._calibrator.get_scale_and_bias_and_sparsity(
x, shared_axes, self.numerics, self.context
)
)
if self.scale_stop_grad:
# TODO(lew): Does not matter in DG, because we are using custom gradient.
@@ -144,6 +147,7 @@ def calibrate(

qt = aqt_tensor.QTensor(
qvalue=None,
sparsity_mask=sparsity_mask,
scale=scale,
scale_t=None,
bias=bias,
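calibrate() now unpacks a 3-tuple and stores the third element on the QTensor it builds; every calibration touched by this change returns None there, so existing configs behave exactly as before. A self-contained toy illustrating the new contract (the stub class and its numbers are illustrative, not AQT code):

import jax.numpy as jnp

class ToyCalibrator:
  # Stand-in for a Calibration: same scale/bias outputs as before, plus a
  # third slot for an optional sparsity mask.
  def get_scale_and_bias_and_sparsity(self, x, shared_axes, numerics_=None, context=None):
    del shared_axes, numerics_, context
    scale = [jnp.max(jnp.abs(x)) / 127.0]  # assumes an int8-style quant bound
    bias = []
    sparsity_mask = None  # None means "no sparsity", i.e. the old behaviour
    return scale, bias, sparsity_mask

x = jnp.array([[0.5, -2.0], [1.0, 0.25]])
scale, bias, sparsity_mask = ToyCalibrator().get_scale_and_bias_and_sparsity(x, shared_axes=None)
# Quantizer.calibrate() threads sparsity_mask into the QTensor it constructs.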
14 changes: 14 additions & 0 deletions aqt/jax/v2/aqt_tensor.py
@@ -59,6 +59,10 @@ class QTensor:
# Use dequant() method to "decompress" to the original tensor.
qvalue: None | ArrayT

sparsity_mask: None | ArrayT = flax.struct.field(
pytree_node=True, default=None
)

# (scale == None) means that scale is unknown/invalid;
# Otherwise, check dequant(self) for semantics.
scale: None | list[ArrayT]
@@ -137,6 +141,13 @@ def quant(self, x) -> Self:
for b in self.bias:
qvalue += b

# TODO(lew): We could apply sparsity AFTER the biases, and maybe we could
# still efficiently compute the post-matmul bias correction (how big is it?).
# But the math is more complex, and it is an additional option rather than a
# superset, so we can do it later.
if self.sparsity_mask is not None:
qvalue = qvalue * self.sparsity_mask

for s in self.scale:
# TODO(lew): We could store s_inv for faster activation quantization.
s_inv = jax.lax.reciprocal(s)
@@ -171,6 +182,9 @@ def dequant(self) -> jnp.ndarray:
for s in self.scale:
ret *= s

if self.sparsity_mask is not None:
ret = ret * self.sparsity_mask

# Apply bias after all rescaling is done. There may be more biases than
# scales, e.g. in native asymmetric matmul output dequantization.
for b in self.bias:
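Because the mask is multiplied in during both quant() and dequant(), masked positions are zero in the quantized representation and stay zero after dequantization, while unmasked positions round-trip as before. A small numeric sketch of that behaviour, using plain jnp, a hand-picked scale, and ignoring biases and the numerics-driven rounding:

import jax.numpy as jnp

x = jnp.array([0.9, -0.4, 0.2, -0.8])
mask = jnp.array([1.0, 0.0, 1.0, 0.0])  # keep elements 0 and 2 only
scale = jnp.array(0.01)                 # hand-picked; AQT derives it from calibration

# quant(): apply the mask, then multiply by 1/scale (bias terms omitted).
qvalue = jnp.round((x * mask) * (1.0 / scale))  # [90., 0., 20., 0.]

# dequant(): multiply by the scale and by the mask again.
x_hat = qvalue * scale * mask  # [0.9, 0., 0.2, 0.]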
30 changes: 16 additions & 14 deletions aqt/jax/v2/calibration.py
@@ -41,14 +41,16 @@ class Calibration(abc.ABC):
po2_scale: bool = utils.static_field(default=False)

@abc.abstractmethod
def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
"""Returns the quantizaiton scale and bias for the given input tensor."""
# NOTE: The sparsity, has to be compatible with scale and bias.
# The equation is defind in QTensor.quant() and QTensor.dequant() functions.
# NOTE: The scale and bias calculation are handled by the Calibration
# class because there is not a single order in which they should be
# calculated. In the case of symmetric quantization, the scale depends on
@@ -69,13 +71,13 @@ class ConstantCalibration(Calibration):
bound: jnp.ndarray | float
bias: None | jnp.ndarray | float = None

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
del context
if isinstance(self.bound, float) and self.bound <= 0.0:
raise ValueError(f'{self.bound=} should be positive.')
@@ -98,7 +100,7 @@ def get_scale_and_bias(
bias = [jnp.full(x.shape, self.bias, x.dtype)]
else:
bias = [self.bias.astype(dtype)]
return [scale.astype(dtype)], bias
return [scale.astype(dtype)], bias, None


@utils.flax_slots_kw_only_dataclass
@@ -112,13 +114,13 @@ class AbsMaxCalibration(Calibration):

clipping_scale: None | float = None

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
"""Calibration.
Args:
@@ -152,7 +154,7 @@ def get_scale_and_bias(

scale = bound / numerics_.get_quant_bound()
scale = ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], []
return [scale.astype(dtype)], [], None


@utils.flax_slots_kw_only_dataclass
Expand All @@ -168,13 +170,13 @@ class AbsMeanCalibration(Calibration):
clipping_scale: float
p: float

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
"""Calibration."""
del context
assert shared_axes is not None
@@ -189,7 +191,7 @@ def get_scale_and_bias(

scale = abs_mean / numerics_.get_quant_bound()
scale = ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], []
return [scale.astype(dtype)], [], None


@utils.flax_slots_kw_only_dataclass
@@ -212,13 +214,13 @@ class SnrBasedAutoCalibration(Calibration):

auto_clip_search_config: utils.AutoScaleSearchConfig

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
"""Produces the scale for quantization based on SNR values.
Args:
@@ -278,7 +280,7 @@ def get_scale_and_bias(
bound = abs_max * best_subchannel_clip_scales
scale = bound / numerics_.get_quant_bound()
scale = ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], []
return [scale.astype(dtype)], [], None

def _update_best_clip_scales_and_max_snr(
self,
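All calibrations in this change return None for the new third element, so quantization behaviour is unchanged; the extra return slot is the hook for calibrations that actually produce sparsity. A hypothetical sketch of such a calibration, written as a plain function rather than a Calibration subclass, with a hard-coded quant bound standing in for numerics_.get_quant_bound(); the magnitude-threshold policy and all names here are illustrative, not part of this commit:

import jax.numpy as jnp

def get_scale_and_bias_and_sparsity(x, shared_axes, quant_bound=127.0, keep_fraction=0.5):
  # Abs-max scale over the shared axes, as in AbsMaxCalibration.
  abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
  scale = abs_max / quant_bound
  # Keep roughly the largest `keep_fraction` of entries by magnitude.
  threshold = jnp.quantile(jnp.abs(x), 1.0 - keep_fraction)
  sparsity_mask = (jnp.abs(x) >= threshold).astype(x.dtype)
  return [scale.astype(x.dtype)], [], sparsity_mask

x = jnp.array([[0.9, -0.1], [0.05, -0.7]])
scales, biases, mask = get_scale_and_bias_and_sparsity(x, shared_axes=(1,))
# mask == [[1., 0.], [0., 1.]]: the two largest-magnitude entries survive.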
12 changes: 6 additions & 6 deletions aqt/jax/v2/flax/aqt_flax_calibration.py
@@ -74,18 +74,18 @@ def get_bound(
# Maybe wait for the JAX language upgrade to have better support for this?
return sum_of_max.value / count.value

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
dtype = self.dtype if self.dtype is not None else x.dtype
bound = self.get_bound(x, shared_axes, context)
scale = bound / numerics_.get_quant_bound()
scale = calibration.ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], []
return [scale.astype(dtype)], [], None


# TODO: b/335764538 - Check the math correctness of the module.
@@ -237,15 +237,15 @@ def init_var_fn(init_val: float) -> jnp.ndarray:
+ self.const_bound_coeff
)

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
dtype = self.dtype if self.dtype is not None else x.dtype
bound = self.get_bound(x, shared_axes, context)
scale = bound / numerics_.get_quant_bound()
scale = calibration.ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], []
return [scale.astype(dtype)], [], None
6 changes: 3 additions & 3 deletions aqt/jax/v2/flax/delayed_scaling_calibration.py
@@ -91,18 +91,18 @@ def get_bound(
amax_history_mutable_arr[:] = new_history[:]
return new_bound.reshape((1,) * len(x.shape))

def get_scale_and_bias(
def get_scale_and_bias_and_sparsity(
self,
x: jnp.ndarray,
shared_axes: None | Sequence[utils.AxisIdx],
numerics_: numerics.AqtNumerics,
context: None | utils.Context = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
) -> tuple[list[jnp.ndarray], list[jnp.ndarray], None | jnp.ndarray]:
dtype = self.dtype if self.dtype is not None else x.dtype
bound = self.get_bound(x, shared_axes, context)
scale = bound / numerics_.get_quant_bound()
scale = calibration.ceil_to_po2(scale) if self.po2_scale else scale
return [scale.astype(dtype)], []
return [scale.astype(dtype)], [], None

def compute_bound(self, amax, prev_bound):
new_bound = jnp.copy(amax)
