-
Notifications
You must be signed in to change notification settings - Fork 435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torch.compile support for pytorch 2.4 #1690
base: main
Are you sure you want to change the base?
Conversation
@@ -266,7 +266,8 @@ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = | |||
).int() | |||
|
|||
pos_logits = [] | |||
for i in range(max_length): | |||
i = 0 | |||
while i < max_length: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remember there was a issue with while
loops by exporting to onnx so we have to be careful here (needs to be checked)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed it because it was some unecessary complication related to breaks in torch.compile. Changing to a while loop and changing the logic a bit helped. Hopefully it works for the onnx also
…on for compatibility with future backends
…rch_backend_available
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @Fabioomega 👋
Tests looks already good to me 👍
Docs section missing here: https://github.com/mindee/doctr/blob/main/docs/source/using_doctr/using_model_export.rst
As mentioned a table would be great :)
For the classification
models it would be enough to add both orientation
models to the table (i don't think we should blow up the table by adding all the backbone models)
Todo's:
- comments
- unittests
- docs
Left some comments to revert unrequired parts :)
As mentioned for follow up PR's we can focus on fixes for the models which does not work yet out of the box 👍
@@ -76,6 +75,18 @@ | |||
" is installed and that either USE_TF or USE_TORCH is enabled." | |||
) | |||
|
|||
if _torch_available: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Fabioomega We can remove this.
2 options:
We pin the lower boundary to >= 2.0.0 here
Line 61 in 9045dcf
"torch>=1.12.0,<3.0.0", |
Line 105 in 9045dcf
"torch>=1.12.0,<3.0.0", |
and
torchvision>=0.15.0
or we mention in the docs that this requires >= 2.0.0 for compile and >=2.4.0 for compile + fullgraph
@odulcy-mindee wdyt ?
We are already at 2.4.0 so i would prefer the >=2.0.0 pin (in this case only to mention >=2.4.0 for fullgraph (triton) support)
@@ -104,3 +115,11 @@ def is_torch_available(): | |||
def is_tf_available(): | |||
"""Whether TensorFlow is installed.""" | |||
return _tf_available | |||
|
|||
def does_torch_have_compile_capability(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be reverted complete
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hame some questions about that! Wasn't the original ideia to add a new argument to enable compilation? Did I misunderstood?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was the first thought as your code looked like changes to the pipeline/models were needed. However, we then saw that these were not needed.
Which is why we only add tests here and a section on how to use it. The compilation therefore remains on the user side, which is at the same time much more flexible. :)
Additional this avoids to add a arg which at the end only does -> model = torch.compile(model, ..)
and is backend depending (PyTorch).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A full sample would look like then for example:
import requests
import torch
from doctr.models import ocr_predictor, parseq, fast_base
from doctr.io import DocumentFile
bytes_data = requests.get(
"https://i1.rgstatic.net/publication/231831562_Another_Boring_Day_in_Paradise_Rock_and_Roll_and_the_Empowerment_of_Everyday_Life/links/57d02a2408ae601b39a05636/largepreview.png"
).content
doc = DocumentFile.from_images([bytes_data])
rec_model = torch.compile(parseq(pretrained=True))
det_model = torch.compile(fast_base(pretrained=True))
predictor = ocr_predictor(det_arch=det_model, reco_arch=rec_model, pretrained=True)
res = predictor(doc)
res.show()
The only required change here would be to allow also:
torch._dynamo.eval_frame.OptimizedModule
in
doctr/doctr/models/recognition/zoo.py
Line 39 in 9045dcf
arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq) |
doctr/doctr/models/detection/zoo.py
Line 59 in 9045dcf
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)): |
doctr/doctr/models/classification/zoo.py
Line 45 in 9045dcf
if not isinstance(arch, classification.MobileNetV3): |
@@ -186,3 +186,46 @@ def test_models_onnx_export(arch_name, input_shape, output_size): | |||
assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) | |||
except AssertionError: | |||
pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") | |||
|
|||
@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0") | |||
@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the first two skipif
Ci runs always on latest pytorch - same for the other tests 👍
Hi @Fabioomega :) Are you still interested to work on it ? :) |
Added support for torch.compile only for version 2.4 or higher of pytorch. Included support for all the detection models and a recognition model (parseq).
Unfortunately, triton support is only available on linux plataforms. WSL seems to work fine tough, so it may be used that way.
Example use:
USE_TRITON = YES
.