diff --git a/examples/torch/common/export.py b/examples/torch/common/export.py
index 7d98bcbf70d..0eb72071b80 100644
--- a/examples/torch/common/export.py
+++ b/examples/torch/common/export.py
@@ -29,7 +29,7 @@ def export_model(ctrl: CompressionAlgorithmController, save_path: str, no_strip_
 
     model = model.eval().cpu()
 
-    export_args = get_export_args(model)
+    export_args = get_export_args(model, device="cpu")
     input_names = generate_input_names_list(count_tensors(export_args))
 
     with torch.no_grad():
diff --git a/nncf/torch/dynamic_graph/graph_tracer.py b/nncf/torch/dynamic_graph/graph_tracer.py
index ee5c84e4e19..89880beda81 100644
--- a/nncf/torch/dynamic_graph/graph_tracer.py
+++ b/nncf/torch/dynamic_graph/graph_tracer.py
@@ -13,10 +13,14 @@
 
 import torch
 
+from nncf.common.logging import nncf_logger
 from nncf.torch.dynamic_graph.context import TracingContext
 from nncf.torch.dynamic_graph.graph import DynamicGraph
+from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
+from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
 from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
 from nncf.torch.utils import get_model_device
+from nncf.torch.utils import is_multidevice
 
 
 class GraphTracer:
@@ -65,8 +69,16 @@ def default_dummy_forward_fn(model):
        from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
        from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_outputs_with_objwalk
 
-       device = get_model_device(model)
-       args, kwargs = input_info.get_forward_inputs(device=str(device))
+       device = None
+       if isinstance(input_info, (FillerInputInfo, LoaderInputInfo)):
+           if is_multidevice(model):
+               nncf_logger.warning(
+                   "Multidevice model detected when tracing the model's dynamic graph - will pass example "
+                   "inputs to the model as-is without changing their device."
+               )
+           else:
+               device = get_model_device(model)
+       args, kwargs = input_info.get_forward_inputs(device)
 
        if with_input_tracing:
            if wrap_inputs_fn is None:
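Note: the tracing change above only moves example inputs to the model's device when NNCF generated those inputs itself (`FillerInputInfo`/`LoaderInputInfo`); user-supplied example inputs keep their devices, and multidevice models are left untouched. A minimal standalone sketch of that decision, using a set-based stand-in for the `is_multidevice` helper added in `nncf/torch/utils.py` below (toy code, not the NNCF implementation):

```python
import torch


def is_multidevice_sketch(model: torch.nn.Module) -> bool:
    # Stand-in for nncf.torch.utils.is_multidevice: True when the model's
    # parameters are spread across more than one device.
    return len({p.device for p in model.parameters()}) > 1


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
# Mirrors default_dummy_forward_fn: device stays None for multidevice models,
# so get_forward_inputs() passes the example inputs through unchanged.
device = None if is_multidevice_sketch(model) else next(model.parameters()).device
print(device)  # cpu for this single-device model
```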
diff --git a/nncf/torch/dynamic_graph/io_handling.py b/nncf/torch/dynamic_graph/io_handling.py
index 639ba0784f4..9d0b7b8e721 100644
--- a/nncf/torch/dynamic_graph/io_handling.py
+++ b/nncf/torch/dynamic_graph/io_handling.py
@@ -13,7 +13,7 @@
 from functools import partial
 from inspect import Parameter
 from inspect import Signature
-from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Type
+from typing import Any, Dict, List, Optional, Protocol, Set, Tuple, Type, Union
 
 import torch
 
@@ -87,7 +87,7 @@ class ModelInputInfo(abc.ABC):
     """
 
     @abc.abstractmethod
-    def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
+    def get_forward_inputs(self, device: Optional[Union[str, torch.device]] = None) -> Tuple[Tuple, Dict]:
         """
         Returns the tuple of (args, kwargs) for passing into the compressed model's forward method when necessary.
         The returned arguments should be such that the model's forward with these arguments executes the main
@@ -204,7 +204,9 @@ def from_nncf_config(cls, config: NNCFConfig):
             return FillerInputInfo(elements)
         raise RuntimeError("Invalid input_infos specified in config - should be either dict or list of dicts")
 
-    def get_forward_inputs(self, device: str = None) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
+    def get_forward_inputs(
+        self, device: Optional[Union[str, torch.device]] = None
+    ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
         args_list = []
         kwargs = {}
         for fe in self.elements:
@@ -227,7 +229,7 @@ def __init__(self, forward_args: Tuple, forward_kwargs: Dict):
         self._forward_args = forward_args
         self._forward_kwargs = forward_kwargs
 
-    def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
+    def get_forward_inputs(self, device: Optional[Union[str, torch.device]] = None) -> Tuple[Tuple, Dict]:
         if device is None:
             return self._forward_args, self._forward_kwargs
         to_device_fn = partial(torch.Tensor.to, device=device)
@@ -237,34 +239,38 @@ def get_forward_inputs(self, device: str = None) -> Tuple[Tuple, Dict]:
         kwargs_at_device = objwalk(kwargs_copy, is_tensor, to_device_fn)
         return args_at_device, kwargs_at_device
 
+
+class ExampleInputInfo(ExactInputsInfo):
     @classmethod
-    def from_example_input(cls, example_input: Any) -> "ExactInputsInfo":
+    def from_example_input(cls, example_input: Any) -> "ExampleInputInfo":
         """
-        Builds an ExactInputsInfo object based on the example input.
+        Builds an ExampleInputInfo object based on the example input.
 
-        :param dataset: An nncf.Dataset whose first element will be used as an example model input
-        :return: An initialized ExactInputsInfo object.
+        :param example_input: An example model input - a tuple of positional args, a dict of kwargs, or a single tensor.
+        :return: An initialized ExampleInputInfo object.
         """
         if isinstance(example_input, tuple):
-            return ExactInputsInfo(example_input, {})
+            return ExampleInputInfo(example_input, {})
         if isinstance(example_input, dict):
-            return ExactInputsInfo(tuple(), example_input)
-        return ExactInputsInfo((example_input,), {})
+            return ExampleInputInfo(tuple(), example_input)
+        return ExampleInputInfo((example_input,), {})
 
     @classmethod
-    def from_nncf_dataset(cls, dataset: Dataset) -> "ExactInputsInfo":
+    def from_nncf_dataset(cls, dataset: Dataset) -> "ExampleInputInfo":
         """
-        Checks the first element of the provided nncf.Dataset and builds an ExactInputsInfo object that would
+        Checks the first element of the provided nncf.Dataset and builds an ExampleInputInfo object that would
         provide the same input to the model at corresponding compression stages.
 
         :param dataset: An nncf.Dataset whose first element will be used as an example model input
-        :return: An initialized ExactInputsInfo object.
+        :return: An initialized ExampleInputInfo object.
         """
         example_input = next(iter(dataset.get_inference_data()))
         return cls.from_example_input(example_input)
 
+
+class LoaderInputInfo(ExactInputsInfo):
     @classmethod
-    def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["ExactInputsInfo"]:
+    def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["LoaderInputInfo"]:
         """
         Examines the user-provided structures registered with the NNCFConfig instance used for compression to find
         structures that contain a dataloader. The dataloader's first element is used to provide an example input to
@@ -273,7 +279,7 @@ def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["ExactInpu
 
         :param config: An nncf.NNCFConfig instance. Must have at least one NNCFExtraConfigStruct attached that can
             provide a dataloader (these are listed in nncf.torch.dynamic_graph.io_handling.EXTRA_STRUCTS_WITH_DATALOADERS)
-        :return: An initialized ExactInputsInfo object.
+        :return: An initialized LoaderInputInfo object.
         """
         extra_structs = config.get_all_extra_structs()
         for extra_struct in extra_structs:
@@ -283,7 +289,7 @@ def from_nncf_config_dataloaders(cls, config: NNCFConfig) -> Optional["ExactInpu
                 wrapped_dataloader = wrap_dataloader_for_init(dataloader)
                 dataloader_output = next(iter(wrapped_dataloader))
                 args, kwargs = wrapped_dataloader.get_inputs(dataloader_output)
-                return ExactInputsInfo(args, kwargs)
+                return LoaderInputInfo(args, kwargs)
 
         # config extra structs had no suitable dataloaders
         return None
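To make the renamed classes concrete, here is a usage sketch of the `from_example_input` dispatch above: a tuple becomes positional args, a dict becomes kwargs, and anything else a single positional arg.

```python
import torch

from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo

args_info = ExampleInputInfo.from_example_input((torch.ones(1), torch.ones(2)))
kwargs_info = ExampleInputInfo.from_example_input({"x": torch.ones(1)})
single_info = ExampleInputInfo.from_example_input(torch.ones(1))

assert len(args_info.get_forward_inputs()[0]) == 2    # two positional args
assert "x" in kwargs_info.get_forward_inputs()[1]     # one keyword arg
assert len(single_info.get_forward_inputs()[0]) == 1  # one positional arg
```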
diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py
index 0a9a8d45bb4..4aa35e1bd9e 100644
--- a/nncf/torch/exporter.py
+++ b/nncf/torch/exporter.py
@@ -45,8 +45,10 @@ def counter_fn(x: torch.Tensor) -> torch.Tensor:
     return count
 
 
-def get_export_args(model: NNCFNetwork, model_args: Optional[Tuple[Any, ...]] = None) -> Tuple:
-    args, kwargs = model.nncf.input_infos.get_forward_inputs()
+def get_export_args(
+    model: NNCFNetwork, model_args: Optional[Tuple[Any, ...]] = None, device: Optional[str] = None
+) -> Tuple:
+    args, kwargs = model.nncf.input_infos.get_forward_inputs(device)
 
     if model_args is not None:
         args = tuple(list(args) + list(model_args[:-1]))
@@ -142,7 +144,7 @@ def _export_to_onnx(self, save_path: str, opset_version: int) -> None:
         original_device = get_model_device(self._model)
         model = self._model.eval().cpu()
 
-        export_args = get_export_args(self._model, model_args=self._model_args)
+        export_args = get_export_args(self._model, model_args=self._model_args, device="cpu")
 
         if self._input_names is not None:
             input_names = self._input_names
diff --git a/nncf/torch/model_creation.py b/nncf/torch/model_creation.py
index f2e6b0ce9b0..0e04c804f3e 100644
--- a/nncf/torch/model_creation.py
+++ b/nncf/torch/model_creation.py
@@ -35,8 +35,9 @@
 from nncf.torch.dynamic_graph.graph_tracer import WrapInputsFnType
 from nncf.torch.dynamic_graph.graph_tracer import WrapOutputsFnType
 from nncf.torch.dynamic_graph.io_handling import EXTRA_STRUCTS_WITH_DATALOADERS
-from nncf.torch.dynamic_graph.io_handling import ExactInputsInfo
+from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
+from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
 from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.utils import is_dist_avail_and_initialized
@@ -167,7 +168,7 @@ def get_input_info_from_config(config: NNCFConfig) -> ModelInputInfo:
     nncf_logger.debug(
         "Config has no 'input_info' section, trying to use dataloader output as model inputs "
         "for graph building."
     )
-    exact_info = ExactInputsInfo.from_nncf_config_dataloaders(config)
+    exact_info = LoaderInputInfo.from_nncf_config_dataloaders(config)
     if exact_info is not None:
         return exact_info
     raise RuntimeError(
@@ -323,7 +324,7 @@ def wrap_model(model: torch.nn.Module, example_input: Any) -> NNCFNetwork:
     :return: A model wrapped by NNCFNetwork.
     """
 
-    input_info = ExactInputsInfo.from_example_input(example_input)
+    input_info = ExampleInputInfo.from_example_input(example_input)
 
     with training_mode_switcher(model, is_training=False):
         nncf_network = NNCFNetwork(model, input_info=input_info)
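`wrap_model` is the public entry point that exercises the new `ExampleInputInfo` path; a usage sketch (with this change, the example input's device is preserved instead of being forced to the model's device):

```python
import torch

from nncf.torch.model_creation import wrap_model

model = torch.nn.Linear(4, 2)
example_input = torch.ones(1, 4)  # stays on its original device during tracing
nncf_network = wrap_model(model, example_input)
nncf_network(example_input)
```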
""" - input_info = ExactInputsInfo.from_example_input(example_input) + input_info = ExampleInputInfo.from_example_input(example_input) with training_mode_switcher(model, is_training=False): nncf_network = NNCFNetwork(model, input_info=input_info) diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index db4bf2fec72..087ea85df6c 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -32,6 +32,8 @@ from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook +from nncf.torch.utils import get_model_device +from nncf.torch.utils import is_multidevice class PTModelTransformer(ModelTransformer): @@ -119,15 +121,21 @@ def _apply_quantizer_insertion_transformations( model.nncf.register_compression_module_type(compression_model_type) insertion_commands: List[PTInsertionCommand] = [] + device = None + if not is_multidevice(model): + device = get_model_device(model) for transformation_command in transformations: target_point: PTTargetPoint = transformation_command.target_point - fn = transformation_command.quantizer + quantizer_module = transformation_command.quantizer + if device is not None: + quantizer_module = quantizer_module.to(device) + fn = quantizer_module if target_point.type is not TargetType.OPERATION_WITH_WEIGHTS: quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id) storage_key = str(quantizer_id) - model.nncf.add_compression_module(storage_key, transformation_command.quantizer, compression_model_type) + model.nncf.add_compression_module(storage_key, quantizer_module, compression_model_type) fn = ExternalQuantizerCallHook(model.nncf.get_tracing_context(), storage_key) insertion_commands.append( diff --git a/nncf/torch/nncf_module_replacement.py b/nncf/torch/nncf_module_replacement.py index fdaaaf0d8bd..38d6a3c5254 100644 --- a/nncf/torch/nncf_module_replacement.py +++ b/nncf/torch/nncf_module_replacement.py @@ -25,6 +25,7 @@ from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT from nncf.torch.layers import UNWRAPPED_USER_MODULES from nncf.torch.layers import add_nncf_functionality_to_user_module +from nncf.torch.utils import get_model_device def is_nncf_module(module: nn.Module) -> bool: @@ -181,11 +182,14 @@ def replace_modules_by_nncf_modules( scope_set, ignored_scopes, target_scopes, eval_op_scopes ) and not _is_module_only_in_user_module(scope_set) if should_process: + device = get_model_device(module) + if custom_replacer is not None: replaced_module = custom_replacer(module) else: replaced_module = nncf_module_from(module) + replaced_module.to(device) inter_dict[replaced_module] = scope_set new_scope_set = set() diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 0cc9eda854a..29eb4b63d03 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -145,7 +145,7 @@ def get_original_forward(self) -> Callable: have its 0-th implicit `self` argument bound to the model object. 
""" if self._original_instance_forward is not None: - return functools.partial(self._original_instance_forward, self._model_ref) + return self._original_instance_forward return functools.partial(self._original_unbound_forward, self._model_ref) @contextmanager @@ -224,8 +224,6 @@ def __init__( self._user_dummy_forward_fn = dummy_forward_fn self._kd_loss_handler = None - device = get_model_device(model) - if wrap_inputs_fn is not None: self._wrap_inputs_fn = wrap_inputs_fn elif self._input_info is not None: @@ -254,7 +252,7 @@ def __init__( eval_op_scopes = self._collect_eval_op_scopes(model, _orig_graph_build_forward_fn) # all modules called in eval mode should be replaced prior to graph building - self._replace_modules_by_nncf_modules(model, device, eval_op_scopes) + self._replace_modules_by_nncf_modules(model, eval_op_scopes) _orig_context = TracingContext() @@ -458,13 +456,10 @@ def wrapped_user_dummy_forward_fn(*args, **kwargs): return wrapped_user_dummy_forward_fn - def _replace_modules_by_nncf_modules( - self, model: torch.nn.Module, device: torch.device, eval_op_scopes: List[Scope] = None - ): + def _replace_modules_by_nncf_modules(self, model: torch.nn.Module, eval_op_scopes: List[Scope] = None): _, self._nncf_replaced_modules = replace_modules_by_nncf_modules( model, ignored_scopes=self._ignored_scopes, target_scopes=self._target_scopes, eval_op_scopes=eval_op_scopes ) - model.to(device) return model def get_nncf_module_scopes(self) -> List[List[Scope]]: @@ -836,10 +831,12 @@ def __call__( new_class.forward = _get_nncf_forward_function_with_signature(inspect.signature(original_class.forward)) # In case of overriding forward by code like `model.forward = wrapper(model.forward)` - forward_inst_attr_fn = original_model.__dict__.get("forward") - if forward_inst_attr_fn is not None: - new_inst_forward = _get_nncf_forward_function_with_signature(inspect.signature(forward_inst_attr_fn)) - original_model.__dict__["forward"] = functools.partial(new_inst_forward, original_model) + original_instance_forward = original_model.__dict__.get("forward") + if original_instance_forward is not None: + bound_new_instance_forward = _get_nncf_forward_function_with_signature( + inspect.signature(original_instance_forward), bind_self_to=original_model + ) + original_model.__dict__["forward"] = bound_new_instance_forward # Make resulting class keep __module__ attributes of the original class, # otherwise these will point to NNCF @@ -884,15 +881,22 @@ def __eq__(cls, other): return other is NNCFNetwork -def _get_nncf_forward_function_with_signature(signature: inspect.Signature): +def _get_nncf_forward_function_with_signature( + signature: inspect.Signature, bind_self_to: torch.nn.Module = None +) -> Callable: """ - Create forward function with copy signature of forward function. + Creates a function that executes code from NNCFNetwork.forward, but with a final signature equal to the provided + one. :param signature: Signature of function that will used for forward function. + :param bind_self_to: If provided, will bind the `self` argument of the returned function to the provided model + object. This should be the model object that we are currently constructing the NNCFNetwork with. :return: New copy of function NNCFNetwork.forward with specified signature. 
""" fn = NNCFNetwork.forward new_forward = types.FunctionType(fn.__code__, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__) new_forward.__dict__.update(fn.__dict__) + if bind_self_to is not None: + new_forward = functools.partial(new_forward, bind_self_to) new_forward.__signature__ = signature if is_debug(): new_forward = debuggable_forward(new_forward) diff --git a/nncf/torch/utils.py b/nncf/torch/utils.py index 8e3a735d6be..546e0172f70 100644 --- a/nncf/torch/utils.py +++ b/nncf/torch/utils.py @@ -11,7 +11,7 @@ import random from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Dict, List +from typing import Any, Dict, Generator, List import numpy as np import torch @@ -420,6 +420,24 @@ def get_model_device(model: torch.nn.Module) -> torch.device: return device +def get_all_model_devices_generator(model: torch.nn.Module) -> Generator[torch.device, None, None]: + for p in model.parameters(): + yield p.device + + +def is_multidevice(model: torch.nn.Module) -> bool: + device_generator = get_all_model_devices_generator(model) + try: + curr_device = next(device_generator) + except StopIteration: # no parameters + return False + + for d in device_generator: + if d != curr_device: + return True + return False + + def get_model_dtype(model: torch.nn.Module) -> torch.dtype: try: dtype = next(model.parameters()).dtype diff --git a/tests/torch/test_graph_building.py b/tests/torch/test_graph_building.py index 84815d1af8c..cf08db48644 100644 --- a/tests/torch/test_graph_building.py +++ b/tests/torch/test_graph_building.py @@ -42,9 +42,10 @@ from nncf.torch.dynamic_graph.graph_tracer import GraphTracer from nncf.torch.dynamic_graph.graph_tracer import create_dummy_forward_fn from nncf.torch.dynamic_graph.io_handling import EXTRA_STRUCTS_WITH_DATALOADERS -from nncf.torch.dynamic_graph.io_handling import ExactInputsInfo +from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo from nncf.torch.dynamic_graph.io_handling import FillerInputElement from nncf.torch.dynamic_graph.io_handling import FillerInputInfo +from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo from nncf.torch.dynamic_graph.io_handling import ModelInputInfo from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_outputs_with_objwalk from nncf.torch.dynamic_graph.trace_tensor import trace_tensors @@ -646,9 +647,10 @@ def test_filler_input_info_arg_generation(filler_gen_test_struct: FillerInputInf "input_info", [ FillerInputInfo([FillerInputElement([1, 3, 3, 3])]), - ExactInputsInfo((torch.Tensor([1]), torch.Tensor([1])), {"a": torch.Tensor([1]), "b": torch.Tensor([1])}), + ExampleInputInfo((torch.Tensor([1]), torch.Tensor([1])), {"a": torch.Tensor([1]), "b": torch.Tensor([1])}), + LoaderInputInfo((torch.Tensor([1]), torch.Tensor([1])), {"a": torch.Tensor([1]), "b": torch.Tensor([1])}), ], - ids=["filler", "exact"], + ids=["filler", "example", "loader"], ) @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_input_infos_respect_device_setting(input_info: ModelInputInfo, device: str): @@ -704,7 +706,7 @@ def test_compressed_model_creation_can_build_exact_input_infos_from_dataloader_i _ = create_compressed_model(mock_model_with_stub_forward, config) input_info_received_by_nncf_network_init = nncf_network_init_spy.call_args.kwargs["input_info"] # input_info - assert isinstance(input_info_received_by_nncf_network_init, ExactInputsInfo) + assert isinstance(input_info_received_by_nncf_network_init, LoaderInputInfo) test_args, test_kwargs = 
diff --git a/tests/torch/test_graph_building.py b/tests/torch/test_graph_building.py
index 84815d1af8c..cf08db48644 100644
--- a/tests/torch/test_graph_building.py
+++ b/tests/torch/test_graph_building.py
@@ -42,9 +42,10 @@
 from nncf.torch.dynamic_graph.graph_tracer import GraphTracer
 from nncf.torch.dynamic_graph.graph_tracer import create_dummy_forward_fn
 from nncf.torch.dynamic_graph.io_handling import EXTRA_STRUCTS_WITH_DATALOADERS
-from nncf.torch.dynamic_graph.io_handling import ExactInputsInfo
+from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
 from nncf.torch.dynamic_graph.io_handling import FillerInputElement
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
+from nncf.torch.dynamic_graph.io_handling import LoaderInputInfo
 from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
 from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_outputs_with_objwalk
 from nncf.torch.dynamic_graph.trace_tensor import trace_tensors
@@ -646,9 +647,10 @@ def test_filler_input_info_arg_generation(filler_gen_test_struct: FillerInputIn
     "input_info",
     [
         FillerInputInfo([FillerInputElement([1, 3, 3, 3])]),
-        ExactInputsInfo((torch.Tensor([1]), torch.Tensor([1])), {"a": torch.Tensor([1]), "b": torch.Tensor([1])}),
+        ExampleInputInfo((torch.Tensor([1]), torch.Tensor([1])), {"a": torch.Tensor([1]), "b": torch.Tensor([1])}),
+        LoaderInputInfo((torch.Tensor([1]), torch.Tensor([1])), {"a": torch.Tensor([1]), "b": torch.Tensor([1])}),
     ],
-    ids=["filler", "exact"],
+    ids=["filler", "example", "loader"],
 )
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_input_infos_respect_device_setting(input_info: ModelInputInfo, device: str):
@@ -704,7 +706,7 @@ def test_compressed_model_creation_can_build_exact_input_infos_from_dataloader_i
     _ = create_compressed_model(mock_model_with_stub_forward, config)
     input_info_received_by_nncf_network_init = nncf_network_init_spy.call_args.kwargs["input_info"]  # input_info
-    assert isinstance(input_info_received_by_nncf_network_init, ExactInputsInfo)
+    assert isinstance(input_info_received_by_nncf_network_init, LoaderInputInfo)
     test_args, test_kwargs = input_info_received_by_nncf_network_init.get_forward_inputs()
     for idx, arg in enumerate(test_args):
diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py
index 96eceba0941..54a8fe767a5 100644
--- a/tests/torch/test_nncf_network.py
+++ b/tests/torch/test_nncf_network.py
@@ -8,11 +8,13 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+import abc
 import functools
 import inspect
 from abc import ABCMeta
 from abc import abstractmethod
 from copy import deepcopy
+from typing import Callable, Type
 
 import pytest
 import torch
@@ -25,6 +27,7 @@
 from nncf.common.graph.operator_metatypes import UnknownMetatype
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.torch import register_module
+from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo
 from nncf.torch.dynamic_graph.io_handling import FillerInputElement
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
@@ -405,29 +408,74 @@ def test_replacing_forward_with_another_own_method(_nncf_caplog):
     assert "set_original_unbound_forward" in _nncf_caplog.text
 
 
-def test_replacing_forward_of_original_model():
-    def decorator(func):
-        def wrap(*args):
-            return func(*args)
+class ReplacementForwardProvider(abc.ABC):
+    @abstractmethod
+    def get_replacement_forward(self, model: torch.nn.Module) -> Callable:
+        pass
+
+
+class ArgDecoratorForward(ReplacementForwardProvider):
+    def get_replacement_forward(self, model: torch.nn.Module) -> Callable:
+        def decorator(func):
+            def wrap(*args):
+                return func(*args)
+
+            return wrap
+
+        return decorator(model.forward)
+
+
+class ArgAndKwargWrapsForward(ReplacementForwardProvider):
+    def get_replacement_forward(self, model: torch.nn.Module) -> Callable:
+        old_forward = model.forward
+
+        @functools.wraps(old_forward)
+        def new_forward(*args, **kwargs):
+            return old_forward(*args, **kwargs)
+
+        return new_forward
+
+
+class EvilSelfForwardProvider(ReplacementForwardProvider):
+    def get_replacement_forward(self, model: torch.nn.Module) -> Callable:
+        old_forward = model.forward
+
+        def evil_forward(self):
+            # since `self` is just a name, and not a reserved word,
+            # in this function `self` will refer just to the 0-th actual (tensor) arg of the forward function
+            return old_forward(self)
 
-        return wrap
+        return evil_forward
 
+
+@pytest.mark.parametrize(
+    "replacement_forward_provider_cls", [ArgDecoratorForward, ArgAndKwargWrapsForward, EvilSelfForwardProvider]
+)
+def test_replacing_forward_of_original_model(replacement_forward_provider_cls: Type[ReplacementForwardProvider]):
     model = BasicConvTestModel()
-    model.forward = decorator(model.forward)
+    provider = replacement_forward_provider_cls()
+    replacement_forward = provider.get_replacement_forward(model)
+    model.forward = replacement_forward
+
+    input_info = FillerInputInfo([FillerInputElement(model.INPUT_SIZE)])
+    input_args, input_kwargs = input_info.get_forward_inputs()
+    original_output = model.forward(*input_args, **input_kwargs)
 
     fn_id = id(model.__dict__["forward"])
     fn_sign = inspect.signature(model.forward)
 
     # type of current
-    assert isinstance(model.__dict__["forward"], type(decorator))
+    assert isinstance(model.__dict__["forward"], type(replacement_forward))
 
-    nncf_net = NNCFNetwork(model, FillerInputInfo([FillerInputElement(model.INPUT_SIZE)]))
-    nncf_net.forward(torch.ones(model.INPUT_SIZE))
+    nncf_net = NNCFNetwork(model, input_info)
 
     # Check that forward was updated
     assert fn_id != id(nncf_net.__dict__["forward"])
     assert fn_sign == inspect.signature(nncf_net.forward)
     assert isinstance(nncf_net.forward, functools.partial)
 
+    # Check that the functional outputs are the same
+    new_output = nncf_net.forward(torch.ones(model.INPUT_SIZE))
+    assert torch.equal(new_output, original_output)
+
 
 def test_temporary_clean_view():
     model = TwoConvTestModelWithUserModule()
@@ -767,3 +815,27 @@ def test_is_compression_module_registered(compression_module_type, is_registered
         assert nncf_model.nncf.is_compression_module_registered(compression_module_type)
     else:
         assert not nncf_model.nncf.is_compression_module_registered(compression_module_type)
+
+
+class MultideviceModel(torch.nn.Module):
+    def __init__(self, linear_0, linear_1):
+        super().__init__()
+        self.linear_cpu = torch.nn.Linear(linear_0[0], linear_0[1], device="cpu")
+        self.linear_gpu = torch.nn.Linear(linear_1[0], linear_1[1], device="cuda")
+
+    def forward(self, x, y):
+        x1 = self.linear_cpu(x)
+        y1 = self.linear_gpu(y)
+        res = x1.to(y1.device) + y1
+        return res
+
+
+def test_multidevice_model():
+    if not torch.cuda.is_available():
+        pytest.skip("GPU required")
+
+    model = MultideviceModel((3, 3), (2, 3))
+    example_input = (torch.ones(3, 3, device="cpu"), torch.ones(3, 2, device="cuda"))
+    input_info = ExampleInputInfo.from_example_input(example_input)
+    nncf_model = NNCFNetwork(model, input_info)
+    nncf_model(*example_input)
diff --git a/tests/torch/test_transform_fn.py b/tests/torch/test_transform_fn.py
index 3d3d80acbb0..f5954189c24 100644
--- a/tests/torch/test_transform_fn.py
+++ b/tests/torch/test_transform_fn.py
@@ -9,12 +9,13 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+from functools import partial
+
 import pytest
 import torch
 from torch import nn
 
 import nncf
-from nncf.torch.nested_objects_traversal import objwalk
 from tests.torch.test_models.alexnet import AlexNet as ModelWithSingleInput
@@ -40,7 +41,9 @@ def forward(self, input_0, input_1):
 
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
 
 
-def single_input_transform_fn(data_item):
+def single_input_transform_fn(data_item, use_cuda):
+    if use_cuda:
+        return data_item[0].cuda()
     return data_item[0]
@@ -49,23 +52,26 @@ def test_transform_fn_single_input(use_cuda):
         pytest.skip("There are no available CUDA devices")
 
     model = ModelWithSingleInput()
-    input_data = single_input_transform_fn(next(iter(dataloader)))
+    input_data = single_input_transform_fn(next(iter(dataloader)), use_cuda)
 
     if use_cuda:
         model = model.cuda()
-        input_data = input_data.cuda()
 
     # Check the transformation function
     model(input_data)
     # Start quantization
-    calibration_dataset = nncf.Dataset(dataloader, single_input_transform_fn)
+    calibration_dataset = nncf.Dataset(dataloader, partial(single_input_transform_fn, use_cuda=use_cuda))
     nncf.quantize(model, calibration_dataset)
 
 
-def multiple_inputs_transform_tuple_fn(data_item):
+def multiple_inputs_transform_tuple_fn(data_item, use_cuda):
+    if use_cuda:
+        return data_item[0].cuda(), data_item[1].cuda()
     return data_item[0], data_item[1]
 
 
-def multiple_inputs_transform_dict_fn(data_item):
+def multiple_inputs_transform_dict_fn(data_item, use_cuda):
+    if use_cuda:
+        return {"input_0": data_item[0].cuda(), "input_1": data_item[1].cuda()}
     return {"input_0": data_item[0], "input_1": data_item[1]}
@@ -77,15 +83,10 @@ def test_transform_fn_multiple_inputs(transform_fn, use_cuda):
         pytest.skip("There are no available CUDA devices")
 
     model = ModelWithMultipleInputs()
-    input_data = transform_fn(next(iter(dataloader)))
+    input_data = transform_fn(next(iter(dataloader)), use_cuda)
 
     if use_cuda:
         model = model.cuda()
 
-        def send_to_cuda(tensor):
-            return tensor.cuda()
-
-        input_data = objwalk(input_data, lambda _: True, send_to_cuda)
-
     # Check the transformation function
     if isinstance(input_data, tuple):
         model(*input_data)
@@ -93,5 +94,5 @@ def send_to_cuda(tensor):
         model(**input_data)
     # Start quantization
-    calibration_dataset = nncf.Dataset(dataloader, transform_fn)
+    calibration_dataset = nncf.Dataset(dataloader, partial(transform_fn, use_cuda=use_cuda))
     nncf.quantize(model, calibration_dataset)
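The test changes above all use the same pattern: `nncf.Dataset` calls the transform with a single data item, so any extra configuration such as `use_cuda` has to be bound beforehand with `functools.partial`. A standalone sketch of the pattern:

```python
from functools import partial

import torch


def transform_fn(data_item, use_cuda):
    batch = data_item[0]
    return batch.cuda() if use_cuda else batch


# Fix the keyword argument up front; the resulting callable has the
# single-argument signature that nncf.Dataset expects.
bound_fn = partial(transform_fn, use_cuda=False)
sample = (torch.ones(1, 3), torch.zeros(1))
assert torch.equal(bound_fn(sample), sample[0])
```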