Add Attention Sinks #1357

Closed
kmn1024 opened this issue Nov 30, 2023 · 14 comments
Labels: feature request (New feature or request)

Comments

kmn1024 commented Nov 30, 2023

🚀 Feature

Add Attention Sinks (https://arxiv.org/pdf/2309.17453.pdf, https://github.com/tomaarsen/attention_sinks/) to MLC.

Motivation

mlc_chat_cli gets noticeably slower as the conversation progresses. I tried this on an Orange Pi 5 with two setups: exactly as described in https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi, and then compiling my own StableLM 3B (q4f16_1, OpenCL). You can see from these screenshots that tokens/sec gradually decreases as the conversation progresses (I waited between generations to ensure it wasn't due to thermal throttling).
[Screenshots: decode tokens/sec gradually dropping over the course of the conversation]

Such a slowdown is unavoidable given the nature of attention, but maybe we can reduce the latency hit without affecting decoding quality too much by using Attention Sinks (figure 1 from the paper). The default cache setting for most models is window attention with window size = sequence length; with Attention Sinks, maybe we can use something smaller.
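
For illustration, here is a minimal sketch of the eviction policy from the paper, assuming a cache that keeps a few "sink" tokens plus a rolling window of recent tokens. This is not MLC/TVM code; the class and parameter names are made up.

```python
# Minimal sketch of the attention-sink eviction policy (StreamingLLM, arXiv:2309.17453).
# Illustrative only: class and parameter names are hypothetical, not MLC/TVM APIs.

class SinkKVCache:
    def __init__(self, num_sink_tokens=4, window_size=1024):
        self.num_sink_tokens = num_sink_tokens  # first tokens are kept forever ("sinks")
        self.window_size = window_size          # budget for the most recent tokens
        self.keys, self.values = [], []         # one (k, v) entry per cached token

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)
        self._evict()

    def _evict(self):
        budget = self.num_sink_tokens + self.window_size
        if len(self.keys) > budget:
            # Drop tokens from the middle: keep the sinks plus the recent window.
            keep_from = len(self.keys) - self.window_size
            self.keys = self.keys[:self.num_sink_tokens] + self.keys[keep_from:]
            self.values = self.values[:self.num_sink_tokens] + self.values[keep_from:]
```

The paper's key observation is that keeping those first few tokens (the "sinks") preserves generation quality far better than a plain sliding window of the same total size.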

Alternatives

It seems there will always be a latency-vs-quality trade-off for any type of cache, but perhaps Attention Sinks currently offers the best trade-off.

Additional context

I would love to work on this, if I can get a mentor to point out which files should be changed for a tidy implementation!

CharlieFRuan (Contributor) commented:

Thanks for the request! Attention sink is definitely something on our minds, and we will probably support it soon! cc @junrushao @davidpissarra

kmn1024 (Author) commented Dec 1, 2023

Thanks for your quick response, Charlie!

I would like to offer my help on this one, since it looks relatively beginner-friendly given the example implementation in https://github.com/tomaarsen/attention_sinks. Would it be OK for me to try to work on this over the next few days and send you a PR for one model (whilst lighting a path to enable it for all other models)?

If experts such as yourself, Junru, or David have extra cycles, may I suggest something much more complicated and impactful: https://github.com/FasterDecoding/Medusa. Medusa is probably one of the most scalable speculative-decoding implementations, since it doesn't require a separate draft model. The claimed gains in tokens/sec are impressive, and it seems to fit well with MLC's focus on universality and high performance.
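
For context, here is a very rough sketch of the decode loop Medusa proposes: extra heads draft several future tokens, which the base model verifies in a single pass. Every name below (model, medusa_heads, forward_hidden, lm_head) is a hypothetical interface for illustration only, and the real Medusa uses top-k candidates with tree attention, which this omits.

```python
# Very rough sketch of Medusa-style speculative decoding. All names are hypothetical:
#   model.forward_hidden(tokens) -> one hidden state per position
#   model.lm_head(h)             -> logits for the next token given hidden state h
#   medusa_heads[i](h)           -> logits for the token (i + 2) steps ahead

def medusa_decode_step(model, medusa_heads, tokens):
    hidden = model.forward_hidden(tokens)          # one pass over the sequence so far
    next_token = model.lm_head(hidden[-1]).argmax()

    # Each Medusa head guesses one token further into the future.
    draft = [next_token] + [head(hidden[-1]).argmax() for head in medusa_heads]

    # Verify the draft with the base model in one forward pass;
    # keep the longest prefix the base model agrees with.
    verify_hidden = model.forward_hidden(tokens + draft)
    accepted = [draft[0]]
    for i in range(1, len(draft)):
        pred = model.lm_head(verify_hidden[len(tokens) + i - 1]).argmax()
        if pred == draft[i]:
            accepted.append(draft[i])
        else:
            break
    return accepted  # a single step can emit several tokens, hence the throughput gain
```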

CharlieFRuan (Contributor) commented:

Thank you so much for offering help! We really appreciate it.

However, implementing attention sink may not be the most beginner-friendly task, as we handle most KV cache logic in a lower-level stack called TVM. For instance, when we introduced sliding window attention, some work needed to be done in TVM: apache/tvm#15963.

Regarding speculative decoding, it is definitely something we are considering as well. We are currently working on another front for model serving (in contrast with simple chatting), which will probably include speculative decoding.

kmn1024 (Author) commented Dec 2, 2023

Thanks for the heads up! Looking forward to an implementation of this, and speculative decoding too.

kmn1024 (Author) commented Dec 6, 2023

@CharlieFRuan @davidpissarra So I didn't heed Charlie's warning, and attempted to implement attention sink. Ended up pulling most of my hair out, but I have something that seems to work. The changes are in the two repos that Charlie pointed out:
mlc-ai/relax@mlc...kmn1024:relax_attention_sinks:main
main...kmn1024:llm_attention_sinks:main (please ignore conv_templates.cc)

I will try to get some /stats tomorrow. Is this something that I can send to you two for review? No worries at all if you prefer not, I can see that you guys are super busy =)

CharlieFRuan (Contributor) commented:

@kmn1024 Wow this is really impressive, thank you for the hard work!

We are in the process of migrating from the relax_model folder to SLIM, essentially a new workflow for compiling models on the mlc-llm layer. We are still wrapping it up and making documentation for it.

Therefore, the changes in lm_support.cc, llm_chat.cc would not be affected; but those in relax_model and mlc_llm/core.py may need to be migrated later when the new workflow is up.

With that being said, once you are ready, feel free to open a PR for both the TVM side and the mlc-llm side (old workflow is fine), then @davidpissarra and/or I will help review. We can later pick the changes to the new workflow.

Really appreciate the contribution!

kmn1024 (Author) commented Dec 7, 2023

Thanks Charlie! I'll begin sending PRs.

Here's a screenshot of the code in action, with added logs to show when the cache trimming happens:
[Screenshot: generation log showing when the KV cache trimming occurs]

CharlieFRuan (Contributor) commented:

@kmn1024 Looks great, thank you so much! We'll look at the PRs.

davidpissarra (Member) commented Dec 17, 2023

Hi, @kmn1024! Regarding sinks, since most of the sliding-window (SWA) logic was already implemented, we were able to reuse the WindowOverride function from SWA to implement it (a three-line change, see apache/tvm#16240). As of now, Mistral is the only architecture that supports SWA and sinks (#1435). Part of our effort now is to bring sinks to the other models.
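
To make the idea concrete, here is a plain-Python sketch of what extending a sliding-window (ring-buffer) cache override with sink slots amounts to. The real change lives in TVM TIR (apache/tvm#16240); the function below is only an illustration, not the actual WindowOverride code.

```python
# Conceptual sketch: map a sequence position to a cache slot so that the first
# `num_sinks` slots are never overwritten, while the remaining slots behave as a
# ring buffer. Illustrative only; the actual implementation is the TIR WindowOverride.

def window_override_slot(seq_pos, window_size, num_sinks=0):
    if seq_pos < window_size:
        return seq_pos                      # cache not yet full: write positions in order
    rolling = window_size - num_sinks       # slots that participate in the wrap-around
    return num_sinks + (seq_pos - num_sinks) % rolling

# Example with window_size=8 and num_sinks=4: positions 0-7 fill slots 0-7, and
# positions 8, 9, 10, ... cycle through slots 4-7, leaving slots 0-3 (the sinks) intact.
```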

kmn1024 (Author) commented Dec 18, 2023

Thanks, really appreciate your work! I will close out my PRs.

I want to take this opportunity to ask you (@davidpissarra) and Charlie (@CharlieFRuan) if there's more information on speculative decoding. Even with a shortened context window, the steady-state throughput is still too low, and extrapolating from the Medusa numbers (my board has a smaller warp size), I think something similar would give me a much-needed ~1.3x boost.

I would be happy to sponsor a bounty, but 1. the TVM/MLC group seems better funded than I am; 2. I'm not sure if bounties are a source of motivation in academia. Anyhow, just laying this out there; please let me know =)

CharlieFRuan (Contributor) commented:

Thanks for proposing it! Unfortunately, there hasn't been much progress on speculative decoding, as the team is occupied with various other things (e.g. serving, multi-GPU, etc.). We would probably look into it after https://github.com/mlc-ai/mlc-llm/tree/serving lands.

junrushao (Member) commented:

Si-ze and I are working on speculative decoding on the serving branch.

kmn1024 (Author) commented Dec 19, 2023

Thanks for the replies! @junrushao, do you plan on implementing it "bring your own draft model" style, integrated like Medusa, or something else?

jpf888 commented Apr 10, 2024

Hi @junrushao, when will Medusa be supported on the serve branch? Do you have any plans?
