
[Feature]: DeepSeek-Coder-V2-Instruct-FP8 on 8xA100 #7322

Closed
halexan opened this issue Aug 9, 2024 · 8 comments

Comments

@halexan commented Aug 9, 2024

🚀 The feature, motivation and pitch

vLLM has announced support for running Llama 3.1 405B FP8 on 8xA100 (see the blog post).

Does vLLM support running DeepSeek-Coder-V2-Instruct-FP8 on 8xA100?

However, I notice that vLLM uses Triton for its FusedMoE kernel, which doesn't support FP8 Marlin mixed precision. See sgl-project/sglang#989 (comment).

Is there any workaround?
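
For reference, this is roughly how I was planning to launch it (a sketch only; the exact Hugging Face repo id and the context length below are assumptions, not a verified configuration):

```python
from vllm import LLM, SamplingParams

# Sketch of the intended launch on 8xA100. The repo id and max_model_len are
# placeholders, not a verified configuration.
llm = LLM(
    model="neuralmagic/DeepSeek-Coder-V2-Instruct-FP8",  # assumed FP8 checkpoint
    tensor_parallel_size=8,    # shard the model across the 8 A100s
    trust_remote_code=True,    # DeepSeek-V2 ships custom model code
    max_model_len=8192,        # keep the KV cache within A100 memory
)

outputs = llm.generate(
    ["Write a quicksort function in Python."],
    SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)
```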

Alternatives

No response

Additional context

No response

@robertgshaw2-neuralmagic (Collaborator)

There is not currently a workaround for this. We have been working on extending Marlin to support FusedMoE and will likely extend this to FP8 at some point, but this will take some time.

See #7079 for progress on the Marlin fused_moe kernel.

@robertgshaw2-neuralmagic (Collaborator)

Closing for now.

@jon-chuang (Contributor)

Hello @robertgshaw2-neuralmagic, may I ask why an FP8-quantized model would use an FP16xINT4 matmul kernel? Could you point me to some resources or a blog post about this? Thank you.

@robertgshaw2-neuralmagic (Collaborator)

Marlin is a mixed-precision inference kernel. It supports int4, int8, and fp8 weights with 16-bit activations (for dense models).

We started by extending Marlin to support fused MoE with int4 and int8 weights and fp16 activations (the PR I linked). A follow-up will extend this to support fp8 weights as well.
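
To illustrate what "mixed precision" means here, a toy PyTorch sketch (not the actual Marlin kernel, which fuses all of this into a single GPU kernel): the weights stay 4-bit in memory and are dequantized to 16-bit on the fly, while the GEMM runs against 16-bit activations.

```python
import torch

def w4a16_matmul(x_fp16, packed_w, scales, zero_point):
    """Toy W4A16 mixed-precision matmul: int4 weights, 16-bit activations."""
    # packed_w: (K//2, N) uint8, two 4-bit weights packed per byte
    lo = (packed_w & 0x0F).to(torch.int16)
    hi = (packed_w >> 4).to(torch.int16)
    w_int4 = torch.stack((lo, hi), dim=1).reshape(-1, packed_w.shape[1])  # (K, N)
    w_16 = ((w_int4 - zero_point) * scales).to(torch.float16)  # dequantize on the fly
    # Real kernels do this GEMM on fp16 tensor cores; fp32 here is only so the
    # toy example runs on CPU.
    return (x_fp16.float() @ w_16.float()).to(torch.float16)

K, N, batch = 128, 64, 4
packed = torch.randint(0, 256, (K // 2, N), dtype=torch.uint8)
scales = (torch.rand(N) * 0.01).to(torch.float16)   # per-output-channel scales
x = torch.randn(batch, K).to(torch.float16)
y = w4a16_matmul(x, packed, scales, zero_point=8)
print(y.shape)  # torch.Size([4, 64])
```

Only the weight storage and traffic shrink; the arithmetic stays 16-bit, which is why this style of kernel pays off in memory-bound regimes.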

@jon-chuang (Contributor)

At what batch size does Marlin become optimal (i.e., hit the roofline) for FP8?

@robertgshaw2-neuralmagic (Collaborator) commented Aug 12, 2024

I’m not sure I follow the question.

A roofline analysis shows the latency of the kernel as a function of batch size. Marlin GEMM is a highly optimized kernel designed to address performance issues with the prior generation of mixed-precision kernels, which did not perform well in the batch 8-64 range even though the computation there is memory bound.

So Marlin follows the roofline plot very well. But you should not expect Marlin to accelerate compute-bound workloads over fp16; for compute-bound workloads we recommend using activation quantization.
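
To make that concrete, a back-of-the-envelope roofline check for a single weight-dominated GEMM (the A100 peak numbers below are rough datasheet figures, not measurements):

```python
# Rough roofline estimate for y = x @ W with W of shape (K, N) stored in fp16.
# Assumed approximate A100 80GB figures: ~312 TFLOP/s fp16 tensor-core peak,
# ~2 TB/s HBM bandwidth.
PEAK_FLOPS = 312e12
PEAK_BW = 2.0e12

def gemm_estimate(batch, K=5120, N=5120, bytes_per_weight=2):
    flops = 2 * batch * K * N               # multiply-adds for the GEMM
    bytes_moved = K * N * bytes_per_weight  # weight traffic dominates at small batch
    t_compute = flops / PEAK_FLOPS
    t_memory = bytes_moved / PEAK_BW
    regime = "memory-bound" if t_memory > t_compute else "compute-bound"
    return max(t_compute, t_memory), regime

for b in (1, 16, 64, 256, 1024):
    t, regime = gemm_estimate(b)
    print(f"batch={b:5d}  ~{t * 1e6:6.1f} us  {regime}")
```

In this toy model the crossover sits around batch ~150: below it, shrinking bytes_per_weight (fp8 or int4 weights) directly cuts latency, which is the regime Marlin targets; above it, only quantizing the activations changes the compute term.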

@robertgshaw2-neuralmagic (Collaborator)

One follow-up: if you're running on Hopper, I don't think it makes sense to use Marlin for fp8, since we can use dynamic activation quantization with high accuracy. The only use of Marlin fp8, IMO, should be for devices that do not support fp8 compute (i.e., A100).
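
As a sketch of that rule of thumb (my reading of the recommendation, not an official API; the returned strings are just descriptive labels):

```python
import torch

def recommended_fp8_path():
    """Pick an fp8 scheme based on whether the GPU has fp8 tensor cores."""
    major, _ = torch.cuda.get_device_capability()
    if major >= 9:
        # Hopper (sm_90+): native fp8 compute, so quantize weights and
        # activations (W8A8 fp8 with dynamic activation scales).
        return "W8A8 fp8 with dynamic activation quantization"
    # Ampere and older (e.g. A100, sm_80): no fp8 math units, so fp8 only
    # shrinks the weights; the GEMM stays in fp16 via a Marlin-style
    # weight-only kernel.
    return "fp8 weight-only (Marlin-style), GEMM in fp16"

if torch.cuda.is_available():
    print(recommended_fp8_path())
```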

@jon-chuang (Contributor)

I see, thank you for the detailed response!
