Fix torch compile applied to model forward #2932
Conversation
```diff
@@ -136,7 +135,7 @@ def wrapped(self, *args, **kwargs):
     from nncf.torch.dynamic_graph.patch_pytorch import unpatching_module_call

     # If called on a model compiled by torch dynamo, we unpatch torch operators and invoke original module call
-    if isinstance(self, OptimizedModule):
+    if "_torchdynamo_orig_callable" in self.forward.__dict__:
```
It seems that, in theory, any model method can be compiled in isolation. The fix presented here won't work for that case. To deal with such a problem, we could do something like this:

```python
for attr in self.__dict__.values():
    if hasattr(attr, "__dict__") and "_torchdynamo_orig_callable" in attr.__dict__:
        return unpatching_module_call(self, *args, **kwargs)
```

Such logic should protect the inference from being patched in case any of the model's methods is compiled. However, this would introduce some overhead during every patched module call. I propose to keep it as is for now, until we stumble across an issue where this could actually help.
LGTM. The OptimizedModule forward always carries the _torchdynamo_orig_callable attribute (https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/eval_frame.py#L148), and it also appears in the compiled model (https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/convert_frame.py#L413). Hope it will stay as it is.
I've added a check so tests will fail if this changes: https://github.com/nikita-savelyevv/nncf/blob/fix-torch-compile-forward/tests/torch/pytorch_patch_isolated.py#L104
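As a quick illustration of the check being discussed, a minimal sketch (it assumes the current torch behavior pointed to by the links above; `_torchdynamo_orig_callable` is a private dynamo attribute, and the toy module here is illustrative, not from NNCF):

```python
import torch
import torch._dynamo
import torch.nn as nn


class TinyModel(nn.Module):
    """Illustrative toy module, not part of NNCF or this PR."""

    def forward(self, x):
        return torch.relu(x)


# NOTE: relies on a private torch.compile implementation detail; may change across torch versions.

# Case 1: the whole model is compiled -> it becomes an OptimizedModule,
# and its forward carries the private marker attribute.
compiled_model = torch.compile(TinyModel())
assert isinstance(compiled_model, torch._dynamo.OptimizedModule)
assert "_torchdynamo_orig_callable" in compiled_model.forward.__dict__

# Case 2: only forward is compiled -> the model stays a plain nn.Module,
# so the old isinstance check misses it, while the attribute check still works.
model = TinyModel()
model.forward = torch.compile(model.forward)
assert not isinstance(model, torch._dynamo.OptimizedModule)
assert "_torchdynamo_orig_callable" in model.forward.__dict__
```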
Changes
Check whether the model is compiled based on the `_torchdynamo_orig_callable` property of the model's `forward` method.

Reason for changes
`torch.compile` can be applied not only to the model itself, but to the `forward()` method only, for example as in the sketch below. In this case the model itself doesn't change and it won't be an instance of `torch._dynamo.OptimizedModule`.
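A minimal sketch of this usage pattern (the concrete module is illustrative, not the one used in the added test):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
# Compile only the forward method; the model object itself stays a plain nn.Module.
model.forward = torch.compile(model.forward)
output = model(torch.randn(1, 8))  # the first call triggers the actual compilation
# type(model) is still nn.Linear, not torch._dynamo.OptimizedModule
```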
Related tickets
143796
Tests
Added a test where `torch.compile` is applied this way. It does not fail without the fix though, because the issue is sporadic.

Relates to
#2665, #2719