diff --git a/deepspeed/compression/inference/config.py b/deepspeed/compression/inference/config.py
deleted file mode 100644
index ab2dd4f3e4e2..000000000000
--- a/deepspeed/compression/inference/config.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-from typing import Dict
-
-_WEIGNT_QUANTIZATION_ = 'weight_quantization'
-_QUANTIZED_INITIALIZATION_ = 'quantized_initialization'
-_POST_INIT_QUIANT_ = 'post_init_quantization'
-
-
-class WeightQuantizationConfig:
-
-    def __init__(self, param_dict: Dict) -> None:
-        super(WeightQuantizationConfig, self).__init__()
-        self.quantized_initialization = None
-        self.post_init_quant = None
-
-        if _WEIGNT_QUANTIZATION_ in param_dict:
-            weight_quantization_config = param_dict[_WEIGNT_QUANTIZATION_]
-
-            assert not (_QUANTIZED_INITIALIZATION_ in weight_quantization_config and _POST_INIT_QUIANT_
-                        in weight_quantization_config), 'Must choose only one quantization flavor.'
-
-            if _QUANTIZED_INITIALIZATION_ in weight_quantization_config:
-                self.quantized_initialization = weight_quantization_config[_QUANTIZED_INITIALIZATION_]
-            if _POST_INIT_QUIANT_ in weight_quantization_config:
-                self.post_init_quant = weight_quantization_config[_POST_INIT_QUIANT_]
diff --git a/deepspeed/inference/__init__.py b/deepspeed/inference/__init__.py
index 0fc748f4e167..208299fb8c50 100644
--- a/deepspeed/inference/__init__.py
+++ b/deepspeed/inference/__init__.py
@@ -2,5 +2,3 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # DeepSpeed Team
-
-from .engine import InferenceEngine
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 70a67c062ad2..2c3b0f4ebb62 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -101,6 +101,8 @@ class BaseQuantConfig(DeepSpeedConfigModel):
 
 class WeightQuantConfig(BaseQuantConfig):
     enabled = True
+    quantized_initialization: Dict = {}
+    post_init_quant: Dict = {}
 
 
 class ActivationQuantConfig(BaseQuantConfig):
diff --git a/deepspeed/compression/inference/__init__.py b/deepspeed/inference/quantization/__init__.py
similarity index 100%
rename from deepspeed/compression/inference/__init__.py
rename to deepspeed/inference/quantization/__init__.py
diff --git a/deepspeed/compression/inference/layers.py b/deepspeed/inference/quantization/layers.py
similarity index 100%
rename from deepspeed/compression/inference/layers.py
rename to deepspeed/inference/quantization/layers.py
diff --git a/deepspeed/compression/inference/quantization.py b/deepspeed/inference/quantization/quantization.py
similarity index 95%
rename from deepspeed/compression/inference/quantization.py
rename to deepspeed/inference/quantization/quantization.py
index 07a0e4292559..9ae39e8d5688 100644
--- a/deepspeed/compression/inference/quantization.py
+++ b/deepspeed/inference/quantization/quantization.py
@@ -7,10 +7,9 @@
 from torch import nn
 from typing import Dict
 import gc
-from deepspeed.compression.inference import layers
+from deepspeed.inference.quantization import layers
 from .layers import QUANTIZATION_LAYER_MAPPINGS
-from .utils import get_AsyncPartitionedParameterSwapper
-from ..helper import recursive_setattr
+from .utils import get_AsyncPartitionedParameterSwapper, recursive_setattr
 from deepspeed.utils.logging import logger
 from collections import deque
 from transformers.utils.generic import ContextManagers
@@ -35,7 +34,7 @@ def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> n
     matched_module_count = 0
 
     assert 'weight_quantization' in ds_config, 'Please provide quantization config in ds_config'
-    quantization_config = ds_config['weight_quantization']
+    quantization_config = ds_config['weight_quantization']['post_init_quant']
 
     # Return nvme swapper if exists, else return None.
     # For nvme offloading we must use the same swapper here as model initialized.
diff --git a/deepspeed/compression/inference/quantization_context.py b/deepspeed/inference/quantization/quantization_context.py
similarity index 100%
rename from deepspeed/compression/inference/quantization_context.py
rename to deepspeed/inference/quantization/quantization_context.py
diff --git a/deepspeed/compression/inference/utils.py b/deepspeed/inference/quantization/utils.py
similarity index 95%
rename from deepspeed/compression/inference/utils.py
rename to deepspeed/inference/quantization/utils.py
index 989d22e3afd4..5d8a73d1b8cc 100644
--- a/deepspeed/compression/inference/utils.py
+++ b/deepspeed/inference/quantization/utils.py
@@ -192,6 +192,24 @@ def get_AsyncPartitionedParameterSwapper(model: nn.Module):
     return None
 
 
+def recursive_setattr(model, module_name, module):
+    """
+    Recursively set the attribute of a module.
+    Args:
+        model (`torch.nn.Module`)
+            The model to set the attribute in.
+        module_name (`str`)
+            The name of the module to set the attribute in.
+        module (`torch.nn.Module`)
+            The module to set the attribute to.
+    """
+    split_list = module_name.split('.')
+    output = model
+    for name in split_list[:-1]:
+        output = getattr(output, name)
+    output.__setattr__(split_list[-1], module)
+
+
 def concat_to_compat_param(quantized_weight: Tensor,
                            quant_scale: Tensor,
                            quant_min: Tensor,
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index ee589427b665..cef63a76c8ca 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -29,7 +29,7 @@
 from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
 from ..comm.config import DeepSpeedCommsConfig
 from ..monitor.config import get_monitor_config
-from ..compression.inference.config import WeightQuantizationConfig
+from ..inference.config import WeightQuantConfig
 from deepspeed import comm as dist
 
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -871,7 +871,8 @@ def _initialize_params(self, param_dict):
 
         self.nebula_config = DeepSpeedNebulaConfig(param_dict)
 
-        self.weight_quantization_config = WeightQuantizationConfig(param_dict)
+        self.weight_quantization_config = WeightQuantConfig(
+            **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None
 
     def _batch_assertion(self):
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index a1daea25d9c8..be90b0ba775e 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -31,7 +31,7 @@
                            debug_param2name_id, debug_param2name_id_shape_status)
 from deepspeed.accelerator import get_accelerator
 from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
-from deepspeed.compression.inference.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict
+from deepspeed.inference.quantization.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict
 
 param_count = 0
 partitioned_param_data_shape = [0]
@@ -295,7 +295,7 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp
         self.wrapped_cls = set()
 
         self.quantized_initialization = None
-        if ds_config is not None and ds_config.weight_quantization_config.quantized_initialization:
+        if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
             self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization
 
     def __enter__(self):
diff --git a/tests/unit/compression/inference/test_int4_quantization.py b/tests/unit/inference/quantization/test_int4_quantization.py
similarity index 90%
rename from tests/unit/compression/inference/test_int4_quantization.py
rename to tests/unit/inference/quantization/test_int4_quantization.py
index 598d4632b21d..769f44928603 100644
--- a/tests/unit/compression/inference/test_int4_quantization.py
+++ b/tests/unit/inference/quantization/test_int4_quantization.py
@@ -7,9 +7,9 @@
 import torch
 import torch.nn as nn
 from deepspeed.accelerator import get_accelerator
-from deepspeed.compression.inference.quantization import _init_group_wise_weight_quantization
-from deepspeed.compression.inference.utils import Quantizer, DeQuantizer
-from deepspeed.compression.inference.layers import QuantizedLinear
+from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization
+from deepspeed.inference.quantization.utils import Quantizer, DeQuantizer
+from deepspeed.inference.quantization.layers import QuantizedLinear
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers import AutoConfig, OPTConfig, AutoModel
 import pytest
@@ -258,35 +258,37 @@ def test_model_quantization():
 
     ds_config = {
         'weight_quantization': {
-            'fc': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.q_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.k_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.v_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.out_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
+            'post_init_quant': {
+                'fc': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.q_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.k_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.v_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.out_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                }
             }
         }
     }
@@ -321,11 +323,13 @@ def test_quantized_linear():
 
     ds_config = {
         'weight_quantization': {
-            'layer': {
-                'num_bits': 4,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
+            'post_init_quant': {
+                'layer': {
+                    'num_bits': 4,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                }
             }
         }
     }
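For reference, a minimal usage sketch of the relocated post-init quantization entry point, distilled from the updated test above. The extra 'post_init_quant' nesting is what lets WeightQuantConfig distinguish this flavor from 'quantized_initialization'. The model name ('facebook/opt-125m') and the single 'fc' rule are illustrative assumptions, not part of the patch:

# Sketch only: config layout mirrors the updated test; model choice and the
# single 'fc' rule are illustrative assumptions.
from transformers import AutoModel

from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization

ds_config = {
    'weight_quantization': {
        # Per-module rules now sit under the new 'post_init_quant' key
        # (previously they lived directly under 'weight_quantization').
        'post_init_quant': {
            # Quantize modules whose name matches 'fc' to int4,
            # grouped along dim 0 in groups of 64, asymmetric.
            'fc': {
                'num_bits': 4,
                'group_size': 64,
                'group_dim': 0,
                'symmetric': False
            }
        }
    }
}

model = AutoModel.from_pretrained('facebook/opt-125m')
model = _init_group_wise_weight_quantization(model, ds_config)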