Enhance 3.x torch algorithm entry (#1779)

Signed-off-by: yuwenzho <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yuwenzho and pre-commit-ci[bot] authored May 9, 2024
1 parent 43c3580 commit ec49a29
Showing 7 changed files with 111 additions and 90 deletions.
3 changes: 1 addition & 2 deletions neural_compressor/torch/algorithms/base_algorithm.py
@@ -99,14 +99,13 @@ def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):

return model

def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any): # pragma: no cover
def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):
"""Execute according to mode.
Args:
model (torch.nn.Module): The model to be executed.
mode (Mode): The mode of current phase, including 'prepare', 'convert' and 'quantize'.
"""
# TODO: remove '# pragma: no cover' once CI test can cover this function
if mode == Mode.PREPARE:
model = self.prepare(model, *args, **kwargs)
elif mode == Mode.CONVERT:
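For reference, a minimal sketch of how the now-covered execute() dispatch is meant to be driven. This is an assumption-laden illustration, not code from the commit: NoOpQuantizer is a made-up subclass, and the import path of the base Quantizer class is inferred from the file layout above.

import torch

from neural_compressor.torch.algorithms.base_algorithm import Quantizer  # assumed import path
from neural_compressor.torch.utils import Mode


class NoOpQuantizer(Quantizer):
    """Toy quantizer that only shows which phase execute() dispatches to."""

    def prepare(self, model, *args, **kwargs):
        print("prepare phase")
        return model

    def convert(self, model, *args, **kwargs):
        print("convert phase")
        return model


quantizer = NoOpQuantizer(quant_config={})
model = torch.nn.Linear(4, 4)
model = quantizer.execute(model, mode=Mode.PREPARE)  # dispatches to prepare()
model = quantizer.execute(model, mode=Mode.CONVERT)  # dispatches to convert()
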
11 changes: 5 additions & 6 deletions neural_compressor/torch/algorithms/weight_only/autoround.py
@@ -26,7 +26,7 @@
class AutoRoundQuantizer(Quantizer):
def __init__(
self,
weight_config: dict = {},
quant_config: dict = None,
enable_full_range: bool = False,
batch_size: int = 8,
amp: bool = True,
@@ -51,8 +51,8 @@ def __init__(
"""Init a AutQRoundQuantizer object.
Args:
weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
weight_config={
quant_config (dict): Configuration for weight quantization (default is None).
quant_config={
'layer1':##layer_name
{
'data_type': 'int',
@@ -89,9 +89,8 @@ def __init__(
scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
have different choices.
"""
super().__init__(weight_config)
super().__init__(quant_config)
self.tokenizer = None
self.weight_config = weight_config
self.enable_full_range = enable_full_range
self.batch_size = batch_size
self.amp = amp
@@ -125,7 +124,7 @@ def prepare(self, model: torch.nn.Module, *args, **kwargs):
self.rounder = AutoRoundProcessor(
model=model,
tokenizer=None,
weight_config=self.weight_config,
weight_config=self.quant_config or {},
enable_full_range=self.enable_full_range,
batch_size=self.batch_size,
amp=self.amp,
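With the rename in place, callers pass the per-layer settings through quant_config, as the updated unit test further below also does. A short sketch; the layer name is hypothetical and the bits/group_size values are only illustrative:

from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

# Per-layer weight-quantization settings keyed by layer name (structure follows the docstring above).
quant_config = {
    "model.decoder.layers.0.self_attn.k_proj": {  # hypothetical layer name
        "data_type": "int",
        "bits": 4,
        "group_size": 32,
        "sym": False,
    }
}
quantizer = AutoRoundQuantizer(quant_config=quant_config)
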
131 changes: 56 additions & 75 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -30,7 +30,14 @@
StaticQuantConfig,
TEQConfig,
)
from neural_compressor.torch.utils import Mode, is_ipex_imported, logger, register_algo
from neural_compressor.torch.utils import (
Mode,
get_quantizer,
is_ipex_imported,
logger,
postprocess_model,
register_algo,
)
from neural_compressor.torch.utils.constants import PT2E_STATIC_QUANT


@@ -69,17 +76,9 @@ def rtn_entry(
"double_quant_group_size": quant_config.double_quant_group_size,
}

if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = RTNQuantizer(quant_config=weight_config)

quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config)
model = quantizer.execute(model, mode=mode)

if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
postprocess_model(model, mode, quantizer)
return model


@@ -126,15 +125,11 @@ def gptq_entry(
)
kwargs.pop("example_inputs")
logger.warning("lm_head in transformer model is skipped by GPTQ")
if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = GPTQuantizer(quant_config=weight_config)

quantizer = get_quantizer(model, quantizer_cls=GPTQuantizer, quant_config=weight_config)
model = quantizer.execute(model, mode=mode, *args, **kwargs)
if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
postprocess_model(model, mode, quantizer)

return model


@@ -180,17 +175,10 @@ def static_quant_entry(
inplace = kwargs.get("inplace", True)
assert example_inputs is not None, "Please provide example_inputs for static quantization."

if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)

quantizer = get_quantizer(model, quantizer_cls=StaticQuantQuantizer, quant_config=quant_config_mapping)
model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
postprocess_model(model, mode, quantizer)

if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
return model


@@ -323,11 +311,7 @@ def awq_quantize_entry(
example_inputs = kwargs.get("example_inputs", None)
assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."

if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = AWQQuantizer(quant_config=weight_config)

quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
model = quantizer.execute(
model,
mode=mode,
@@ -340,11 +324,8 @@
return_int=return_int,
use_full_range=use_full_range,
)
postprocess_model(model, mode, quantizer)

if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
return model


@@ -386,10 +367,18 @@ def teq_quantize_entry(
absorb_to_layer = quant_config.absorb_to_layer
folding = quant_config.folding
assert isinstance(model, torch.nn.Module), "only support torch module"
quantizer = TEQuantizer(
quant_config=weight_config, folding=folding, absorb_to_layer=absorb_to_layer, example_inputs=example_inputs

quantizer = get_quantizer(
model,
quantizer_cls=TEQuantizer,
quant_config=weight_config,
folding=folding,
absorb_to_layer=absorb_to_layer,
example_inputs=example_inputs,
)
model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
postprocess_model(model, mode, quantizer)

return model


@@ -436,35 +425,33 @@ def autoround_quantize_entry(
scale_dtype = quant_config.scale_dtype

kwargs.pop("example_inputs")
if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = AutoRoundQuantizer(
weight_config=weight_config,
enable_full_range=enable_full_range,
batch_size=batch_size,
lr_scheduler=lr_scheduler,
use_quant_input=use_quant_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
n_samples=n_samples,
sampler=sampler,
seed=seed,
n_blocks=n_blocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
scale_dtype=scale_dtype,
)

quantizer = get_quantizer(
model,
quantizer_cls=AutoRoundQuantizer,
quant_config=weight_config,
enable_full_range=enable_full_range,
batch_size=batch_size,
lr_scheduler=lr_scheduler,
use_quant_input=use_quant_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
low_gpu_mem_usage=low_gpu_mem_usage,
iters=iters,
seqlen=seqlen,
n_samples=n_samples,
sampler=sampler,
seed=seed,
n_blocks=n_blocks,
gradient_accumulate_steps=gradient_accumulate_steps,
not_use_best_mse=not_use_best_mse,
dynamic_max_gap=dynamic_max_gap,
scale_dtype=scale_dtype,
)
model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
postprocess_model(model, mode, quantizer)

logger.info("AutoRound quantization done.")
return model

@@ -482,17 +469,11 @@ def hqq_entry(
from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer

logger.info("Quantize model with the HQQ algorithm.")
if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = HQQuantizer(quant_config=configs_mapping)

quantizer = get_quantizer(model, quantizer_cls=HQQuantizer, quant_config=configs_mapping)
model = quantizer.execute(model, mode=mode)
postprocess_model(model, mode, quantizer)

if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
return model


6 changes: 2 additions & 4 deletions neural_compressor/torch/quantization/quantize.py
@@ -91,7 +91,7 @@ def prepare(
quant_config: BaseConfig,
inplace: bool = True,
example_inputs: Any = None,
): # pragma: no cover
):
"""Prepare the model for calibration.
Insert observers into the model so that it can monitor the input and output tensors during calibration.
@@ -105,7 +105,6 @@ def prepare(
Returns:
prepared and calibrated module.
"""
# TODO: remove '# pragma: no cover' once CI test can cover this function
prepared_model = model if inplace else copy.deepcopy(model)
registered_configs = config_registry.get_cls_configs()
if isinstance(quant_config, dict):
@@ -148,7 +147,7 @@ def convert(
model: torch.nn.Module,
quant_config: BaseConfig = None,
inplace: bool = True,
): # pragma: no cover
):
"""Convert the prepared model to a quantized model.
Args:
@@ -159,7 +158,6 @@ def convert(
Returns:
The quantized model.
"""
# TODO: remove '# pragma: no cover' once CI test can cover this function
q_model = model if inplace else copy.deepcopy(model)

# TODO: Optimize the check for prepared flag after adding HQT FP8 Quant
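With the pragma markers gone, prepare() and convert() are now exercised by CI; the end-to-end flow they support looks roughly like this. A hedged sketch: the toy model is a placeholder, and RTNConfig is assumed to be exported from neural_compressor.torch.quantization alongside prepare() and convert().

import torch

from neural_compressor.torch.quantization import RTNConfig, convert, prepare  # assumed exports

model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # stand-in for a real float model
quant_config = RTNConfig()  # any BaseConfig subclass is handled the same way

# Insert observers / stage the algorithm; the quantizer object is cached on the returned model.
prepared_model = prepare(model, quant_config=quant_config, inplace=False)
# ... run calibration data through prepared_model here if the algorithm needs it ...
# Finish quantization; the config is assumed to be recorded during prepare(), so it can be omitted here.
quantized_model = convert(prepared_model)
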
41 changes: 41 additions & 0 deletions neural_compressor/torch/utils/utility.py
@@ -137,6 +137,47 @@ class Mode(Enum):
QUANTIZE = "quantize"


def get_quantizer(model, quantizer_cls, quant_config=None, *args, **kwargs):
"""Get the quantizer.
Initialize a quantizer or get `quantizer` attribute from model.
Args:
model (torch.nn.Module): pytorch model.
quantizer_cls (Quantizer): quantizer class of a specific algorithm.
quant_config (dict, optional): Specifies how to apply the algorithm on the given model.
Defaults to None.
Returns:
quantizer object.
"""
if not hasattr(model, "quantizer"):
quantizer = quantizer_cls(quant_config=quant_config, *args, **kwargs)
return quantizer
else:
return model.quantizer


def postprocess_model(model, mode, quantizer):
"""Process `quantizer` attribute of model according to current phase.
In `prepare` phase, the `quantizer` is set as an attribute of the model
to avoid redundant initialization during `convert` phase.
In 'convert' or 'quantize' phase, the unused `quantizer` attribute is removed.
Args:
model (torch.nn.Module): pytorch model.
mode (Mode): The mode of current phase, including 'prepare', 'convert' and 'quantize'.
quantizer (Quantizer): quantizer object.
"""
if mode == Mode.PREPARE:
model.quantizer = quantizer
elif mode == Mode.CONVERT or mode == Mode.QUANTIZE:
if getattr(model, "quantizer", False):
del model.quantizer


def create_quant_spec_from_config(dtype, sym, granularity, algo) -> QuantizationSpec:
dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
qscheme_mapping = {
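Taken together, the two new helpers implement a small hand-off of the quantizer object across phases. A hedged walk-through (SomeQuantizer, model, and cfg are placeholders, not names from the commit):

from neural_compressor.torch.utils import Mode, get_quantizer, postprocess_model

# Prepare phase: no cached quantizer yet, so get_quantizer() builds a fresh one.
quantizer = get_quantizer(model, quantizer_cls=SomeQuantizer, quant_config=cfg)
model = quantizer.execute(model, mode=Mode.PREPARE)
postprocess_model(model, Mode.PREPARE, quantizer)
assert hasattr(model, "quantizer")  # kept around for the upcoming convert phase

# Convert phase: the cached object (and any calibration state it carries) is reused, then detached.
quantizer = get_quantizer(model, quantizer_cls=SomeQuantizer, quant_config=cfg)
model = quantizer.execute(model, mode=Mode.CONVERT)
postprocess_model(model, Mode.CONVERT, quantizer)
assert not hasattr(model, "quantizer")  # dropped once quantization is finished

This is exactly the pattern the algorithm entries above were refactored to share.
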
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/weight_only/test_autoround.py
@@ -102,7 +102,7 @@ def test_quantizer(self):
"sym": False,
}
}
quantizer = AutoRoundQuantizer(weight_config=weight_config)
quantizer = AutoRoundQuantizer(quant_config=weight_config)
fp32_model = gpt_j_model

# quantizer execute
7 changes: 5 additions & 2 deletions test/3x/torch/quantization/weight_only/test_mixed_algos.py
@@ -10,8 +10,11 @@


def run_fn(model):
model(torch.tensor([[10, 20, 30]], dtype=torch.long))
model(torch.tensor([[40, 50, 60]], dtype=torch.long))
# GPTQ raises ValueError once it has collected the input data of the first block, to skip the rest of the forward pass.
# This wrapper is only needed in unit tests; real examples do not require it.
with pytest.raises(ValueError):
model(torch.tensor([[10, 20, 30]], dtype=torch.long))
model(torch.tensor([[40, 50, 60]], dtype=torch.long))


class TestMixedTwoAlgo:
