Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantized KV Cache #30483

Merged
merged 21 commits into from
May 23, 2024
Merged

Quantized KV Cache #30483

merged 21 commits into from
May 23, 2024

Conversation

zucchini-nlp
Copy link
Member

@zucchini-nlp zucchini-nlp commented Apr 25, 2024

What does this PR do?

An implementation of quantized cache with quanto library. Introduces a new CacheConfig to store cache related arguments and a new cache class QuantoQuantizedCache. The implementation is based partially on the KIVI paper, but in this case we do a per-token quantization for both: keys and values.

PR for HF blogpost here

Example usage:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager").to("cuda:0")

inputs = tokenizer("Hello, how are you?", truncation=True, return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized")
out_fp16 = model.generate(**inputs, do_sample=False, max_new_tokens=20)

print(f"text with quant cache: {tokenizer.batch_decode(out)}")
print(f"text with fp16 cache: {tokenizer.batch_decode(out_fp16)}")
Perplexity plots Here the results are different from what we got earlier because I was calculating perplexity in one forward pass, by quantizing and then dequantizing all keys and values. The new script uses cache object and calculates pplx per new token. Perplexity Latency
Eval on LongBench (scripts taken from LongBench repo) This is to compare with the KIVI method, since they did the same evals on all datasets from LongBench.
Dataset KIVI 16fp KIVI int2 Our fp16 Our int4 Our int2
TREC 63.0 67.5 63.0 63.0 55.0
SAMSum 41.12 42.18 41.12 41.3 14.04

I cannot find KIVI results on all of the LongBench, so here will be only transformers version.

Dataset fp16 int4 int2
TriviaQA 84.28 84.76 63.64
HotPotQA 30.08 30.04 17.3
Passage_retrieval_en 8.5 9.5 4.82
Memory vs Latency plots Same old plots showing memory consumption and latency for differeny cache types: Latency as a function of batch size Memory consumption as a function of batch size Memory consumption as a function of max new tokens

UPDATE:
Latest commit has added possibility to choose HQQ or quanto as backend. Usage:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import QuantoQuantizedCache, QuantizedCacheConfig, HQQQuantizedCache

os.environ["TOKENIZERS_PARALLELISM"] = "0"

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager", device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id

inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

# now users can group keys and values over different axis (e.g quanto can do better on int2 if axis_key=0 and axis_value=-1)
cache_config = QuantizedCacheConfig(
    backend="HQQ",
    nbits=4,
    axis_key=0,
    axis_value=1,
    compute_dtype=torch.float16,
    device=model.device
)

out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="quantized", cache_config=cache_config)
print(tokenizer.batch_decode(out, skip_special_tokens=True))

@zucchini-nlp
Copy link
Member Author

As we discussed quantized cache can be started to be integrated to the library, given the results we got so far. All the possible speed optimizations/pre-fill stage optimizations can be done further, as we will be getting feedback from the community.

So, I would like to get a review on the PR :)

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API wise looks really great ! I did not spotted anything critical here that needs to be addressed (and I will let joao give a deeper review on the cache file changes) - except for guarding quanto imports (also I would say safer to make local imports whenever possible - e.g. at QuantCache init)
You raised a concern about switching between cache implementations - I made an attempt while ago: #29030 that got stale (😅 ) maybe that PR might solve your concern?
Maybe we could also track models that support quant cache with a private attribute _supports_quant_cache in xxxPreTrainedModel - what do you think?

src/transformers/generation/utils.py Outdated Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

Thanks for the comments!

except for guarding quanto imports (also I would say safer to make local imports whenever possible - e.g. at QuantCache init)

Okey noted!

You raised a concern about switching between cache implementations - I made an attempt while ago: #29030 that got stale (😅 ) maybe that PR might solve your concern?

I love the generalized cache implementation idea. Not sure how this will work on overall API level, given that Joao and Arthur are working on changing cache thing. I'll let Joao to decide about that

Maybe we could also track models that support quant cache with a private attribute _supports_quant_cache in xxxPreTrainedModel - what do you think?

Hmm, Actually quant cache should be supported abywhere dynamicCache is, that means everything except for old models like bart/t5. Yeah I think we can add it for explicitness, until the cache API is refactored to be same everywhere

@younesbelkada
Copy link
Contributor

Thanks !

Hmm, Actually quant cache should be supported abywhere dynamicCache is, that means everything except for old models like bart/t5. Yeah I think we can add it for explicitness, until the cache API is refactored to be same everywhere

Ok that's great if that's the case then, i would say no need for that !

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will need a rebase due to #30476, but I love this POC -- in fact, I've reviewed it as if it was not a POC 😉

After removing the extra .py files and adding some docs, I believe it is ready to be launched! And I also think it deserves a blog post :D

src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/models/llama/modeling_llama.py Outdated Show resolved Hide resolved
@zucchini-nlp zucchini-nlp marked this pull request as ready for review May 2, 2024 10:33
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments for you to work on, but let's gather the benchmarks first :)

docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
src/transformers/__init__.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work @zucchini-nlp ! 🚀 I only left nits and one open question with respect to tests, otherwise it looks really great !

src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
tests/models/llama/test_modeling_llama.py Outdated Show resolved Hide resolved
tests/generation/test_utils.py Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

zucchini-nlp commented May 3, 2024

@gante added benchmark results on the PR description. Right now int4 has almost same performance as fp16, sometimes a bit better. Also added some comparison with the KIVI paper.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 🙌 Thank you for iterating on this very cool project!

@gante
Copy link
Member

gante commented May 8, 2024

(CI needs fixing -- possibly a simple make fix-copies)

@gante gante requested a review from ArthurZucker May 8, 2024 14:20
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very interesting work! Having both cache and quantizing on the fly when needed is very interesting!

docs/source/en/generation_strategies.md Show resolved Hide resolved
docs/source/en/generation_strategies.md Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
@zucchini-nlp zucchini-nlp changed the title [POC] Quantized KV Cache Quantized KV Cache May 9, 2024
Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work ! Left one nit about tests !

tests/quantization/quanto_integration/test_quanto.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Last few nits and should be good to go!

src/transformers/cache_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
docs/source/en/generation_strategies.md Show resolved Hide resolved
src/transformers/cache_utils.py Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
@gante
Copy link
Member

gante commented May 13, 2024

@ArthurZucker @ydshieh: "torch.compile with quanto is only supported for 8 bits quantization for now" (from @SunMarc, on a related conversation on slack)

@zucchini-nlp
Copy link
Member Author

I made the KV cache work with HQQ as a backend. It can be simply plugged in if a user writes their own "CacheClass". I am not planning to add it now as it needs more evaluation and experiments, but wanted to show how anyone can add more backends. Do you think I should continue experimenting with HQQ or we can simply put the below code as example for users?

BTW, if we were to actually support more cache quant classes in the library, maybe we'll need to change the current QuantCache API a bit to be more versatile.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from hqq.core.quantize import Quantizer as HQQQuantizer


class HQQQuantizedCache(DynamicCache):
    def __init__(
        self,
        nbits: int = 4,
        axis: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
        compute_dtype: torch.dtype = torch.float16,
        device: str = "cpu",
    ) -> None:
        if nbits not in [2, 4, 8]:
            raise ValueError(f"`nbits` has to be one of [`2`, `4`, `8`] but got {nbits}")

        if axis not in [0, 1]:
            raise ValueError(f"`axis` has to be one of [`1`, `2`] but got {axis}")

        self._quantized_key_cache: List[Tuple[torch.Tensor, Dict]] = []
        self._quantized_value_cache: List[Tuple[torch.Tensor, Dict]] = []
        self.nbits = nbits
        self.axis = axis

        self.residual_length = residual_length
        self.q_group_size = q_group_size
        self.compute_dtype = compute_dtype
        self.quantizer = HQQQuantizer
        self.device = device

        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            q_key, meta_key = self._quantize(key_states.contiguous())
            self._quantized_key_cache.append((q_key, meta_key))

            q_value, meta_value = self._quantize(value_states.contiguous())
            self._quantized_value_cache.append((q_value, meta_value))

            self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            keys_to_return, values_to_return = key_states, value_states
        else:
            quant_key, meta_key = self._quantized_key_cache[layer_idx]
            dequant_key = self.quantizer.dequantize(quant_key, meta_key)

            quant_value, meta_value = self._quantized_value_cache[layer_idx]
            dequant_value = self.quantizer.dequantize(quant_value, meta_value)

            keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
            values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]

            keys_to_return = torch.cat(keys_to_return, dim=-2)
            values_to_return = torch.cat(values_to_return, dim=-2)
            if (
                self.key_cache[layer_idx].dim() == 4
                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
            ):
                q_key, meta_key = self._quantize(keys_to_return.contiguous())
                self._quantized_key_cache[layer_idx] = (q_key, meta_key)

                q_value, meta_value = self._quantize(values_to_return.contiguous())
                self._quantized_key_cache[layer_idx] = (q_value, meta_value)

                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
        # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
        # this part of code otherwise fails when used to verify attn_weight shape in some models
        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

    def _quantize(self, tensor):
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=self.axis,
            device=self.device,
            compute_dtype=self.compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.compute_dtype
        return qtensor, meta


tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager", device_map = "auto")

inputs = tokenizer("I like rock music because" return_tensors="pt").to(model.device)

out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=50,
    past_key_values=HQQQuantizedCache(
        nbits=2,
        axis=1, # 2bit with axis=0 generates garbage
        compute_dtype=torch.float16,
        device=model.device
    ),
)


print(f"text with HQQ backend: {tokenizer.batch_decode(out)}")

@ArthurZucker
Copy link
Collaborator

I think that making the cache class versatile is great to have people build on top of it, without necessarily including anythinig in transformers! But this can comme in a follow up PR

@zucchini-nlp
Copy link
Member Author

@ArthurZucker yes, making a versatile cache class will go on another PR. In that case we can leave quanto as the only choice available, and the rest can be implemented by users themselves

@ArthurZucker
Copy link
Collaborator

sounds good

@zucchini-nlp
Copy link
Member Author

@ArthurZucker @gante I made a few changes from the last review:

  1. Now we support HQQ and quanto (quanto by default as it is a bit faster, we'll work on using optimized kernels later). For that we have a base "QuantizedCacheClass" and all quantization methods can make their own class from it by overriding the _quantize and _dequantize methods.
  2. Added more kwargs to the config, so the users can indicate axis to quantize for keys and values separately, and have more control over the process
  3. Added _supports_quantized_cache mainly because of Jamba. Jamba comes out to _supports_cache_class but in "modeling" it checks for attribute that is not available for all cache classes (here)

I added a new usage ex in the description and will rework a bit the blogpost, given that now support HQQ. This PR is ready for the second review!

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💛 💛 💛

@zucchini-nlp
Copy link
Member Author

zucchini-nlp commented May 23, 2024

Cool, merging 🤞🏻

Ran slow tests in quantization and generation locally, everything is passing.

@zucchini-nlp zucchini-nlp merged commit d583f13 into huggingface:main May 23, 2024
23 checks passed
@ydshieh
Copy link
Collaborator

ydshieh commented May 23, 2024

I am wondering if we can have this works together #30862. If so, we can probably get further more speedup!

@zucchini-nlp Could you share the simplest code snippet that you use for this PR to measure the runtime (latency)? I can try to incorporate this with #30862 🙏

@zucchini-nlp
Copy link
Member Author

@ydshieh This PR actually results in slow-down because of quantization 😅 But we can check the memory usage probably. Here is a script I used, but you'd have to replace QuantCache with QuantoQuantizedCache because the evaluation was done on an older commit

@ydshieh
Copy link
Collaborator

ydshieh commented May 24, 2024

OK. Thanks for sharing, so this PR is more about memory instead of speed.

itazap pushed a commit that referenced this pull request May 24, 2024
* clean-up

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Arthur <[email protected]>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
zucchini-nlp added a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
* clean-up

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: Younes Belkada <[email protected]>

* Update src/transformers/generation/configuration_utils.py

Co-authored-by: Arthur <[email protected]>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants