From 610e8004d46832afe1691351b322942fae604013 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Wed, 15 Nov 2023 07:56:19 +0200 Subject: [PATCH] Remove to_device from PTEngine (#2260) ### Changes Remove logic to set device in `PTEngine`, to support multi-device model https://github.com/openvinotoolkit/nncf/pull/2253 --- .../torch/ssd300_vgg16/main.py | 7 ++++--- nncf/torch/engine.py | 9 --------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/main.py b/examples/post_training_quantization/torch/ssd300_vgg16/main.py index 6d6b9365a34..3bed9cfee45 100644 --- a/examples/post_training_quantization/torch/ssd300_vgg16/main.py +++ b/examples/post_training_quantization/torch/ssd300_vgg16/main.py @@ -29,6 +29,7 @@ from torchvision.models.detection.ssd import SSD from torchvision.models.detection.ssd import GeneralizedRCNNTransform from nncf.common.logging.track_progress import track +from functools import partial ROOT = Path(__file__).parent.resolve() DATASET_URL = "https://ultralytics.com/assets/coco128.zip" @@ -125,10 +126,10 @@ def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.devi return computed_metrics["map_50"] -def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor: +def transform_fn(data_item: Tuple[torch.Tensor, Dict], device: torch.device) -> torch.Tensor: # Skip label and add a batch dimension to an image tensor images, _ = data_item - return images[None] + return images[None].to(device) def main(): @@ -149,7 +150,7 @@ def main(): disable_tracing(SSD.postprocess_detections) # Quantize model - calibration_dataset = nncf.Dataset(dataset, transform_fn) + calibration_dataset = nncf.Dataset(dataset, partial(transform_fn, device=device)) quantized_model = nncf.quantize(model, calibration_dataset) # Convert to OpenVINO diff --git a/nncf/torch/engine.py b/nncf/torch/engine.py index 63b4e93f114..44271123d6b 100644 --- a/nncf/torch/engine.py +++ b/nncf/torch/engine.py @@ -15,9 +15,6 @@ from torch import nn from nncf.common.engine import Engine -from nncf.torch.nested_objects_traversal import objwalk -from nncf.torch.utils import get_model_device -from nncf.torch.utils import is_tensor class PTEngine(Engine): @@ -34,7 +31,6 @@ def __init__(self, model: nn.Module): self._model = model self._model.eval() - self._device = get_model_device(model) def infer( self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] @@ -46,11 +42,6 @@ def infer( :return: Model outputs. """ - def send_to_device(tensor): - return tensor.to(self._device) - - input_data = objwalk(input_data, is_tensor, send_to_device) - if isinstance(input_data, dict): return self._model(**input_data) if isinstance(input_data, tuple):