Add HQQ quantization support #29637

Conversation
Hi @mobicham, thanks for opening this PR! Is there an associated issue/feature request for this? If so, could you add it to the PR description?
Hi @mobicham, thanks for working on this! This looks very good already. I left a few comments. Let me know if you have any questions!
if type(module) is not torch.nn.Linear:
    return
Can't we do this in `check_quantized_param`? If not, could you add a comment about the reason.
Done in `check_quantized_param`.
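For reference, a minimal sketch of what that check can look like once it lives in the quantizer (the exact method signature and the `get_module_from_name` helper follow my reading of the HF quantizer interface, so treat the details as an assumption rather than the PR's final code):

```python
import torch
from transformers.quantizers.quantizers_utils import get_module_from_name

# Method of the HQQ quantizer: decide whether a given checkpoint parameter
# should be routed through the HQQ quantization path.
def check_quantized_param(self, model, param_value, param_name, state_dict, **kwargs) -> bool:
    # Resolve the module that owns this parameter; only nn.Linear weights are
    # quantized, everything else falls back to the regular loading path.
    module, _ = get_module_from_name(model, param_name)
    return isinstance(module, torch.nn.Linear)
```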
if type(module) is not torch.nn.Linear:
    return
Is there a reason to keep this?
Thanks @mobicham for this huge work! I had a look at the PR with @SunMarc and left some comments. The serialization / loading logic for HQQ weights seems quite involved so far; maybe we could first go with a v1 with just on-the-fly quantization, and then I will refactor the HF Quantizers to be able to incorporate some of the new logic there. Wdyt?
@mobicham Tried running this branch on my machine (merged into the 4.39.3 Transformers release branch), got this error:
Code I was running:
Code I was running:
@mobicham Tried saving and loading a model on this branch. Adding a `cache_dir` argument to the `from_pretrained()` call fixed it, but ideally I think you should be able to omit that?
@mobicham Tried using this to do an HQQ quantization of the new Mixtral-8x22B model on my desktop. It worked, and the model then saved to disk fine. But when I tried to reload it from disk, this part appeared to work:

while this part caused it to OOM (even though it had successfully loaded before):

Here's the stack trace:

The cause seems likely to be a mismatch between the device map in the …
@mobicham Used a monkey-patch to solve the issue above; however, it now errors out when you try to do generation:
@mobicham Tried generating from a multi-device HQQ model that had been quantized live (vs. being quantized, saved, and then re-loaded), and got an interesting-looking stack trace:

Here's the code I used:
@rationalism I haven't tested it since the pull request; things have changed since then, so it is likely to break with a merge.
@mobicham Thank you very much! Very excited
Updated PR. You can test it with this: https://gist.github.com/mobicham/cb07c1eff443ad0918c49ab7bb03e269
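In case it helps, here is roughly what testing the integration boils down to (a minimal sketch, not the gist verbatim; the model id, prompt, and generation settings are placeholders I picked):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model; any causal LM should work

# 4-bit HQQ settings; quantization happens on the fly while the checkpoint is loaded.
quant_config = HqqConfig(nbits=4, group_size=64)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=quant_config,
)

inputs = tokenizer("Explain HQQ quantization in one sentence.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```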
Thanks for this integration! I left some minor refactoring comments; please take a look at them before we merge the PR, and let me know if you need any help!
Thanks again! Just left two tiny comments while doing a small review!
EDIT: False alarm
cc @amyeroberts this is ready for a final review 🙏 Btw, I don't know why a llama test is failing on main; it seems unrelated to this PR though!
Thanks for all the work iterating on this - looks great!
r""" | ||
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. | ||
""" | ||
pass |
As long as the arguments are correct
That's what this method is supposed to check :)
Serializes this instance to a Python dictionary.

Returns:
    `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
Ah, I see, `HQQBaseQuantizeConfig` is a dict in the hqq library 👍
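A tiny illustration of what that means in practice (just a sketch built on the config values and the `to_dict()` code already shown in this thread):

```python
from transformers import HqqConfig

cfg = HqqConfig(nbits=4, group_size=64)

# HqqConfig wraps the hqq quantize config, which already behaves like a plain dict,
# so to_dict() can simply hand back the stored quant_config for serialization.
print(isinstance(cfg.to_dict(), dict))  # expected: True
```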
@younesbelkada Yes - the llama tests are unrelated and due to an upstream commit. I think we're free to merge!
@mobicham Fabulous work!!!
Thank you @danielhanchen!
@mobicham My code:

```python
import logging
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, BitsAndBytesConfig, HqqConfig

hqq_config = HqqConfig(
    nbits=1,
    group_size=64,
    quant_zero=False,
    quant_scale=False,
    axis=0)  # axis=0 is used by default

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class SpeechToText:
    """Class for converting audio to text using a pre-trained speech recognition model."""

    def __init__(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
        self.model = None
        self.device = None
        if self.model is None:
            self.load_model(model_id)
        else:
            logging.info("Model already loaded.")

    def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
        """
        Loads the pre-trained speech recognition model and moves it to the specified device.

        Args:
            model_id (str): Identifier of the pre-trained model to be loaded.
        """
        logging.info("Loading model...")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="flash_attention_2",
            device_map="auto")
        logging.info("Model loaded successfully.")
        processor = AutoProcessor.from_pretrained(model_id)
        self.processor = processor
        self.model = model

    def __call__(
        self,
        chunk_length_s: int = 30,
        stride_length_s: int = 5,
        audio_path: str = "test.mp3",
        max_new_tokens: int = 128,
        batch_size: int = 100,
        language: str = "turkish"):
        """
        Converts audio to text using the pre-trained speech recognition model.

        Args:
            audio_path (str): Path to the audio file to be transcribed.

        Returns:
            str: Transcribed text from the audio.
        """
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            max_new_tokens=max_new_tokens,
            batch_size=100,
            device_map="auto",
            return_timestamps=True,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            model_kwargs={"use_flash_attention_2": True},
            generate_kwargs={"language": language},
        )
        logging.info("Transcribing audio...")
        result = pipe(audio_path)
        return result

output = SpeechToText(model_id="distil-whisper/distil-large-v3", quant_config=hqq_config)  # or bnb_config
transcript = output(
    audio_path="testv0.mp3",
    chunk_length_s=30,
    stride_length_s=5,
    max_new_tokens=128,
    batch_size=100,
    language="english",
)
```

Error message:

```
File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
     49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
     52     q,
     53     k,
     54     v,
     55     None,
     56     alibi_slopes,
     57     dropout_p,
     58     softmax_scale,
     59     causal,
     60     window_size[0],
     61     window_size[1],
     62     return_softmax,
     63     None,
     64 )
     65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type
```

Cli-env:

- `transformers` version: 4.41.0.dev0
- Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.23.0
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config: not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
It works when you set the `attn_implementation="flash_attention_2"` parameter to `"sdpa"` instead.
@kadirnar that's not related to hqq. Flash Attention only works with fp16/bf16, as the error says; try:
@mobicham

```python
hqq_config = HqqConfig(
    nbits=1,
    group_size=64,
    quant_zero=False,
    quant_scale=False,
    axis=0)  # axis=0 is used by default

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
```

Load-Model:

```python
def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        quantization_config=quant_config,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
        device_map='auto',
        max_memory={0: "24GiB"}
    )
    logging.info("Model loaded successfully.")
    processor = AutoProcessor.from_pretrained(model_id)
    self.processor = processor
    self.model = model
```
Well, that's already good :)! For faster auto-regressive generation, you need:

```python
from hqq.utils.patching import prepare_for_inference

prepare_for_inference(model, backend="torchao_int4")
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
```

You need to change your quant settings to use … I am not sure how torch.compile would work with … Please refer to the repo for the documentation.
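For context, the kind of settings the torchao_int4 backend expects, as far as I understand the hqq docs (4-bit weights quantized along axis=1; please double-check against the repo before relying on this):

```python
from transformers import HqqConfig

# torchao_int4 targets 4-bit weights with axis=1 (my assumption from the hqq docs),
# so the 1-bit / axis=0 config used earlier in this thread would need to change along these lines.
quant_config = HqqConfig(nbits=4, group_size=64, axis=1)
```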
Thank you very much for your help ❤️ I added HQQ optimization to the WhisperPlus library and now it runs 2 seconds faster. I will share today. torch.compile didn't run: `AttributeError: 'WhisperForConditionalGeneration' object has no attribute 'base_class'. Did you mean: '_auto_class'?`
@sanchit-gandhi , |
@kadirnar if you got that error, it means it's not working properly. I will check on Monday; the whole pipeline was only tested on AutoModelForCausalLM. Please open an issue here: https://github.com/mobiusml/hqq/
@kadirnar Can you also try the 4-bit version with torchao_int4, since it will be much faster? You should get faster execution time but hopefully with much lower WER.
I tried it with 4-bit; speed and word count were the same. I don't know the accuracy of the words.
I can recommend https://github.com/huggingface/open_asr_leaderboard/tree/main as a good potential evaluation. It is English only, but it might be a good relative comparison for the effect of bitrate.
I tested for 1-bit and 4-bit. Accuracy loss is very low. I will just create a detailed doc in the 0.66 WhisperPlus repo.

```python
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
```

WER: 120.88 (hqq)

https://github.com/kadirnar/whisper-plus/blob/main/benckmarks.md
It is great to see similar performance to base. But 120% WER (assuming that is what you used) is quite poor, since it translates to 1.2 errors per word it transcribes. Not sure if it is because of the general quality of the model or some mismatch in the eval script.
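If it helps for a quick relative comparison, WER can be computed directly with the `evaluate` library (a small sketch; the transcripts below are made-up placeholders):

```python
import evaluate

wer_metric = evaluate.load("wer")

references = ["the quick brown fox jumps over the lazy dog"]
predictions = ["the quick brown fox jumped over a lazy dog"]

# WER = (substitutions + insertions + deletions) / number of reference words,
# which is why a score above 1.0 means more than one error per reference word.
print(wer_metric.compute(predictions=predictions, references=references))
```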
@kadirnar, this is my gist to convert the Whisper model fully to HQQ: https://gist.github.com/huseinzol05/70daae3a4557616f315e7744ba3fcc93. The speed does not seem faster than Flash Attention 2 on 30-second examples, but a simple matmul is faster: https://gist.github.com/huseinzol05/ff59996034604d17c1e53074e9adc03f. @mobicham any thoughts?
@huseinzol05 @kadirnar taking a look at this; let's move this conversation here please: mobiusml/hqq#68
* update HQQ transformers integration
* push import_utils.py
* add force_hooks check in modeling_utils.py
* fix | with Optional
* force bias as param
* check bias is Tensor
* force forward for multi-gpu
* review fixes pass
* remove torch grad()
* if any key in linear_tags fix
* add cpu/disk check
* isinstance return
* add multigpu test + refactor tests
* clean hqq_utils imports in hqq.py
* clean hqq_utils imports in quantizer_hqq.py
* delete hqq_utils.py
* Delete src/transformers/utils/hqq_utils.py
* ruff init
* remove torch.float16 from __init__ in test
* refactor test
* isinstance -> type in quantizer_hqq.py
* cpu/disk device_map check in quantizer_hqq.py
* remove type(module) nn.linear check in quantizer_hqq.py
* add BaseQuantizeConfig import inside HqqConfig init
* remove hqq import in hqq.py
* remove accelerate import from test_hqq.py
* quant config.py doc update
* add hqqconfig to main_classes doc
* make style
* __init__ fix
* ruff __init__
* skip_modules list
* hqqconfig format fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* hqqconfig doc fix
* test_hqq.py remove mistral comment
* remove self.using_multi_gpu is False
* torch_dtype default val set and logger.info
* hqq.py isinstance fix
* remove torch=None
* torch_device test_hqq
* rename test_hqq
* MODEL_ID in test_hqq
* quantizer_hqq setattr fix
* quantizer_hqq typo fix
* imports quantizer_hqq.py
* isinstance quantizer_hqq
* hqq_layer.bias reformat quantizer_hqq
* Step 2 as comment in quantizer_hqq
* prepare_for_hqq_linear() comment
* keep_in_fp32_modules fix
* HqqHfQuantizer reformat
* quantization.md hqqconfig
* quantization.md model example reformat
* quantization.md # space
* quantization.md space })
* quantization.md space })
* quantization_config fix doc (Co-authored-by: amyeroberts <[email protected]>)
* axis value check in quantization_config
* format
* dynamic config explanation
* quant config method in quantization.md
* remove shard-level progress
* .cuda fix modeling_utils
* test_hqq fixes
* make fix-copies

---------

Co-authored-by: amyeroberts <[email protected]>
This PR is intended to add support for Half-Quadratic Quantization (HQQ) to the transformers library as requested in #28328
HQQ has been gaining popularity lately since it's fast to quantize and produces good quality models without using any calibration data. More details here: https://github.com/mobiusml/hqq/
Since quantization requires a CUDA device, the quantization step happens in `create_quantized_param()`.

The tricky part is the serialization: the current logic unfortunately is not compatible with `HQQLinear`'s `state_dict` structure. For now, I am using the same logic from the hqq package, which stores the `state_dict`s of the modules and uses `torch.save` for saving the weights. Until we figure out a better way of doing it, this is the best solution I found so far.

I added a progress bar (inside the model loading progress bar :D) to keep track of the quantization step. I noticed that the quantization step is slower than doing it with the hqq package.
I wrote some basic tests to check if the quantization is done properly on a Mistral model with different settings.
Full example here: hqq_transformers_llama_example.py
Let me know if you have any questions or requests!
Thank you in advance!