# [Fix] Patching nncf_model_inputs (openvinotoolkit#2278)
### Changes
Strip traced tensors when exiting the tracing context.

### Reason for changes

`nncf_model_input` patched `torch.Tensor` inputs into `TracedTensor`, which does not support `deepcopy`.
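
For illustration, a minimal sketch of the user-visible flow this fixes, mirroring the `test_wrap_model_with_example_input` check below (the toy `nn.Linear` model and input shape are assumptions, not taken from the PR):

```python
import copy

import torch
from torch import nn

from nncf.torch.model_creation import wrap_model

# A toy model; any nn.Module illustrates the same behavior.
model = nn.Linear(4, 4)
example_input = torch.ones(1, 4)

wrap_model(model, example_input)

# Before this fix, tracing left `example_input` patched as a TracedTensor,
# and a later deepcopy of it raised a RuntimeError. With the fix, the tensor
# is stripped back to torch.Tensor when the tracing context exits.
assert type(example_input) is torch.Tensor
copy.deepcopy(example_input)
```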

### Related tickets

125357

### Tests

- test_wrap_model_with_example_input
- e2e 506
- test_examples 154
alexsu52 authored Nov 28, 2023
1 parent 2362c02 commit ee78ab6
Showing 6 changed files with 111 additions and 64 deletions.
#### nncf/torch/dynamic_graph/context.py (19 additions, 15 deletions)

```diff
@@ -71,6 +71,7 @@ def reset(self):
         self.operator_counters = {}
         self.node_call_tracker = {}
         self.traced_tensor_weakrefs = []
+        self.nested_contexts_stack = []


 class CopySafeThreadingVars:
@@ -90,7 +91,6 @@ class TracingContext:
     def __init__(self):
         self.graph = DynamicGraph()

-        self._save_context = None
         self._post_hooks = {}
         self._pre_hooks: Dict[PreHookId, List[Callable]] = {}
         self._num_nested_hooks = 0
@@ -122,25 +122,29 @@ def __enter__(self):
         # all replicas. Otherwise we will have data races on setting and reading the global _CURRENT_CONTEXT
         # variable, which will in turn lead to DP-specific runtime errors such as
         # "'_thread._local' object has no attribute 'scopes'"
-        self._save_context = get_current_context()
+        self._threading.thread_local.nested_contexts_stack.append(get_current_context())
         set_current_context(self)
-        self._reset_thread_local()
-        if is_debug():
-            self.reset_node_call_counters()

         return self

     def __exit__(self, *args):
-        if self._save_context is not self:  # NNCFNetwork.rebuild_graph() uses the compressed context nested in self
-            for traced_tensor_weakref in self._threading.thread_local.traced_tensor_weakrefs:
-                tt = traced_tensor_weakref()
-                if tt is not None:
-                    tt.nncf_expire()
-
-        self._reset_thread_local()
-
-        set_current_context(self._save_context)
-        self._save_context = None
+        previous_context = self._threading.thread_local.nested_contexts_stack.pop(-1)
+        for traced_tensor_weakref in self._threading.thread_local.traced_tensor_weakrefs:
+            tt = traced_tensor_weakref()
+            if tt is None or not isinstance(tt, TracedTensor):
+                continue
+            if previous_context is None:
+                tt.strip()
+            elif previous_context is not self:
+                previous_context.register_traced_tensor(tt)
+
+        if previous_context is not self:
+            self._reset_thread_local()
+
+        if is_debug():
+            self.reset_node_call_counters()
+
+        set_current_context(previous_context)

     def find_operator_node(
         self, tensor_metas: List[Optional[TensorMeta]], op_address: OperationAddress
```
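
To summarize the new semantics: the enclosing context is now resolved from a per-thread stack rather than the removed `_save_context` slot. A rough sketch of what this means for nested contexts (illustrative only; the traced tensors themselves would come from model calls made inside the contexts):

```python
from nncf.torch.dynamic_graph.context import TracingContext

outer = TracingContext()
inner = TracingContext()

with outer:
    with inner:
        pass  # tensors traced here are registered with `inner`
    # Exiting `inner`: the popped previous context is `outer`, which is
    # neither None nor `inner` itself, so surviving TracedTensors are
    # handed over via outer.register_traced_tensor(...).
# Exiting `outer`: the popped previous context is None, so every surviving
# TracedTensor is stripped back to its original tensor class.
```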
#### nncf/torch/dynamic_graph/trace_tensor.py (32 additions, 17 deletions)

```diff
@@ -8,14 +8,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union

 import numpy as np
 import torch

 from nncf import nncf_logger
 from nncf.common.graph.layer_attributes import Dtype
 from nncf.torch.dynamic_graph.op_input_processing import OperatorInput
+from nncf.torch.nested_objects_traversal import objwalk


 class TensorMeta:
@@ -57,23 +58,28 @@ class TracedTensor(torch.Tensor):
     """

     @staticmethod
-    def from_torch_tensor(tensor, tensor_meta: TensorMeta):
+    def from_torch_tensor(tensor: torch.Tensor, tensor_meta: TensorMeta) -> "TracedTensor":
+        """
+        Creates a TracedTensor by patching a given torch.Tensor, associating it with the provided tensor_meta.
+
+        :param tensor: The input torch.Tensor.
+        :param tensor_meta: The metadata associated with the tensor.
+        :return: The resulting TracedTensor.
+        """
         tensor.tensor_meta = tensor_meta
-        tensor.__class__ = TracedTensor
-        tensor._nncf_expired = False
+        if not isinstance(tensor, TracedTensor):
+            tensor.original_class = tensor.__class__
+            tensor.__class__ = TracedTensor
+
         return tensor

-    def nncf_expire(self):
+    def strip(self) -> None:
         """
-        Mark the traced tensor as "expired". The tensor's metainformation should
-        then be considered outdated/invalid.
+        Reverts the tensor to its original class by removing tracing attributes.
         """
-        self._nncf_expired = True
-
-    @property
-    def nncf_expired(self) -> bool:
-        return self._nncf_expired
+        self.__class__ = self.original_class
+        delattr(self, "tensor_meta")
+        delattr(self, "original_class")

     def as_subclass(self, cls: "TracedTensor") -> "TracedTensor":
         """
@@ -166,14 +172,23 @@ def make_tensor_metas(inputs: OperatorInput) -> List[Optional[TensorMeta]]:
     for i, node_input_index_entry in enumerate(inputs):
         node_input = node_input_index_entry.getter()
         if isinstance(node_input, TracedTensor):
-            if not node_input.nncf_expired:
-                tensor_metas.append(node_input.tensor_meta)
-            else:
-                meta = TensorMeta(None, i, node_input.shape)
-                tensor_metas.append(meta)
+            tensor_metas.append(node_input.tensor_meta)
         elif isinstance(node_input, torch.Tensor) and not isinstance(node_input, TracedTensor):
             meta = TensorMeta(None, i, node_input.shape)
             tensor_metas.append(meta)
         else:
             tensor_metas.append(None)
     return tensor_metas


+def strip_traced_tensors(args: Tuple, kwargs: Dict) -> Tuple[Tuple, Dict]:
+    """
+    Required to guard against new forward calls on tensors that have already passed
+    through NNCF's forward once and got turned into TracedTensors by reference access.
+    """
+    is_traced_tensor_predicate = lambda x: isinstance(x, TracedTensor)
+    strip_traced_tensor = lambda x: x.strip()
+
+    args = objwalk(args, is_traced_tensor_predicate, strip_traced_tensor)
+    kwargs = objwalk(kwargs, is_traced_tensor_predicate, strip_traced_tensor)
+    return args, kwargs
```
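
A small round-trip sketch of the new lifecycle (illustrative; the `TensorMeta(None, 0, shape)` arguments mirror the call in `make_tensor_metas` above):

```python
import torch

from nncf.torch.dynamic_graph.trace_tensor import TensorMeta, TracedTensor

t = torch.ones(2, 3)
traced = TracedTensor.from_torch_tensor(t, TensorMeta(None, 0, t.shape))

# from_torch_tensor patches the tensor in place and records its original class.
assert traced is t
assert isinstance(t, TracedTensor)
assert t.original_class is torch.Tensor

# strip() reverts the patching and removes the tracing attributes.
traced.strip()
assert type(t) is torch.Tensor
assert not hasattr(t, "tensor_meta")
```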
#### nncf/torch/nncf_network.py (11 additions, 26 deletions)

```diff
@@ -51,7 +51,7 @@
 from nncf.torch.dynamic_graph.patch_pytorch import ORIGINAL_CALL
 from nncf.torch.dynamic_graph.scope import Scope
 from nncf.torch.dynamic_graph.scope_access import get_module_by_scope
-from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
+from nncf.torch.dynamic_graph.trace_tensor import strip_traced_tensors
 from nncf.torch.dynamic_graph.wrappers import wrap_module_call
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph_builder import GraphBuilder
@@ -61,7 +61,6 @@
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
 from nncf.torch.layer_utils import _NNCFModuleMixin
-from nncf.torch.nested_objects_traversal import objwalk
 from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
 from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
 from nncf.torch.utils import compute_FLOPs_hook
@@ -245,14 +244,17 @@ def __init__(
         self._extra_module_types: List[ExtraCompressionModuleType] = []
         self._insertions_into_original_graph: Dict[PTTargetPoint, List[Tuple[Callable, TransformationPriority]]] = {}

-        _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
-            with_input_tracing=True, with_output_tracing=True
-        )
+        if isinstance(model, NNCFNetwork):
+            self._nncf_replaced_modules = model.nncf._nncf_replaced_modules
+        else:
+            _orig_graph_build_forward_fn = self._get_dummy_forward_fn_for_graph_building(
+                with_input_tracing=True, with_output_tracing=True
+            )

-        eval_op_scopes = self._collect_eval_op_scopes(model, _orig_graph_build_forward_fn)
+            eval_op_scopes = self._collect_eval_op_scopes(model, _orig_graph_build_forward_fn)

-        # all modules called in eval mode should be replaced prior to graph building
-        self._replace_modules_by_nncf_modules(model, eval_op_scopes)
+            # all modules called in eval mode should be replaced prior to graph building
+            self._replace_modules_by_nncf_modules(model, eval_op_scopes)

         _orig_context = TracingContext()
@@ -314,23 +316,6 @@ def _model_ref(self) -> "NNCFNetwork":
     def input_infos(self) -> ModelInputInfo:
         return deepcopy(self._input_info)

-    def _strip_traced_tensors(self, args: Tuple, kwargs: Dict) -> Tuple[Tuple, Dict]:
-        """
-        Required to guard against new forward calls on tensors that have already passed
-        through NNCF's forward once and got turned into TracedTensors by reference access.
-        """
-        is_traced_tensor_predicate = lambda x: isinstance(x, TracedTensor)
-
-        def strip_fn(tensor: TracedTensor) -> torch.Tensor:
-            if hasattr(torch.Tensor, "as_subclass"):
-                return torch.Tensor.as_subclass(tensor, torch.Tensor)
-            # Torch < 1.7.0 fallback
-            return torch.tensor(tensor, device=tensor.device, requires_grad=tensor.requires_grad)
-
-        args = objwalk(args, is_traced_tensor_predicate, strip_fn)
-        kwargs = objwalk(kwargs, is_traced_tensor_predicate, strip_fn)
-        return args, kwargs
-
     def create_knowledge_distillation_loss_handler(
         self, kd_original_model: nn.Module, calculate_fn
     ) -> KnowledgeDistillationLossHandler:
@@ -935,7 +920,7 @@ def forward(self, *args, **kwargs):
             if not self.nncf._in_user_dummy_forward:
                 # If a user supplies own dummy forward, he is responsible for
                 # correctly wrapping inputs inside it as well.
-                args, kwargs = self.nncf._strip_traced_tensors(args, kwargs)
+                args, kwargs = strip_traced_tensors(args, kwargs)
                 args, kwargs = self.nncf._wrap_inputs_fn(args, kwargs)
```
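
Since the guard is now a free function rather than a method of `NNCFNetwork`, it can be exercised directly. A usage sketch (the inputs here are hypothetical plain tensors, for which the call is a no-op):

```python
import torch

from nncf.torch.dynamic_graph.trace_tensor import strip_traced_tensors

# objwalk traverses nested containers, so tuples, lists, and dicts of
# tensors are all handled uniformly.
args = (torch.ones(1), [torch.zeros(2)])
kwargs = {"mask": torch.ones(3)}
args, kwargs = strip_traced_tensors(args, kwargs)
```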
#### tests/torch/ptq/test_wrap_model.py (8 additions, 1 deletion)

```diff
@@ -15,6 +15,7 @@

 from nncf.torch.dynamic_graph.context import no_nncf_trace
 from nncf.torch.model_creation import wrap_model
+from nncf.torch.nested_objects_traversal import objwalk


 class ArgumentModel(nn.Module):
@@ -64,9 +65,15 @@ def forward(self, *, x, y):
 def test_wrap_model_with_example_input(example_input, model_cls):
     model = model_cls(example_input)
     nncf_network = wrap_model(model, example_input)
+
+    def check_type(x):
+        assert type(x) == torch.Tensor
+        return x
+
+    objwalk(example_input, lambda x: True, check_type)
+
     nncf_graph = nncf_network.nncf.get_original_graph()
     all_nodes = nncf_graph.get_all_nodes()
-
     num_nodes = 2
     if isinstance(example_input, (tuple, dict)):
         num_nodes *= len(example_input)
```
#### tests/torch/test_nncf_network.py (8 additions, 0 deletions)

```diff
@@ -839,3 +839,11 @@ def test_multidevice_model():
     input_info = ExampleInputInfo.from_example_input(example_input)
     nncf_model = NNCFNetwork(model, input_info)
     nncf_model(*example_input)
+
+
+def test_access_to_input_info():
+    model = SimplestModel()
+    example_input = torch.ones(SimplestModel.INPUT_SIZE)
+    input_info = ExampleInputInfo.from_example_input(example_input)
+    nncf_model = NNCFNetwork(model, input_info)
+    nncf_model.nncf.input_infos
```
#### tests/torch/test_tracing_context.py (33 additions, 5 deletions)

```diff
@@ -93,16 +93,16 @@ def forward(self, x):
         return self.conv2d(x)


-def test_traced_tensors_are_expired_on_context_exit():
+def test_traced_tensors_are_stripped_on_context_exit():
     module = ModuleForTest()
     module.train()
     tensor = torch.ones([1, 1, 1, 1])
     with TracingContext():
         result = module(tensor)
-    assert isinstance(module.cached_tensor, TracedTensor)
-    assert module.cached_tensor.nncf_expired
-    assert isinstance(result, TracedTensor)
-    assert result.nncf_expired
+        assert isinstance(module.cached_tensor, TracedTensor)
+        assert isinstance(result, TracedTensor)
+    assert isinstance(module.cached_tensor, torch.Tensor)
+    assert isinstance(result, torch.Tensor)


 def test_no_cross_forward_run_dependency():
@@ -118,3 +118,31 @@ def test_no_cross_forward_run_dependency():
         ctx.enable_trace_dynamic_graph()
         _ = module(tensor)
         ctx.disable_trace_dynamic_graph()
+
+
+@pytest.mark.parametrize(
+    "contexts",
+    [3 * [TracingContext()], [TracingContext(), TracingContext(), TracingContext()]],
+    ids=["same", "different"],
+)
+def test_nested_contexts(contexts):
+    module = ModuleForTest()
+    module.train()
+    tensor = torch.ones([1, 1, 1, 1])
+    nesting_count = [1]
+    with contexts[0]:
+        with contexts[1]:
+            count = contexts[:2].count(contexts[1])
+            nesting_count.append(count)
+            with contexts[2]:
+                count = contexts.count(contexts[2])
+                nesting_count.append(count)
+                module(tensor)
+                assert len(contexts[2]._threading.thread_local.nested_contexts_stack) == nesting_count[2]
+                assert len(contexts[2]._threading.thread_local.traced_tensor_weakrefs) > 0
+            assert len(contexts[1]._threading.thread_local.nested_contexts_stack) == nesting_count[1]
+            assert len(contexts[1]._threading.thread_local.traced_tensor_weakrefs) > 0
+        assert len(contexts[0]._threading.thread_local.nested_contexts_stack) == nesting_count[0]
+        assert contexts[0]._threading.thread_local.nested_contexts_stack[0] is None
+        assert len(contexts[0]._threading.thread_local.traced_tensor_weakrefs) > 0
+    assert len(contexts[0]._threading.thread_local.traced_tensor_weakrefs) == 0
```
