diff --git a/nncf/torch/dynamic_graph/io_handling.py b/nncf/torch/dynamic_graph/io_handling.py index 2e69d22a1ac..47a020a9bcb 100644 --- a/nncf/torch/dynamic_graph/io_handling.py +++ b/nncf/torch/dynamic_graph/io_handling.py @@ -10,6 +10,7 @@ # limitations under the License. import abc from copy import deepcopy +from functools import partial from inspect import Parameter from inspect import Signature from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Type @@ -227,7 +228,14 @@ def __init__(self, forward_args: Tuple, forward_kwargs: Dict): self._forward_kwargs = forward_kwargs def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]: - return self._forward_args, self._forward_kwargs + if device is None: + return self._forward_args, self._forward_kwargs + to_device_fn = partial(torch.Tensor.to, device=device) + args_copy = deepcopy(self._forward_args) + kwargs_copy = deepcopy(self._forward_kwargs) + args_at_device = objwalk(args_copy, is_tensor, to_device_fn) + kwargs_at_device = objwalk(kwargs_copy, is_tensor, to_device_fn) + return args_at_device, kwargs_at_device @classmethod def from_nncf_dataset(cls, dataset: Dataset) -> "ExactInputsInfo":