implement TorchBaseConfig (#1911)
Signed-off-by: xin3he <[email protected]>
xin3he authored Jul 16, 2024
1 parent 7a4715c commit be42d03
Showing 2 changed files with 30 additions and 13 deletions.
41 changes: 29 additions & 12 deletions neural_compressor/torch/quantization/config.py
@@ -72,9 +72,26 @@ class OperatorConfig(NamedTuple):
    valid_func_list: List[Callable] = []


+class TorchBaseConfig(BaseConfig):
+    # Override _get_op_name_op_type_config so that op_types given as plain strings are kept as a fallback,
+    # because the IPEX backend has special fused op_types such as `Linear&Relu`, `Linear&add`, ...
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                # Convert the Callable to a string.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
+            else:
+                op_name_config_dict[name] = config
+                op_type_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict


######################## RTN Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
-class RTNConfig(BaseConfig):
+class RTNConfig(TorchBaseConfig):
    """Config class for round-to-nearest weight-only quantization."""

    name = RTN
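
A minimal sketch (not part of this commit) of the string-op_type fallback that the new TorchBaseConfig enables. It uses only the StaticQuantConfig / set_local API already exercised by the updated test further down; calling the private _get_op_name_op_type_config helper directly is just for illustration, and the exact string produced for a class key depends on _op_type_to_str.

import torch
from neural_compressor.torch.quantization import StaticQuantConfig

quant_config = StaticQuantConfig()
fallback = StaticQuantConfig(w_dtype="fp32", act_dtype="fp32")
quant_config.set_local(torch.nn.Linear, fallback)  # class key, matched as an op_type
quant_config.set_local("Linear&add", fallback)     # IPEX fused op_type passed as a plain string

op_type_cfgs, op_name_cfgs = quant_config._get_op_name_op_type_config()
# torch.nn.Linear is converted to its string form and lands only in op_type_cfgs;
# "Linear&add" lands in both dicts, so an IPEX fused op can still be matched by
# op_type even though there is no Callable to hand to set_local.

The updated test_static_quant_fallback in the second changed file below exercises the same fallback end to end through prepare and convert.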
@@ -242,7 +259,7 @@ def get_default_double_quant_config(type="BNB_NF4"):

######################## GPTQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
-class GPTQConfig(BaseConfig):
+class GPTQConfig(TorchBaseConfig):
    """Config class for GPTQ.

    GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
@@ -397,7 +414,7 @@ def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.Proc

######################## AWQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=AWQ, priority=PRIORITY_AWQ)
-class AWQConfig(BaseConfig):
+class AWQConfig(TorchBaseConfig):
    """Config class for AWQ.

    AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -539,7 +556,7 @@ def get_default_awq_config() -> AWQConfig:

######################## TEQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=TEQ, priority=PRIORITY_TEQ)
-class TEQConfig(BaseConfig):
+class TEQConfig(TorchBaseConfig):
    """Config class for TEQ.

    TEQ: Trainable Equivalent Transformation for weight-only quantization of LLMs.
@@ -677,7 +694,7 @@ def get_default_teq_config() -> TEQConfig:

######################## AUTOROUND Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=AUTOROUND, priority=PRIORITY_AUTOROUND)
-class AutoRoundConfig(BaseConfig):
+class AutoRoundConfig(TorchBaseConfig):
    """Config class for AUTOROUND.

    AUTOROUND: Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs.
@@ -827,7 +844,7 @@ def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils

######################## MX Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=MX_QUANT)
-class MXQuantConfig(BaseConfig):
+class MXQuantConfig(TorchBaseConfig):
    """Config class for MX quantization."""

    supported_configs: List[OperatorConfig] = []
@@ -940,7 +957,7 @@ def get_default_mx_config() -> MXQuantConfig:

######################## Dynamic Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
-class DynamicQuantConfig(BaseConfig):
+class DynamicQuantConfig(TorchBaseConfig):
    """Config class for dynamic quantization."""

    name = PT2E_DYNAMIC_QUANT
@@ -1014,7 +1031,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig:

######################## Static Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
-class StaticQuantConfig(BaseConfig):
+class StaticQuantConfig(TorchBaseConfig):
    """Config class for static quantization."""

    name = STATIC_QUANT
@@ -1103,7 +1120,7 @@ def get_default_static_config() -> StaticQuantConfig:

######################## Smooth Quant Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
-class SmoothQuantConfig(BaseConfig):
+class SmoothQuantConfig(TorchBaseConfig):
    """Config class for smooth quantization."""

    name = SMOOTH_QUANT
@@ -1217,7 +1234,7 @@ def get_default_sq_config() -> SmoothQuantConfig:

######################## HQQ Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ)
-class HQQConfig(BaseConfig):
+class HQQConfig(TorchBaseConfig):
    # Half-Quadratic Quantization (HQQ), more details:
    # Blog: https://mobiusml.github.io/hqq_blog/
    # Code: https://github.com/mobiusml/hqq
@@ -1298,7 +1315,7 @@ def get_default_hqq_config() -> HQQConfig:

######################## FP8 Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
-class FP8Config(BaseConfig):
+class FP8Config(TorchBaseConfig):
    """Config class for FP8 quantization."""

    name = FP8_QUANT
@@ -1393,7 +1410,7 @@ def get_default_fp8_config_set() -> FP8Config:

######################## MixPrecision Config ###############################
@register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
-class MixPrecisionConfig(BaseConfig):
+class MixPrecisionConfig(TorchBaseConfig):
    """Config class for mix-precision."""

    name = MIX_PRECISION
2 changes: 1 addition & 1 deletion test/3x/torch/quantization/test_static_quant.py
@@ -76,7 +76,7 @@ def test_static_quant_fallback(self):
        quant_config = get_default_static_config()
        example_inputs = self.input
        # fallback by op_type
-        quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config.set_local([torch.nn.Linear, "Linear&add"], StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
        prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
        run_fn(prepared_model)
        q_model = convert(prepared_model)
