Make static cache compatible with torch.export #32168
I will need such a wrapper class because, for example, if I changed the test to use the gemma-2b model directly, it won't work: the input type is not supported by `torch.export`, and we would see an export error to that effect.
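A minimal sketch of the kind of wrapper this refers to, assuming the static cache is owned by the module rather than passed in as an argument (the class name, cache sizes, and exact `StaticCache` constructor arguments are illustrative and may differ from the actual test code):

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import StaticCache


class ExportableModuleWithStaticCache(torch.nn.Module):
    """Illustrative wrapper: torch.export traces tensor inputs, so the
    StaticCache is created up front and held as module state instead of
    being passed through forward()."""

    def __init__(self, model, max_batch_size=1, max_cache_len=128):
        super().__init__()
        self.model = model
        # Constructor argument names may differ across transformers versions.
        self.static_cache = StaticCache(
            config=model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=model.device,
            dtype=model.dtype,
        )

    def forward(self, input_ids, cache_position):
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            past_key_values=self.static_cache,
            use_cache=True,
        )
        return outputs.logits


# Exporting the wrapper with tensor-only example inputs:
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
wrapper = ExportableModuleWithStaticCache(model)
example_inputs = (
    torch.zeros((1, 1), dtype=torch.long),   # input_ids
    torch.tensor([0], dtype=torch.long),     # cache_position
)
exported_program = torch.export.export(wrapper, example_inputs)
```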
@ArthurZucker BTW, we don't need to address it in this PR, so it shouldn't block merging this PR. It's just to kick off another static-cache-related discussion for ExecuTorch, since this code snippet is a good example to explain the context.
I see below that we can use this structure to export a model compatible with the forward pass -- the user has to implement their own generation loop to use the exported model.

In a perfect world, I'm assuming it would be interesting to export the entire `generate` function, which would bundle the model, the cache, and the generation loop. Is this assumption correct? (At the moment, `generate` is not compatible with `torch.compile`, but we have a PR open.)
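As a sketch of what that user-side generation loop might look like today (greedy decoding over the exported module; the wrapper and input names follow the illustrative example above, not the actual test code):

```python
import torch


def greedy_generate(exported_program, prompt_ids, max_new_tokens=20):
    """Minimal greedy loop over an exported forward pass. The static cache
    lives inside the exported module, so only the current token and its
    position are fed at each step. prompt_ids is assumed to have shape
    (1, prompt_len)."""
    module = exported_program.module()
    generated = prompt_ids.tolist()[0]

    # Prefill: feed the prompt one token at a time for simplicity.
    for pos, token in enumerate(generated):
        logits = module(
            torch.tensor([[token]], dtype=torch.long),
            torch.tensor([pos], dtype=torch.long),
        )

    # Decode new tokens greedily.
    for _ in range(max_new_tokens):
        next_token = int(torch.argmax(logits[:, -1, :], dim=-1))
        generated.append(next_token)
        logits = module(
            torch.tensor([[next_token]], dtype=torch.long),
            torch.tensor([len(generated) - 1], dtype=torch.long),
        )
    return generated
```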
@gante Let me explain more about the `generate` (inference) part, and give the full picture of how I imagine the collaboration/integration.

Disclaimer: Note that the ultimate goal of exporting and lowering to ExecuTorch is to run inference on edge devices, just like ONNX, TFLite, etc., via `Optimum`.

Why is the adapter forward() needed?
In ExecuTorch we have a C++ runtime for LLMs that can load the exported, transformed model (a binary `.pte` format) for inference. To utilize that runtime, the `forward()` must comply with the runtime's expected signature. So basically, in the prototype PR I created this adapter `forward()` for Gemma-2b to make it compatible with that C++ runtime. With the adapter, after the model is exported, it can be loaded for inference via a CLI command. Please note that the primary goal of creating the adapter `forward()` is to demonstrate the end-to-end workflow of exporting and lowering a Hugging Face model to ExecuTorch with minimal changes by reusing the C++ runtime. It doesn't mean we need to adapt all Hugging Face models to it.
How to generate in a more scalable way?
Of course, it doesn't scale to add such an adapter `forward()` for every model. For users to be able to run inference/generation with the exported model, ideally the experience should be similar to an eager or torch.compiled model. To make that happen, ExecuTorch can provide a dedicated runtime that:
- complies with the `forward()` of Hugging Face transformers `PreTrainedModel` (with one exception, explained in the next section);
- integrates with `Optimum` (either implemented directly in Python like this, or by exposing the C++ implementation via pybind).

With that dedicated runtime, such an adapter won't be needed; a rough sketch of the idea follows this list.
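A rough illustration of that idea, under the assumption that the runtime wraps a loaded program and mirrors the Hugging Face `forward()` signature (the class name is hypothetical, and a real integration would load a `.pte` via the ExecuTorch runtime rather than reuse a `torch.export` program):

```python
import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


class HypotheticalExecuTorchCausalLM(torch.nn.Module):
    """Hypothetical runtime wrapper whose forward() mirrors the Hugging Face
    signature, so generic generation utilities (e.g. in Optimum) could drive
    it without a per-model adapter."""

    def __init__(self, exported_program):
        super().__init__()
        # Stand-in for loading a lowered .pte with the ExecuTorch runtime.
        self.program = exported_program.module()

    def forward(self, input_ids, cache_position=None, **kwargs):
        logits = self.program(input_ids, cache_position)
        # The static cache is owned by the exported program, so there is no
        # past_key_values to return here.
        return CausalLMOutputWithPast(logits=logits)
```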
What is the requirement?
It's important to emphasize that to make the above approach work, there is a technical requirement that must be met: the cache must be statically configured at the time of export.

Today in HF transformers the cache is not configurable via `AutoConfig` (the static config used to construct a transformer model). To make it work with `torch.export()`, I have to statically instantiate the `StaticCache` in the adapter `forward()` as a workaround. That is also the other reason for having this adapter `forward()` in both the prototype PR and in this unit test. As discussed with @amyeroberts yesterday, I'm proposing to add an option to make it statically configurable at model construction time.

I feel it may be better to turn this into a co-design proposal somewhere so we can iterate on it and loop in the Optimum team. @amyeroberts @gante @ArthurZucker, what would be the recommended place to repost it?
(Writing here in case I forget, in advance of potentially moving the discussion somewhere else)
@guangy10 This is actually something we want to change! At the moment, the `model` instance holds:
- `config`, which specifies the model architecture
- `generation_config`, which specifies the generation parameterization

We were thinking of creating a `cache_config` field within `generation_config`, which would fully parameterize a cache. I'm assuming this would solve the question, correct? If so, we can (and should) fast-track it 🤗
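For illustration, a sketch of what this proposal could look like from the user side (field names such as `cache_config` and `cache_implementation` follow the proposal above and are assumptions, not an existing API at the time of this thread):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Hypothetical: fully parameterize the cache on the generation config so a
# StaticCache can be constructed at model/export time instead of inside an
# adapter forward().
model.generation_config.cache_implementation = "static"
model.generation_config.cache_config = {
    "max_batch_size": 1,
    "max_cache_len": 128,
}
```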
The cache is actually already configurable via the `AutoConfig` API, because you can set parameters there that will be passed to the `generation_config`. As @gante mentioned, having the `cache_config` always passed to the `generation_config` would solve the configurable-at-construct-time constraint! We can open a PR for this if it's not already the case!
Thanks @ArthurZucker @gante, glad to hear there is already a plan to change it.
Not exactly. Let me repost it as a GitHub issue, and we can consolidate all the discussion there. We will have new PRs to address it.
@guangy10 Feel free to open an issue in transformers and link it here (or vice versa). The Optimum team will be able to comment and discuss there. cc @michaelbenayoun for the Optimum side
Opened in #32253. Let's continue the discussion there.
So the recommended way to export with static cache would be attaching the cache?
Do you mean attaching the cache as a param to `forward()`? Would you mind elaborating a bit more?