
Fix torch compile applied to model forward #2932

Conversation

@nikita-savelyevv (Collaborator) commented Aug 29, 2024

Changes

Check whether the model is compiled based on the "_torchdynamo_orig_callable" attribute of the model's forward method.

Reason for changes

torch.compile can be applied not only to the model itself, but also to the forward() method alone. For example:

model.forward = torch.compile(model.forward)

In this case, the model itself doesn't change, so it won't be an instance of torch._dynamo.OptimizedModule.
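To make this concrete, here is a torch-free sketch of why an isinstance check misses this case. DummyModel and fake_compile are hypothetical stand-ins for torch.nn.Module and torch.compile; fake_compile mimics dynamo by recording the original callable under the "_torchdynamo_orig_callable" attribute on the wrapper it returns.

```python
class DummyModel:
    """Hypothetical stand-in for a torch.nn.Module subclass."""

    def forward(self, x):
        return x + 1


def fake_compile(fn):
    """Mimics torch.compile: wraps fn and marks the wrapper the way dynamo does."""

    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._torchdynamo_orig_callable = fn
    return wrapper


model = DummyModel()
model.forward = fake_compile(model.forward)

# The model's class is unchanged, so an isinstance(model, OptimizedModule)
# check would still be False here ...
print(type(model).__name__)  # DummyModel
# ... but the marker attribute on forward reveals the compilation:
print("_torchdynamo_orig_callable" in model.forward.__dict__)  # True
```

The compiled forward still works as before; only the detection strategy has to change.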

Related tickets

143796

Tests

Added a test for the case when torch.compile is applied this way. It does not fail without the fix, though, because the issue is sporadic.

Relates to

#2665, #2719

@nikita-savelyevv nikita-savelyevv requested a review from a team as a code owner August 29, 2024 15:01
@github-actions github-actions bot added the NNCF PT Pull requests that updates NNCF PyTorch label Aug 29, 2024
@@ -136,7 +135,7 @@ def wrapped(self, *args, **kwargs):
from nncf.torch.dynamic_graph.patch_pytorch import unpatching_module_call

# If called on a model compiled by torch dynamo, we unpatch torch operators and invoke original module call
-    if isinstance(self, OptimizedModule):
+    if "_torchdynamo_orig_callable" in self.forward.__dict__:
@nikita-savelyevv (Collaborator, Author) commented:
It seems that, in theory, any model method can be compiled in isolation. The fix presented here won't cover that case. To handle it, we could do something like this:

for attr in self.__dict__.values():
    if hasattr(attr, "__dict__") and "_torchdynamo_orig_callable" in attr.__dict__:
        return unpatching_module_call(self, *args, **kwargs)

Such logic would protect the inference from being patched if any of the model's methods is compiled. However, it would introduce some overhead on every patched module call. I propose to keep the fix as is for now, until we stumble upon an issue where this would actually help.
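For illustration, the broader scan proposed above can be sketched in a runnable, torch-free form. As before, DummyModel and fake_compile are hypothetical stand-ins; fake_compile marks its wrapper the same way dynamo does, and any_method_compiled is an illustrative helper, not an NNCF function.

```python
def fake_compile(fn):
    """Mimics torch.compile: wraps fn and marks the wrapper the way dynamo does."""

    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._torchdynamo_orig_callable = fn
    return wrapper


class DummyModel:
    """Hypothetical stand-in for a torch.nn.Module with several methods."""

    def forward(self, x):
        return x + 1

    def encode(self, x):
        return x * 2


def any_method_compiled(module):
    """True if any attribute stored on the instance carries the dynamo marker."""
    return any(
        hasattr(attr, "__dict__") and "_torchdynamo_orig_callable" in attr.__dict__
        for attr in module.__dict__.values()
    )


m = DummyModel()
print(any_method_compiled(m))      # False
m.encode = fake_compile(m.encode)  # compile a method other than forward
print(any_method_compiled(m))      # True
```

The scan only inspects the instance `__dict__`, which is where a compiled method lands when it is assigned back onto the instance; this is also why it costs extra work on every patched module call.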

@daniil-lyakhov (Collaborator) left a comment:

LGTM. The OptimizedModule always has the _torchdynamo_orig_callable attribute (https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/eval_frame.py#L148), and it appears on the compiled model as well (https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/convert_frame.py#L413). Hope it stays that way.

@nikita-savelyevv (Collaborator, Author) replied:
> LGTM. The OptimizedModule always has the _torchdynamo_orig_callable attribute, and it appears on the compiled model as well (https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/convert_frame.py#L413). Hope it stays that way.

I've added a check so that tests will fail if this changes: https://github.com/nikita-savelyevv/nncf/blob/fix-torch-compile-forward/tests/torch/pytorch_patch_isolated.py#L104
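The general shape of such a guard test can be sketched without torch. The real test in the repository exercises torch.compile itself; here fake_compile is a hypothetical stand-in that records the original callable under the same attribute name, which is exactly the upstream behavior the guard pins down.

```python
def fake_compile(fn):
    """Hypothetical stand-in for torch.compile: marks the wrapper like dynamo."""

    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._torchdynamo_orig_callable = fn
    return wrapper


def test_compile_sets_marker():
    def forward(x):
        return x - 1

    compiled = fake_compile(forward)
    # If the attribute name ever changes upstream, an assertion like this
    # fails loudly instead of NNCF silently mis-detecting compiled models.
    assert "_torchdynamo_orig_callable" in compiled.__dict__
    assert compiled.__dict__["_torchdynamo_orig_callable"] is forward


test_compile_sets_marker()
```

A guard test of this form turns an implicit dependency on a private torch attribute into an explicit, early-failing check.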

@AlexanderDokuchaev AlexanderDokuchaev merged commit eb61347 into openvinotoolkit:develop Aug 30, 2024
13 checks passed
nikita-savelyevv added a commit to KodiaqQ/nncf that referenced this pull request Sep 5, 2024