Skip to content

Commit

Permalink
Correct device assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Nov 8, 2023
1 parent 96de397 commit 77888da
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion nncf/torch/dynamic_graph/io_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.
import abc
from copy import deepcopy
from functools import partial
from inspect import Parameter
from inspect import Signature
from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Type
Expand Down Expand Up @@ -227,7 +228,14 @@ def __init__(self, forward_args: Tuple, forward_kwargs: Dict):
self._forward_kwargs = forward_kwargs

def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
return self._forward_args, self._forward_kwargs
if device is None:
return self._forward_args, self._forward_kwargs
to_device_fn = partial(torch.Tensor.to, device=device)
args_copy = deepcopy(self._forward_args)
kwargs_copy = deepcopy(self._forward_kwargs)
args_at_device = objwalk(args_copy, is_tensor, to_device_fn)
kwargs_at_device = objwalk(kwargs_copy, is_tensor, to_device_fn)
return args_at_device, kwargs_at_device

@classmethod
def from_nncf_dataset(cls, dataset: Dataset) -> "ExactInputsInfo":
Expand Down

0 comments on commit 77888da

Please sign in to comment.