diff --git a/Readme.md b/Readme.md
index 107d2d1..ff4ae90 100755
--- a/Readme.md
+++ b/Readme.md
@@ -37,6 +37,8 @@ The quantization parameters are set as follows:
 - ```quant_zero``` (bool): if True, it quantizes the zero-point to 8-bit without grouping.
 - ```quant_scale``` (bool): if True, it quantizes the scaling factor to 8-bit with a group_size of 128.
 
+Additionally, you can set ```offload_meta=True``` to offload the meta-data to the CPU. This dramatically decreases the GPU memory requirements but makes processing slightly slower for smaller group-sizes. With ```offload_meta=True```, you can run Llama2-70B and Mixtral with HQQ 2-bit using only 18.8GB and 13GB VRAM respectively!
+
 You can try to change the backend which could speed-up the runtime:
 ```Python
 HQQLinear.set_backend(HQQBackend.PYTORCH) #Pytorch backend (default)
diff --git a/hqq/__init__.py b/hqq/__init__.py
index 85c63a7..6664ce4 100755
--- a/hqq/__init__.py
+++ b/hqq/__init__.py
@@ -1,3 +1,3 @@
-__version__ = "0.1.3"
+__version__ = "0.1.3.post1"
 __author__ = 'Dr. Hicham Badri'
 __credits__ = 'Mobius Labs GmbH'
\ No newline at end of file
diff --git a/hqq/core/quantize.py b/hqq/core/quantize.py
index a68098b..be3dc46 100755
--- a/hqq/core/quantize.py
+++ b/hqq/core/quantize.py
@@ -1,6 +1,6 @@
 #Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
 #####################################################
-import torch
+import torch, copy
 import numpy as np
 
 from .utils import *
@@ -98,7 +98,8 @@ def dequantize(cls, W_q, meta):
     @classmethod
     def to_inplace(cls, W_q, meta, device):
         compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
-        W_q = W_q.to(device).contiguous()
+        if(W_q is not None):
+            W_q = W_q.to(device).contiguous()
         for key in meta:
             if(type(meta[key])==torch.Tensor):
                 meta[key] = (meta[key].to(compute_dtype) if torch.is_floating_point(meta[key]) else meta[key]).to(device).contiguous()
@@ -107,7 +108,10 @@ def to_ooplace(cls, W_q, meta, device):
         compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
-        W_q_c = W_q.to(device).contiguous()
+        if(W_q is not None):
+            W_q_c = W_q.to(device).contiguous()
+        else:
+            W_q_c = None
         meta_c = {}
         for key in meta:
             if(type(meta[key])==torch.Tensor):
@@ -250,12 +254,14 @@ def __init__(self, linear_layer, quant_config, del_orig=True, compute_dtype=torc
         self.bias = None
         self.device_n = device_n
         self.compute_dtype = compute_dtype
-        self.quant_config = quant_config
+        self.quant_config = copy.deepcopy(quant_config)
+        self.offload_meta = self.quant_config.pop('offload_meta') if (self.quant_config is not None) else None
+
         self.set_backend(HQQLinear.backend) #Default backend
 
         if(linear_layer is not None):
             self.bias = None if (linear_layer.bias==None) else linear_layer.bias.to(self.compute_dtype).cuda()
-            self.quantize(linear_layer.weight.data, **quant_config)
+            self.quantize(linear_layer.weight.data, **self.quant_config)
 
         if(del_orig): del linear_layer
         torch.cuda.empty_cache()
@@ -266,15 +272,29 @@ def set_backend(cls, backend: HQQBackend):
         HQQLinear.backend = backend
         cls.forward = getattr(cls, backend.value)
 
+    #TODO: rewrite this mess
     def cuda(self, device_n=0):
         if(self.in_gpu): return
         self.meta['compute_dtype'] = self.compute_dtype
+
         self.W_q, self.meta = Quantizer.cuda(self.W_q, self.meta, device_n)
-        if(self.meta['quant_scale']):
-            self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cuda(self.meta['scale_q'], self.meta['meta_scale'], device_n)
+
         if(self.meta['quant_zero']):
             self.meta['zero_q'] , self.meta['meta_zero'] = Quantizer.cuda(self.meta['zero_q'], self.meta['meta_zero'], device_n)
+        if(self.meta['quant_scale']):
+            self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cuda(self.meta['scale_q'], self.meta['meta_scale'], device_n)
+
+        if(self.offload_meta):
+            if(self.meta['quant_scale'] and self.meta['quant_zero']):
+                self.meta['zero_scale'] = torch.stack((self.meta['zero_q'], self.meta['scale_q']))
+                del self.meta['scale_q'], self.meta['zero_q']
+            else:
+                self.meta['zero_scale'] = torch.stack((self.meta['zero'], self.meta['scale'])).to(self.compute_dtype)
+                del self.meta['scale'], self.meta['zero']
+
+            self.meta['zero_scale'] = self.meta['zero_scale'].contiguous().cpu()
+
         if(self.bias is not None):
             self.bias = self.bias.to(self.compute_dtype).cuda(device_n)
@@ -282,6 +302,8 @@ def cuda(self, device_n=0):
         self.device = self.W_q.device
         self.in_gpu = True
 
+        torch.cuda.empty_cache()
+
     def to(self, *args, **kwargs):
         pass
 
@@ -297,15 +319,18 @@ def load_state_dict(self, state_dict):
         self.bias = state_dict['bias'] if ('bias' in state_dict) else None
         self.in_gpu = self.W_q.device.type == 'cuda'
         if(self.in_gpu):
-            if('scale' in self.meta):
-                self.meta['scale'] = self.meta['scale'].to(self.compute_dtype)
             if('zero' in self.meta):
                 self.meta['zero'] = self.meta['zero'].to(self.compute_dtype)
+
+            if('scale' in self.meta):
+                self.meta['scale'] = self.meta['scale'].to(self.compute_dtype)
+
+            if(('zero_scale' in self.meta) and (self.meta['quant_scale']==False) and (self.meta['zero_scale']==False)):
+                self.meta['zero_scale'] = self.meta['zero_scale'].to(self.compute_dtype)
         else:
             self.cuda(self.device_n)
         self.ready = True
 
-    #@torch.inference_mode()
     def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params):
         quant_scale = scale_quant_params is not None
         quant_zero = zero_quant_params is not None
@@ -315,13 +340,15 @@ def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params
         #Quantize
         W_q , meta = Quantizer.quantize(W, **weight_quant_params)
         meta.update({'quant_scale':quant_scale, 'quant_zero':quant_zero})
-        if(meta['quant_scale']):
-            meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale']
-            meta['meta_scale']['compute_dtype'] = self.compute_dtype
+
         if(meta['quant_zero']):
             meta['zero_q'], meta['meta_zero'] = Quantizer.quantize(meta['zero'], **zero_quant_params); del meta['zero']
             meta['meta_zero']['compute_dtype'] = self.compute_dtype
+        if(meta['quant_scale']):
+            meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale']
+            meta['meta_scale']['compute_dtype'] = self.compute_dtype
+
         self.W_q = W_q
         self.meta = meta
         self.cuda(self.device_n)
@@ -332,11 +359,23 @@ def dequantize(self):
         W_q, meta = self.W_q, self.meta
         del_keys = []
-        if(meta['quant_scale']):
-            meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
+
+        if('zero_scale' in meta):
+            zero_scale = meta['zero_scale'].to(self.W_q.device)
+
+            if(zero_scale.dtype==torch.uint8):
+                meta['zero_q'] = zero_scale[0]; del_keys.append('zero_q');
+                meta['scale_q'] = zero_scale[1]; del_keys.append('scale_q');
+            else:
+                meta['zero'] = zero_scale[0]; del_keys.append('zero');
+                meta['scale'] = zero_scale[1]; del_keys.append('scale');
+
         if(meta['quant_zero']):
             meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero')
+        if(meta['quant_scale']):
+            meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
+
         W_est = Quantizer.dequantize(W_q, meta)
 
         #Cleanup
@@ -381,11 +420,23 @@ def dequantize_aten(self):
         W_q, meta = self.W_q, self.meta
         del_keys = []
+
+        if('zero_scale' in meta):
+            zero_scale = meta['zero_scale'].to(self.W_q.device)
+
+            if(zero_scale.dtype==torch.uint8):
+                meta['zero_q'] = zero_scale[0]; del_keys.append('zero_q');
+                meta['scale_q'] = zero_scale[1]; del_keys.append('scale_q');
+            else:
+                meta['zero'] = zero_scale[0]; del_keys.append('zero');
+                meta['scale'] = zero_scale[1]; del_keys.append('scale');
+
         if(meta['quant_scale']):
             if(meta['meta_scale']['group_size']):
                 meta['scale'] = self.dequantize_Wq_aten(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
             else:
                 meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
+
         if(meta['quant_zero']):
             if(meta['meta_zero']['group_size']):
                 meta['zero'] = self.dequantize_Wq_aten(meta['zero_q'], meta['meta_zero']); del_keys.append('zero')
@@ -450,14 +501,26 @@ def forward_aten_backprop(self, x):
 #   return hqq_aten.forward_with_quant(*args)
 
 
-def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False):
+def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False, offload_meta=False):
     assert nbits in Quantizer.SUPPORTED_BITS, "nbits value not supported. Check Quantizer.SUPPORTED_BITS."
     if(group_size is not None):
         assert is_divisible(group_size, 8), "Invalid group_size param: the value should be a multiple of 8."
     weight_quant_params = {'nbits':nbits,'channel_wise':True, 'group_size':group_size, 'optimize':True, 'round_zero':True if nbits==4 else False}
-    scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
-    zero_quant_params = {'nbits':8, 'channel_wise':False, 'group_size':None, 'optimize':False} if (quant_zero) else None
-    return {'weight_quant_params':weight_quant_params, 'scale_quant_params':scale_quant_params, 'zero_quant_params':zero_quant_params}
+
+    if(offload_meta):
+        if((quant_scale!=quant_zero)):
+            print(colored("quant_zero and quant_scale must be the same when offload_meta is set to True. Setting quant_scale=quant_zero." , 'yellow'))
+            quant_scale = quant_zero
+
+        scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
+        zero_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_zero) else None
+
+    else:
+        scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
+        zero_quant_params = {'nbits':8, 'channel_wise':False, 'group_size':None, 'optimize':False} if (quant_zero) else None
+
+
+    return {'weight_quant_params':weight_quant_params, 'scale_quant_params':scale_quant_params, 'zero_quant_params':zero_quant_params, 'offload_meta':offload_meta}
 
 #Alias: follow similar Auto-GPTQ naming
 BaseQuantizeConfig = hqq_base_quant_config
\ No newline at end of file
diff --git a/hqq/engine/base.py b/hqq/engine/base.py
index fa19f3a..079e757 100755
--- a/hqq/engine/base.py
+++ b/hqq/engine/base.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 from typing import Dict
 from ..models.base import BaseHQQModel
+import torch
 
 #Wrapper that makes it easier to add quantization support to different engines (HF, VLLM, etc.)
 
@@ -47,7 +48,7 @@ def _set_quantized(cls, model, quantized):
 
     #####################################################
     @classmethod
-    def quantize_model_(cls, model, quant_config, compute_dtype):
+    def quantize_model_(cls, model, quant_config, compute_dtype=torch.float16):
         if(cls._is_quantizable(model)==False):
             cls._make_quantizable(model, quantized=False)
         cls._check_arch_support(model)
@@ -61,7 +62,7 @@ def save_quantized_(cls, model, save_dir):
         cls._get_hqq_class(model).save_quantized(model, save_dir=save_dir)
 
     @classmethod
-    def from_quantized(cls, save_dir_or_hub, compute_dtype, cache_dir=''):
+    def from_quantized(cls, save_dir_or_hub, compute_dtype=torch.float16, cache_dir=''):
         #Both local and hub-support
         save_dir = BaseHQQModel.try_snapshot_download(save_dir_or_hub)
         arch_key = cls._get_arch_key_from_save_dir(save_dir)
diff --git a/hqq/engine/hf.py b/hqq/engine/hf.py
index 599f7af..ffc31cd 100755
--- a/hqq/engine/hf.py
+++ b/hqq/engine/hf.py
@@ -1,4 +1,4 @@
-import transformers, json
+import transformers, json, torch
 from typing import Dict
 
 _HQQ_REGISTRY = {}
@@ -34,7 +34,7 @@ def __init__(self, *args, **kwargs):
     def _make_quantizable(cls, model, quantized):
         model.hqq_quantized = quantized
         model.arch_key = model.config.architectures[0]
-        model.quantize_model = lambda quant_config, compute_dtype: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
+        model.quantize_model = lambda quant_config, compute_dtype=torch.float16: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
         model.save_quantized = lambda save_dir: cls.save_quantized_(model=model, save_dir=save_dir)
         model.cuda = lambda *args, **kwargs: model if(quantized) else model.cuda
         model.to = lambda *args, **kwargs: model if(quantized) else model.to
diff --git a/hqq/engine/timm.py b/hqq/engine/timm.py
index 82edf33..8c5db17 100755
--- a/hqq/engine/timm.py
+++ b/hqq/engine/timm.py
@@ -1,4 +1,4 @@
-import timm, json
+import timm, json, torch
 from typing import Dict
 from ..models.base import BaseHQQModel
 from ..models.timm.vit_clip import ViTCLIPHQQ
@@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs):
     def _make_quantizable(cls, model, quantized):
         model.hqq_quantized = quantized
         model.arch_key = model.default_cfg['architecture']
-        model.quantize_model = lambda quant_config, compute_dtype: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
+        model.quantize_model = lambda quant_config, compute_dtype=torch.float16: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
         model.save_quantized = lambda save_dir: cls.save_quantized_(model=model, save_dir=save_dir)
         model.cuda = lambda *args, **kwargs: model if(quantized) else model.cuda
         model.to = lambda *args, **kwargs: model if(quantized) else model.to
diff --git a/setup.py b/setup.py
index 632a49a..1789697 100755
--- a/setup.py
+++ b/setup.py
@@ -2,12 +2,12 @@
 setup(
     name='hqq',
-    version='0.1.3',
+    version='0.1.3.post1',
     description='Half-Quadratic Quantization (HQQ)',
     url='https://github.com/mobiusml/hqq/',
     author='Dr. Hicham Badri',
     author_email='hicham@mobiuslabs.com',
     license='Apache 2',
     packages=['hqq', 'hqq/core', 'hqq/engine', 'hqq/models', 'hqq/models/hf', 'hqq/models/timm', 'hqq/models/vllm'],
-    install_requires=['numpy>=1.24.4','tqdm>=4.64.1', 'huggingface_hub', 'accelerate', 'timm', 'transformers>=4.36.1', 'termcolor'], #'torch>=2.1.1', add vllm/langchain?
+    install_requires=['numpy>=1.24.4','tqdm>=4.64.1', 'huggingface_hub', 'accelerate', 'timm', 'transformers>=4.36.1', 'termcolor'],
 )
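
For reference, below is a minimal usage sketch of the new ```offload_meta``` option introduced by this patch. It assumes the HF engine wrapper (```HQQModelForCausalLM```) shown in the project Readme; the model id and the 2-bit/group-size settings are illustrative placeholders, not part of the patch.

```Python
#Minimal sketch of the new offload_meta flag (model id and settings below are placeholders).
import torch
from hqq.engine.hf import HQQModelForCausalLM
from hqq.core.quantize import BaseQuantizeConfig

#offload_meta=True keeps the (quantized) zero/scale meta-data on the CPU;
#quant_zero and quant_scale are forced to match when offload_meta is set (see hqq_base_quant_config above).
quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True, offload_meta=True)

model = HQQModelForCausalLM.from_pretrained('meta-llama/Llama-2-70b-hf') #placeholder checkpoint
model.quantize_model(quant_config=quant_config, compute_dtype=torch.float16) #compute_dtype now defaults to float16
```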