AQLM CUDA support #3287

Merged: 116 commits into vllm-project:main from jf/aqlm on Apr 23, 2024

Conversation

jaemzfleming (Contributor) commented Mar 8, 2024

SUMMARY:
Supports AQLM compressed inference, see

https://github.com/Vahe1994/AQLM
https://arxiv.org/pdf/2401.06118.pdf

The optimized formats are 1x16 and 2x8. Tensor parallelism is supported. Only CUDA kernels are provided. Formats other than 1x16 and 2x8 will run, but at lower performance.

Also adds underlying support for all quantization schemes that require a separate fixed size codebook per layer.

The only trickiness was that QKVParallelLinear concatenates the Q, K, and V tensors, whose sizes and offsets are determined by the number of heads, kv heads, and tensor parallelism. The corresponding codebooks all need to be present and concatenated for apply_weights. To support this we add the is_metadata attribute which, if present, concatenates the Q, K, and V tensors along the zeroth dimension, using just the size of the loaded tensor.
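
As a rough sketch of the is_metadata idea (illustrative names only, not the actual vLLM weight loader), the metadata path simply stacks each loaded Q/K/V codebook shard along the zeroth dimension, advancing the offset by the size of whatever was just loaded:

    import torch

    def append_metadata_shard(merged: torch.Tensor, shard: torch.Tensor,
                              offset: int) -> int:
        # Copy the shard into the merged codebook tensor along dim 0, using only
        # the loaded shard's own size rather than head-count-derived offsets.
        rows = shard.shape[0]
        merged.narrow(0, offset, rows).copy_(shard)
        return offset + rows  # the next shard starts where this one ended

    # Q, K, and V codebooks (sizes may differ, e.g. with grouped-query attention)
    q_cb, k_cb, v_cb = torch.randn(8, 4), torch.randn(4, 4), torch.randn(4, 4)
    merged = torch.empty(16, 4)
    offset = 0
    for shard in (q_cb, k_cb, v_cb):
        offset = append_metadata_shard(merged, shard, offset)
    assert torch.equal(merged, torch.cat((q_cb, k_cb, v_cb), dim=0))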

Here's a benchmark server graph comparing 2-bit 1x16 and 2x8 against FP16, plotting mean TPOT vs. queries per second. At low query rates, you can see that the 1x16 is 1.36x faster and the 2x8 is 2.12x faster than FP16. By 15 queries per second, the 1x16 is 1.56x slower and the 2x8 is 1.16x slower. So either format is a good choice if memory is limited, especially if you are serving at low QPS, but 2x8 is best if you can afford the slightly lower accuracy.

[aqlm_benchmark graph: mean TPOT vs. queries per second for FP16, 2-bit 1x16, and 2-bit 2x8]

Tested on several models:

  • ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf
  • ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf
  • ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf
  • ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf
  • BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf

Tested with both single and multiple GPUs and the associated tensor parallelism.
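
For reference, here is a minimal sketch of loading one of these checkpoints through vLLM's Python API (the PR also ships an examples/aqlm_example.py script; the model choice and sampling settings below are just examples, and the quantization scheme is detected from the checkpoint config):

    from vllm import LLM, SamplingParams

    # Sketch only: pick any of the AQLM checkpoints listed above.
    llm = LLM(model="ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
              tensor_parallel_size=1)  # raise for multi-GPU tensor parallelism
    params = SamplingParams(temperature=0.8, max_tokens=32)
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)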

remiconnesson commented Apr 14, 2024

Hello, I'm trying to run it, but there seems to be an issue with:

File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 63, in get_lock
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
TypeError: BaseFileLock.__init__() got an unexpected keyword argument 'mode'

https://github.com/neuralmagic/nm-vllm/blob/22f7faeee16f63548b33ad6ebcc78e256de93524/vllm/model_executor/weight_utils.py#L62-L64

    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)

Removing mode=0o666 seems to clear the problem.
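
For anyone hitting the same error before updating, a hypothetical compatibility shim (not part of this PR, and not necessarily how it was fixed on main) would be to fall back when the installed filelock predates the mode keyword:

    import os
    import filelock

    def get_shared_lock(lock_dir: str, lock_file_name: str) -> filelock.FileLock:
        # mode=0o666 lets the lock file be shared across users, but older
        # filelock releases reject the keyword with a TypeError.
        lock_path = os.path.join(lock_dir, lock_file_name)
        try:
            return filelock.FileLock(lock_path, mode=0o666)
        except TypeError:
            return filelock.FileLock(lock_path)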


Traceback

root@86c9ebba321d:/workspace/nm-vllm/examples# python aqlm_example.py 
config.json: 100%|███████████████████████████████████████████████████████████████████| 968/968 [00:00<00:00, 9.93MB/s]
WARNING 04-14 02:07:14 config.py:222] aqlm quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 04-14 02:07:14 llm_engine.py:81] Initializing an LLM engine (v0.4.0.post1) with config: model='ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf', speculative_config=None, tokenizer='ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=True, quantization=aqlm, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, seed=0)
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████| 776/776 [00:00<00:00, 7.82MB/s]
tokenizer.model: 100%|█████████████████████████████████████████████████████████████| 500k/500k [00:00<00:00, 3.24MB/s]
tokenizer.json: 100%|████████████████████████████████████████████████████████████| 1.84M/1.84M [00:00<00:00, 9.20MB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████| 414/414 [00:00<00:00, 4.37MB/s]
INFO 04-14 02:07:16 pynccl.py:58] Loading nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 04-14 02:07:16 selector.py:77] Cannot use FlashAttention backend because the flash_attn package is not found. Please install it for better performance.
INFO 04-14 02:07:16 selector.py:33] Using XFormers backend.
INFO 04-14 02:07:18 weight_utils.py:194] Using model weights format ['*.safetensors']
Exception ignored in: <function BaseFileLock.__del__ at 0x7efd01ce3f40>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/filelock/_api.py", line 240, in __del__
    self.release(force=True)
  File "/usr/local/lib/python3.10/dist-packages/filelock/_api.py", line 201, in release
    with self._thread_lock:
AttributeError: 'UnixFileLock' object has no attribute '_thread_lock'
Traceback (most recent call last):
  File "/workspace/nm-vllm/examples/aqlm_example.py", line 46, in <module>
    main()
  File "/workspace/nm-vllm/examples/aqlm_example.py", line 36, in main
    model = LLM(args.model if args.model is not None else models[args.choice],
  File "/workspace/nm-vllm/vllm/entrypoints/llm.py", line 112, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "/workspace/nm-vllm/vllm/engine/llm_engine.py", line 231, in from_engine_args
    engine = cls(
  File "/workspace/nm-vllm/vllm/engine/llm_engine.py", line 119, in __init__
    self.model_executor = executor_class(
  File "/workspace/nm-vllm/vllm/executor/gpu_executor.py", line 41, in __init__
    self._init_worker()
  File "/workspace/nm-vllm/vllm/executor/gpu_executor.py", line 67, in _init_worker
    self.driver_worker.load_model()
  File "/workspace/nm-vllm/vllm/worker/worker.py", line 108, in load_model
    self.model_runner.load_model()
  File "/workspace/nm-vllm/vllm/worker/model_runner.py", line 155, in load_model
    self.model = get_model(
  File "/workspace/nm-vllm/vllm/model_executor/model_loader.py", line 101, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/workspace/nm-vllm/vllm/model_executor/models/llama.py", line 393, in load_weights
    for name, loaded_weight in hf_model_weights_iterator(
  File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 241, in hf_model_weights_iterator
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
  File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 197, in prepare_hf_model_weights
    with get_lock(model_name_or_path, cache_dir):
  File "/workspace/nm-vllm/vllm/model_executor/weight_utils.py", line 63, in get_lock
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
TypeError: BaseFileLock.__init__() got an unexpected keyword argument 'mode'

mgoin (Member) commented Apr 16, 2024

Hey @remiconnesson, that filelock issue seems to be unrelated to this PR and is addressed on main. Please try the updated version of this branch or main.

Collaborator:

I think there's an Apache 2 LICENSE at the root; should the attribution (GitHub link) go at the top of aqlm_cuda_entry.cpp and aqlm_cuda_kernel.cu?

Collaborator:

Oh, I see it is already there; I think this can be removed.

Member:

That makes sense, thanks. Removing now.

Collaborator:

Please see awq and others for namespacing into vllm.

Member:

Changed to

    "csrc/quantization/aqlm/cuda_entry.cpp"
    "csrc/quantization/aqlm/gemm_kernels.cu"

Collaborator:

same comment about namespacing

Member:

ditto done :)

vllm/config.py Outdated
Collaborator:

Change this to use the registry from #4098?

Member:

updated with main

Comment on lines 404 to 408
    params_dtype, linear_method, [
        self.num_heads * tp_size * self.head_size,
        self.num_kv_heads * tp_size * self.head_size,
        self.num_kv_heads * tp_size * self.head_size
    ])
Collaborator:

please move this to a variable for readability

Member:

done
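
For context, a hypothetical sketch of that kind of refactor (illustrative names, not the actual vLLM code), pulling the Q/K/V output sizes into a named helper before they are passed along:

    def qkv_output_sizes(num_heads: int, num_kv_heads: int,
                         head_size: int, tp_size: int) -> list:
        # Widths of the Q, K, and V projections across all tensor-parallel ranks.
        return [
            num_heads * tp_size * head_size,     # Q
            num_kv_heads * tp_size * head_size,  # K
            num_kv_heads * tp_size * head_size,  # V
        ]

    # e.g. Llama-2-7B-style shapes on a single GPU -> [4096, 4096, 4096]
    print(qkv_output_sizes(num_heads=32, num_kv_heads=32, head_size=128, tp_size=1))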

simon-mo (Collaborator):

Sorry, I should have clarified the comment about namespacing.

The exposed torch binding should not be in a C++ namespace:

    torch::Tensor awq_gemm(
        torch::Tensor _in_feats,
        torch::Tensor _kernel,
        torch::Tensor _scaling_factors,
        torch::Tensor _zeros,
        int split_k_iters)

All other helpers should be inside namespace vllm:

    namespace vllm {
    namespace awq {

mgoin (Member) commented Apr 18, 2024

@simon-mo thanks for the clarification, I was thinking about the wrong namespace! I got rid of the cpp file and wrapped everything except the two external functions in vllm::aqlm::

jaemzfleming (Contributor, Author):

Thanks for doing all this work @mgoin, much appreciated.

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 2b7949c into vllm-project:main Apr 23, 2024
47 checks passed
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic deleted the jf/aqlm branch April 23, 2024 17:59
xjpang pushed a commit to xjpang/vllm that referenced this pull request Apr 25, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 26, 2024
alexeykondrat pushed a commit to alexeykondrat/ci-vllm that referenced this pull request May 1, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024