Make static cache compatible with torch.export #32168

Merged
32 changes: 22 additions & 10 deletions src/transformers/cache_utils.py
@@ -23,12 +23,14 @@
logger = logging.get_logger(__name__)


@dataclass
class Cache:
class Cache(torch.nn.Module):
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def __init__(self):
super().__init__()

def update(
self,
key_states: torch.Tensor,
@@ -299,6 +301,7 @@ class DynamicCache(Cache):
"""

def __init__(self) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
@@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache):
"""

def __init__(self, cache_config: QuantizedCacheConfig) -> None:
super().__init__()
self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []

@@ -634,6 +638,7 @@ class SinkCache(Cache):
"""

def __init__(self, window_length: int, num_sink_tokens: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.window_length = window_length
@@ -786,7 +791,7 @@ def update(

class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.
Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

Parameters:
config (`PretrainedConfig):
@@ -817,18 +822,22 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
for _ in range(config.num_hidden_layers):
for idx in range(config.num_hidden_layers):
# Note: `torch.export()` requires mutations to be registered as buffers.
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
key_cache = getattr(self, f"key_cache_{idx}")
value_cache = getattr(self, f"value_cache_{idx}")
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
torch._dynamo.mark_static_address(key_cache)
torch._dynamo.mark_static_address(value_cache)
self.key_cache.append(key_cache)
self.value_cache.append(value_cache)

def update(
self,
@@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache):
"""

def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache):
"""

def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
super().__init__()
self.self_attention_cache = self_attention_cache
self.cross_attention_cache = cross_attention_cache

@@ -1148,6 +1159,7 @@ def batch_select_indices(self, indices: torch.Tensor):

class HybridCache(Cache):
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
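As context for the register_buffer change above, here is a minimal standalone sketch (not from this PR) showing that torch.export() only captures in-place mutation of state that has been registered as a buffer; the Counter class and its names are illustrative:

import torch
from torch.export import export

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # State mutated inside forward() must be a registered buffer so that
        # export() can record the mutation in the graph signature.
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        self.count.add_(1)  # in-place buffer mutation, tracked by export
        return x + self.count

ep = export(Counter(), (torch.ones(2),))
print(ep.graph_signature)  # the mutated buffer shows up in the signature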
58 changes: 58 additions & 0 deletions tests/utils/test_cache_utils.py
@@ -15,12 +15,14 @@

import unittest

from packaging import version
from parameterized import parameterized

from transformers import set_seed
from transformers.testing_utils import (
is_torch_available,
require_auto_gptq,
require_read_token,
require_torch,
require_torch_gpu,
slow,
@@ -32,6 +34,7 @@
import torch

from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
@@ -164,6 +167,61 @@ def _random_kvs(config):
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
self.assertTrue(cached_values.shape == (1, 1, 10, 128))

@slow
@require_read_token
def test_static_cache_exportability(self):
"""
Tests that static cache works with `torch.export()`
"""
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")

device = "cpu"
dtype = torch.float32
max_batch_size = 1

config = AutoConfig.from_pretrained(
"google/gemma-2b",
torch_dtype=dtype,
use_cache=True,
)
m = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
config=config,
torch_dtype=dtype,
attn_implementation="sdpa",  # Export and ExecuTorch only work with SdpaAttention
).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]

class ExportatibleModelWithStaticCache(torch.nn.Module):
guangy10 (Contributor Author) commented on Jul 24, 2024:

I will need such a wrapper class because:

  1. The model needs to be self-contained so that the runtime can load it as a binary and run it as-is. Therefore, the cache config needs to be determined at model construction time and be part of the exported binary.
  2. The export may not be able to support passing the static cache instance as a param to forward().

For example, if I changed the test to use the gemma-2b model directly:

# This is the forward() signature for GemmaPreTrainedModel
# 
    # def forward(
    #     self,
    #     input_ids: torch.LongTensor = None,
    #     attention_mask: Optional[torch.Tensor] = None,
    #     position_ids: Optional[torch.LongTensor] = None,
    #     past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    #     inputs_embeds: Optional[torch.FloatTensor] = None,
    #     labels: Optional[torch.LongTensor] = None,
    #     use_cache: Optional[bool] = None,
    #     output_attentions: Optional[bool] = None,
    #     output_hidden_states: Optional[bool] = None,
    #     return_dict: Optional[bool] = None,
    #     cache_position: Optional[torch.LongTensor] = None,
    # ) -> Union[Tuple, CausalLMOutputWithPast]:

m = AutoModelForCausalLM.from_pretrained("google/gemma-2b", attn_implementation="sdpa")
with torch.no_grad():
  static_kv_cache = StaticCache(config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device)
  export(m, args=(inputs,), kwargs={"past_key_values": static_kv_cache, "use_cache": True, "cache_position": torch.arange(1)})

It won't work because the input type is not supported. So we will see something like this:

E    torch._dynamo.exc.UserError: It looks like one of the inputs with type `<class 'transformers.cache_utils.StaticCache'>` is not supported or pytree-flattenable.
E    Exported graphs inputs can only contain the following supported types: [<class 'torch.Tensor'>, <class 'torch.SymInt'>, <class 'torch.SymFloat'>, <class 'torch.SymBool'>, <class 'torch.ScriptObject'>, <class 'NoneType'>, <class 'complex'>, <class 'torch.dtype'>, <class 'str'>, <class 'bool'>, <class 'ellipsis'>, <class 'int'>, <class 'torch.layout'>, <class 'code'>, <class 'torch.memory_format'>, <class 'bytes'>, <class 'float'>, <class 'torch.device'>].

guangy10 (Contributor Author):

@ArthurZucker BTW, we don't need to address this in this PR, so it shouldn't block merging. It's just to kick off another static-cache-related discussion for ExecuTorch, since this code snippet is a good example to explain the context.

gante (Member):

I see below that we can use this structure to export a model compatible with the forward pass -- the user has to implement their own generation loop to use the exported model.

In a perfect world, I'm assuming it would be interesting to export the entire generate function, which would bundle model, cache, and the generation loop. Is this assumption correct?

(at the moment, generate is not compatible with torch.compile, but we have a PR open)
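For illustration, a minimal sketch of such a user-implemented greedy loop (it assumes the wrapper's forward(tokens, input_pos) signature from the test below and a program exported for single-token inputs; greedy_generate and its arguments are illustrative names, not part of this PR):

import torch

def greedy_generate(exported_program, prompt_ids, max_new_tokens=20):
    # prompt_ids: (1, prompt_len) LongTensor; the exported wrapper returns logits.
    prog = exported_program.module()
    tokens = prompt_ids.clone()
    logits = None
    # Prefill the static cache one token at a time.
    for pos in range(tokens.shape[1]):
        logits = prog(tokens[:, pos : pos + 1], input_pos=torch.tensor([pos]))
    # Greedy decode: feed back the argmax token and its position each step.
    for _ in range(max_new_tokens):
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=-1)
        logits = prog(next_token, input_pos=torch.tensor([tokens.shape[1] - 1]))
    return tokens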

guangy10 (Contributor Author):

@gante Let me explain more about the generate (inference) part, and give the full picture of how I imagine the collaboration/integration.

Disclaimer: Note that the ultimate goal of exporting and lowering to ExecuTorch is to run inference on edge devices, just like ONNX, TFLite, etc., via Optimum.

Why is the adapter forward() needed?

In ExecuTorch we have a C++ runtime for LLMs that can load the exported transformer model (a binary .pte format) for inference. To utilize that runtime, the forward() must comply with the same signature, which looks like:

    def forward(
        token: torch.Tensor,
        input_pos: Optional[torch.Tensor],
    ) -> torch.Tensor:

So basically in the prototype PR I created this adapter forward() for Gemma-2b to make it compatible with that C++ runtime. With the adapter, after the model is exported, it can be loaded for inference by running a CLI command like this:

cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=gemma_tokenizer.bin --prompt="Hello world!"

Please note that the primary goal of creating the adapter forward() is to demonstrate the end-to-end workflow of exporting and lowering a Hugging Face model to ExecuTorch with minimal changes by reusing the C++ runtime. It doesn't mean we need to adapt all Hugging Face models to it.

How to generate in a more scalable way?

Of course, it doesn't scale to add such an adapter forward() for every model. For users to run inference/generation with the exported model, the experience should ideally be similar to an eager or torch.compile'd model. To make that happen, ExecuTorch can provide a dedicated runtime that:

  1. complies with the forward() of Hugging Face transformers PreTrainedModel (with one exception, explained in the next section)
  2. can be used in eager Python directly or via Optimum (either implemented directly in Python like this, or by exposing the C++ implementation via pybind)

With that dedicated runtime, such an adapter won't be needed.

What is the requirement?

It's important to emphasize that, to make the above approach work, there is a technical requirement that must be met: the cache must be statically configured at the time of export.

Today in HF transformers the cache is not configurable via AutoConfig (the static config used to construct a transformer model). To make it work with torch.export(), I have to statically instantiate the StaticCache in the adapter forward() as a workaround. That is also the other reason for having this adapter forward() in both the prototype PR and in this unit test. As discussed with @amyeroberts yesterday, I'm proposing to add an option to make it statically configurable at model construction time.

I feel it may be better to turn this into a co-design proposal somewhere so we can iterate on it and loop in the Optimum team. @amyeroberts @gante @ArthurZucker, what would be the recommended place to post it?

gante (Member) commented on Jul 26, 2024:

(Writing here in case I forget, in advance of potentially moving the discussion somewhere else)

Today in HF transformers the cache is not configurable via AutoConfig

@guangy10 This is actually something we want to change! At the moment, the model instance holds:

  1. config, which specifies the model architecture
  2. generation_config, which specifies generation parameterization

We were thinking of creating a cache_config field within generation_config, which would fully parameterize a cache. I'm assuming this would solve the question, correct? If so, we can (and should) fast-track it 🤗
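For reference, the knob that already exists today looks roughly like this (a sketch using the current cache_implementation field on generation_config, not the proposed cache_config; the checkpoint name is taken from the test in this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
model.generation_config.cache_implementation = "static"  # selects StaticCache inside generate()
inputs = tok("The best color is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8)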

ArthurZucker (Collaborator):

The cache is actually already configurable via the AutoConfig API, because you can set parameters there that will be passed to the generation_config. As @gante mentioned, having the cache_config always passed to the generation_config would solve the configurable-at-construction-time constraint!

We can open a PR for this if it's not already the case!

guangy10 (Contributor Author):

Thanks @ArthurZucker @gante, glad to hear there is already a plan to change it.

We were thinking of creating a cache_config field within generation_config, which would fully parameterize a cache. I'm assuming this would solve the question, correct?

Not exactly. Let me repost it as a GitHub issue, and we can consolidate the discussion there. New PRs will follow to address it.

Collaborator:

@guangy10 Feel free to open an issue in transformers and link it here (or vice versa). The Optimum team will be able to comment and discuss there. cc @michaelbenayoun for the Optimum side

guangy10 (Contributor Author):

Opened in #32253. Let's continue the discussion there.

def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
self.static_cache = StaticCache(
config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
)
Comment on lines +202 to +204
Collaborator:

So the recommended way to export with a static cache would be attaching the cache?

guangy10 (Contributor Author):

Do you mean attaching the cache as a param to forward()? Would you mind elaborating a bit more?


def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
outs = self.model(
input_ids=tokens,
attention_mask=None,
position_ids=input_pos.unsqueeze(0),
cache_position=input_pos,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits

set_seed(0)
with torch.no_grad():
from torch.export import ExportedProgram, export

model = ExportatibleModelWithStaticCache(config, m)
exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
self.assertTrue(isinstance(exported_program, ExportedProgram))
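For completeness, a short sketch (not part of the diff) of how the resulting ExportedProgram could be run and serialized, assuming torch >= 2.3; the file name is illustrative:

# Run the exported graph eagerly with the original calling convention.
out = exported_program.module()(inputs, input_pos=torch.arange(1))
# Save/reload the program with the (prototype) torch.export serialization APIs.
torch.export.save(exported_program, "gemma_static_cache.pt2")
reloaded = torch.export.load("gemma_static_cache.pt2")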


@require_torch_gpu
@slow