Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) #4527

Merged
merged 6 commits into from
May 4, 2024

Conversation

mgoin
Copy link
Collaborator

@mgoin mgoin commented May 1, 2024

Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436.

This PR enables the following checkpoint loading features for Mixtral:

Notes:

  • The Expert Gate/Router always runs at half / full precision for now.
  • If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.

Future work:

  • cutlass kernels for separate QKV weight scales
  • support memory compression from loading fp16 checkpoints and dynamically quantizing to fp8 (blocked on weight loader refactor)
  • generalize MoE implementation to apply to other MoE models

Smoke test output:

python test-mixtral-fp8.py 
WARNING 05-03 01:42:29 config.py:187] fp8 quantization is not fully optimized yet. The speed can be slower than non-quantized models.
INFO 05-03 01:42:29 llm_engine.py:100] Initializing an LLM engine (v0.4.1) with config: model='nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8', speculative_config=None, tokenizer='nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0)
INFO 05-03 01:42:29 utils.py:623] Found nccl from library /home/paperspace/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-03 01:42:30 selector.py:75] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
INFO 05-03 01:42:30 selector.py:31] Using XFormers backend.
WARNING 05-03 01:42:31 fp8.py:29] Detected fp8 checkpoint. Please note that the format is experimental and subject to change.
INFO 05-03 01:42:31 weight_utils.py:199] Using model weights format ['*.safetensors']
WARNING 05-03 01:42:41 utils.py:428] Found act_scales that are not equal for fp8 MoE layer. Using the maximum across experts for each layer. 
INFO 05-03 01:42:42 model_runner.py:172] Loading model weights took 43.7487 GB
INFO 05-03 01:42:51 gpu_executor.py:114] # GPU blocks: 9689, # CPU blocks: 2048
INFO 05-03 01:42:53 model_runner.py:872] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-03 01:42:53 model_runner.py:876] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-03 01:43:00 model_runner.py:953] Graph capturing finished in 7 secs.
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.30it/s]
Prompt: 'Hello, my name is', Generated text: ' Alyssa and I am a 17-year-old girl'
Prompt: 'The president of the United States is', Generated text: ' the head of the executive branch of the United States government and is the highest political'
Prompt: 'The capital of France is', Generated text: " a beautiful and historic city that is home to some of the world's most"
Prompt: 'The future of AI is', Generated text: ' a rapidly evolving field, with new developments and innovations happening all the time'

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

vllm/model_executor/models/mixtral.py Outdated Show resolved Hide resolved

# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
Copy link
Collaborator

@pcmoritz pcmoritz May 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be removed -- we do support activation scales for FP16 checkpoints too (same as kv store scales going forward)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah never mind, I misunderstood -- FP16 checkpoints with "quantization": "fp8" are also considered fp8 serialized (this is pretty confusing)

@pcmoritz pcmoritz merged commit 2a05201 into vllm-project:main May 4, 2024
49 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
… Dynamic/Static Activations) (vllm-project#4527)

Follow on to vllm-project#4332 to enable FP8 checkpoint loading for Mixtral and supersedes vllm-project#4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants