GPT-Neo ONNX export #12911
@sgugger @LysandreJik What do you think would be the best way to approach exporting features for downstream tasks? I think there are two possible ways:
I think using a |
…jects (draft version with lots of printing and comments, committed to have them available if need be)
…entionMixin._get_block_length_and_num_blocks in a graph friendly fashion
@michaelbenayoun is the PR ready for review? 🥰
I also implemented a "factory" called From what @sgugger said, I went with the "task argument" approach. Basically, a feature is the combination of a task and the potential use of past keys and values, for instance:
Any feature containing "-with-past" will be mapped by the factory to an OnnxConfig instantiated using the
@mfuntowicz any comments on the changes I have made?
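The feature naming described above (a task, optionally combined with "-with-past") can be illustrated with a small sketch. Everything here is hypothetical and is not the code from this PR: `SketchOnnxConfig`, `config_from_feature`, and the example feature names are stand-ins chosen for illustration only.

```python
from dataclasses import dataclass

WITH_PAST = "-with-past"


@dataclass
class SketchOnnxConfig:
    """Hypothetical stand-in for an ONNX config; not the transformers class."""

    task: str = "default"
    use_past: bool = False


def config_from_feature(feature: str) -> SketchOnnxConfig:
    """Map a feature name to a config (assumed behavior, for illustration)."""
    # A feature is a task plus an optional "-with-past" marker.
    use_past = WITH_PAST in feature
    task = feature.replace(WITH_PAST, "")
    return SketchOnnxConfig(task=task, use_past=use_past)


# Example feature names below are illustrative, not an official list.
print(config_from_feature("default"))
print(config_from_feature("causal-lm-with-past"))
```

Keeping the task and the past-key-values flag as separate fields means a new task only needs a new name rather than a new pair of classes.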
@@ -1121,7 +1121,7 @@ def forward(
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )

-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
I think you can use torch.take_along_dim(logits, sequence_lengths, dim=1). You might need to match the shapes of logits and sequence_lengths.
=> It will remove the need to gather from the shape object.
I applied the following changes:
if isinstance(sequence_lengths, torch.Tensor):
    pooled_logits = torch.take_along_dim(
        logits, indices=sequence_lengths.unsqueeze(1).unsqueeze(1), dim=1
    ).squeeze()
else:
    pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
The forward pass works and gives the same output as the original version, but the ONNX conversion fails with:
RuntimeError: Exporting the operator take_along_dim to ONNX opset version 11 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
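For reference, the same per-row selection can also be written with torch.gather. This is only a sketch with made-up shapes, not something proposed in the thread, and whether it exports cleanly at opset 11 for this model is not verified here. It uses squeeze(1) rather than a bare squeeze() so that a batch of size 1 keeps its batch dimension.

```python
import torch

# Toy shapes, chosen only for illustration.
batch_size, seq_len, num_labels = 3, 5, 2
logits = torch.randn(batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([4, 1, 3])  # one position per batch row

# Indexing used in the diff: pick logits[i, sequence_lengths[i], :] for each i.
expected = logits[torch.arange(batch_size), sequence_lengths]

# Same selection with gather: the index must have the same number of
# dimensions as logits, so expand it to (batch, 1, num_labels).
index = sequence_lengths.view(-1, 1, 1).expand(-1, 1, num_labels)
pooled_logits = logits.gather(1, index).squeeze(1)

assert torch.equal(pooled_logits, expected)
```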
LGTM, thanks a lot for the PR!
Nice! It requires a bit more reading to understand what to add when manually adding a new configuration. It would be nice to add the ONNX part to the existing templates so that when adding new models, users would automatically make them compatible with ONNX (not in this PR though).
LGTM!
GPT-Neo ONNX export and task / feature refactoring Authored-by: Michael Benayoun <[email protected]>
What does this PR do?
This PR enables the export of GPT-Neo to ONNX by extending the new module transformers.onnx.
It also provides a possible way of implementing the export for specific tasks: the task can be specified when instantiating an OnnxConfig. This approach makes it easy to factor out most of the code for the inputs / outputs, but it is less aligned with the transformers DNA than having subclasses (such as OnnxConfigForSequenceClassification, etc.) take care of that.
The issue with having many subclasses is that they would have to be written every time one wants to add support for a model.
What do you think?
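To make the trade-off concrete, here is a minimal sketch of the "task as an argument" approach, assuming a lookup table from task name to output axes. `TaskOnnxConfig`, `TASK_OUTPUTS`, and the task and axis names are hypothetical and are not taken from transformers.onnx.

```python
from collections import OrderedDict

# Hypothetical table: task name -> output names with their dynamic axes.
TASK_OUTPUTS = {
    "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
    "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
}


class TaskOnnxConfig:
    """Single config class; the task selects the outputs (illustrative only)."""

    def __init__(self, task: str = "default"):
        if task not in TASK_OUTPUTS:
            raise ValueError(f"Unsupported task: {task!r}")
        self.task = task

    @property
    def outputs(self):
        # Shared across models: no per-task subclass is needed.
        return TASK_OUTPUTS[self.task]


config = TaskOnnxConfig(task="sequence-classification")
print(list(config.outputs))
```

With subclasses, each row of the table would instead become a class that has to be repeated for every supported model.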