
Add initial support for GPTQ #1580

Closed
wants to merge 63 commits

Conversation

WoosukKwon (Collaborator)

This PR is a simplified version of the great PR #916.
The main difference is that this PR does not use the exllama kernels, while #916 does.
The purpose of this PR is to minimize the code changes and avoid possible conflicts with @zhuohan123's ongoing refactoring effort.

@WoosukKwon WoosukKwon marked this pull request as ready for review November 8, 2023 20:33
@WoosukKwon (Collaborator, Author)

@zhuohan123 The PR is ready for review. Please take a look!

@WoosukKwon (Collaborator, Author)

@zhaoyang-star The kernels used in this PR are not optimized, so you cannot get any speedup for now. We will optimize the quantized GEMM kernels in the next PR.

@LimpidEarth The kernel we are using now only supports 4-bit quantization, but we can extend it in the next PR. BTW, it seems most of the GPTQ models on the HF model hub use 4-bit quantization. Could you point us to an 8-bit GPTQ model?

@bash99

bash99 commented Nov 10, 2023

> @zhaoyang-star The kernels used in this PR are not optimized, so you cannot get any speedup for now. We will optimize the quantized GEMM kernels in the next PR.
>
> @LimpidEarth The kernel we are using now only supports 4-bit quantization, but we can extend it in the next PR. BTW, it seems most of the GPTQ models on the HF model hub use 4-bit quantization. Could you point us to an 8-bit GPTQ model?

Most new GPTQ models released by TheBloke now have an 8-bit version (-1g, 128g, 32g), e.g.
https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GPTQ/tree/gptq-8bit-128g-actorder_True

@zhuohan123 (Member) left a comment

Thanks for the work! The changes in general LGTM. However, there are some places that will conflict with #1622, mainly because of auxiliary variables like g_idx and shifter. What do you think about the merging plan? Should we merge this PR first or #1622 first? Whichever plan we go with, I can help with the merge.

Comment on lines +128 to +130
logger.warning(f"{self.quantization} quantization is not fully "
               "optimized yet. The speed can be slower than "
               "non-quantized models.")
Member

Does this mean all quantization methods are not optimized yet?

Collaborator Author

Unfortunately, yes. For the three quantization methods we support, we are using the original authors' kernels, which can be further optimized. In particular, the SqueezeLLM and GPTQ kernels are slow when the batch size is larger than 1. As for AWQ, I think its kernel is much better than the other two, but it is still slow for large batch sizes and does not support bfloat16.

Comment on lines +90 to +94
self.shifter = torch.tensor(
    [0, 4, 8, 12, 16, 20, 24, 28],
    device="cuda",
    dtype=torch.int32,
)
Member

Is this shifter a must-have? The problem here is that we created a "non-parameter" tensor. We will need to modify the weight creation code in #1622 to make this creation work.

Collaborator Author

How difficult would it be to add support for such tensors? While this tensor is only used by the PyTorch-based GPTQ matmul implementation, and thus will eventually become unused once we develop a more optimized kernel, such non-parameter buffers can also be useful for other quantization methods. I believe we should take this into account in the new design.
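
For context, a minimal sketch of how a non-parameter buffer like `shifter` is typically used to unpack 4-bit GPTQ weights in pure PyTorch; the `qweight` shape and the `unpack_int4` helper name are illustrative assumptions, not the PR's actual code:

```python
import torch

# Each packed int32 holds eight 4-bit values; shifting by 0, 4, ..., 28 bits
# and masking the low nibble recovers them.
shifter = torch.tensor([0, 4, 8, 12, 16, 20, 24, 28], dtype=torch.int32)

def unpack_int4(qweight: torch.Tensor) -> torch.Tensor:
    """Unpack [in_features // 8, out_features] int32 into [in_features, out_features]."""
    shifts = shifter.to(qweight.device).view(1, -1, 1)
    nibbles = (qweight.unsqueeze(1) >> shifts) & 0xF
    return nibbles.reshape(-1, qweight.shape[-1])
```

Because `shifter` is a plain tensor rather than an `nn.Parameter`, the weight-creation path in #1622 would need a way to register such buffers alongside the quantized parameters.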

out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
num_tokens = x.shape[:-1].numel()
if num_tokens <= 32:
Member

Why do we have an if here? Is the CUDA kernel slower when num_tokens > 32, or does it not work at all in that case?

Collaborator Author

Good point. Actually, the current GPTQ kernel is designed for batch size 1 and performs extremely poorly when the batch size is large, often taking 10+ minutes for the initial memory profiling. As a workaround, I implemented a simple PyTorch-based GPTQ matmul that is faster than the original kernel for large batch sizes. Still, both implementations are quite slow and probably much slower than optimized implementations like exllama.
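
To make the trade-off concrete, here is a rough sketch of the large-batch fallback path described above, reusing the `unpack_int4` helper from the earlier sketch; the per-channel `scales` handling and the symmetric offset are simplifying assumptions that ignore zero points and `g_idx`:

```python
import torch

def torch_gptq_matmul(x: torch.Tensor,
                      qweight: torch.Tensor,  # [in_features // 8, out_features], packed int32
                      scales: torch.Tensor,   # [out_features]; per-channel scales (assumed)
                      ) -> torch.Tensor:
    # Dequantize the whole weight once, then run a dense matmul. For large
    # batches this amortizes far better than a kernel tuned for batch size 1.
    w = unpack_int4(qweight).to(x.dtype)   # [in_features, out_features]
    w = (w - 8) * scales                   # simplified symmetric dequantization
    out_shape = x.shape[:-1] + (qweight.shape[-1], )
    reshaped_x = x.reshape(-1, x.shape[-1])
    return (reshaped_x @ w).reshape(out_shape)
```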

Collaborator Author

Added a comment on this.

)
# Initialize g_idx to be sequential.
# This is required because old GPTQ models may not have g_idx.
start_idx = self.tp_rank * self.input_size_per_partition
Member

Just leaving this as a note: this line will also require some modification in #1622.
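
As a side note on the snippet above, a minimal sketch of the sequential g_idx initialization it refers to; `group_size` and the example sizes are assumptions rather than values taken from the PR:

```python
import torch

input_size_per_partition = 4096   # example shard width (assumption)
group_size = 128                  # example GPTQ group size (assumption)
tp_rank = 0                       # tensor-parallel rank of this shard

# Old GPTQ checkpoints may not ship g_idx, so default to sequential groups:
# input channel i belongs to quantization group i // group_size.
start_idx = tp_rank * input_size_per_partition
g_idx = torch.arange(start_idx, start_idx + input_size_per_partition,
                     dtype=torch.int32) // group_size
```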

@LimpidEarth

LimpidEarth commented Nov 13, 2023

> @LimpidEarth The kernel we are using now only supports 4-bit quantization, but we can extend it in the next PR. BTW, it seems most of the GPTQ models on the HF model hub use 4-bit quantization. Could you point us to an 8-bit GPTQ model?

@WoosukKwon Got it, and looking forward to the PR adding 8-bit support! The main reason we want an 8-bit GPTQ model is that we found its evaluation results to be better than the 4-bit model's on our domain tasks.

zhuohan123 added a commit that referenced this pull request Nov 16, 2023
…inear logic and extend quantization support to all models (#1622)

Refactor the tensor parallelism, quantization, and weight-loading code.

Summary of the new features enabled by this PR:
- **All models** can now be quantized with AWQ and SqueezeLLM, and [soon GPTQ](#1580).
- The model loading code became much simpler.
- Model parallelism is supported for all MQA/GQA models, even when the number of key/value heads is smaller than the tensor parallel size.
@WoosukKwon (Collaborator, Author)

Closed, as some of the changes have already been merged and #916 will be merged instead.

@WoosukKwon WoosukKwon closed this Dec 5, 2023
@WoosukKwon WoosukKwon deleted the minimal-gptq branch March 12, 2024 06:14