
Export to ExecuTorch #32253

Open
13 of 26 tasks
guangy10 opened this issue Jul 26, 2024 · 11 comments
Labels
ExecuTorch Feature request Request for a new feature

Comments

@guangy10
Contributor

guangy10 commented Jul 26, 2024

Feature request

Unlock a new workflow for on-device use-cases via torch.export and ExecuTorch.

Ideally, users get an end-to-end (e2e) experience: load a pretrained transformers model from Hugging Face, export and lower it to ExecuTorch, and get reasonable performance out of the box.

For example:

  1. Load a model with StaticCache:
model = AutoModelForCausalLM.from_pretrained(
    hf_model_repo,
    config=config,
    attn_implementation="sdpa",
    cache_config={
        "use_cache": True, 
        "cache_implementation": "static", 
        "max_cache_length": 128,
    },  # Mandatory field to set ONLY for "Export to ExecuTorch" workflow, optional in other use-cases
)
  2. Then export the model with StaticCache:
exported_program = convert_and_export_with_cache(
    model,
    args=(model_inputs,),
    kwargs={"position_ids": <val>, "inputs_embeds": <val>, "cache_position": <val>},
)

and then further lower the exported program to ExecuTorch with delegates for performance:

executorch_m = lower_to_executorch(
    model, 
    recipes="xnnpack_fp32",  # Delegate to XNNPACK backend
)

# The lowered artifact can be saved into a `.pte` binary format for integration and distribution.
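# For illustration only (an assumption about the proposed API, not a documented
# contract): if the lowered object exposes the serialized program as `.buffer`,
# as ExecuTorch's ExecutorchProgramManager does today, saving it could look like:
with open("model.pte", "wb") as f:
    f.write(executorch_m.buffer)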

With that, you get an on-device model with reasonable performance to start with.

From there, and still within the ExecuTorch stack, you can easily tailor the experience for your use cases, of course with better performance! Note that ExecuTorch supports delegation to the XNNPACK backend, Apple Core ML and MPS, Qualcomm QNN, ARM Ethos-U, Vulkan GPU, and more. You can learn more by reading our tutorial.

  3. Use the exported/lowered artifact for inference:

# The lowered artifact can run on a local device in the ExecuTorch runtime, in C++ or via pybind, providing the same experience as running inference with the eager model on a server.

generate(model=executorch_m, prompt="Hello world")  # Will generate up to the maximum sequence length/cache length
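
A minimal sketch of such a generate() is essentially a greedy-decode loop over the lowered artifact. The version below is illustrative only: it assumes the artifact behaves like a plain callable mapping input_ids to logits, and it glosses over the actual ExecuTorch runtime calling convention and KV-cache plumbing.

import torch
from transformers import AutoTokenizer

def generate(model, prompt, tokenizer_name="distilgpt2", max_new_tokens=32):
    # Greedy decoding: repeatedly pick the most likely next token.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_new_tokens):
        logits = model(token_ids)  # assumed signature: input_ids -> logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        token_ids = torch.cat([token_ids, next_token], dim=-1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(token_ids[0], skip_special_tokens=True)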

The example workflow above shows a direct integration between ExecuTorch and HF transformers models. Eventually this workflow could be accessible via Optimum (exporters-et), Transformers.js, or directly in ExecuTorch and torchchat.

Motivation

Unlock a whole new on-device experience for using Hugging Face models without leaving the PyTorch ecosystem (ExecuTorch is native PyTorch!).

Issues Tracker

  - Cache
  - E2E workflow
  - Optimization
  - Models

Your contribution

  1. Co-design the "Export to ExecuTorch" workflow.
  2. Co-design the generate() for the exported model and the integration in Optimum.

Here is how ExecuTorch implements generate() for llama2/3 in eager Python and C++.

cc: @amyeroberts @gante @ArthurZucker @michaelbenayoun

@guangy10 guangy10 added the Feature request Request for a new feature label Jul 26, 2024
@guangy10 guangy10 changed the title Make Cache statically configurable at model construction time Export to ExecuTorch Jul 26, 2024
@gante
Member

gante commented Jul 27, 2024

Thank you for detailing ExecuTorch's goals 🤗

Two follow-up questions:

  1. In the snippet you shared at the top, you explicitly load the model config before loading the model with .from_pretrained. However, .from_pretrained handles loading and modifying the base config -- for instance, model = AutoModelForCausalLM.from_pretrained("distilgpt2", use_cache=False) will change use_cache in model.config from the default True to False (see the short snippet after this list). Am I correct in saying that we don't need to manually load the config then?
  2. We have been separating the parameterization of everything specifically related to auto-regressive generation into generation_config (i.e. it is not only for generate, just like config is not only for our model classes). As such, we want to place the cache config in generation_config, as KV caching only exists in auto-regressive generation. generation_config is also loaded in .from_pretrained, just like config. However, updating parameters through .from_pretrained is not yet supported (e.g. .from_pretrained(model_repo, cache_config={...}) wouldn't work). If the answer to 1. is yes: would this API [passing the cache config to .from_pretrained] be useful to you?
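
To make the override behaviour in point 1 concrete, here is a short check (nothing new, just the existing .from_pretrained kwarg-override mechanism):

from transformers import AutoModelForCausalLM

# .from_pretrained loads the hub config and applies keyword overrides,
# so there is no need to instantiate the config manually beforehand.
model = AutoModelForCausalLM.from_pretrained("distilgpt2", use_cache=False)
print(model.config.use_cache)  # False (the hub default is True)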

@guangy10
Contributor Author

@gante Thanks for the great follow-up questions:

For #1, yes, if we can pass/override the config while loading the pretrained model, e.g. model = AutoModelForCausalLM.from_pretrained("distilgpt2", use_cache=True, cache_implementation="static", max_seq_length=128, attn_implementation="sdpa", ...), or, even better, consolidate all cache-related configs into one field cache_config (something you and Arthur are suggesting?).

For #2, yes, I understand there are use cases where keeping the cache config closer to auto-regressive generation is cleaner. In my proposal the KV cache config can still be passed through generation_config for any use case, with no conflict. It only needs to be passed to .from_pretrained(cache_config={"use_cache": True, "cache_implementation": "static", "max_cache_length": 128, ...}) like this for the "Export to ExecuTorch" workflow. That is because ExecuTorch handles the memory planning ahead of time (during export and lowering to ExecuTorch), so the runtime doesn't deal with dynamic memory allocation; that's where the fast inference comes from. And yes, the API [passing the cache config to .from_pretrained] will be useful for the ExecuTorch use case!
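
To make the ahead-of-time requirement concrete, here is a rough sketch of what a statically shaped cache looks like with the existing StaticCache class (constructor argument names may differ slightly between transformers versions):

import torch
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# Every key/value tensor is pre-allocated to a fixed maximum length up front,
# so no reallocation or shape change happens while decoding. Fixed shapes are
# what allow export/lowering to plan all memory ahead of time.
cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=128,
                    device="cpu", dtype=torch.float32)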

@zucchini-nlp
Member

Hey, saw your comments from another PR and wanted to share that I was thinking of making the cache config savable/loadable the same way as the generation config. It would hold all the needed args for all cache types, and loading a model with from_pretrained should also load the cache config and assign self.cache_config = cache_config. @gante WDYT about it?
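
Purely as a hypothetical sketch of what that could look like (none of these names are real APIs today; they just mirror how GenerationConfig is saved/loaded):

# Hypothetical API, for illustration only.
from transformers import AutoModelForCausalLM
from transformers import CacheConfig  # hypothetical class in this sketch

cache_config = CacheConfig(cache_implementation="static", max_cache_length=128)
cache_config.save_pretrained("my-model")  # would write e.g. cache_config.json
model = AutoModelForCausalLM.from_pretrained("my-model")
assert model.cache_config.cache_implementation == "static"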

@gante
Member

gante commented Jul 30, 2024

I'm quite biased towards keeping the cache config inside generation_config:

  1. A separate file to hold a handful of fields seems overkill
  2. Caching exists because of generation; it is not a stand-alone feature (contrary to e.g. quantization, which is not part of any existing configuration file)

But happy to reconsider if there are strong arguments to keep them separate :)

@zucchini-nlp
Member

Wait, I just realized that we will save the cache config even if it's inside the generation config, so it will be loadable from the Hub. Okay, that makes sense, thanks!

@guangy10
Contributor Author

I'm quite biased towards keeping the cache config inside generation_config:

  1. A separate file to hold a handful of fields seems overkill
  2. Caching exists because of generation; it is not a stand-alone feature (contrary to e.g. quantization, which is not part of any existing configuration file)

But happy to reconsider if there are strong arguments to keep them separate :)

@gante, I see two orthogonal things in your and @zucchini-nlp's comments. Let's get more clarity on them:

  1. Enable the ability to pass/override the cache config via PreTrainedModel.from_pretrained()

It would take the cache config to construct the model. This is a new feature needed in order to support torch.export and ExecuTorch. I think we're on the same page on this?

  2. Decide where to load/read the cache config from

This is about whether PreTrainedModel.from_pretrained() will load the cache config from a separate config, the generation config, or some other config. I don't have a strong opinion on where it should go. To me, there may be a use case for quantizing the cache at some point, and for torch.export and ExecuTorch the quantization process is independent of generation.

@gante
Member

gante commented Aug 1, 2024

It would take the cache config to construct the model. This is a new feature needed in order to support torch.export and ExecuTorch. I think we're on the same page on this?

Yes :)

To me, there may be a use case to quantize the cache at some point

We do indeed have support for quantized caches! Their quantization configuration is set at initialization time, so it will belong in the cache config as well :) (we can have, e.g., an FP16 model and a quantized cache, to support very long generation)
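
For context, a quantized cache can already be requested at generation time along these lines (requires an extra quantization backend such as quanto; the argument names follow the current cache docs and may evolve):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

# The KV cache is quantized while the model weights keep their precision,
# trading some compute for much lower memory during long generations.
out = model.generate(
    **inputs,
    max_new_tokens=32,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(out[0], skip_special_tokens=True))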

@bhack
Contributor

bhack commented Nov 19, 2024

This could eventually also enable AOTI compilation:
https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html
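
For anyone curious, the linked recipe boils down to roughly this flow (a sketch based on the PyTorch 2.5+ AOTInductor Python API; exact signatures vary by version, and whether a given transformers model exports out of the box depends on its modeling code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2").eval()
example_ids = tokenizer("Hello world", return_tensors="pt").input_ids

# 1) Capture a full graph with torch.export, then 2) AOT-compile it with Inductor
# into a self-contained package that can be loaded without the Python model code.
ep = torch.export.export(model, (example_ids,))
package_path = torch._inductor.aoti_compile_and_package(ep, package_path="model.pt2")
compiled = torch._inductor.aoti_load_package(package_path)
outputs = compiled(example_ids)  # should match the eager forward pass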

@guangy10
Contributor Author

This could eventually also enable AOTI compilation: https://pytorch.org/tutorials/recipes/torch_export_aoti_python.html

Yeah, we have a proof of concept making AOTI a backend of ExecuTorch, enabling users to utilize the desktop GPU, HTP, and CPU together on a desktop where all these accelerators are available.

@cptspacemanspiff

@guangy10
I have been trying to get Whisper implemented but want to check my understanding. Would you expect there to be 2 artifacts? The encoder that runs independently, and a decoder that has an additional input containing the input embeddings?

Or am I missing something about how encoder/decoder architectures are best implemented in ExecuTorch?

@guangy10
Contributor Author

guangy10 commented Nov 22, 2024

@guangy10 I have been trying to get Whisper implemented but want to check my understanding. Would you expect there to be 2 artifacts? The encoder that runs independently, and a decoder that has an additional input containing the input embeddings?

Or am I missing something about how encoder/decoder architectures are best implemented in ExecuTorch?

Yeah, you can export the model to multiple artifacts. Here is an example of how another encoder-decoder model (t5) is supported: https://github.com/huggingface/transformers/blob/d9e6f307e71b5108a7882ec00ffcc0d0eb316cb7/tests/models/t5/test_modeling_t5.py#L1650-L1706. That example uses torch.compile, but the idea is the same for using torch.export to ExecuTorch.
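
As a rough illustration of the two-artifact idea for Whisper (the shapes and the assumption that the encoder/decoder submodules trace cleanly with torch.export are mine, not a tested recipe):

import torch
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny").eval()

# Artifact 1: the encoder, which runs once per audio clip.
input_features = torch.randn(1, 80, 3000)  # (batch, mel bins, frames) for whisper-tiny
encoder_ep = torch.export.export(model.get_encoder(), (input_features,))

# Artifact 2: the decoder, which runs auto-regressively and additionally takes
# the encoder outputs as a cross-attention input.
decoder_ids = torch.tensor([[model.config.decoder_start_token_id]])
encoder_hidden_states = torch.randn(1, 1500, model.config.d_model)
decoder_ep = torch.export.export(
    model.get_decoder(),
    (decoder_ids,),
    kwargs={"encoder_hidden_states": encoder_hidden_states},
)
# Each exported program can then be lowered to its own .pte artifact.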
