Quantized KV Cache #30483

Merged: 21 commits, merged May 23, 2024
Changes from 9 commits

3 changes: 3 additions & 0 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -48,6 +48,9 @@ RUN python3 -m pip install --no-cache-dir decord av==9.2.0
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes

# Some tests require quanto
RUN python3 -m pip install --no-cache-dir quanto

# For `dinat` model
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
38 changes: 38 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -173,6 +173,44 @@ your screen, one word at a time:
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```


## KV Cache Quantization

The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computations. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.

KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently works with the `quanto` backend. For more information on the inner workings, see the paper.

To enable quantization of the key-value cache, pass `cache_implementation="quantized"` in the `generation_config`.
Quantization-related arguments should be passed to the `generation_config`, either as a `dict` or as an instance of the [`QuantizedCacheConfig`] class.

<Tip warning={true}>

Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization.

</Tip>


```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
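
The same arguments can also be provided as an explicit [`QuantizedCacheConfig`] instance instead of a plain `dict`. The sketch below reuses `model` and `inputs` from the example above and is illustrative only, not part of the tested doc example:

```python
>>> from transformers import QuantizedCacheConfig

>>> cache_config = QuantizedCacheConfig(nbits=4, q_group_size=64, residual_length=128)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config=cache_config)
```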


## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
12 changes: 11 additions & 1 deletion docs/source/en/internal/generation_utils.md
@@ -356,13 +356,23 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] Cache
- update

[[autodoc]] CacheConfig
- update

[[autodoc]] QuantizedCacheConfig
- validate

[[autodoc]] DynamicCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantoQuantizedCache
- update
- get_seq_length

[[autodoc]] SinkCache
- update
- get_seq_length
@@ -371,4 +381,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache
- reset
20 changes: 18 additions & 2 deletions src/transformers/__init__.py
@@ -1173,7 +1173,15 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["cache_utils"] = [
"Cache",
"CacheConfig",
"DynamicCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
@@ -5730,7 +5738,15 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
254 changes: 252 additions & 2 deletions src/transformers/cache_utils.py
@@ -1,12 +1,18 @@
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from .configuration_utils import PretrainedConfig
from .utils import logging
from .utils import is_quanto_available, logging


if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4

logger = logging.get_logger(__name__)


@@ -82,6 +88,161 @@ def seen_tokens(self):
return None


@dataclass
class CacheConfig:
"""
Base class for cache configs
"""

cache_implementation: None

@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a CacheConfig instance from a dictionary of parameters.
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
CacheConfig: Instance of CacheConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
return config

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.

Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

writer.write(json_string)

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return copy.deepcopy(self.__dict__)

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
Returns:
str: JSON formatted string representing the configuration instance.
"""
return json.dumps(self.__dict__, indent=2) + "\n"

# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
returning all the unused kwargs.

Args:
kwargs (`Dict[str, Any]`):
Dictionary of attributes to tentatively update this class.

Returns:
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
"""
to_remove = []
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)

# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs

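As a quick illustration of the serialization helpers above, here is a minimal sketch using the `QuantizedCacheConfig` subclass defined just below (the file name and field values are arbitrary; it assumes a `transformers` build that includes this PR):

```python
import json
import os
import tempfile

from transformers import QuantizedCacheConfig

config = QuantizedCacheConfig(nbits=2, q_group_size=32, residual_length=64)

# `to_dict` exposes the fields as a plain Python dict.
print(config.to_dict())  # {'nbits': 2, 'q_group_size': 32, 'residual_length': 64}

# `to_json_file` writes the same content to disk; `from_dict` rebuilds a config from it.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cache_config.json")
    config.to_json_file(path)
    with open(path) as f:
        restored = QuantizedCacheConfig.from_dict(json.load(f))

print(restored.residual_length)  # 64
```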

@dataclass
class QuantizedCacheConfig(CacheConfig):
"""
Configuration class for quantized cache settings.

Attributes:
nbits (`Optional[int]`, *optional*, defaults to 4):
Number of bits, can be 2 or 4. Defaults to 4.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
Defaults to 64.
residual_length (`Optional[int]`, *optional*, defaults to 128):
Length of the residual cache which will always be stored in original precision.
Defaults to 128.
"""

def __init__(
self,
nbits: Optional[int] = 4,
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
):
self.nbits = nbits
self.q_group_size = q_group_size
self.residual_length = residual_length

def validate(self):
"""Validates if the arguments passed are correct"""

incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
if self.nbits not in [2, 4]:
raise ValueError(
incorrect_arg_msg.format(
key="nbits",
correct_value="2 or 4",
found_value=self.nbits,
),
)
if self.q_group_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="q_group_size",
correct_value="a positive integer",
found_value=self.q_group_size,
),
)
if self.residual_length < 0:
raise ValueError(
incorrect_arg_msg.format(
key="residual_length",
correct_value="a positive integer",
found_value=self.residual_length,
),
)

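A short sketch of how `from_dict`, `update`, and `validate` interact (illustrative only; assumes a `transformers` build that includes this PR):

```python
from transformers import QuantizedCacheConfig

# Known keys passed to `update` override the current values; unknown keys are returned unused.
config = QuantizedCacheConfig.from_dict({"nbits": 2, "q_group_size": 32})
unused = config.update(residual_length=64, not_a_field=True)
print(unused)  # {'not_a_field': True}

# `validate` raises a ValueError for unsupported values, e.g. an nbits outside {2, 4}.
config.nbits = 3
try:
    config.validate()
except ValueError as err:
    print(err)
```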

class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
@@ -186,6 +347,95 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
return cache


class QuantoQuantizedCache(DynamicCache):
"""
A cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.

The cache has two types of storage, one for original precision and one for the quantized cache. A `residual_length` is set as a maximum capacity for the
original precision cache. When the length exceeds the maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. The current implementation
supports `int2` and `int4` dtypes from `quanto`.

The cache stores the original precision Key and Value states as a list of tensors, one for each layer. The maximum expected shape for each tensor is
`[batch_size, num_heads, residual_length, head_dim]`. Quantized Keys and Values are stored separately as a list of quantized tensors, one for each layer.
The size of each tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.

Parameters:
nbits (`Optional[int]`, *optional*, defaults to 4):
Number of bits, can be 2 or 4. Defaults to 4.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
Defaults to 64.
residual_length (`Optional[int]`, *optional*, defaults to 128):
Length of the residual cache which will always be stored in original precision.
Defaults to 128.
"""

def __init__(self, nbits: int = 4, q_group_size: int = 64, residual_length: int = 128) -> None:
if nbits not in [2, 4]:
raise ValueError(f"`nbits` has to be one of [`2`, `4`] but got {nbits}")

self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []

self._seen_tokens = 0
self.residual_length = residual_length
self.qtype = qint4 if nbits == 4 else qint2
self.q_group_size = q_group_size

super().__init__()

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

if len(self.key_cache) <= layer_idx:
self._quantized_key_cache.append(self._quantize(key_states.contiguous()))
self._quantized_value_cache.append(self._quantize(value_states.contiguous()))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
dequant_key = self._quantized_key_cache[layer_idx].dequantize()
dequant_value = self._quantized_value_cache[layer_idx].dequantize()
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]

keys_to_return = torch.cat(keys_to_return, dim=-2)
values_to_return = torch.cat(values_to_return, dim=-2)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous())
self._quantized_value_cache[layer_idx] = self._quantize(values_to_return.contiguous())
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

return keys_to_return, values_to_return

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
return self._seen_tokens

def _quantize(self, tensor):
qtensor = QBitsTensor.quantize(tensor, axis=0, qtype=self.qtype, group_size=self.q_group_size)
return qtensor

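To make the two-tier storage described in the class docstring concrete, here is a minimal sketch that drives `update` directly with dummy tensors (shapes are arbitrary and `quanto` must be installed; in normal use the cache is created for you when `cache_implementation="quantized"` is passed to `generate`):

```python
import torch

from transformers import QuantoQuantizedCache

cache = QuantoQuantizedCache(nbits=4, q_group_size=64, residual_length=8)
batch, heads, head_dim = 1, 4, 64

# "Prefill": the first call for a layer quantizes the prompt keys/values right away
# and leaves an empty full-precision residual buffer.
keys = torch.randn(batch, heads, 16, head_dim)
values = torch.randn(batch, heads, 16, head_dim)
out_keys, out_values = cache.update(keys, values, layer_idx=0)
print(out_keys.shape)          # torch.Size([1, 4, 16, 64])
print(cache.get_seq_length())  # 16

# Decoding: new tokens accumulate in the full-precision residual buffer until it
# reaches `residual_length`, at which point everything is re-quantized.
for _ in range(8):
    k = torch.randn(batch, heads, 1, head_dim)
    v = torch.randn(batch, heads, 1, head_dim)
    out_keys, out_values = cache.update(k, v, layer_idx=0)
print(cache.get_seq_length())  # 24
```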

class SinkCache(Cache):
"""
A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to