Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NNCF compatibility with torch.compile #2665

Merged
merged 18 commits into from
May 24, 2024

Conversation

nikita-savelyevv
Copy link
Collaborator

@nikita-savelyevv nikita-savelyevv commented May 3, 2024

Changes

  • Unpatch torch during torch.compile() call for vanilla PyTorch model
  • Raise ValueError if torch.compile() is called for NNCF-optimized model
  • Unpatch torch during forward call of compiled model

Reason for changes

PyTorch dynamo compilation conflicts with nncf patching of PyTorch. This results in errors during compiled model forward, even if the model was not quantized, i.e. just import of nncf.torch results in failure.

Related tickets

140265

Tests

Added test for torch.compile compatibility with nncf

@nikita-savelyevv nikita-savelyevv requested a review from a team as a code owner May 3, 2024 13:26
@nikita-savelyevv nikita-savelyevv requested a review from vshampor May 3, 2024 13:26
@github-actions github-actions bot added the NNCF PT Pull requests that updates NNCF PyTorch label May 3, 2024
Copy link

codecov bot commented May 3, 2024

Codecov Report

Attention: Patch coverage is 87.50000% with 3 lines in your changes are missing coverage. Please review.

Project coverage is 34.78%. Comparing base (3586e58) to head (42752b5).
Report is 18 commits behind head on develop.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff              @@
##           develop    #2665       +/-   ##
============================================
- Coverage    46.85%   34.78%   -12.07%     
============================================
  Files          493      495        +2     
  Lines        45551    46007      +456     
============================================
- Hits         21341    16003     -5338     
- Misses       24210    30004     +5794     
Files Coverage Δ
nncf/torch/dynamic_graph/patch_pytorch.py 79.26% <87.50%> (-6.31%) ⬇️

... and 113 files with indirect coverage changes

Flag Coverage Δ
COMMON ?
ONNX 34.78% <87.50%> (-0.05%) ⬇️
OPENVINO ?

Flags with carried forward coverage won't be shown. Click here to find out more.

Components Coverage Δ
common 58.18% <ø> (-11.49%) ⬇️
torch 32.96% <87.50%> (+0.10%) ⬆️
tensorflow 0.00% <ø> (ø)
onnx 93.06% <ø> (ø)
openvino 0.00% <ø> (-94.21%) ⬇️
ptq 40.41% <ø> (-38.93%) ⬇️

@alexsu52 alexsu52 requested a review from daniil-lyakhov May 6, 2024 09:04
@nikita-savelyevv
Copy link
Collaborator Author

The torch nightly build 207 has failed

@nikita-savelyevv nikita-savelyevv marked this pull request as draft May 6, 2024 10:57
@daniil-lyakhov
Copy link
Collaborator

@nikita-savelyevv, nice PR!

One question: why the context was a thread-local in the first place?

@nikita-savelyevv
Copy link
Collaborator Author

@nikita-savelyevv, nice PR!

One question: why the context was a thread-local in the first place?

Thanks! That's a good question. @vshampor could you please comment on that?

@nikita-savelyevv nikita-savelyevv marked this pull request as ready for review May 17, 2024 14:59
@nikita-savelyevv
Copy link
Collaborator Author

torch nightly build 234 is green, PR is ready for review

@nikita-savelyevv
Copy link
Collaborator Author

@AlexanderDokuchaev @daniil-lyakhov @vshampor please take a look

tests/torch/test_pytorch_patch.py Outdated Show resolved Hide resolved
nncf/torch/dynamic_graph/patch_pytorch.py Show resolved Hide resolved
tests/torch/pytorch_patch_isolated.py Outdated Show resolved Hide resolved
nncf/torch/dynamic_graph/patch_pytorch.py Outdated Show resolved Hide resolved
nncf/torch/dynamic_graph/patch_pytorch.py Outdated Show resolved Hide resolved
tests/torch/test_pytorch_patch.py Outdated Show resolved Hide resolved
tests/torch/test_pytorch_patch.py Outdated Show resolved Hide resolved
nncf/torch/dynamic_graph/patch_pytorch.py Outdated Show resolved Hide resolved
@nikita-savelyevv
Copy link
Collaborator Author

@AlexanderDokuchaev thank you for your suggestions! Applied all except one

@alexsu52 alexsu52 merged commit 1468a9b into openvinotoolkit:develop May 24, 2024
12 checks passed
This was referenced May 29, 2024
AlexanderDokuchaev pushed a commit that referenced this pull request Jun 11, 2024
### Changes

- Add condition to skip operator wrapping if torch is in unpatched state
- Fix test from a previous PR #2665 (remove dependency of the test on
`nncf.torch`)

### Reason for changes

Fix inference of compiled torch models when `nncf.torch` is imported

### Related tickets

140265

### Tests

Added `test_operator_unpatching`
AlexanderDokuchaev pushed a commit that referenced this pull request Aug 30, 2024
### Changes

Check whether model is compiled based on `"_torchdynamo_orig_callable"`
property of the model forward.

### Reason for changes

`torch.compile` can be applied not only to the model itself, but to
`forward()` method only. For example:
```
model.forward = torch.compile(model.forward)
```
In this case the model itself doesn't change and it won't be an instance
of `torch._dynamo.OptimizedModule`.

### Related tickets

143796

### Tests

Added test when `torch.compile` is applied this way. It does not fail
without the fix though, because the issue is sporadic.


### Relates to
#2665, #2719
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
NNCF PT Pull requests that updates NNCF PyTorch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants