Refactor: move int4 code to deepspeed/inference (#528)
* Move int4 code to deepspeed/inference

* fix

* fix

* fix
donglinz authored Jun 5, 2023
1 parent 2461449 commit 8751edf
Showing 11 changed files with 69 additions and 76 deletions.
29 changes: 0 additions & 29 deletions deepspeed/compression/inference/config.py

This file was deleted.

2 changes: 0 additions & 2 deletions deepspeed/inference/__init__.py
@@ -2,5 +2,3 @@
 # SPDX-License-Identifier: Apache-2.0

 # DeepSpeed Team
-
-from .engine import InferenceEngine
2 changes: 2 additions & 0 deletions deepspeed/inference/config.py
@@ -101,6 +101,8 @@ class BaseQuantConfig(DeepSpeedConfigModel):

 class WeightQuantConfig(BaseQuantConfig):
     enabled = True
+    quantized_initialization: Dict = {}
+    post_init_quant: Dict = {}


 class ActivationQuantConfig(BaseQuantConfig):
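
Together, the two new fields split weight quantization into two entry points: 'quantized_initialization' is consumed while the model is being materialized (see the partition_parameters.py change below), and 'post_init_quant' is consumed by _init_group_wise_weight_quantization on an already-built model. A minimal sketch of the resulting ds_config layout; only the nesting is established by this diff, and the inner keys and values are illustrative, mirroring the tests below:

# Hypothetical ds_config combining both workflows; only the nesting is
# established by this commit, the inner values are illustrative.
ds_config = {
    'weight_quantization': {
        # Consumed during model construction (see partition_parameters.py below).
        'quantized_initialization': {
            'num_bits': 4,
            'group_size': 64,
            'group_dim': 0,
            'symmetric': False
        },
        # Consumed by _init_group_wise_weight_quantization after init;
        # keys are module-name patterns to match, as in the tests below.
        'post_init_quant': {
            'fc': {
                'num_bits': 4,
                'group_size': 64,
                'group_dim': 0,
                'symmetric': False
            }
        }
    }
}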
File renamed without changes.
File renamed without changes.
deepspeed/compression/inference/quantization.py → deepspeed/inference/quantization/quantization.py
@@ -7,10 +7,9 @@
 from torch import nn
 from typing import Dict
 import gc
-from deepspeed.compression.inference import layers
+from deepspeed.inference.quantization import layers
 from .layers import QUANTIZATION_LAYER_MAPPINGS
-from .utils import get_AsyncPartitionedParameterSwapper
-from ..helper import recursive_setattr
+from .utils import get_AsyncPartitionedParameterSwapper, recursive_setattr
 from deepspeed.utils.logging import logger
 from collections import deque
 from transformers.utils.generic import ContextManagers
@@ -35,7 +34,7 @@ def _init_group_wise_weight_quantization(model: nn.Module, ds_config: Dict) -> nn.Module:
     matched_module_count = 0

     assert 'weight_quantization' in ds_config, 'Please provide quantization config in ds_config'
-    quantization_config = ds_config['weight_quantization']
+    quantization_config = ds_config['weight_quantization']['post_init_quant']

     # Return nvme swapper if exists, else return None.
     # For nvme offloading we must use the same swapper here as model initialized.
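
Under the new layout, callers still pass the full ds_config and the function selects the 'post_init_quant' subsection itself. A minimal usage sketch; the checkpoint and the matched layer name are illustrative, mirroring the tests below:

# Minimal sketch: quantize an already-initialized model in place.
from transformers import AutoModel
from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization

model = AutoModel.from_pretrained('facebook/opt-125m')  # illustrative checkpoint
ds_config = {
    'weight_quantization': {
        'post_init_quant': {
            'fc': {'num_bits': 4, 'group_size': 64, 'group_dim': 0, 'symmetric': False}
        }
    }
}
model = _init_group_wise_weight_quantization(model, ds_config)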
deepspeed/compression/inference/utils.py → deepspeed/inference/quantization/utils.py
@@ -192,6 +192,24 @@ def get_AsyncPartitionedParameterSwapper(model: nn.Module):
     return None


+def recursive_setattr(model, module_name, module):
+    """
+    Recursively set the attribute of a module.
+    Args:
+        model (`torch.nn.Module`)
+            The model to set the attribute in.
+        module_name (`str`)
+            The name of the module to set the attribute in.
+        module (`torch.nn.Module`)
+            The module to set the attribute to.
+    """
+    split_list = module_name.split('.')
+    output = model
+    for name in split_list[:-1]:
+        output = getattr(output, name)
+    output.__setattr__(split_list[-1], module)
+
+
 def concat_to_compat_param(quantized_weight: Tensor,
                            quant_scale: Tensor,
                            quant_min: Tensor,
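
The recursive_setattr helper now lives next to the quantization utilities so the quantizer can swap nested submodules by dotted path. A small usage sketch; the module layout is illustrative:

import torch.nn as nn
from deepspeed.inference.quantization.utils import recursive_setattr

class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = nn.Module()
model.block = Block()

# Walk 'block.proj' and replace the leaf module, as the quantizer does when
# substituting a QuantizedLinear for a matched nn.Linear.
recursive_setattr(model, 'block.proj', nn.Linear(8, 4))
assert model.block.proj.out_features == 4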
5 changes: 3 additions & 2 deletions deepspeed/runtime/config.py
@@ -29,7 +29,7 @@
 from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
 from ..comm.config import DeepSpeedCommsConfig
 from ..monitor.config import get_monitor_config
-from ..compression.inference.config import WeightQuantizationConfig
+from ..inference.config import WeightQuantConfig

 from deepspeed import comm as dist
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -871,7 +871,8 @@ def _initialize_params(self, param_dict):

         self.nebula_config = DeepSpeedNebulaConfig(param_dict)

-        self.weight_quantization_config = WeightQuantizationConfig(param_dict)
+        self.weight_quantization_config = WeightQuantConfig(
+            **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

     def _batch_assertion(self):
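
Note the behavioral change: DeepSpeedConfig now builds a WeightQuantConfig only when the user actually supplies a 'weight_quantization' section, so weight_quantization_config can be None and downstream consumers must check for it (see partition_parameters.py below). A short sketch of the guarded construction:

# Sketch of the guarded construction added above.
from deepspeed.inference.config import WeightQuantConfig

param_dict = {}  # a ds_config without any 'weight_quantization' section
wq_config = (WeightQuantConfig(**param_dict['weight_quantization'])
             if 'weight_quantization' in param_dict else None)
assert wq_config is None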
4 changes: 2 additions & 2 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -31,7 +31,7 @@
                            debug_param2name_id, debug_param2name_id_shape_status)
 from deepspeed.accelerator import get_accelerator
 from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
-from deepspeed.compression.inference.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict
+from deepspeed.inference.quantization.utils import _quantize_param, WEIGHT_QUANTIZATION_LAYERS, wrap_quantized_functional, wrap_load_from_state_dict

 param_count = 0
 partitioned_param_data_shape = [0]
@@ -295,7 +295,7 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None):
         self.wrapped_cls = set()

         self.quantized_initialization = None
-        if ds_config is not None and ds_config.weight_quantization_config.quantized_initialization:
+        if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization:
             self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization

     def __enter__(self):
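
The added `and ds_config.weight_quantization_config` check pairs with the runtime/config.py change above, since that attribute is now None when no 'weight_quantization' section is supplied. For reference, a hedged sketch of the quantized-initialization path; the zero.Init keyword `config_dict_or_path` and the surrounding config keys are assumed from DeepSpeed's public API rather than shown in this diff:

# Hedged sketch: quantize supported weights as the model is materialized.
import deepspeed
from transformers import AutoModelForCausalLM  # illustrative model source

ds_config = {
    'train_micro_batch_size_per_gpu': 1,  # assumed: DeepSpeedConfig validation wants a batch size
    'zero_optimization': {'stage': 3},
    'weight_quantization': {
        'quantized_initialization': {'num_bits': 4, 'group_size': 64, 'group_dim': 0, 'symmetric': False}
    }
}

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')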
@@ -7,9 +7,9 @@
 import torch
 import torch.nn as nn
 from deepspeed.accelerator import get_accelerator
-from deepspeed.compression.inference.quantization import _init_group_wise_weight_quantization
-from deepspeed.compression.inference.utils import Quantizer, DeQuantizer
-from deepspeed.compression.inference.layers import QuantizedLinear
+from deepspeed.inference.quantization.quantization import _init_group_wise_weight_quantization
+from deepspeed.inference.quantization.utils import Quantizer, DeQuantizer
+from deepspeed.inference.quantization.layers import QuantizedLinear
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers import AutoConfig, OPTConfig, AutoModel
 import pytest
@@ -258,35 +258,37 @@ def test_model_quantization():

     ds_config = {
         'weight_quantization': {
-            'fc': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.q_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.k_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.v_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
-            },
-            'self_attn.out_proj': {
-                'num_bits': bits,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
+            'post_init_quant': {
+                'fc': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.q_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.k_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.v_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                },
+                'self_attn.out_proj': {
+                    'num_bits': bits,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                }
             }
         }
     }
@@ -321,11 +323,13 @@ def test_quantized_linear():

     ds_config = {
         'weight_quantization': {
-            'layer': {
-                'num_bits': 4,
-                'group_size': 64,
-                'group_dim': 0,
-                'symmetric': False
+            'post_init_quant': {
+                'layer': {
+                    'num_bits': 4,
+                    'group_size': 64,
+                    'group_dim': 0,
+                    'symmetric': False
+                }
             }
         }
     }
