From 7ef226278d816dc9d8a6d68a9874c47adfcf6ad2 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 3 May 2024 15:18:08 +0200
Subject: [PATCH 01/18] Do not inherit torch patching context from threading.local

---
 nncf/torch/dynamic_graph/context.py   |  4 ++--
 tests/torch/pytorch_patch_isolated.py | 21 +++++++++++++++++++++
 tests/torch/test_pytorch_patch.py     |  6 ++++++
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py
index 05b89843740..64d7f129263 100644
--- a/nncf/torch/dynamic_graph/context.py
+++ b/nncf/torch/dynamic_graph/context.py
@@ -37,13 +37,13 @@
 from nncf.torch.dynamic_graph.trace_tensor import TracedTensorMixin
 
 
-class ThreadLocalGlobalContext(threading.local):
+class GlobalContext:
     def __init__(self):
         super().__init__()
         self.context = None
 
 
-_CURRENT_CONTEXT = ThreadLocalGlobalContext()
+_CURRENT_CONTEXT = GlobalContext()
 
 
 class PreHookId:
diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py
index 6724fbd130f..43717218ef8 100644
--- a/tests/torch/pytorch_patch_isolated.py
+++ b/tests/torch/pytorch_patch_isolated.py
@@ -18,6 +18,7 @@
 import torch
 
 from tests.shared.isolation_runner import ISOLATION_RUN_ENV_VAR
+from tests.torch.test_models.lenet import LeNet
 
 
 def clean_source_code(code_source):
@@ -77,3 +78,23 @@ def test_jit_script_exception_preserves_patching_isolated():
     # torch.nn.Module.__call__ is one of the fundamental patched functions, if the code object points to NNCF code,
     # then it means patching is still present
     assert "nncf" in torch.nn.Module.__call__.__code__.co_filename
+
+
+def _compile_and_run_lenet() -> torch.Tensor:
+    model = LeNet()
+
+    torch.manual_seed(0)
+    state_dict = {}
+    for k, v in model.state_dict().items():
+        state_dict[k] = torch.rand(v.shape)
+    model.load_state_dict(state_dict)
+
+    compiled_model = torch.compile(model)
+    return compiled_model(torch.ones([1, 3, 32, 32]))
+
+
+def test_compile():
+    before_nncf = _compile_and_run_lenet()
+    import nncf.torch
+    after_nncf = _compile_and_run_lenet()
+    assert torch.allclose(before_nncf, after_nncf)
diff --git a/tests/torch/test_pytorch_patch.py b/tests/torch/test_pytorch_patch.py
index 07d6a8bf82f..9a3c15efd0c 100644
--- a/tests/torch/test_pytorch_patch.py
+++ b/tests/torch/test_pytorch_patch.py
@@ -27,6 +27,7 @@
 from tests.torch.helpers import BasicConvTestModel
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 from tests.torch.helpers import register_bn_adaptation_init_args
+from tests.torch.pytorch_patch_isolated import test_compile
 from tests.torch.pytorch_patch_isolated import test_jit_if_tracing_script_source_equals
 from tests.torch.pytorch_patch_isolated import test_jit_script_exception_preserves_patching_isolated
 
@@ -106,6 +107,11 @@ def test_jit_script_exception_preserves_patching():
     run_pytest_case_function_in_separate_process(test_jit_script_exception_preserves_patching_isolated)
 
 
+def test_torch_compile():
+    # Run test case in a separate process to track patching of torch by NNCF
+    run_pytest_case_function_in_separate_process(test_compile)
+
+
 def test_jit_script_signature():
     # Check that torch.jit.script has the same signature as the wrapper was designed for
     signature = inspect.signature(_ORIG_JIT_SCRIPT)

From c24ddad63e5fac1a139e7d1c270ce701beffe31c Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 3 May 2024 15:18:33 +0200
Subject: [PATCH 02/18] Linters

---
 tests/torch/pytorch_patch_isolated.py | 1 +
 1 file changed, 1 insertion(+)
diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py
index 43717218ef8..92a006e7004 100644
--- a/tests/torch/pytorch_patch_isolated.py
+++ b/tests/torch/pytorch_patch_isolated.py
@@ -96,5 +96,6 @@ def _compile_and_run_lenet() -> torch.Tensor:
 def test_compile():
     before_nncf = _compile_and_run_lenet()
     import nncf.torch
+
     after_nncf = _compile_and_run_lenet()
     assert torch.allclose(before_nncf, after_nncf)

From 450838a0b50029b536aafba4256972280eabcfa7 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 3 May 2024 15:31:05 +0200
Subject: [PATCH 03/18] Ignore an import line by ruff

---
 tests/torch/pytorch_patch_isolated.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py
index 92a006e7004..cb5374b3080 100644
--- a/tests/torch/pytorch_patch_isolated.py
+++ b/tests/torch/pytorch_patch_isolated.py
@@ -95,7 +95,7 @@ def _compile_and_run_lenet() -> torch.Tensor:
 
 def test_compile():
     before_nncf = _compile_and_run_lenet()
-    import nncf.torch
+    import nncf.torch  # noqa: F401
 
     after_nncf = _compile_and_run_lenet()
     assert torch.allclose(before_nncf, after_nncf)

From 5a5be12d724eaac52050801cd329648aeafb06c0 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 3 May 2024 17:16:43 +0200
Subject: [PATCH 04/18] Test tweaks

---
 tests/torch/pytorch_patch_isolated.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py
index cb5374b3080..aa842657b98 100644
--- a/tests/torch/pytorch_patch_isolated.py
+++ b/tests/torch/pytorch_patch_isolated.py
@@ -18,7 +18,6 @@
 import torch
 
 from tests.shared.isolation_runner import ISOLATION_RUN_ENV_VAR
-from tests.torch.test_models.lenet import LeNet
 
 
 def clean_source_code(code_source):
@@ -80,7 +79,9 @@ def test_jit_script_exception_preserves_patching_isolated():
     assert "nncf" in torch.nn.Module.__call__.__code__.co_filename
 
 
-def _compile_and_run_lenet() -> torch.Tensor:
+def compile_and_run_lenet() -> torch.Tensor:
+    from tests.torch.test_models.lenet import LeNet
+
     model = LeNet()
 
     torch.manual_seed(0)
     state_dict = {}
@@ -93,9 +94,10 @@ def _compile_and_run_lenet() -> torch.Tensor:
     return compiled_model(torch.ones([1, 3, 32, 32]))
 
 
+@pytest.mark.skipif(ISOLATION_RUN_ENV_VAR not in os.environ, reason="Should be run via isolation proxy")
 def test_compile():
-    before_nncf = _compile_and_run_lenet()
+    before_nncf = compile_and_run_lenet()
     import nncf.torch  # noqa: F401
 
-    after_nncf = _compile_and_run_lenet()
+    after_nncf = compile_and_run_lenet()
     assert torch.allclose(before_nncf, after_nncf)

From 9f03b1a7b832128a851a5a7c6ebdcd2b0b43da8c Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Tue, 7 May 2024 10:31:02 +0200
Subject: [PATCH 05/18] Add thread safe lock

---
 nncf/torch/dynamic_graph/context.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py
index 64d7f129263..7a44eb1f286 100644
--- a/nncf/torch/dynamic_graph/context.py
+++ b/nncf/torch/dynamic_graph/context.py
@@ -39,8 +39,18 @@
 
 class GlobalContext:
     def __init__(self):
-        super().__init__()
-        self.context = None
+        self._context = None
+        self.lock = threading.Lock()
+
+    @property
+    def context(self):
+        with self.lock:
+            return self._context
+
+    @context.setter
+    def context(self, value):
+        with self.lock:
+            self._context = value
 
 
 _CURRENT_CONTEXT = GlobalContext()

From f49258a57ba7a14c3c187654cdb74283f6a21c7d Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Wed, 8 May 2024 14:23:33 +0200
Subject: [PATCH 06/18] WIP

---
 nncf/torch/dynamic_graph/context.py       | 37 ++++++++++++++---------
 nncf/torch/dynamic_graph/patch_pytorch.py |  8 +++++
 nncf/torch/dynamic_graph/wrappers.py      | 11 +++++++
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py
index 7a44eb1f286..fcc34c44cd3 100644
--- a/nncf/torch/dynamic_graph/context.py
+++ b/nncf/torch/dynamic_graph/context.py
@@ -37,23 +37,32 @@
 from nncf.torch.dynamic_graph.trace_tensor import TracedTensorMixin
 
 
-class GlobalContext:
+# class GlobalContext:
+#     def __init__(self):
+#         self._context = None
+#         self.lock = threading.Lock()
+#
+#     @property
+#     def context(self):
+#         with self.lock:
+#             return self._context
+#
+#     @context.setter
+#     def context(self, value):
+#         with self.lock:
+#             self._context = value
+#
+#
+# _CURRENT_CONTEXT = GlobalContext()
+
+
+class ThreadLocalGlobalContext(threading.local):
     def __init__(self):
-        self._context = None
-        self.lock = threading.Lock()
-
-    @property
-    def context(self):
-        with self.lock:
-            return self._context
-
-    @context.setter
-    def context(self, value):
-        with self.lock:
-            self._context = value
+        super().__init__()
+        self.context = None
 
 
-_CURRENT_CONTEXT = GlobalContext()
+_CURRENT_CONTEXT = ThreadLocalGlobalContext()
 
 
 class PreHookId:
diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 6f01e44f49b..67fe143d8ca 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -243,6 +243,14 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def get_disable_patching_wrapper(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        with disable_patching():
+            return fn(*args, **kwargs)
+    return wrapper
+
+
 class OriginalOpInfo:
     def __init__(self, name: str, namespace, op):
         self.name = name
diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py
index 6728d5cd904..139ad1380bb 100644
--- a/nncf/torch/dynamic_graph/wrappers.py
+++ b/nncf/torch/dynamic_graph/wrappers.py
@@ -128,10 +128,21 @@ def wrap_module_call(module_call):
     @functools.wraps(module_call)
     def wrapped(self, *args, **kwargs):
         ctx = get_current_context()
+        # if ctx is not None:
+        #     # import threading
+        #     # from nncf.torch.dynamic_graph.context import _CURRENT_CONTEXT
+        #     print(ctx, type(ctx), ctx is None, ctx is not None)
+        #     # print(threading.get_ident(), ctx, type(ctx), ctx is None, ctx is not None)
+        #     # print(_CURRENT_CONTEXT)
+        if "_torchdynamo_orig_callable" in self.forward.__dict__:
+            from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
+            with disable_patching():
+                return module_call(self, *args, **kwargs)
         if not ctx or self.__class__ in _IGNORED_SCOPES:
             if isinstance(self, DataParallel):
                 _warn_data_parallel()
             return module_call(self, *args, **kwargs)
+        print(self.__class__)
         ctx.push_scope(self)
         is_nncf_layer = isinstance(self, _NNCFModuleMixin)
         if is_nncf_layer:

From 5e92034632817dc2d3ee2f36beb52c7d2d6437bc Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Wed, 8 May 2024 15:10:12 +0200
Subject: [PATCH 07/18] Subset test

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 4d9b8419a37..f1a8aada913 100644
--- a/Makefile
+++ b/Makefile
@@ -151,7 +151,7 @@ test-torch-cuda:
not nightly and not models_hub" test-torch-nightly: - pytest ${COVERAGE_ARGS} tests/torch -m nightly --junitxml ${JUNITXML_PATH} $(DATA_ARG) + pytest ${COVERAGE_ARGS} tests/torch -m nightly -k "quantization.test_sanity_sample" --junitxml ${JUNITXML_PATH} $(DATA_ARG) test-torch-weekly: pytest ${COVERAGE_ARGS} tests/torch -m weekly \ From c3ce084053e66cec03d22206f7224553b32fa41d Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Wed, 8 May 2024 15:16:23 +0200 Subject: [PATCH 08/18] Subset test --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index f1a8aada913..36ae7573d92 100644 --- a/Makefile +++ b/Makefile @@ -151,7 +151,7 @@ test-torch-cuda: pytest ${COVERAGE_ARGS} tests/torch -ra -m "cuda and not weekly and not nightly and not models_hub" test-torch-nightly: - pytest ${COVERAGE_ARGS} tests/torch -m nightly -k "quantization.test_sanity_sample" --junitxml ${JUNITXML_PATH} $(DATA_ARG) + pytest ${COVERAGE_ARGS} tests/torch -m nightly -k "test_sanity_sample" --junitxml ${JUNITXML_PATH} $(DATA_ARG) test-torch-weekly: pytest ${COVERAGE_ARGS} tests/torch -m weekly \ From d422d11c925b1eb9ee38f90d7d4cf21b1fcee4f1 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Wed, 8 May 2024 18:11:22 +0200 Subject: [PATCH 09/18] WIP --- Makefile | 2 +- nncf/torch/dynamic_graph/context.py | 19 ---------- nncf/torch/dynamic_graph/patch_pytorch.py | 45 ++++++++++++++++++++--- nncf/torch/dynamic_graph/wrappers.py | 16 +++----- reproducer.py | 27 ++++++++++++++ 5 files changed, 73 insertions(+), 36 deletions(-) create mode 100644 reproducer.py diff --git a/Makefile b/Makefile index 36ae7573d92..4d9b8419a37 100644 --- a/Makefile +++ b/Makefile @@ -151,7 +151,7 @@ test-torch-cuda: pytest ${COVERAGE_ARGS} tests/torch -ra -m "cuda and not weekly and not nightly and not models_hub" test-torch-nightly: - pytest ${COVERAGE_ARGS} tests/torch -m nightly -k "test_sanity_sample" --junitxml ${JUNITXML_PATH} $(DATA_ARG) + pytest ${COVERAGE_ARGS} tests/torch -m nightly --junitxml ${JUNITXML_PATH} $(DATA_ARG) test-torch-weekly: pytest ${COVERAGE_ARGS} tests/torch -m weekly \ diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py index fcc34c44cd3..05b89843740 100644 --- a/nncf/torch/dynamic_graph/context.py +++ b/nncf/torch/dynamic_graph/context.py @@ -37,25 +37,6 @@ from nncf.torch.dynamic_graph.trace_tensor import TracedTensorMixin -# class GlobalContext: -# def __init__(self): -# self._context = None -# self.lock = threading.Lock() -# -# @property -# def context(self): -# with self.lock: -# return self._context -# -# @context.setter -# def context(self, value): -# with self.lock: -# self._context = value -# -# -# _CURRENT_CONTEXT = GlobalContext() - - class ThreadLocalGlobalContext(threading.local): def __init__(self): super().__init__() diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 67fe143d8ca..6d0e9d0f57d 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -243,12 +243,39 @@ def wrapper(*args, **kwargs): return wrapper -def get_disable_patching_wrapper(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): - with disable_patching(): - return fn(*args, **kwargs) - return wrapper +# def get_disable_patching_wrapper(fn): +# @functools.wraps(fn) +# def wrapper(*args, **kwargs): +# with disable_patching(): +# wrapped_call = fn.__dict__["__wrapped__"] +# original_call = wrapped_call.__dict__["__wrapped__"] +# # 
fn.__dict__["__wrapped__"] = original_call +# result = fn(*args, **kwargs) +# fn.__dict__["__wrapped__"] = wrapped_call +# return result +# return wrapper +# +# +# def get_dynamo_optimize_context_wrapper(fn): +# @functools.wraps(fn) +# def wrapper(self, target): +# result = fn(self, target) +# if hasattr(target, "__self__"): +# return get_disable_patching_wrapper(result) +# return result +# return wrapper +# +# +# def get_dynamo_optimized_module_forward_wrapper(fn): +# @functools.wraps(fn) +# def wrapper(*args, **kwargs): +# return fn(*args, **kwargs) +# return wrapper +# +# +# def patch_dynamo(): +# from torch._dynamo.eval_frame import OptimizeContext +# OptimizeContext.__call__ = get_dynamo_optimize_context_wrapper(OptimizeContext.__call__) class OriginalOpInfo: @@ -286,6 +313,8 @@ def patch_torch_jit(): # unpatched torch.jit.script and the patching above does not affect it setattr(torch.jit, "_script_if_tracing", torch_jit_script_if_tracing) + # patch_dynamo() + def patch_namespace_opname(namespace, op_info: PatchedOperatorInfo): op_name = op_info.name @@ -431,3 +460,7 @@ def disable_patching(): # Need to restore the previous state of patching in this case before continuing to the exception handling. if was_patched: patch_torch_operators() + + +def is_patched_by_dynamo(model: torch.nn.Module): + return hasattr(model, "forward") and "_torchdynamo_orig_callable" in model.forward.__dict__ diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 139ad1380bb..f9f15460eb9 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -127,22 +127,18 @@ def wrap_module_call(module_call): @functools.wraps(module_call) def wrapped(self, *args, **kwargs): - ctx = get_current_context() - # if ctx is not None: - # # import threading - # # from nncf.torch.dynamic_graph.context import _CURRENT_CONTEXT - # print(ctx, type(ctx), ctx is None, ctx is not None) - # # print(threading.get_ident(), ctx, type(ctx), ctx is None, ctx is not None) - # # print(_CURRENT_CONTEXT) - if "_torchdynamo_orig_callable" in self.forward.__dict__: - from nncf.torch.dynamic_graph.patch_pytorch import disable_patching + from nncf.torch.dynamic_graph.patch_pytorch import disable_patching + from nncf.torch.dynamic_graph.patch_pytorch import is_patched_by_dynamo + + if is_patched_by_dynamo(self): with disable_patching(): return module_call(self, *args, **kwargs) + + ctx = get_current_context() if not ctx or self.__class__ in _IGNORED_SCOPES: if isinstance(self, DataParallel): _warn_data_parallel() return module_call(self, *args, **kwargs) - print(self.__class__) ctx.push_scope(self) is_nncf_layer = isinstance(self, _NNCFModuleMixin) if is_nncf_layer: diff --git a/reproducer.py b/reproducer.py new file mode 100644 index 00000000000..a16105465f4 --- /dev/null +++ b/reproducer.py @@ -0,0 +1,27 @@ +import time + +import nncf.torch +# from optimum.intel import OVModelForFeatureExtraction + +import openvino.torch +from transformers import AutoModel, AutoTokenizer +import torch + +# import nncf.torch + +# import torch._dynamo +# torch._dynamo.config.suppress_errors = True + +tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5") +model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5") +model = torch.compile(model, backend="openvino") +# model = torch.compile(model) + +encoded_input = tokenizer( + ["hello world"], padding=True, truncation=True, return_tensors="pt" +) +with torch.no_grad(): + time.sleep(1) + model_output = model(**encoded_input) + 
+    sentence_embeddings = model_output[0][:, 0]
+sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

From f7eea5e05816585a8fd8a19fd327bfcef47e2d29 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 17 May 2024 14:00:30 +0200
Subject: [PATCH 10/18] WIP

---
 nncf/torch/dynamic_graph/patch_pytorch.py | 48 +++++++----------------
 reproducer.py                             |  4 ++
 2 files changed, 19 insertions(+), 33 deletions(-)

diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 6d0e9d0f57d..b44bc33974b 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -243,39 +243,12 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-# def get_disable_patching_wrapper(fn):
-#     @functools.wraps(fn)
-#     def wrapper(*args, **kwargs):
-#         with disable_patching():
-#             wrapped_call = fn.__dict__["__wrapped__"]
-#             original_call = wrapped_call.__dict__["__wrapped__"]
-#             # fn.__dict__["__wrapped__"] = original_call
-#             result = fn(*args, **kwargs)
-#             fn.__dict__["__wrapped__"] = wrapped_call
-#             return result
-#     return wrapper
-#
-#
-# def get_dynamo_optimize_context_wrapper(fn):
-#     @functools.wraps(fn)
-#     def wrapper(self, target):
-#         result = fn(self, target)
-#         if hasattr(target, "__self__"):
-#             return get_disable_patching_wrapper(result)
-#         return result
-#     return wrapper
-#
-#
-# def get_dynamo_optimized_module_forward_wrapper(fn):
-#     @functools.wraps(fn)
-#     def wrapper(*args, **kwargs):
-#         return fn(*args, **kwargs)
-#     return wrapper
-#
-#
-# def patch_dynamo():
-#     from torch._dynamo.eval_frame import OptimizeContext
-#     OptimizeContext.__call__ = get_dynamo_optimize_context_wrapper(OptimizeContext.__call__)
+def get_disable_patching_wrapper(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        with disable_patching():
+            return f(*args, **kwargs)
+    return wrapper
 
 
 class OriginalOpInfo:
@@ -286,6 +264,8 @@ def __init__(self, name: str, namespace, op):
 _OPERATORS_ALREADY_WRAPPED = False
 _ORIG_JIT_SCRIPT = None
 _ORIG_JIT_TRACE_MAKE_MODULE = None
+_COMPILE_ALREADY_WRAPPED = False
+_ORIG_TORCH_COMPILE = None
 
 
 def patch_torch_jit():
@@ -367,6 +342,13 @@
         patch_torch_jit()
         _JIT_ALREADY_WRAPPED = True
 
+    global _COMPILE_ALREADY_WRAPPED
+    if not _COMPILE_ALREADY_WRAPPED:
+        global _ORIG_TORCH_COMPILE
+        _ORIG_TORCH_COMPILE = torch.compile
+        setattr(torch, "compile", get_disable_patching_wrapper(_ORIG_TORCH_COMPILE))
+        _COMPILE_ALREADY_WRAPPED = True
+
     # Do not patch operators twice as well
     global _OPERATORS_ALREADY_WRAPPED
     if _OPERATORS_ALREADY_WRAPPED:
diff --git a/reproducer.py b/reproducer.py
index a16105465f4..178345c8d7d 100644
--- a/reproducer.py
+++ b/reproducer.py
@@ -1,12 +1,14 @@
 import time
 
 import nncf.torch
+from nncf.torch.dynamic_graph.patch_pytorch import unpatch_torch_operators, patch_torch_operators
 # from optimum.intel import OVModelForFeatureExtraction
 
 import openvino.torch
 from transformers import AutoModel, AutoTokenizer
 import torch
 
+
 # import nncf.torch
 
 # import torch._dynamo
 # torch._dynamo.config.suppress_errors = True
@@ -14,7 +16,9 @@
 tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
 model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
+# unpatch_torch_operators()
 model = torch.compile(model, backend="openvino")
+# patch_torch_operators()
 # model = torch.compile(model)

From 88859ec4ea70389afc971f9e20258b77f651817c Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Fri, 17 May 2024 14:59:17 +0200
Subject: [PATCH 11/18] Unpatch torch operators during torch compile call and during compiled model forward call

---
 nncf/torch/dynamic_graph/patch_pytorch.py | 30 +++++++++++++++++-----
 nncf/torch/dynamic_graph/wrappers.py      |  7 -----
 reproducer.py                             | 31 -----------------------
 3 files changed, 23 insertions(+), 45 deletions(-)
 delete mode 100644 reproducer.py

diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index b44bc33974b..35434497669 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -244,10 +244,31 @@ def wrapper(*args, **kwargs):
 
 
 def get_disable_patching_wrapper(f):
+    """
+    :param: f: A callable object to be wrapped.
+    :return: A wrapper of a function that unpatches torch operators before the
+    function call and patches them back after the call.
+    """
+
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         with disable_patching():
             return f(*args, **kwargs)
+
+    return wrapper
+
+
+def module_call_wrapper(module_call):
+    wrapped_module_call = wrap_module_call(module_call)
+    unpatched_module_call = get_disable_patching_wrapper(module_call)
+
+    @functools.wraps(module_call)
+    def wrapper(self, *args, **kwargs):
+        # Check if model was patched by torch dynamo during compilation
+        if "_torchdynamo_orig_callable" in self.forward.__dict__:
+            unpatched_module_call(self, *args, **kwargs)
+        return wrapped_module_call(self, *args, **kwargs)
+
     return wrapper
 
 
@@ -361,6 +382,7 @@ def patch_torch_operators():
         patch_torch_jit()
         _JIT_ALREADY_WRAPPED = True
 
+    # Unpatch torch operators during model compilation.
     global _COMPILE_ALREADY_WRAPPED
     if not _COMPILE_ALREADY_WRAPPED:
         global _ORIG_TORCH_COMPILE
@@ -435,7 +457,7 @@ def patch_torch_operators():
         patch_namespace_opname(TracedTensor, op_info)
 
     ORIGINAL_OPERATORS.append(OriginalOpInfo("__call__", torch.nn.Module, torch.nn.Module.__call__))
-    torch.nn.Module.__call__ = wrap_module_call(torch.nn.Module.__call__)
+    torch.nn.Module.__call__ = module_call_wrapper(torch.nn.Module.__call__)
 
     ignore_scope(DataParallel)
     ignore_scope(DistributedDataParallel)
@@ -442,7 +462,3 @@ def disable_patching():
     # Need to restore the previous state of patching in this case before continuing to the exception handling.
     if was_patched:
         patch_torch_operators()
-
-
-def is_patched_by_dynamo(model: torch.nn.Module):
-    return hasattr(model, "forward") and "_torchdynamo_orig_callable" in model.forward.__dict__
diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py
index f9f15460eb9..6728d5cd904 100644
--- a/nncf/torch/dynamic_graph/wrappers.py
+++ b/nncf/torch/dynamic_graph/wrappers.py
@@ -127,13 +127,6 @@ def wrap_module_call(module_call):
 
     @functools.wraps(module_call)
     def wrapped(self, *args, **kwargs):
-        from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
-        from nncf.torch.dynamic_graph.patch_pytorch import is_patched_by_dynamo
-
-        if is_patched_by_dynamo(self):
-            with disable_patching():
-                return module_call(self, *args, **kwargs)
-
         ctx = get_current_context()
         if not ctx or self.__class__ in _IGNORED_SCOPES:
             if isinstance(self, DataParallel):
                 _warn_data_parallel()
             return module_call(self, *args, **kwargs)
diff --git a/reproducer.py b/reproducer.py
deleted file mode 100644
index 178345c8d7d..00000000000
--- a/reproducer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import time
-
-import nncf.torch
-from nncf.torch.dynamic_graph.patch_pytorch import unpatch_torch_operators, patch_torch_operators
-# from optimum.intel import OVModelForFeatureExtraction
-
-import openvino.torch
-from transformers import AutoModel, AutoTokenizer
-import torch
-
-
-# import nncf.torch
-
-# import torch._dynamo
-# torch._dynamo.config.suppress_errors = True
-
-tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
-model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
-# unpatch_torch_operators()
-model = torch.compile(model, backend="openvino")
-# patch_torch_operators()
-# model = torch.compile(model)
-
-encoded_input = tokenizer(
-    ["hello world"], padding=True, truncation=True, return_tensors="pt"
-)
-with torch.no_grad():
-    time.sleep(1)
-    model_output = model(**encoded_input)
-    sentence_embeddings = model_output[0][:, 0]
-sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

From 133e0f7dc6fa91fb73d784c307296cec306b8f61 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Tue, 21 May 2024 15:31:43 +0200
Subject: [PATCH 12/18] Added docstring; improved test

---
 nncf/torch/dynamic_graph/patch_pytorch.py | 9 +++++++--
 tests/torch/pytorch_patch_isolated.py     | 4 ++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 35434497669..09eeb636ba4 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -258,7 +258,12 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def module_call_wrapper(module_call):
+def module_call_wrapper():
+    """
+    An additional wrapper over wrap_module_call() that disables torch patching if
+    model is compiled by torch dynamo.
+ """ + module_call = torch.nn.Module.__call__ wrapped_module_call = wrap_module_call(module_call) unpatched_module_call = get_disable_patching_wrapper(module_call) @@ -435,7 +440,7 @@ def patch_torch_operators(): patch_namespace_opname(TracedTensor, op_info) ORIGINAL_OPERATORS.append(OriginalOpInfo("__call__", torch.nn.Module, torch.nn.Module.__call__)) - torch.nn.Module.__call__ = module_call_wrapper(torch.nn.Module.__call__) + torch.nn.Module.__call__ = module_call_wrapper() ignore_scope(DataParallel) ignore_scope(DistributedDataParallel) diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py index aa842657b98..98cdc485824 100644 --- a/tests/torch/pytorch_patch_isolated.py +++ b/tests/torch/pytorch_patch_isolated.py @@ -91,6 +91,10 @@ def compile_and_run_lenet() -> torch.Tensor: model.load_state_dict(state_dict) compiled_model = torch.compile(model) + + # This key is used to check if model is compiled at patch_pytorch.py + assert "_torchdynamo_orig_callable" in compiled_model.forward.__dict__ + return compiled_model(torch.ones([1, 3, 32, 32])) From a97ef09956f5ae4fe211c26579a5683752cc25b1 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Tue, 21 May 2024 15:35:41 +0200 Subject: [PATCH 13/18] Replaced key check by type check --- nncf/torch/dynamic_graph/patch_pytorch.py | 5 +++-- tests/torch/pytorch_patch_isolated.py | 4 ---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 09eeb636ba4..1dbe92954d7 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -16,6 +16,7 @@ import torch import torch.utils.cpp_extension +from torch._dynamo import OptimizedModule from torch._jit_internal import createResolutionCallbackFromFrame from torch.jit import is_tracing from torch.nn import DataParallel @@ -269,8 +270,8 @@ def module_call_wrapper(): @functools.wraps(module_call) def wrapper(self, *args, **kwargs): - # Check if model was patched by torch dynamo during compilation - if "_torchdynamo_orig_callable" in self.forward.__dict__: + # Check if model was compiled by torch dynamo + if isinstance(self, OptimizedModule): unpatched_module_call(self, *args, **kwargs) return wrapped_module_call(self, *args, **kwargs) diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py index 98cdc485824..aa842657b98 100644 --- a/tests/torch/pytorch_patch_isolated.py +++ b/tests/torch/pytorch_patch_isolated.py @@ -91,10 +91,6 @@ def compile_and_run_lenet() -> torch.Tensor: model.load_state_dict(state_dict) compiled_model = torch.compile(model) - - # This key is used to check if model is compiled at patch_pytorch.py - assert "_torchdynamo_orig_callable" in compiled_model.forward.__dict__ - return compiled_model(torch.ones([1, 3, 32, 32])) From 42752b53fb5e14d3e7496322beab0011ceb57637 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Tue, 21 May 2024 15:37:52 +0200 Subject: [PATCH 14/18] Rename --- nncf/torch/dynamic_graph/patch_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 1dbe92954d7..779030cf7be 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -259,7 +259,7 @@ def wrapper(*args, **kwargs): return wrapper -def module_call_wrapper(): +def get_module_call_wrapper(): """ An additional wrapper over 
     An additional wrapper over wrap_module_call() that disables torch patching if
     model is compiled by torch dynamo.
@@ -441,7 +441,7 @@ def patch_torch_operators():
         patch_namespace_opname(TracedTensor, op_info)
 
     ORIGINAL_OPERATORS.append(OriginalOpInfo("__call__", torch.nn.Module, torch.nn.Module.__call__))
-    torch.nn.Module.__call__ = module_call_wrapper()
+    torch.nn.Module.__call__ = get_module_call_wrapper()
 
     ignore_scope(DataParallel)
     ignore_scope(DistributedDataParallel)

From b36a727320b9c7e52bbc0e318a516308725f977a Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Tue, 21 May 2024 15:53:43 +0200
Subject: [PATCH 15/18] Addressed changes

---
 nncf/torch/dynamic_graph/patch_pytorch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 779030cf7be..4e80753a74e 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -246,7 +246,7 @@ def wrapper(*args, **kwargs):
 
 def get_disable_patching_wrapper(f):
     """
-    :param: f: A callable object to be wrapped.
+    :param f: A callable object to be wrapped.
     :return: A wrapper of a function that unpatches torch operators before the
     function call and patches them back after the call.
     """
@@ -272,7 +272,7 @@ def get_module_call_wrapper():
     def wrapper(self, *args, **kwargs):
         # Check if model was compiled by torch dynamo
         if isinstance(self, OptimizedModule):
-            unpatched_module_call(self, *args, **kwargs)
+            return unpatched_module_call(self, *args, **kwargs)
         return wrapped_module_call(self, *args, **kwargs)
 
     return wrapper

From 169e803f59b9018701aa5a1918844aa283048e3e Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Tue, 21 May 2024 18:49:41 +0200
Subject: [PATCH 16/18] Add exception for compilation of optimized model

---
 nncf/torch/dynamic_graph/patch_pytorch.py | 22 +++++++++--
 tests/torch/test_pytorch_patch.py         | 45 ++++++++++++++++++++---
 2 files changed, 58 insertions(+), 9 deletions(-)

diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 4e80753a74e..9376344a912 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -12,7 +12,7 @@
 import functools
 import inspect
 from contextlib import contextmanager
-from typing import List
+from typing import Callable, List, Union
 
 import torch
 import torch.utils.cpp_extension
@@ -259,6 +259,22 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def get_torch_compile_wrapper():
+    """
+    Wrapper for torch.compile() that disables NNCF patching when called for vanilla PyTorch model and
+    raises an exception when called for an NNCF-optimized model.
+ """ + + @functools.wraps(_ORIG_TORCH_COMPILE) + def wrapper(model, *args, **kwargs): + if hasattr(model, "nncf"): + raise ValueError("At the moment torch.compile() is not supported for models optimized by NNCF.") + with disable_patching(): + return _ORIG_TORCH_COMPILE(model, *args, **kwargs) + + return wrapper + + def get_module_call_wrapper(): """ An additional wrapper over wrap_module_call() that disables torch patching if @@ -292,7 +308,7 @@ def __init__(self, name: str, namespace, op): _ORIG_JIT_SCRIPT = None _ORIG_JIT_TRACE_MAKE_MODULE = None _COMPILE_ALREADY_WRAPPED = False -_ORIG_TORCH_COMPILE = None +_ORIG_TORCH_COMPILE: Union[Callable, None] = None def patch_torch_jit(): @@ -372,7 +388,7 @@ def patch_torch_operators(): if not _COMPILE_ALREADY_WRAPPED: global _ORIG_TORCH_COMPILE _ORIG_TORCH_COMPILE = torch.compile - setattr(torch, "compile", get_disable_patching_wrapper(_ORIG_TORCH_COMPILE)) + setattr(torch, "compile", get_torch_compile_wrapper()) _COMPILE_ALREADY_WRAPPED = True # Do not patch operators twice as well diff --git a/tests/torch/test_pytorch_patch.py b/tests/torch/test_pytorch_patch.py index 9a3c15efd0c..4966da0c43c 100644 --- a/tests/torch/test_pytorch_patch.py +++ b/tests/torch/test_pytorch_patch.py @@ -15,6 +15,7 @@ import pytest import torch +import nncf from nncf.config import NNCFConfig from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.dynamic_graph.patch_pytorch import _ORIG_JIT_SCRIPT @@ -112,6 +113,31 @@ def test_torch_compile(): run_pytest_case_function_in_separate_process(test_compile) +def test_torch_compile_on_nncf_model(): + model = BasicConvTestModel() + quantized_model = nncf.quantize(model, nncf.Dataset([torch.rand(model.INPUT_SIZE)])) + with pytest.raises(ValueError) as e: + torch.compile(quantized_model) + assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value) + + model = BasicConvTestModel() + config = get_test_quantization_config(model) + compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config) + with pytest.raises(ValueError) as e: + torch.compile(compressed_model) + assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value) + + stripped_model = compression_ctrl.strip() + with pytest.raises(ValueError) as e: + torch.compile(stripped_model) + assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value) + + with pytest.raises(ValueError) as e: + # Compiling this model would actually work, but inference of the compiled model will fail + torch.compile(model) + assert "At the moment torch.compile() is not supported for models optimized by NNCF." 
+    assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value)
+
+
 def test_jit_script_signature():
     # Check that torch.jit.script has the same signature as the wrapper was designed for
     signature = inspect.signature(_ORIG_JIT_SCRIPT)
@@ -133,6 +159,18 @@ def class_method(self, x):
 
 def test_jit_trace_model():
     model = BasicConvTestModel()
+    config = get_test_quantization_config(model)
+
+    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
+    torch.jit.trace(compressed_model, example_inputs=torch.rand(model.INPUT_SIZE))
+
+    model = compression_ctrl.strip()
+    torch.jit.trace(model, example_inputs=torch.rand(model.INPUT_SIZE))
+
+
+def get_test_quantization_config(
+    model,
+):
     config = NNCFConfig()
     config.update(
         {
@@ -142,9 +180,4 @@ def test_jit_trace_model():
         }
     )
     register_bn_adaptation_init_args(config)
-
-    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
-    torch.jit.trace(compressed_model, example_inputs=torch.rand(model.INPUT_SIZE))
-
-    model = compression_ctrl.strip()
-    torch.jit.trace(model, example_inputs=torch.rand(model.INPUT_SIZE))
+    return config

From 4adc0c415c1ca24f89be723cbfef652bb3b52b95 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Thu, 23 May 2024 10:34:38 +0200
Subject: [PATCH 17/18] Apply review suggestions

---
 nncf/torch/dynamic_graph/patch_pytorch.py | 50 ++++++-----------------
 nncf/torch/dynamic_graph/wrappers.py      |  7 ++++
 tests/torch/pytorch_patch_isolated.py     | 18 ++++----
 tests/torch/test_pytorch_patch.py         | 29 ++++++++-----
 4 files changed, 44 insertions(+), 60 deletions(-)

diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py
index 9376344a912..78a38bcd8df 100644
--- a/nncf/torch/dynamic_graph/patch_pytorch.py
+++ b/nncf/torch/dynamic_graph/patch_pytorch.py
@@ -16,7 +16,6 @@
 
 import torch
 import torch.utils.cpp_extension
-from torch._dynamo import OptimizedModule
 from torch._jit_internal import createResolutionCallbackFromFrame
 from torch.jit import is_tracing
 from torch.nn import DataParallel
@@ -244,21 +243,6 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
-def get_disable_patching_wrapper(f):
-    """
-    :param f: A callable object to be wrapped.
-    :return: A wrapper of a function that unpatches torch operators before the
-    function call and patches them back after the call.
-    """
-
-    @functools.wraps(f)
-    def wrapper(*args, **kwargs):
-        with disable_patching():
-            return f(*args, **kwargs)
-
-    return wrapper
-
-
 def get_torch_compile_wrapper():
     """
     Wrapper for torch.compile() that disables NNCF patching when called for vanilla PyTorch model and
@@ -267,33 +251,16 @@ def wrapper(*args, **kwargs):
- """ - module_call = torch.nn.Module.__call__ - wrapped_module_call = wrap_module_call(module_call) - unpatched_module_call = get_disable_patching_wrapper(module_call) - - @functools.wraps(module_call) - def wrapper(self, *args, **kwargs): - # Check if model was compiled by torch dynamo - if isinstance(self, OptimizedModule): - return unpatched_module_call(self, *args, **kwargs) - return wrapped_module_call(self, *args, **kwargs) - - return wrapper - - class OriginalOpInfo: def __init__(self, name: str, namespace, op): self.name = name @@ -311,6 +278,13 @@ def __init__(self, name: str, namespace, op): _ORIG_TORCH_COMPILE: Union[Callable, None] = None +@functools.wraps(ORIGINAL_CALL) +def unpatching_module_call(*args, **kwargs): + # Wrapper for module.__call__ that unpatches torch operators during model forward + with disable_patching(): + return ORIGINAL_CALL(*args, **kwargs) + + def patch_torch_jit(): # This import statement is required, otherwise we get a # "RuntimeError: undefined value torch" inside the real torch.jit.script @@ -457,7 +431,7 @@ def patch_torch_operators(): patch_namespace_opname(TracedTensor, op_info) ORIGINAL_OPERATORS.append(OriginalOpInfo("__call__", torch.nn.Module, torch.nn.Module.__call__)) - torch.nn.Module.__call__ = get_module_call_wrapper() + torch.nn.Module.__call__ = wrap_module_call(torch.nn.Module.__call__) ignore_scope(DataParallel) ignore_scope(DistributedDataParallel) diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 6728d5cd904..d59b2fa6c5a 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -13,6 +13,7 @@ from typing import Callable, List, Tuple import torch +from torch._dynamo import OptimizedModule from torch.nn import DataParallel from nncf.common.graph.definitions import MODEL_CONST_OP_NAME @@ -127,6 +128,12 @@ def wrap_module_call(module_call): @functools.wraps(module_call) def wrapped(self, *args, **kwargs): + from nncf.torch.dynamic_graph.patch_pytorch import unpatching_module_call + + # If called on a model compiled by torch dynamo, we unpatch torch operators and invoke original module call + if isinstance(self, OptimizedModule): + return unpatching_module_call(self, *args, **kwargs) + ctx = get_current_context() if not ctx or self.__class__ in _IGNORED_SCOPES: if isinstance(self, DataParallel): diff --git a/tests/torch/pytorch_patch_isolated.py b/tests/torch/pytorch_patch_isolated.py index aa842657b98..ce26f67eea3 100644 --- a/tests/torch/pytorch_patch_isolated.py +++ b/tests/torch/pytorch_patch_isolated.py @@ -79,25 +79,21 @@ def test_jit_script_exception_preserves_patching_isolated(): assert "nncf" in torch.nn.Module.__call__.__code__.co_filename -def compile_and_run_lenet() -> torch.Tensor: - from tests.torch.test_models.lenet import LeNet +def compile_and_run_test_model() -> torch.Tensor: + from tests.torch.helpers import BasicConvTestModel - model = LeNet() - - torch.manual_seed(0) - state_dict = {} - for k, v in model.state_dict().items(): - state_dict[k] = torch.rand(v.shape) + model = BasicConvTestModel() + state_dict = {"conv.weight": model.default_weight(), "conv.bias": model.default_bias()} model.load_state_dict(state_dict) compiled_model = torch.compile(model) - return compiled_model(torch.ones([1, 3, 32, 32])) + return compiled_model(torch.ones(model.INPUT_SIZE)) @pytest.mark.skipif(ISOLATION_RUN_ENV_VAR not in os.environ, reason="Should be run via isolation proxy") def test_compile(): - before_nncf = compile_and_run_lenet() + before_nncf = 
+    before_nncf = compile_and_run_test_model()
     import nncf.torch  # noqa: F401
 
-    after_nncf = compile_and_run_lenet()
+    after_nncf = compile_and_run_test_model()
     assert torch.allclose(before_nncf, after_nncf)
diff --git a/tests/torch/test_pytorch_patch.py b/tests/torch/test_pytorch_patch.py
index 4966da0c43c..c800b74ce0a 100644
--- a/tests/torch/test_pytorch_patch.py
+++ b/tests/torch/test_pytorch_patch.py
@@ -114,28 +114,37 @@ def test_torch_compile():
 
 
 def test_torch_compile_on_nncf_model():
+    # Calling torch.compile on a regular torch model should work fine
+    model = BasicConvTestModel()
+    compiled_model = torch.compile(model)
+    compiled_model(torch.ones(model.INPUT_SIZE))
+
     model = BasicConvTestModel()
     quantized_model = nncf.quantize(model, nncf.Dataset([torch.rand(model.INPUT_SIZE)]))
-    with pytest.raises(ValueError) as e:
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
         torch.compile(quantized_model)
-    assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value)
 
     model = BasicConvTestModel()
     config = get_test_quantization_config(model)
     compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
-    with pytest.raises(ValueError) as e:
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
         torch.compile(compressed_model)
-    assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value)
 
     stripped_model = compression_ctrl.strip()
-    with pytest.raises(ValueError) as e:
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
         torch.compile(stripped_model)
 
-    with pytest.raises(ValueError) as e:
+    with pytest.raises(
+        TypeError, match="At the moment torch\\.compile\\(\\) is not supported for models optimized by NNCF\\."
+    ):
         # Compiling this model would actually work, but inference of the compiled model will fail
         torch.compile(model)
-    assert "At the moment torch.compile() is not supported for models optimized by NNCF." in str(e.value)
 
 
 def test_jit_script_signature():
@@ -168,9 +177,7 @@ def test_jit_trace_model():
     torch.jit.trace(model, example_inputs=torch.rand(model.INPUT_SIZE))
 
 
-def get_test_quantization_config(
-    model,
-):
+def get_test_quantization_config(model):
     config = NNCFConfig()
     config.update(
         {
@@ -180,4 +187,4 @@ def get_test_quantization_config(
         }
     )
     register_bn_adaptation_init_args(config)
     return config

From d91c650cf0250c7624ad85c2699a983855838617 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Thu, 23 May 2024 11:24:17 +0200
Subject: [PATCH 18/18] Trigger checks
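
Below is a minimal usage sketch of the behavior this series converges on, based on the wrappers and tests in the patches above. The toy module and input sizes are placeholders introduced for illustration only; the hedged `nncf.quantize()` call is shown commented out because it requires an NNCF installation and a calibration dataset.

```python
# Sketch only: after `import nncf.torch`, torch is patched, but torch.compile()
# is wrapped so that compilation and compiled-module forwards run unpatched.
import torch
import nncf.torch  # noqa: F401  # patches torch operators on import


class ToyModel(torch.nn.Module):  # placeholder model, not part of the series
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = ToyModel()
compiled = torch.compile(model)        # patching is disabled around the compile call
out = compiled(torch.ones(1, 4))       # OptimizedModule forward runs without NNCF patching

# Per PATCH 16/17, compiling an NNCF-optimized model is rejected:
# quantized = nncf.quantize(model, nncf.Dataset([torch.ones(1, 4)]))
# torch.compile(quantized)  # raises TypeError
```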