Skip to content

Commit

Permalink
Fix torch compile applied to model forward (#2932)
Browse files Browse the repository at this point in the history
### Changes

Check whether model is compiled based on `"_torchdynamo_orig_callable"`
property of the model forward.

### Reason for changes

`torch.compile` can be applied not only to the model itself, but to
`forward()` method only. For example:
```
model.forward = torch.compile(model.forward)
```
In this case the model itself doesn't change and it won't be an instance
of `torch._dynamo.OptimizedModule`.

### Related tickets

143796

### Tests

Added test when `torch.compile` is applied this way. It does not fail
without the fix though, because the issue is sporadic.


### Relates to
#2665, #2719
  • Loading branch information
nikita-savelyevv authored Aug 30, 2024
1 parent 2088962 commit eb61347
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
3 changes: 1 addition & 2 deletions nncf/torch/dynamic_graph/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Callable, List, Tuple

import torch
from torch._dynamo import OptimizedModule
from torch.nn import DataParallel

from nncf.common.graph.definitions import MODEL_CONST_OP_NAME
Expand Down Expand Up @@ -136,7 +135,7 @@ def wrapped(self, *args, **kwargs):
from nncf.torch.dynamic_graph.patch_pytorch import unpatching_module_call

# If called on a model compiled by torch dynamo, we unpatch torch operators and invoke original module call
if isinstance(self, OptimizedModule):
if "_torchdynamo_orig_callable" in self.forward.__dict__:
return unpatching_module_call(self, *args, **kwargs)

ctx = get_current_context()
Expand Down
14 changes: 10 additions & 4 deletions tests/torch/pytorch_patch_isolated.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_jit_script_exception_preserves_patching_isolated():
assert "nncf" in torch.nn.Module.__call__.__code__.co_filename


def compile_and_run_test_model() -> torch.Tensor:
def compile_and_run_test_model(compile_forward: bool) -> torch.Tensor:
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -96,14 +96,20 @@ def forward(self, x):
state_dict[k] = torch.rand(v.shape)
model.load_state_dict(state_dict)

compiled_model = torch.compile(model)
if compile_forward:
compiled_model = model
compiled_model.forward = torch.compile(model.forward)
else:
compiled_model = torch.compile(model)
assert "_torchdynamo_orig_callable" in compiled_model.forward.__dict__
return compiled_model(torch.rand([1, 3, 5, 5]))


@pytest.mark.skipif(ISOLATION_RUN_ENV_VAR not in os.environ, reason="Should be run via isolation proxy")
def test_compile():
before_nncf = compile_and_run_test_model()
compile_forward = os.environ.get("COMPILE_FORWARD", None) == "1"
before_nncf = compile_and_run_test_model(compile_forward)
import nncf.torch # noqa: F401

after_nncf = compile_and_run_test_model()
after_nncf = compile_and_run_test_model(compile_forward)
assert torch.allclose(before_nncf, after_nncf)
5 changes: 4 additions & 1 deletion tests/torch/test_pytorch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import inspect
import os
from typing import List

import pytest
Expand Down Expand Up @@ -114,8 +115,10 @@ def test_jit_script_exception_preserves_patching():


@pytest.mark.xfail(is_windows(), reason="https://github.com/pytorch/pytorch/issues/122094")
def test_torch_compile():
@pytest.mark.parametrize("compile_forward", [False, True])
def test_torch_compile(compile_forward):
# Run test case in a separate process to track patching of torch by NNCF
os.environ["COMPILE_FORWARD"] = f"{int(compile_forward)}"
run_pytest_case_function_in_separate_process(test_compile)


Expand Down

0 comments on commit eb61347

Please sign in to comment.