
Commit: add offload_meta
mobicham committed Feb 20, 2024
1 parent 96ce17d commit 5bc48cb
Showing 7 changed files with 94 additions and 28 deletions.
Readme.md (2 changes: 2 additions & 0 deletions)
@@ -37,6 +37,8 @@ The quantization parameters are set as follows:
- ```quant_zero``` (bool): if True, it quantizes the zero-point to 8-bit without grouping.
- ```quant_scale``` (bool): if True, it quantizes the scaling factor to 8-bit with a group_size of 128.

Additionally, you can set ```offload_meta=True``` to offload the meta-data to the CPU. This dramatically decreases the GPU memory requirements but makes processing slightly slower for smaller group-sizes. With ```offload_meta=True```, you can run Llama2-70B and Mixtral with HQQ 2-bit using only 18.8GB and 13GB VRAM respectively!
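
For reference, a minimal sketch of how this option could be passed through the config helper added in this commit (the `nbits`/`group_size` values are illustrative only):
```Python
from hqq.core.quantize import BaseQuantizeConfig

# Example 2-bit config with the meta-data offloaded to the CPU.
# When offload_meta=True, quant_zero and quant_scale must match (hqq_base_quant_config enforces this).
quant_config = BaseQuantizeConfig(nbits=2, group_size=16,
                                  quant_zero=True, quant_scale=True,
                                  offload_meta=True)
```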

You can try to change the backend which could speed-up the runtime:
```Python
HQQLinear.set_backend(HQQBackend.PYTORCH) #Pytorch backend (default)
```
hqq/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,3 +1,3 @@
__version__ = "0.1.3"
__version__ = "0.1.3.post1"
__author__ = 'Dr. Hicham Badri'
__credits__ = 'Mobius Labs GmbH'
hqq/core/quantize.py (101 changes: 82 additions & 19 deletions)
@@ -1,6 +1,6 @@
#Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################
import torch
import torch, copy
import numpy as np

from .utils import *
@@ -98,7 +98,8 @@ def dequantize(cls, W_q, meta):
@classmethod
def to_inplace(cls, W_q, meta, device):
compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
W_q = W_q.to(device).contiguous()
if(W_q is not None):
W_q = W_q.to(device).contiguous()
for key in meta:
if(type(meta[key])==torch.Tensor):
meta[key] = (meta[key].to(compute_dtype) if torch.is_floating_point(meta[key]) else meta[key]).to(device).contiguous()
@@ -107,7 +108,10 @@ def to_inplace(cls, W_q, meta, device):
@classmethod
def to_ooplace(cls, W_q, meta, device):
compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
W_q_c = W_q.to(device).contiguous()
if(W_q is not None):
W_q_c = W_q.to(device).contiguous()
else:
W_q_c = None
meta_c = {}
for key in meta:
if(type(meta[key])==torch.Tensor):
@@ -250,12 +254,14 @@ def __init__(self, linear_layer, quant_config, del_orig=True, compute_dtype=torc
self.bias = None
self.device_n = device_n
self.compute_dtype = compute_dtype
self.quant_config = quant_config
self.quant_config = copy.deepcopy(quant_config)
self.offload_meta = self.quant_config.pop('offload_meta') if (self.quant_config is not None) else None

self.set_backend(HQQLinear.backend) #Default backend

if(linear_layer is not None):
self.bias = None if (linear_layer.bias==None) else linear_layer.bias.to(self.compute_dtype).cuda()
self.quantize(linear_layer.weight.data, **quant_config)
self.quantize(linear_layer.weight.data, **self.quant_config)

if(del_orig): del linear_layer
torch.cuda.empty_cache()
@@ -266,22 +272,38 @@ def set_backend(cls, backend: HQQBackend):
HQQLinear.backend = backend
cls.forward = getattr(cls, backend.value)

#TODO: rewrite this mess
def cuda(self, device_n=0):
if(self.in_gpu): return
self.meta['compute_dtype'] = self.compute_dtype

self.W_q, self.meta = Quantizer.cuda(self.W_q, self.meta, device_n)
if(self.meta['quant_scale']):
self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cuda(self.meta['scale_q'], self.meta['meta_scale'], device_n)

if(self.meta['quant_zero']):
self.meta['zero_q'] , self.meta['meta_zero'] = Quantizer.cuda(self.meta['zero_q'], self.meta['meta_zero'], device_n)

if(self.meta['quant_scale']):
self.meta['scale_q'] , self.meta['meta_scale'] = Quantizer.cuda(self.meta['scale_q'], self.meta['meta_scale'], device_n)

if(self.offload_meta):
if(self.meta['quant_scale'] and self.meta['quant_zero']):
self.meta['zero_scale'] = torch.stack((self.meta['zero_q'], self.meta['scale_q']))
del self.meta['scale_q'], self.meta['zero_q']
else:
self.meta['zero_scale'] = torch.stack((self.meta['zero'], self.meta['scale'])).to(self.compute_dtype)
del self.meta['scale'], self.meta['zero']

self.meta['zero_scale'] = self.meta['zero_scale'].contiguous().cpu()

if(self.bias is not None):
self.bias = self.bias.to(self.compute_dtype).cuda(device_n)

self.W_q = torch.nn.Parameter(self.W_q, requires_grad=False)
self.device = self.W_q.device
self.in_gpu = True

torch.cuda.empty_cache()

def to(self, *args, **kwargs):
pass

@@ -297,15 +319,18 @@ def load_state_dict(self, state_dict):
self.bias = state_dict['bias'] if ('bias' in state_dict) else None
self.in_gpu = self.W_q.device.type == 'cuda'
if(self.in_gpu):
if('scale' in self.meta):
self.meta['scale'] = self.meta['scale'].to(self.compute_dtype)
if('zero' in self.meta):
self.meta['zero'] = self.meta['zero'].to(self.compute_dtype)

if('scale' in self.meta):
self.meta['scale'] = self.meta['scale'].to(self.compute_dtype)

if(('zero_scale' in self.meta) and (self.meta['quant_scale']==False) and (self.meta['zero_scale']==False)):
self.meta['zero_scale'] = self.meta['zero_scale'].to(self.compute_dtype)
else:
self.cuda(self.device_n)
self.ready = True

#@torch.inference_mode()
def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params):
quant_scale = scale_quant_params is not None
quant_zero = zero_quant_params is not None
@@ -315,13 +340,15 @@ def quantize(self, W, weight_quant_params, scale_quant_params, zero_quant_params
#Quantize
W_q , meta = Quantizer.quantize(W, **weight_quant_params)
meta.update({'quant_scale':quant_scale, 'quant_zero':quant_zero})
if(meta['quant_scale']):
meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale']
meta['meta_scale']['compute_dtype'] = self.compute_dtype

if(meta['quant_zero']):
meta['zero_q'], meta['meta_zero'] = Quantizer.quantize(meta['zero'], **zero_quant_params); del meta['zero']
meta['meta_zero']['compute_dtype'] = self.compute_dtype

if(meta['quant_scale']):
meta['scale_q'] , meta['meta_scale'] = Quantizer.quantize(meta['scale'], **scale_quant_params); del meta['scale']
meta['meta_scale']['compute_dtype'] = self.compute_dtype

self.W_q = W_q
self.meta = meta
self.cuda(self.device_n)
@@ -332,11 +359,23 @@ def dequantize(self):
W_q, meta = self.W_q, self.meta

del_keys = []
if(meta['quant_scale']):
meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')

if('zero_scale' in meta):
zero_scale = meta['zero_scale'].to(self.W_q.device)

if(zero_scale.dtype==torch.uint8):
meta['zero_q'] = zero_scale[0]; del_keys.append('zero_q');
meta['scale_q'] = zero_scale[1]; del_keys.append('scale_q');
else:
meta['zero'] = zero_scale[0]; del_keys.append('zero');
meta['scale'] = zero_scale[1]; del_keys.append('scale');

if(meta['quant_zero']):
meta['zero'] = Quantizer.dequantize(meta['zero_q'], meta['meta_zero']); del_keys.append('zero')

if(meta['quant_scale']):
meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')

W_est = Quantizer.dequantize(W_q, meta)

#Cleanup
@@ -381,11 +420,23 @@ def dequantize_aten(self):
W_q, meta = self.W_q, self.meta

del_keys = []

if('zero_scale' in meta):
zero_scale = meta['zero_scale'].to(self.W_q.device)

if(zero_scale.dtype==torch.uint8):
meta['zero_q'] = zero_scale[0]; del_keys.append('zero_q');
meta['scale_q'] = zero_scale[1]; del_keys.append('scale_q');
else:
meta['zero'] = zero_scale[0]; del_keys.append('zero');
meta['scale'] = zero_scale[1]; del_keys.append('scale');

if(meta['quant_scale']):
if(meta['meta_scale']['group_size']):
meta['scale'] = self.dequantize_Wq_aten(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')
else:
meta['scale'] = Quantizer.dequantize(meta['scale_q'], meta['meta_scale']); del_keys.append('scale')

if(meta['quant_zero']):
if(meta['meta_zero']['group_size']):
meta['zero'] = self.dequantize_Wq_aten(meta['zero_q'], meta['meta_zero']); del_keys.append('zero')
@@ -450,14 +501,26 @@ def forward_aten_backprop(self, x):
# return hqq_aten.forward_with_quant(*args)


def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False):
def hqq_base_quant_config(nbits=4, group_size=64, quant_zero=True, quant_scale=False, offload_meta=False):
assert nbits in Quantizer.SUPPORTED_BITS, "nbits value not supported. Check Quantizer.SUPPORTED_BITS."
if(group_size is not None):
assert is_divisible(group_size, 8), "Invalid group_size param: the value should be a multiple of 8."
weight_quant_params = {'nbits':nbits,'channel_wise':True, 'group_size':group_size, 'optimize':True, 'round_zero':True if nbits==4 else False}
scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
zero_quant_params = {'nbits':8, 'channel_wise':False, 'group_size':None, 'optimize':False} if (quant_zero) else None
return {'weight_quant_params':weight_quant_params, 'scale_quant_params':scale_quant_params, 'zero_quant_params':zero_quant_params}

if(offload_meta):
if((quant_scale!=quant_zero)):
print(colored("quant_zero and quant_scale must be the same when offload_meta is set to True. Setting quant_scale=quant_zero." , 'yellow'))
quant_scale = quant_zero

scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
zero_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_zero) else None

else:
scale_quant_params = {'nbits':8, 'channel_wise':True, 'group_size':128, 'optimize':False} if (quant_scale) else None
zero_quant_params = {'nbits':8, 'channel_wise':False, 'group_size':None, 'optimize':False} if (quant_zero) else None


return {'weight_quant_params':weight_quant_params, 'scale_quant_params':scale_quant_params, 'zero_quant_params':zero_quant_params, 'offload_meta':offload_meta}

#Alias: follow similar Auto-GPTQ naming
BaseQuantizeConfig = hqq_base_quant_config
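
As a usage illustration of the new offload_meta path in this file, a hedged single-layer sketch (the layer shape, the input tensor, and the CUDA assumption are illustrative and not part of the commit):
```Python
import torch
import torch.nn as nn
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig

quant_config = BaseQuantizeConfig(nbits=2, group_size=16,
                                  quant_zero=True, quant_scale=True,
                                  offload_meta=True)

# Quantize a single linear layer; with offload_meta the quantized zero/scale are
# stacked into meta['zero_scale'] and kept contiguous on the CPU (see cuda() above).
layer = nn.Linear(4096, 4096, bias=False)
hqq_layer = HQQLinear(layer, quant_config, compute_dtype=torch.float16)

x = torch.randn(1, 4096, dtype=torch.float16).cuda()
y = hqq_layer(x)  # dequantize() moves zero_scale to the GPU on the fly before reconstructing W
```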
hqq/engine/base.py (5 changes: 3 additions & 2 deletions)
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Dict
from ..models.base import BaseHQQModel
import torch

#Wrapper that makes it easier to add quantization support to different engines (HF, VLLM, etc.)

@@ -47,7 +48,7 @@ def _set_quantized(cls, model, quantized):

#####################################################
@classmethod
def quantize_model_(cls, model, quant_config, compute_dtype):
def quantize_model_(cls, model, quant_config, compute_dtype=torch.float16):
if(cls._is_quantizable(model)==False):
cls._make_quantizable(model, quantized=False)
cls._check_arch_support(model)
@@ -61,7 +62,7 @@ def save_quantized_(cls, model, save_dir):
cls._get_hqq_class(model).save_quantized(model, save_dir=save_dir)

@classmethod
def from_quantized(cls, save_dir_or_hub, compute_dtype, cache_dir=''):
def from_quantized(cls, save_dir_or_hub, compute_dtype=torch.float16, cache_dir=''):
#Both local and hub-support
save_dir = BaseHQQModel.try_snapshot_download(save_dir_or_hub)
arch_key = cls._get_arch_key_from_save_dir(save_dir)
hqq/engine/hf.py (4 changes: 2 additions & 2 deletions)
@@ -1,4 +1,4 @@
import transformers, json
import transformers, json, torch
from typing import Dict

_HQQ_REGISTRY = {}
@@ -34,7 +34,7 @@ def __init__(self, *args, **kwargs):
def _make_quantizable(cls, model, quantized):
model.hqq_quantized = quantized
model.arch_key = model.config.architectures[0]
model.quantize_model = lambda quant_config, compute_dtype: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
model.quantize_model = lambda quant_config, compute_dtype=torch.float16: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
model.save_quantized = lambda save_dir: cls.save_quantized_(model=model, save_dir=save_dir)
model.cuda = lambda *args, **kwargs: model if(quantized) else model.cuda
model.to = lambda *args, **kwargs: model if(quantized) else model.to
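
For context on how these wrapper hooks are typically exercised, a hedged end-to-end sketch (the HQQModelForCausalLM wrapper, the model id, and the save directory are assumptions, not shown in this diff):
```Python
import torch
from hqq.engine.hf import HQQModelForCausalLM   # assumed wrapper class; not part of this diff
from hqq.core.quantize import BaseQuantizeConfig

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical model id, for illustration only
model = HQQModelForCausalLM.from_pretrained(model_id)

quant_config = BaseQuantizeConfig(nbits=2, group_size=16,
                                  quant_zero=True, quant_scale=True,
                                  offload_meta=True)

# compute_dtype now defaults to torch.float16 (see quantize_model_/from_quantized above)
model.quantize_model(quant_config=quant_config)
model.save_quantized(save_dir="llama2-7b-hqq-2bit")  # illustrative save directory
```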
hqq/engine/timm.py (4 changes: 2 additions & 2 deletions)
@@ -1,4 +1,4 @@
import timm, json
import timm, json, torch
from typing import Dict
from ..models.base import BaseHQQModel
from ..models.timm.vit_clip import ViTCLIPHQQ
@@ -26,7 +26,7 @@ def __init__(self, *args, **kwargs):
def _make_quantizable(cls, model, quantized):
model.hqq_quantized = quantized
model.arch_key = model.default_cfg['architecture']
model.quantize_model = lambda quant_config, compute_dtype: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
model.quantize_model = lambda quant_config, compute_dtype=torch.float16: cls.quantize_model_(model=model, quant_config=quant_config, compute_dtype=compute_dtype)
model.save_quantized = lambda save_dir: cls.save_quantized_(model=model, save_dir=save_dir)
model.cuda = lambda *args, **kwargs: model if(quantized) else model.cuda
model.to = lambda *args, **kwargs: model if(quantized) else model.to
setup.py (4 changes: 2 additions & 2 deletions)
@@ -2,12 +2,12 @@

setup(
name='hqq',
version='0.1.3',
version='0.1.3.post1',
description='Half-Quadratic Quantization (HQQ)',
url='https://github.com/mobiusml/hqq/',
author='Dr. Hicham Badri',
author_email='[email protected]',
license='Apache 2',
packages=['hqq', 'hqq/core', 'hqq/engine', 'hqq/models', 'hqq/models/hf', 'hqq/models/timm', 'hqq/models/vllm'],
install_requires=['numpy>=1.24.4','tqdm>=4.64.1', 'huggingface_hub', 'accelerate', 'timm', 'transformers>=4.36.1', 'termcolor'], #'torch>=2.1.1', add vllm/langchain?
install_requires=['numpy>=1.24.4','tqdm>=4.64.1', 'huggingface_hub', 'accelerate', 'timm', 'transformers>=4.36.1', 'termcolor'],
)
