
[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS #6036

Open · wants to merge 73 commits into base: main

Conversation

@LeiWang1999 commented Jul 1, 2024

Hi all, this PR introduces support for the Microsoft Runtime Kernel Library to enhance our low precision computation capabilities.

Brief Introduction of BitBLAS

BitBLAS is a library to support mixed-precision BLAS operations on GPUs, for example, the $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication where $C_{cdtype}[M, N] = A_{adtype}[M, K] \times W_{wdtype}[N, K]$.
BitBLAS aims to support efficient mixed-precision DNN model deployment, especially the $W_{wdtype}A_{adtype}$ quantization in large language models (LLMs), for example, the $W_{UINT4}A_{FP16}$ in GPTQ, the $W_{INT2}A_{FP16}$ in BitDistiller, the $W_{INT2}A_{INT8}$ in BitNet-b1.58.
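
For context, the shape of the BitBLAS interface that this PR builds on looks roughly like the following sketch (the M/N/K shapes and dtypes here are illustrative assumptions, not values taken from this PR):

```python
import bitblas

# Illustrative W_INT4 x A_FP16 GEMM; shapes and dtypes are placeholder assumptions.
config = bitblas.MatmulConfig(
    M=1,                    # rows of the activation A
    N=4096,                 # output features (rows of the weight W)
    K=4096,                 # reduction dimension
    A_dtype="float16",      # activation dtype
    W_dtype="int4",         # low-precision weight dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",    # output dtype
    layout="nt",            # A non-transposed, W transposed: C[M, N] = A[M, K] x W[N, K]^T
    with_bias=False,
)
matmul = bitblas.Matmul(config=config)  # builds the mixed-precision operator for this config
```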

PR Overview

This PR integrates BitBLAS into vLLM and adds examples of its usage. We provide two forms (a usage sketch follows the list):

  1. Load from GPTQ Checkpoints: This allows the loading of models from GPTQ format checkpoints.
  2. Load from GPTQ CKPT with BitBLAS Format: This enables the loading of models using the BitBLAS format for further optimized performance.
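
A minimal usage sketch from the vLLM side, assuming a hypothetical checkpoint name and assuming the quantization method strings mirror the backends added in this PR (gptq_bitblas for GPTQ checkpoints, bitblas for repacked BitBLAS checkpoints; the exact names are defined in the PR's quantization layer):

```python
from vllm import LLM

# 1. GPTQ checkpoint executed through the BitBLAS kernels.
#    "some-org/llama-2-7b-gptq" is a hypothetical model ID.
llm_from_gptq = LLM(model="some-org/llama-2-7b-gptq",
                    quantization="gptq_bitblas")

# 2. Checkpoint already repacked into the BitBLAS format.
llm_from_bitblas = LLM(model="some-org/llama-2-7b-bitblas",
                       quantization="bitblas")

print(llm_from_gptq.generate("Hello, BitBLAS!")[0].outputs[0].text)
```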

Below are the benchmarking results that we evaluated several months ago:

TODO ITEMS

  • Update and provide the latest benchmarking results.
  • 1.58Bits Model
  • Provide Benchmark/Test Scripts

Any feedback and suggestions to improve this integration are appreciated.

@robertgshaw2-redhat (Collaborator)

Nice!

@LeiWang1999 (Author) commented Jul 1, 2024

BTW, are there any tools available that can automatically resolve lint issues?

vllm/model_executor/layers/quantization/gptq_bitblas.py:28:1: E402 Module level import not at top of file
vllm/model_executor/layers/quantization/gptq_bitblas.py:28:8: F811 Redefinition of unused `bitblas` from line 21
vllm/model_executor/layers/quantization/gptq_bitblas.py:29:1: E402 Module level import not at top of file
vllm/model_executor/layers/quantization/gptq_bitblas.py:66:81: E501 Line too long (107 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:172:81: E501 Line too long (85 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:222:81: E501 Line too long (105 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:230:81: E501 Line too long (89 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:233:81: E501 Line too long (110 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:236:81: E501 Line too long (99 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:242:81: E501 Line too long (84 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:253:81: E501 Line too long (94 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:414:81: E501 Line too long (86 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:417:29: G004 Logging statement uses f-string
vllm/model_executor/layers/quantization/gptq_bitblas.py:420:17: G004 Logging statement uses f-string
vllm/model_executor/layers/quantization/gptq_bitblas.py:427:81: E501 Line too long (103 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:433:81: E501 Line too long (116 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:454:81: E501 Line too long (82 > 80)

@robertgshaw2-redhat (Collaborator)

BTW, are there any tools available that can automatically resolve lint issues?

./format.sh fixes whatever it can automatically, but not everything can be auto-fixed (especially line length).

@mgoin (Member) commented Jul 1, 2024

@LeiWang1999 thanks for the WIP, very cool interface with bitblas as a package. Can you explain if the GPTQ benchmarking results in vLLM were run with the base "gptq" kernels or using the "gptq_marlin" interface to take advantage of Marlin kernels? This will be important to compare with the current baseline we consider for GPTQ models in vLLM

@LeiWang1999 (Author)

@LeiWang1999 thanks for the WIP, very cool interface with bitblas as a package. Can you explain if the GPTQ benchmarking results in vLLM were run with the base "gptq" kernels or using the "gptq_marlin" interface to take advantage of Marlin kernels? This will be important to compare with the current baseline we consider for GPTQ models in vLLM

Thanks. Our benchmarks at that time used the exllamav2-based kernels; we will add a comparison against the Marlin kernel.
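
For reference, a minimal sketch of how the two GPTQ baselines are selected in vLLM for such a comparison (the checkpoint name is a hypothetical placeholder):

```python
from vllm import LLM

# Baseline 1: the base GPTQ path (exllamav2-based kernels at the time of this discussion).
llm_gptq = LLM(model="some-org/llama-2-7b-gptq", quantization="gptq")

# Baseline 2: the Marlin-accelerated GPTQ path.
llm_gptq_marlin = LLM(model="some-org/llama-2-7b-gptq", quantization="gptq_marlin")
```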

@LeiWang1999 (Author) commented Jul 19, 2024

Hi all, I recently updated support for the 1.58-bit model and the related BitBLAS inference kernel for vLLM.

Throughput in tokens per second (tok/s):

model                 framework                 BS16 IN32 OUT128   BS1 IN512 OUT1024   BS32 IN32 OUT128
bitnet-3b-1.58bits    pytorch                   106.83             49.34               209.03
bitnet-3b-1.58bits    pytorch-bitblas           240.33             103.09              493.31
bitnet-3b-1.58bits    vllm-bitblas              379.25             117.43              752.55
bitnet-3b-1.58bits    vllm-bitblas-cuda-graph   2543.58            1621.08             2731.79

@LeiWang1999 LeiWang1999 marked this pull request as ready for review July 19, 2024 04:23
@LeiWang1999 (Author)

We will benchmark against Marlin soon. It also looks like the docs build failed because of the bitblas dependency; do you have any ideas for fixing this? Should we add the bitblas requirement to doc/requirements, or is there an option to skip this dependency? @mgoin

@LeiWang1999 (Author)

@mgoin, @LucasWilkinson, thanks for your detailed and valuable review comments. The requested modifications and improvements have been applied; please take a look :)

@mgoin (Member) left a comment:

I think this essentially looks good to go with these last fixes @LeiWang1999

Resolved review threads on: docs/source/quantization/bitblas.rst (3 comments), vllm/model_executor/layers/quantization/kernels/bitblas.py (1 comment)

mergify bot commented Nov 19, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LeiWang1999.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Dec 19, 2024
@LeiWang1999 (Author)

@mgoin apologies for the delayed response. In the latest update, we double-checked the correctness, optimized INT8 GEMM performance with DP4A on V100, and added support for high-performance GEMM on MI300 within BitBLAS. We’ve also updated recent benchmark results, which you can find at bitblas-benchmark.

Additionally, we released version 0.1.0 as part of this pull request. Let's work together to get this PR in and start planning the next one for BitNet :)

@mgoin mgoin changed the title [Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation [Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS Jan 2, 2025

mergify bot commented Jan 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LeiWang1999.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 2, 2025
@mgoin (Member) left a comment:

Happy new year and thanks for your patience over the holidays! This looks good to me to land; just a few nits, plus the merge conflicts need to be resolved.

Review comment (Member):

You need to update your docs changes to work with the new .md format

Comment on lines +685 to +701
def find_flash_attn_supported_head_dims(self, head_dim: int) -> int:
    """
    Find the closest head dimension to the given head dimension that
    is supported by Flash Attention.
    """
    from vllm.attention.backends.flash_attn import FlashAttentionBackend

    FLASHATTN_SUPPORTED_HEAD_DIMS = (
        FlashAttentionBackend.get_supported_head_sizes())

    for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS:
        if head_dim <= supported_head_dim:
            return supported_head_dim
    raise ValueError(
        f"Head dimension {head_dim} is not supported by Flash Attention."
        f"Supported head dimensions are {FLASHATTN_SUPPORTED_HEAD_DIMS}.")

Review comment (Member):

This seems like an unrelated and unused change? There are a few other changes in this file as well, but I could understand if this is just the formatter switching up

BITBLAS_SUPPORTED_SYM = [False, True]


# For binary size and compile time, we don't support the same types for with and
Review comment (Member):

This comment doesn't look finished

Comment on lines +129 to +138
A_dtype,
W_dtype,
out_dtype,
accum_dtype,
layout,
with_bias,
group_size,
with_scaling,
with_zeros,
zeros_mode,
Review comment (Member):

Maybe you could make all these configs more readable by putting the shared args in a list and unpacking them, e.g. (1, 16384, 16384, *shared_args) where shared_args = [A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode]. A quick sketch follows.
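
A quick sketch of that suggestion; the dtype and layout values below are placeholder assumptions standing in for the benchmark's real configuration:

```python
# Placeholder values; the real ones come from the benchmark configuration in this PR.
A_dtype, W_dtype, out_dtype, accum_dtype = "float16", "int4", "float16", "float16"
layout, with_bias, group_size = "nt", False, -1
with_scaling, with_zeros, zeros_mode = False, False, None

shared_args = [A_dtype, W_dtype, out_dtype, accum_dtype, layout,
               with_bias, group_size, with_scaling, with_zeros, zeros_mode]

# Each benchmark entry then only spells out the (M, N, K) shape that varies.
configs = [
    (1, 16384, 16384, *shared_args),
    (1, 8192, 8192, *shared_args),   # second shape is a placeholder example
]
```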

Labels: documentation, needs-rebase
4 participants