
add onnx support for deberta and debertav2 #17617

Merged (31 commits) on Jun 21, 2022

Conversation

@sam-h-bean (Contributor) commented Jun 9, 2022

What does this PR do?

Details: Add ONNX Support for DeBERTa and DeBERTaV2.

Issue: huggingface/optimum#207

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@JingyaHuang @chainyo

@HuggingFaceDocBuilderDev commented Jun 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@chainyo (Contributor) commented Jun 9, 2022

Hello @sam-h-bean, thanks for the PR; it looks nice!

@JingyaHuang (Contributor) left a comment

Hi @sam-h-bean, thanks for contributing! The OnnxConfig of DeBERTa-V2 looks good to me. Can you also add it to the test tests/onnx/test_onnx_v2.py and run

RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -s -k "deberta_v2"

to confirm that the export goes well?

Review comment on src/transformers/onnx/features.py (outdated, resolved)
@sam-h-bean (Contributor, Author) commented Jun 9, 2022

Hi @sam-h-bean, thanks for contributing! The OnnxConfig of DeBERTa-V2 looks good to me. Can you also add it to the test tests/onnx/test_onnx_v2.py and run

RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -s -k "deberta_v2"

to confirm that the export goes well?

@JingyaHuang I am seeing a segmentation fault locally but tests pass in CI/CD?

10461 ± RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -s -k "deberta_v2"
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
=========================================================================================================== test session starts ============================================================================================================
platform darwin -- Python 3.9.10, pytest-7.1.2, pluggy-1.0.0
rootdir: /Users/marklar/workspace/transformers, configfile: setup.cfg
plugins: xdist-2.5.0, forked-1.4.0, timeout-2.1.0, hypothesis-6.47.0, dash-2.5.0
collected 341 items / 329 deselected / 12 selected                                                                                                                                                                                         

Downloading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 633/633 [00:00<00:00, 130kB/s]
Downloading: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52.0/52.0 [00:00<00:00, 53.1kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.33M/2.33M [00:00<00:00, 9.58MB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
FSpecial tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
FSpecial tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Fatal Python error: Segmentation fault

Thread 0x000070000d4e3000 (most recent call first):

UPDATE:

I realized that if you follow the documentation to the letter, it only has you install the dev dependencies, which do not include onnxruntime, which feels a bit odd. I am now getting a non-segfault error:

======================================================================================================================================= short test summary info ========================================================================================================================================
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_041_deberta_v2_default - AttributeError: 'torch._C.Value' object has no attribute 'dtype'
FAILED tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_041_deberta_v2_default - AttributeError: 'torch._C.Value' object has no attribute 'dtype'
====================================================================================================================== 2 failed, 340 deselected, 35 warnings in 89.38s (0:01:29) =======================================================================================================================

SECOND UPDATE:

Still getting some segfaults and some of this dtype error. I was able to run the distilbert tests OK. I now see that CI/CD does not run the slow tests, so this just seems to be a serialization issue. I was able to attach a debugger and get into the core torch export function, but the failure is in there, so I'm not sure whether it makes sense to go deeper in the debugger or whether this is just something obvious I'm missing.

@sam-h-bean (Contributor, Author) commented Jun 10, 2022

I am seeing that DeBERTa-v2 already leverages some symbolic ONNX export machinery. Line 135 is where the crash is happening now.

UPDATE:

I tried deleting the symbolic API from the XSoftmax class and rerunning the test; however, I think the code is being pulled from the hub, which unfortunately still has that method in it and crashes the ONNX export with:

E   AssertionError: deberta-v2, default -> Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice is a deprecated experimental op. Please use statically allocated variables or export to a higher opset version.

This is coming directly from the opset9 library. I also tried setting the default opset to 9 in my ONNX config object and got the same error. Does it make sense to delete that method from the code on the model hub, since the method is unused?

@JingyaHuang (Contributor)

Hi @sam-h-bean, the symbolic function was added intentionally (see #14013) to support ONNX export, as XSoftmax is not a natively supported ONNX op. Can you try to export DeBERTa-V2 with a higher opset, e.g. opset=15 (onnx>=1.10.0)?

onnx_inputs, onnx_outputs = export(preprocessor, model, onnx_config, 15, Path(output.name), device=device)

And maybe find the minimal opset if possible (currently the default is set to 11).

Keep me posted if it works; then we can make the necessary changes to the export.

@JingyaHuang (Contributor)

@sam-h-bean To confirm, can you try this snippet:

from collections import OrderedDict
from typing import Mapping
from pathlib import Path
from transformers.onnx import export
from transformers.onnx import OnnxConfig
from transformers import AutoTokenizer, AutoModel, AutoConfig

class DebertaV2OnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            dynamic_axis = {0: "batch", 1: "sequence"}
        return OrderedDict(
            [
                ("input_ids", dynamic_axis),
                ("attention_mask", dynamic_axis),
            ]
        )
    @property
    def default_onnx_opset(self) -> int:
        return 15

config = AutoConfig.from_pretrained("microsoft/deberta-v3-large")
base_model = AutoModel.from_pretrained("microsoft/deberta-v3-large")
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
onnx_config = DebertaV2OnnxConfig(config)
onnx_path = Path("deberta.onnx")
onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
print(onnx_config.default_onnx_opset)

@sam-h-bean (Contributor, Author) commented Jun 10, 2022

@JingyaHuang I tried the export in a try/except as follows

try:
    onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
except (RuntimeError, ValueError) as e:
    print(onnx_config.default_onnx_opset)

And it prints 15 with no onnx model exported 😢

The exception is

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/marklar/workspace/transformers/src/transformers/onnx/convert.py", line 335, in export
    return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
  File "/Users/marklar/workspace/transformers/src/transformers/onnx/convert.py", line 190, in export_pytorch
    onnx_export(
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 738, in _export
    proto, export_map, val_use_external_data_format = graph._export_onnx(
RuntimeError: ONNX export failed: Couldn't export Python operator XSoftmax

UPDATE: I had made some modifications to XSoftmax; once I reverted them, I got this exception running your snippet:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/marklar/workspace/transformers/src/transformers/onnx/convert.py", line 335, in export
    return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device)
  File "/Users/marklar/workspace/transformers/src/transformers/onnx/convert.py", line 190, in export_pytorch
    onnx_export(
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 719, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 503, in _model_to_graph
    graph = _optimize_graph(graph, operator_export_type,
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 232, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/__init__.py", line 359, in _run_symbolic_method
    return utils._run_symbolic_method(*args, **kwargs)
  File "/Users/marklar/workspace/transformers/venv/lib/python3.9/site-packages/torch/onnx/utils.py", line 846, in _run_symbolic_method
    return symbolic_fn(g, *args)
  File "/Users/marklar/workspace/transformers/src/transformers/models/deberta_v2/modeling_deberta_v2.py", line 135, in symbolic
    output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
AttributeError: 'torch._C.Value' object has no attribute 'dtype'

@michaelbenayoun (Member)

Hi @sam-h-bean,

I merged a PR a few days ago changing the DeBERTa symbolic function (#17539, following what is planned to be done here), and that seems to break the ONNX export.

I think the culprit here is the torch.finfo(some_tensor.dtype) part. Let me try to find a workaround.

Pinging @ydshieh to make him aware of the issue.

@sam-h-bean (Contributor, Author)

Hi @sam-h-bean,

I merged a PR a few days ago changing the DeBERTa symbolic function (#17539, following what is planned to be done here), and that seems to break the ONNX export.

I think the culprit here is the torch.finfo(some_tensor.dtype) part. Let me try to find a workaround.

Pinging @ydshieh to make him aware of the issue.

@michaelbenayoun Yeah it definitely does break. I have changed that line to

output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(torch.float32).min)))

and it seems to work. I am now facing

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:token_type_ids
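As a side note on why hardcoding float32 here barely matters in practice (this is an illustration only, not the transformers code): after exponentiation inside softmax, both the old float("-inf") fill value and the most negative finite float32 collapse to a weight of exactly 0.

```python
import math

# The mask fill value changed from float("-inf") to the most negative finite
# float32. torch.finfo(self.dtype) fails during ONNX export because `self`
# there is a torch._C.Value (a JIT graph value), not a torch.Tensor.
FLOAT32_MIN = -(2 - 2 ** -23) * 2 ** 127  # same value as torch.finfo(torch.float32).min

print(FLOAT32_MIN)                 # ≈ -3.4028234663852886e+38
print(math.exp(float("-inf")))     # 0.0
print(math.exp(FLOAT32_MIN))       # 0.0 as well: indistinguishable inside softmax
```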

@sam-h-bean (Contributor, Author)

@michaelbenayoun @ydshieh this PR now contains a fix for the symbolic function

Comment on lines 267 to 268:

if model_name == "microsoft/deberta-v2-xlarge":
    config.type_vocab_size = 2

@sam-h-bean (Contributor, Author)

@JingyaHuang A weird thing about the DeBERTa model is that it is the only model whose forward method takes token_type_ids by default (the tokenizer returns them in its output dictionary by default), while the model config sets type_vocab_size to 0 by default, which basically makes the forward method ignore that value.

This has a few downstream side effects: the ONNX graph perceives this as the model taking only 2 inputs by default, which causes the test to fail because we attempt to pass all 3 items as ONNX inputs. What I have done is override the value to 2 in the ONNX config so that the graph is constructed to take 3 inputs (the default for the tokenizer but not the model), and I had to hack in this statement to initialize the model config to expect token_type_ids as well.

This feels pretty bad. I think the tokenizer should by default return only 2 items unless the model config has type_vocab_size > 0. However, that would be a larger change to the DeBERTa classes, and I was unsure if that was desired.

@JingyaHuang (Contributor)

Hi @sam-h-bean, whether the exported ONNX model needs 2 or 3 inputs depends on whether you've included token_type_ids in the custom OnnxConfig, as the dummy inputs used for tracing look at the custom OnnxConfig to decide which generated inputs are considered for export.

As the DeBERTa tokenizer returns token_type_ids by default, we suggest including token_type_ids in its OnnxConfig even though it is optional (it will be ignored if type_vocab_size is set to 0).
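A sketch of what this suggestion amounts to (not necessarily the exact code merged in this PR): an inputs property that lists the optional token_type_ids alongside the other inputs, with axis names mirroring the DebertaV2OnnxConfig snippet earlier in the thread.

```python
from collections import OrderedDict
from typing import Mapping

def deberta_onnx_inputs(task: str = "default") -> Mapping[str, Mapping[int, str]]:
    # Dynamic axes as in the earlier snippet: batch/sequence, plus choice for multiple-choice.
    if task == "multiple-choice":
        dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
    else:
        dynamic_axis = {0: "batch", 1: "sequence"}
    return OrderedDict(
        [
            ("input_ids", dynamic_axis),
            ("attention_mask", dynamic_axis),
            # Optional: traced into the graph only if listed here; the model
            # ignores this input when type_vocab_size == 0.
            ("token_type_ids", dynamic_axis),
        ]
    )

print(list(deberta_onnx_inputs().keys()))
```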

@sam-h-bean (Contributor, Author) commented Jun 14, 2022

Yeah, I can do that, but we still need to set the AutoConfig value somehow, because the token-type embedding is initialized on model creation here. So if I just override it in the ONNX config, it fails because there is no token embedding layer, since by default there is none when we instantiate the model here. So how should I change the config here dynamically?
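A minimal sketch of the workaround being discussed: override type_vocab_size on the config object before the model is instantiated, so the token-type embedding layer exists when the ONNX graph is traced. DummyDebertaConfig is a hypothetical stand-in for the real AutoConfig object, used only for illustration.

```python
class DummyDebertaConfig:
    def __init__(self, type_vocab_size: int = 0):
        # DeBERTa's default is 0, i.e. no token-type embedding layer is created.
        self.type_vocab_size = type_vocab_size

config = DummyDebertaConfig()
if config.type_vocab_size == 0:
    # Must happen before model instantiation, since the embedding layer
    # is built from this value in the model's __init__.
    config.type_vocab_size = 2

print(config.type_vocab_size)
```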

@ydshieh (Collaborator) commented Jun 10, 2022

I have very little knowledge of this ONNX thing. But what I can add here is: we need to understand what self actually is in

def symbolic(g, self, mask, dim):

My quick search (in a previous PR) shows it should be some input to the model. And if it is a torch.Tensor, it should have dtype.

  • Maybe it is not an input to the model?
  • Maybe it is, but for some reason it is not a torch.Tensor (for example some symbolic tensor, etc.; I have no idea)

I suggest setting some breakpoints, launching code that will call this method, and investigating what self is.
Please don't merge this PR with float32.dtype without further investigation first, thank you 🙏

@sam-h-bean (Contributor, Author) commented Jun 10, 2022

I have very little knowledge of this ONNX thing. But what I can add here is: we need to understand what self actually is in

def symbolic(g, self, mask, dim):

My quick search (in a previous PR) shows it should be some input to the model. And if it is a torch.Tensor, it should have dtype.

* Maybe it is not an input to the model?

* Maybe it is, but for some reason it is not a `torch.Tensor` (for example some symbolic tensor, etc.; I have no idea)

I suggest setting some breakpoints, launching code that will call this method, and investigating what self is. Please don't merge this PR with float32.dtype without further investigation first, thank you 🙏

This code previously was

output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(float("-inf"))))

@ydshieh Then it was changed to this self reference, which broke the code. This now uses the built-in torch -inf, so the type will be the same. I can also just change it back to torch.tensor(float("-inf")) if you would prefer, but I imagine this won't make a great deal of difference in practice.

See screenshot of the inspection: [screenshot omitted]

It is some object that is clearly not a tensor.

Here is the value as it is now: [screenshot omitted]

And the value before the broken code was merged: [screenshot omitted]

@sam-h-bean sam-h-bean changed the title add onnx support for debertav2 add onnx support for deberta and debertav2 Jun 17, 2022
@sam-h-bean (Contributor, Author) commented Jun 17, 2022

Here is the full test suite passing. @lewtun @sgugger, if I could get a review on this, it would unblock getting the support into Optimum, which I sorely need for my production DeBERTa microservice at you.com.

10571 ± RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "deberta" -v                                                                                                                          ⏎
===================================================================================== test session starts ======================================================================================
platform darwin -- Python 3.9.10, pytest-7.1.2, pluggy-1.0.0 -- /Users/marklar/workspace/transformers/venv/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/Users/marklar/workspace/transformers/.hypothesis/examples')
rootdir: /Users/marklar/workspace/transformers, configfile: setup.cfg
plugins: xdist-2.5.0, forked-1.4.0, timeout-2.1.0, hypothesis-6.47.0, dash-2.5.0
collected 367 items / 345 deselected / 22 selected                                                                                                                                             

tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_043_deberta_v2_default PASSED                                                                                      [  4%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_044_deberta_v2_masked_lm PASSED                                                                                    [  9%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_045_deberta_v2_multiple_choice PASSED                                                                              [ 13%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_046_deberta_v2_question_answering PASSED                                                                           [ 18%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_047_deberta_v2_sequence_classification PASSED                                                                      [ 22%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_048_deberta_v2_token_classification PASSED                                                                         [ 27%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_049_deberta_default PASSED                                                                                         [ 31%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_050_deberta_masked_lm PASSED                                                                                       [ 36%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_051_deberta_question_answering PASSED                                                                              [ 40%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_052_deberta_sequence_classification PASSED                                                                         [ 45%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_053_deberta_token_classification PASSED                                                                            [ 50%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_043_deberta_v2_default PASSED                                                                              [ 54%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_044_deberta_v2_masked_lm PASSED                                                                            [ 59%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_045_deberta_v2_multiple_choice PASSED                                                                      [ 63%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_046_deberta_v2_question_answering PASSED                                                                   [ 68%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_047_deberta_v2_sequence_classification PASSED                                                              [ 72%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_048_deberta_v2_token_classification PASSED                                                                 [ 77%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_049_deberta_default PASSED                                                                                 [ 81%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_050_deberta_masked_lm PASSED                                                                               [ 86%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_051_deberta_question_answering PASSED                                                                      [ 90%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_052_deberta_sequence_classification PASSED                                                                 [ 95%]
tests/onnx/test_onnx_v2.py::OnnxExportTestCaseV2::test_pytorch_export_on_cuda_053_deberta_token_classification PASSED                                                                    [100%]

@michaelbenayoun (Member) left a comment

Nice work, thanks a lot!!

@JingyaHuang (Contributor) left a comment

Looks great, thanks for iterating on it!

It probably makes sense to filter out the invalid inputs, and I opened a PR on Optimum for that. But for the moment, please remove token_type_ids from the inputs of InferenceSession if it doesn't exist in the exported ONNX model.
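A minimal sketch of this stopgap: drop any tokenizer output (e.g. token_type_ids) that the exported ONNX model does not declare as an input before building the InferenceSession feed. In real onnxruntime code the input-name list would come from [i.name for i in session.get_inputs()]; here it is hardcoded for illustration.

```python
def filter_feed(tokenizer_outputs: dict, onnx_input_names: list) -> dict:
    # Keep only the entries the exported graph actually declares as inputs.
    return {k: v for k, v in tokenizer_outputs.items() if k in onnx_input_names}

encoded = {
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "token_type_ids": [[0, 0, 0]],  # absent from the graph when type_vocab_size == 0
}
feed = filter_feed(encoded, ["input_ids", "attention_mask"])
print(sorted(feed))
```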

@sam-h-bean (Contributor, Author) commented Jun 21, 2022

Looks great, thanks for iterating on it!

It probably makes sense to filter out the invalid inputs, and I opened a PR on Optimum for that. But for the moment, please remove token_type_ids from the inputs of InferenceSession if it doesn't exist in the exported ONNX model.

@JingyaHuang I'm not sure what you mean by this. Do you want something beyond removing the inputs in generate dummy inputs? Or is this comment strictly about my personal use of this functionality?

@JingyaHuang (Contributor)

Looks great, thanks for iterating on it!
It probably makes sense to filter out the invalid inputs, and I opened a PR on Optimum for that. But for the moment, please remove token_type_ids from the inputs of InferenceSession if it doesn't exist in the exported ONNX model.

@JingyaHuang I'm not sure what you mean by this. Do you want something beyond removing the inputs in generate dummy inputs? Or is this comment strictly about my personal use of this functionality?

Hi @sam-h-bean, nothing you should worry about. Here I mention it for another API (ORTModelForXXX) in Optimum. I will merge this PR, and then, by building transformers from source, you will be able to leverage the quantization and graph-optimization features in Optimum. Thank you again for the contribution.

@JingyaHuang JingyaHuang merged commit eb16be4 into huggingface:main Jun 21, 2022
@chainyo (Contributor) commented Jun 21, 2022

Congrats @sam-h-bean, excellent work. Thanks for adding these configs!!

@sam-h-bean (Contributor, Author)

Can I get a t-shirt 😏 ?

younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
* add onnx support for debertav2

* debertav2 -> deberta-v2 in onnx features file

* remove causal lm

* add deberta-v2-xlarge to onnx tests

* use self.type().dtype() in xsoftmax

Co-authored-by: Jingya HUANG <[email protected]>

* remove hack for deberta

* remove unused imports

* Update src/transformers/models/deberta_v2/configuration_deberta_v2.py

Co-authored-by: Jingya HUANG <[email protected]>

* use generate dummy inputs

* linter

* add imports

* add support for deberta v1 as well

* deberta does not support multiple choice

* Update src/transformers/models/deberta/configuration_deberta.py

Co-authored-by: Jingya HUANG <[email protected]>

* Update src/transformers/models/deberta_v2/configuration_deberta_v2.py

Co-authored-by: Jingya HUANG <[email protected]>

* one line ordered dict

* fire build

Co-authored-by: Jingya HUANG <[email protected]>
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 29, 2022