
Cannot use StaticCache with Phi3 #32338

Closed
2 of 4 tasks
helunwencser opened this issue Jul 30, 2024 · 0 comments

@helunwencser
Contributor

System Info

Name: torch
Version: 2.5.0.dev20240716

Name: transformers
Version: 4.44.0.dev0

Who can help?

@ArthurZucker
@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to run the Phi-3 model on an edge device via ExecuTorch, where StaticCache is the only cache I can use. However, the current Phi3 model fails when used with StaticCache.

To reproduce this issue, please run the following script:

import torch
from transformers import Phi3ForCausalLM, StaticCache, AutoTokenizer

end_of_text_token = 32000

class Phi3Mini(torch.nn.Module):

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        super().__init__()
        self.model = model
        self.cache = StaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        cache_position: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        return self.model.forward(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=self.cache,
            cache_position=cache_position,
        ).logits

def _generate_token_with_kv_cache(seq_len, model, prompt_tokens):
    print("Generating tokens:", end="", flush=True)

    model = Phi3Mini(model, 1, seq_len + prompt_tokens.shape[-1])

    for input_pos in range(prompt_tokens.shape[-1]):
        # One token per call, so cache_position holds just that token's slot index.
        result = model.forward(
            input_ids=prompt_tokens[:, input_pos : input_pos + 1],
            cache_position=torch.tensor([input_pos], device=model.model.device),
        )

    current_token = torch.argmax(result[:, -1, :], dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

    while current_token != end_of_text_token and len(generated_tokens) < seq_len:
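        # The token fed back in occupies the slot right after the prompt and
        # the previously generated tokens.
        result = model.forward(
            input_ids=torch.tensor([[current_token]], dtype=torch.long, device=model.model.device),
            cache_position=torch.tensor(
                [prompt_tokens.shape[-1] + len(generated_tokens) - 1],
                device=model.model.device,
            ),
        )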
        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(
        prompt,
        seq_len,
):
    seed = 42
    torch.manual_seed(seed)
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(prompt, return_tensors="pt")

    generated_tokens = _generate_token_with_kv_cache(seq_len, model, tokens)

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )


if __name__ == "__main__":
    main(
        prompt="Tell me a story",
        seq_len=128
    )

It fails with the following error:

/opt/anaconda3/envs/executorch/bin/python /Users/lunwenh/executorch/examples/models/phi-3-mini/test.py 
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.03it/s]
Generating tokens:You are not running the flash-attention implementation, expect numerical differences.
Traceback (most recent call last):
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 92, in <module>
    main(
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 77, in main
    generated_tokens = _generate_token_with_kv_cache(seq_len, model, tokens)
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 38, in _generate_token_with_kv_cache
    result = model.forward(
  File "/Users/lunwenh/executorch/examples/models/phi-3-mini/test.py", line 24, in forward
    return self.model.forward(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 1207, in forward
    outputs = self.model(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 1002, in forward
    layer_outputs = decoder_layer(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 739, in forward
    attn_outputs, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1727, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/anaconda3/envs/executorch/lib/python3.10/site-packages/transformers/models/phi3/modeling_phi3.py", line 405, in forward
    raise ValueError(
ValueError: Attention weights should be of size (1, 32, 1, 1), but is torch.Size([1, 32, 1, 132])

Process finished with exit code 1

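This happens because the current StaticCache implementation does not slice k_out and v_out in update(); it returns the whole pre-allocated cache of length max_cache_len (132 in this run, i.e. seq_len plus the prompt length). The attention weights therefore have shape (1, 32, 1, 132), while the size check in Phi3's attention forward expects the last dimension to equal the number of tokens seen so far (1 on the first call).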
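The shape mismatch can be reproduced without transformers. The sketch below is a minimal illustration, assuming a toy cache class (`ToyStaticCache`, hypothetical, not the library's StaticCache) that mirrors the behavior described above: buffers are pre-allocated to max_cache_len and update() returns them unsliced. The sizes follow the traceback (32 heads, max_cache_len 132); head_dim=96 is an assumption.

```python
import torch


class ToyStaticCache:
    """Hypothetical stand-in for a static KV cache: buffers are pre-allocated
    to max_cache_len and update() returns the whole buffer, not a slice."""

    def __init__(self, batch: int, heads: int, max_cache_len: int, head_dim: int):
        shape = (batch, heads, max_cache_len, head_dim)
        self.k = torch.zeros(shape)
        self.v = torch.zeros(shape)

    def update(self, k_new, v_new, cache_position):
        # Write the new keys/values into their slots along the sequence dim (dim 2).
        self.k.index_copy_(2, cache_position, k_new)
        self.v.index_copy_(2, cache_position, v_new)
        # No slicing: the caller always sees max_cache_len positions.
        return self.k, self.v


batch, heads, head_dim, max_cache_len = 1, 32, 96, 132  # head_dim is assumed
cache = ToyStaticCache(batch, heads, max_cache_len, head_dim)

# First single-token step (q_len == 1), written to slot 0.
q = torch.randn(batch, heads, 1, head_dim)
k_new = torch.randn(batch, heads, 1, head_dim)
v_new = torch.randn(batch, heads, 1, head_dim)
k_all, v_all = cache.update(k_new, v_new, torch.tensor([0]))

attn_weights = q @ k_all.transpose(2, 3) / head_dim**0.5
kv_seq_len = 1  # tokens actually written so far
print(tuple(attn_weights.shape))  # (1, 32, 1, 132)
print((batch, heads, 1, kv_seq_len))  # (1, 32, 1, 1): the shape the size check expects
```

Because the last dimension of the attention weights follows the buffer length rather than the number of written tokens, any check of the form "last dim == kv_seq_len" fails on every step until the cache is full.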

In the long term, #31421 and #30862 should resolve this problem by supporting both StaticCache and dynamic lengths.

For now, removing this size check (the ValueError raised in modeling_phi3.py, line 405 in the traceback above) should be enough to make Phi3 work with StaticCache.
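Dropping the check only changes results if the unwritten slots of the static buffer leak into the softmax. The sketch below, with small made-up sizes, assumes those slots are masked with -inf (an assumption about how the causal mask covers the static buffer, not verified against modeling_phi3.py here). It compares attention over the full masked buffer with attention over only the filled prefix.

```python
import torch

torch.manual_seed(0)
batch, heads, head_dim, max_cache_len = 1, 2, 8, 16
filled = 5  # slots 0..4 have been written; the rest of the buffer is zeros

q = torch.randn(batch, heads, 1, head_dim)
k_buf = torch.zeros(batch, heads, max_cache_len, head_dim)
v_buf = torch.zeros(batch, heads, max_cache_len, head_dim)
k_buf[:, :, :filled] = torch.randn(batch, heads, filled, head_dim)
v_buf[:, :, :filled] = torch.randn(batch, heads, filled, head_dim)

# Option 1: attend over the whole static buffer, masking unwritten slots with -inf.
mask = torch.full((max_cache_len,), float("-inf"))
mask[:filled] = 0.0
scores_full = q @ k_buf.transpose(2, 3) / head_dim**0.5 + mask  # (1, 2, 1, 16)
out_full = torch.softmax(scores_full, dim=-1) @ v_buf

# Option 2: slice the cache to the filled prefix first (the shape a size check expects).
scores_sliced = q @ k_buf[:, :, :filled].transpose(2, 3) / head_dim**0.5  # (1, 2, 1, 5)
out_sliced = torch.softmax(scores_sliced, dim=-1) @ v_buf[:, :, :filled]

print(scores_full.shape[-1], scores_sliced.shape[-1])  # 16 5
print(torch.allclose(out_full, out_sliced, atol=1e-6))  # True
```

The attention weights differ in shape but the outputs agree, because exp(-inf) contributes exactly zero weight to the masked slots. Under that masking assumption the size check is stricter than it needs to be for a static cache.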

Expected behavior

After removing the size check, the script above runs as expected.
