diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 6f01e44f49b..78a38bcd8df 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -12,7 +12,7 @@
 import functools
 import inspect
 from contextlib import contextmanager
-from typing import List
+from typing import Callable, List, Union
 
 import torch
 import torch.utils.cpp_extension
@@ -243,6 +243,24 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def get_torch_compile_wrapper():
+    """
+    Wrapper for torch.compile() that disables NNCF patching when called for a vanilla PyTorch model and
+    raises an exception when called for an NNCF-optimized model.
+    """
+
+    @functools.wraps(_ORIG_TORCH_COMPILE)
+    def wrapper(model, *args, **kwargs):
+        from nncf.torch.nncf_network import NNCFNetwork
+
+        if isinstance(model, NNCFNetwork):
+            raise TypeError("At the moment torch.compile() is not supported for models optimized by NNCF.")
+        with disable_patching():
+            return _ORIG_TORCH_COMPILE(model, *args, **kwargs)
+
+    return wrapper
+
+
 class OriginalOpInfo:
     def __init__(self, name: str, namespace, op):
         self.name = name
@@ -256,6 +274,15 @@ def __init__(self, name: str, namespace, op):
 _OPERATORS_ALREADY_WRAPPED = False
 _ORIG_JIT_SCRIPT = None
 _ORIG_JIT_TRACE_MAKE_MODULE = None
+_COMPILE_ALREADY_WRAPPED = False
+_ORIG_TORCH_COMPILE: Union[Callable, None] = None
+
+
+@functools.wraps(ORIGINAL_CALL)
+def unpatching_module_call(*args, **kwargs):
+    # Wrapper for module.__call__ that unpatches torch operators during the model forward
+    with disable_patching():
+        return ORIGINAL_CALL(*args, **kwargs)
 
 
 def patch_torch_jit():
@@ -330,6 +357,14 @@ def patch_torch_operators():
         patch_torch_jit()
         _JIT_ALREADY_WRAPPED = True
 
+    # Unpatch torch operators during model compilation
+    global _COMPILE_ALREADY_WRAPPED
+    if not _COMPILE_ALREADY_WRAPPED:
+        global _ORIG_TORCH_COMPILE
+        _ORIG_TORCH_COMPILE = torch.compile
+        setattr(torch, "compile", get_torch_compile_wrapper())
+        _COMPILE_ALREADY_WRAPPED = True
+
     # Do not patch operators twice as well
     global _OPERATORS_ALREADY_WRAPPED
     if _OPERATORS_ALREADY_WRAPPED:
diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py
index 6728d5cd904..d59b2fa6c5a 100644
--- a/nncf/torch/dynamic_graph/wrappers.py
+++ b/nncf/torch/dynamic_graph/wrappers.py
@@ -13,6 +13,7 @@
 from typing import Callable, List, Tuple
 
 import torch
+from torch._dynamo import OptimizedModule
 from torch.nn import DataParallel
 
 from nncf.common.graph.definitions import MODEL_CONST_OP_NAME
@@ -127,6 +128,12 @@ def wrap_module_call(module_call):
     @functools.wraps(module_call)
     def wrapped(self, *args, **kwargs):
+        from nncf.torch.dynamic_graph.patch_pytorch import unpatching_module_call
+
+        # If called on a model compiled by torch dynamo, unpatch torch operators and invoke the original module call
+        if isinstance(self, OptimizedModule):
+            return unpatching_module_call(self, *args, **kwargs)
+
         ctx = get_current_context()
         if not ctx or self.__class__ in _IGNORED_SCOPES:
             if isinstance(self, DataParallel):
diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py
index 6724fbd130f..ce26f67eea3 100644
--- a/tests/torch/pytorch_patch_isolated.py
+++ b/tests/torch/pytorch_patch_isolated.py
@@ -77,3 +77,23 @@ def test_jit_script_exception_preserves_patching_isolated():
     # torch.nn.Module.__call__ is one of the fundamental patched functions, if the code object points to NNCF code,
     # then it means patching is still present
     assert "nncf" in torch.nn.Module.__call__.__code__.co_filename
+
+
+def compile_and_run_test_model() -> torch.Tensor:
+    from tests.torch.helpers import BasicConvTestModel
+
+    model = BasicConvTestModel()
+    state_dict = {"conv.weight": model.default_weight(), "conv.bias": model.default_bias()}
+    model.load_state_dict(state_dict)
+
+    compiled_model = torch.compile(model)
+    return compiled_model(torch.ones(model.INPUT_SIZE))
+
+
+@pytest.mark.skipif(ISOLATION_RUN_ENV_VAR not in os.environ, reason="Should be run via isolation proxy")
+def test_compile():
+    before_nncf = compile_and_run_test_model()
+    import nncf.torch  # noqa: F401
+
+    after_nncf = compile_and_run_test_model()
+    assert torch.allclose(before_nncf, after_nncf)
diff --git a/tests/torch/test_pytorch_patch.py b/tests/torch/test_pytorch_patch.py
index 07d6a8bf82f..c800b74ce0a 100644
--- a/tests/torch/test_pytorch_patch.py
+++ b/tests/torch/test_pytorch_patch.py
@@ -15,6 +15,7 @@
 import pytest
 import torch
 
+import nncf
 from nncf.config import NNCFConfig
 from nncf.torch.dynamic_graph.context import TracingContext
 from nncf.torch.dynamic_graph.patch_pytorch import _ORIG_JIT_SCRIPT
@@ -27,6 +28,7 @@
 from tests.torch.helpers import BasicConvTestModel
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 from tests.torch.helpers import register_bn_adaptation_init_args
+from tests.torch.pytorch_patch_isolated import test_compile
 from tests.torch.pytorch_patch_isolated import test_jit_if_tracing_script_source_equals
 from tests.torch.pytorch_patch_isolated import test_jit_script_exception_preserves_patching_isolated
@@ -106,6 +108,45 @@ def test_jit_script_exception_preserves_patching():
     run_pytest_case_function_in_separate_process(test_jit_script_exception_preserves_patching_isolated)
 
 
+def test_torch_compile():
+    # Run the test case in a separate process to track patching of torch by NNCF
+    run_pytest_case_function_in_separate_process(test_compile)
+
+
+def test_torch_compile_on_nncf_model():
+    # Calling torch.compile on a regular torch model should work fine
+    model = BasicConvTestModel()
+    compiled_model = torch.compile(model)
+    compiled_model(torch.ones(model.INPUT_SIZE))
+
+    model = BasicConvTestModel()
+    quantized_model = nncf.quantize(model, nncf.Dataset([torch.rand(model.INPUT_SIZE)]))
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
+        torch.compile(quantized_model)
+
+    model = BasicConvTestModel()
+    config = get_test_quantization_config(model)
+    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
+        torch.compile(compressed_model)
+
+    stripped_model = compression_ctrl.strip()
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
+        torch.compile(stripped_model)
+
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
+        # Compiling this model would actually work, but inference of the compiled model would fail
+        torch.compile(model)
+
+
 def test_jit_script_signature():
     # Check that torch.jit.script has the same signature as the wrapper was designed for
     signature = inspect.signature(_ORIG_JIT_SCRIPT)
@@ -127,6 +168,16 @@ def class_method(self, x):
 
 def test_jit_trace_model():
     model = BasicConvTestModel()
+    config = get_test_quantization_config(model)
+
+    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
+    torch.jit.trace(compressed_model, example_inputs=torch.rand(model.INPUT_SIZE))
+
+    model = compression_ctrl.strip()
+    torch.jit.trace(model, example_inputs=torch.rand(model.INPUT_SIZE))
+
+
+def get_test_quantization_config(model):
     config = NNCFConfig()
     config.update(
         {
@@ -136,9 +187,4 @@ def test_jit_trace_model():
         }
     )
     register_bn_adaptation_init_args(config)
-
-    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
-    torch.jit.trace(compressed_model, example_inputs=torch.rand(model.INPUT_SIZE))
-
-    model = compression_ctrl.strip()
-    torch.jit.trace(model, example_inputs=torch.rand(model.INPUT_SIZE))
+    return config
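
For reviewers trying the change out locally, here is a minimal sketch of the intended user-facing behavior after this patch. The `ToyModel` class and the input shapes are illustrative, not taken from the test suite; only the `torch.compile` wrapping, the `nncf.quantize` call, and the error message come from the diff above.

```python
import torch
import torch.nn as nn

import nncf
import nncf.torch  # noqa: F401  # importing nncf.torch replaces torch.compile with the wrapper

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

model = ToyModel()

# Vanilla model: the wrapper temporarily disables NNCF patching, so dynamo
# traces the unpatched module and compilation proceeds as usual.
compiled = torch.compile(model)
compiled(torch.ones(1, 1, 4, 4))

# NNCF-optimized model: the wrapper rejects it explicitly.
quantized = nncf.quantize(model, nncf.Dataset([torch.rand(1, 1, 4, 4)]))
try:
    torch.compile(quantized)
except TypeError as e:
    print(e)  # At the moment torch.compile() is not supported for models optimized by NNCF.
```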