Skip to content

Commit

Permalink
GPTQ export w/a (huggingface#451)
Browse files Browse the repository at this point in the history
* Patch the modules in order to export GPTQ models on CPU

* Style
  • Loading branch information
AlexKoff88 authored Oct 18, 2023
1 parent 65b20f5 commit 8273e7f
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,34 @@ def _from_transformers(
if use_cache:
task = task + "-with-past"

# Patch the modules to allow exporting GPTQ models without a GPU
do_gptq_patching = False
config_dict = config.to_dict()
quantization_config = config_dict.get("quantization_config", None)
do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
if do_gptq_patching:
torch.set_default_dtype(torch.float32)
orig_cuda_check = torch.cuda.is_available
torch.cuda.is_available = lambda: True

from optimum.gptq import GPTQQuantizer

orig_post_init_model = GPTQQuantizer.post_init_model

def post_init_model(self, model):
    """Patched replacement for ``GPTQQuantizer.post_init_model``.

    The stock implementation performs GPU-oriented post-initialization; this
    patch only records the quantization settings the exporter needs and, when
    act-order + exllama is in use, applies the max-input-length fix, so GPTQ
    models can be exported on CPU.

    Args:
        self: the ``GPTQQuantizer`` instance (provides ``desc_act``,
            ``disable_exllama`` and ``max_input_length``).
        model: the quantized model being post-initialized.

    Returns:
        The (possibly wrapped) model with a ``quantize_config`` attribute set.
    """

    # Minimal attribute holder mimicking auto_gptq's quantize_config object.
    class StoreAttr:
        pass

    model.quantize_config = StoreAttr()
    model.quantize_config.desc_act = self.desc_act
    if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
        # Import lazily: auto_gptq is only required on the exllama path, so a
        # CPU-only environment without it can still run this patched method.
        from auto_gptq import exllama_set_max_input_length

        model = exllama_set_max_input_length(model, self.max_input_length)
    return model

GPTQQuantizer.post_init_model = post_init_model

main_export(
model_name_or_path=model_id,
output=save_dir_path,
Expand All @@ -238,10 +266,14 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
model_kwargs=kwargs,
int8=load_in_8bit,
)

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
GPTQQuantizer.post_init_model = orig_post_init_model

config.is_decoder = True
config.is_encoder_decoder = False
config.save_pretrained(save_dir_path)
Expand Down

0 comments on commit 8273e7f

Please sign in to comment.