Commit: PR suggestions
zucchini-nlp committed May 10, 2024
1 parent 6549943 commit 5c8b25c
Showing 8 changed files with 74 additions and 73 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/generation_strategies.md
@@ -184,7 +184,7 @@ KV Cache quantization in `transformers` is largely inspired by the paper [KIVI:
(https://arxiv.org/abs/2402.02750). For more information on the inner workings see the paper.

To enable quantization of the key-value cache, one needs to indicate `cache_implementation="quantized"` in the `generation_config`.
Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a `CacheConfig` class.
Quantization related arguments should be passed to the `generation_config` either as a `dict` or an instance of a [`QuantizedCacheConfig`] class.
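
For illustration, here is a minimal sketch of what enabling the quantized cache could look like after this change. The checkpoint, prompt, and argument values are assumptions chosen for the example, not part of the diff, and `quanto` is assumed to be installed:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("KV cache quantization lets us", return_tensors="pt")

out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="quantized",
    # Passed as a dict here; it is converted to a QuantizedCacheConfig internally.
    cache_config={"nbits": 4, "q_group_size": 64, "residual_length": 128},
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```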

<Tip warning={true}>

4 changes: 3 additions & 1 deletion docs/source/en/internal/generation_utils.md
@@ -358,6 +358,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] CacheConfig
- update

[[autodoc]] QuantizedCacheConfig
- validate

[[autodoc]] DynamicCache
@@ -367,7 +369,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantCache
[[autodoc]] QuantoQuantizedCache
- update
- get_seq_length

13 changes: 11 additions & 2 deletions src/transformers/__init__.py
@@ -1177,7 +1177,8 @@
"Cache",
"CacheConfig",
"DynamicCache",
"QuantCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
@@ -5737,7 +5738,15 @@
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, CacheConfig, DynamicCache, QuantCache, SinkCache, StaticCache
from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
89 changes: 35 additions & 54 deletions src/transformers/cache_utils.py
@@ -91,28 +91,10 @@ def seen_tokens(self):
@dataclass
class CacheConfig:
"""
Configuration class for quantized cache settings.
Attributes:
nbits (`Optional[int]`, *optional*, defaults to 2):
Number of bits, can be 2 or 4. Defaults to 2.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
Defaults to 64.
residual_length (`Optional[int]`, *optional*, defaults to 128):
Length of the residual cache which will always be stored in original precision.
Defaults to 128.
Base class for cache configs
"""

def __init__(
self,
nbits: Optional[int] = 4,
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
):
self.nbits = nbits
self.q_group_size = q_group_size
self.residual_length = residual_length
cache_implementation: None

@classmethod
def from_dict(cls, config_dict, **kwargs):
@@ -180,7 +162,36 @@ def update(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)


@dataclass
class QuantizedCacheConfig(CacheConfig):
"""
Configuration class for quantized cache settings.
Attributes:
nbits (`Optional[int]`, *optional*, defaults to 4):
Number of bits, can be 2 or 4. Defaults to 4.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
Defaults to 64.
residual_length (`Optional[int]`, *optional*, defaults to 128):
Length of the residual cache which will always be stored in original precision.
Defaults to 128.
"""

def __init__(
self,
nbits: Optional[int] = 4,
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
):
self.nbits = nbits
self.q_group_size = q_group_size
self.residual_length = residual_length

def validate(self):
"""Validates if the arguments passed are correct"""

incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
@@ -315,22 +326,22 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
return cache


class QuantCache(Cache):
class QuantoQuantizedCache(DynamicCache):
"""
A cache similar to the one described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
It allows the model to generate longer sequences without allocating too much memory for the Key and Value cache by applying quantization.
The cache has two types of storage: one in original precision and one quantized. A `residual_length` is set as the maximum capacity of the
original precision cache. When that capacity is exceeded, the original precision states are quantized and moved into the quantized cache, and the
original precision cache is emptied. The quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. The current implementation
supports `int2` and `int4` cache.
supports the `int2` and `int4` dtypes from `quanto`.
Cache stores the original precision Key and Value states as a list of tensors, one for each layer. The maximum expected shape for each tensor is
`[batch_size, num_heads, residual_length, head_dim]`. Quantized Key and Value are stored separately as a list of quantized tensors, one for each layer.
The size of each tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
Parameters:
nbits (`Optional[int]`, *optional*, defaults to 2):
nbits (`Optional[int]`, *optional*, defaults to 4):
Number of bits, can be 2 or 4. Defaults to 4.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
@@ -344,40 +355,14 @@ def __init__(self, nbits: int = 4, q_group_size: int = 64, residual_length: int
if nbits not in [2, 4]:
raise ValueError(f"`nbits` has to be one of [`2`, `4`] but got {nbits}")

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self._key_cache_quant: List[torch.Tensor] = []
self._value_cache_quant: List[torch.Tensor] = []
self.seen_token = 0

self.residual_length = residual_length
self.qtype = qint4 if nbits == 4 else qint2
self.q_group_size = q_group_size

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if layer_idx < len(self):
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
else:
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
for layer_idx in range(len(self)):
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)
super().__init__()

def update(
self,
@@ -435,10 +420,6 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return 0
return self.seen_token

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def _quantize(self, tensor):
qtensor = QBitsTensor.quantize(tensor, axis=0, qtype=self.qtype, group_size=self.q_group_size)
return qtensor
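To summarize how the renamed classes in `cache_utils.py` fit together, here is an illustrative sketch. It assumes `quanto` is installed and that the model supports the new `Cache` API; the checkpoint and argument values are examples only, not part of the diff:

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    QuantizedCacheConfig,
    QuantoQuantizedCache,
)

# Build and validate a config, then construct the cache from its fields,
# mirroring what generate() does for cache_implementation="quantized".
cache_config = QuantizedCacheConfig(nbits=4, q_group_size=64, residual_length=128)
cache_config.validate()  # checks, for example, that nbits is 2 or 4

cache = QuantoQuantizedCache(
    nbits=cache_config.nbits,
    q_group_size=cache_config.q_group_size,
    residual_length=cache_config.residual_length,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The KV cache", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=10, past_key_values=cache)
```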
14 changes: 7 additions & 7 deletions src/transformers/generation/configuration_utils.py
@@ -37,7 +37,7 @@


if is_torch_available():
from ..cache_utils import CacheConfig
from ..cache_utils import QuantizedCacheConfig


if TYPE_CHECKING:
@@ -287,9 +287,9 @@ class GenerationConfig(PushToHubMixin):
cache_implementation (`str`, *optional*, default to `None`):
Cache class that should be used when generating.
cache_config (`Union[CacheConfig, dict]`, *optional*, default to `None`):
cache_config (`Union[QuantizedCacheConfig, dict]`, *optional*, default to `None`):
Arguments used for quantized cache that stores keys and values in lower precision for memory efficiency.
If passed as `Dict`, it will be converted to a `CacheConfig` internally.
If passed as `Dict`, it will be converted to a `QuantizedCacheConfig` internally.
Accepts the following keys:
- nbits (`int`, *optional*, defaults to 4):
Number of bits, can be 2 or 4. Defaults to 4.
@@ -378,9 +378,9 @@ def __init__(self, **kwargs):
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation == "quantized":
if self.cache_config is None:
self.cache_config = CacheConfig()
self.cache_config = QuantizedCacheConfig()
elif isinstance(self.cache_config, dict):
self.cache_config = CacheConfig.from_dict(self.cache_config)
self.cache_config = QuantizedCacheConfig.from_dict(self.cache_config)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
@@ -634,8 +634,8 @@ def validate(self, is_init=False):

# 5. check `cache_config`
if self.cache_config is not None:
if not isinstance(self.cache_config, CacheConfig):
self.cache_config = CacheConfig.from_dict(self.cache_config)
if not isinstance(self.cache_config, QuantizedCacheConfig):
self.cache_config = QuantizedCacheConfig.from_dict(self.cache_config)
self.cache_config.validate()

# 6. check common issue: passing `generate` arguments inside the generation config
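With the change in `configuration_utils.py` above, a plain dict passed as `cache_config` is promoted to a `QuantizedCacheConfig` when `cache_implementation="quantized"`. A small illustrative check (the values are arbitrary, not from the diff):

```python
from transformers import GenerationConfig, QuantizedCacheConfig

gen_config = GenerationConfig(
    cache_implementation="quantized",
    cache_config={"nbits": 2, "q_group_size": 64, "residual_length": 64},
)
assert isinstance(gen_config.cache_config, QuantizedCacheConfig)
print(gen_config.cache_config.nbits)  # 2
```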
12 changes: 7 additions & 5 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, QuantCache, StaticCache
from ..cache_utils import Cache, DynamicCache, QuantizedCacheConfig, QuantoQuantizedCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -45,7 +45,7 @@
_prepare_attention_mask,
_prepare_token_type_ids,
)
from .configuration_utils import CacheConfig, GenerationConfig, GenerationMode
from .configuration_utils import GenerationConfig, GenerationMode
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
@@ -95,7 +95,7 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "quantized": QuantCache}
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "quantized": QuantoQuantizedCache}


@dataclass
@@ -1600,9 +1600,11 @@ def generate(
)

cache_config = (
generation_config.cache_config if generation_config.cache_config is not None else CacheConfig()
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
model_kwargs["past_key_values"] = QuantCache(
model_kwargs["past_key_values"] = QuantoQuantizedCache(
nbits=cache_config.nbits,
q_group_size=cache_config.q_group_size,
residual_length=cache_config.residual_length,
9 changes: 8 additions & 1 deletion src/transformers/utils/dummy_pt_objects.py
@@ -37,7 +37,14 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class QuantCache(metaclass=DummyObject):
class QuantizedCacheConfig(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class QuantoQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
4 changes: 2 additions & 2 deletions tests/generation/test_utils.py
@@ -55,7 +55,7 @@
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache, QuantCache
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1670,7 +1670,7 @@ def test_generate_with_quant_cache(self):
}

results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, QuantCache))
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))

# passing past key values of different type should raise Error
with self.assertRaises(ValueError):
