Skip to content

Commit

Permalink
remove to_device from PTEngine
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Nov 11, 2023
1 parent 898faf1 commit 638737d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torchvision.models.detection.ssd import SSD
from torchvision.models.detection.ssd import GeneralizedRCNNTransform
from nncf.common.logging.track_progress import track
from functools import partial

ROOT = Path(__file__).parent.resolve()
DATASET_URL = "https://ultralytics.com/assets/coco128.zip"
Expand Down Expand Up @@ -125,10 +126,10 @@ def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.devi
return computed_metrics["map_50"]


def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor:
def transform_fn(data_item: Tuple[torch.Tensor, Dict], device: torch.device) -> torch.Tensor:
# Skip label and add a batch dimension to an image tensor
images, _ = data_item
return images[None]
return images[None].to(device)


def main():
Expand All @@ -149,7 +150,7 @@ def main():
disable_tracing(SSD.postprocess_detections)

# Quantize model
calibration_dataset = nncf.Dataset(dataset, transform_fn)
calibration_dataset = nncf.Dataset(dataset, partial(transform_fn, device=device))
quantized_model = nncf.quantize(model, calibration_dataset)

# Convert to OpenVINO
Expand Down
6 changes: 0 additions & 6 deletions nncf/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self, model: nn.Module):

self._model = model
self._model.eval()
self._device = get_model_device(model)

def infer(
self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
Expand All @@ -46,11 +45,6 @@ def infer(
:return: Model outputs.
"""

def send_to_device(tensor):
return tensor.to(self._device)

input_data = objwalk(input_data, is_tensor, send_to_device)

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
Expand Down
38 changes: 16 additions & 22 deletions tests/torch/test_transform_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import pytest
import torch
from torch import nn
Expand Down Expand Up @@ -40,33 +42,31 @@ def forward(self, input_0, input_1):
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)


def single_input_transform_fn(data_item):
return data_item[0]
def single_input_transform_fn(data_item, device):
return data_item[0].to(device)


def test_transform_fn_single_input(use_cuda):
if use_cuda and not torch.cuda.is_available():
pytest.skip("There are no available CUDA devices")

model = ModelWithSingleInput()
input_data = single_input_transform_fn(next(iter(dataloader)))
if use_cuda:
model = model.cuda()
input_data = input_data.cuda()
device = torch.device("cuda" if use_cuda else "cpu")
model = ModelWithSingleInput().to(device)
input_data = single_input_transform_fn(next(iter(dataloader)), device)

# Check the transformation function
model(input_data)
# Start quantization
calibration_dataset = nncf.Dataset(dataloader, single_input_transform_fn)
calibration_dataset = nncf.Dataset(dataloader, partial(single_input_transform_fn, device=device))
nncf.quantize(model, calibration_dataset)


def multiple_inputs_transform_tuple_fn(data_item):
return data_item[0], data_item[1]
def multiple_inputs_transform_tuple_fn(data_item, device):
return data_item[0].to(device), data_item[1].to(device)


def multiple_inputs_transform_dict_fn(data_item):
return {"input_0": data_item[0], "input_1": data_item[1]}
def multiple_inputs_transform_dict_fn(data_item, device):
return {"input_0": data_item[0].to(device), "input_1": data_item[1].to(device)}


@pytest.mark.parametrize(
Expand All @@ -76,15 +76,9 @@ def test_transform_fn_multiple_inputs(transform_fn, use_cuda):
if use_cuda and not torch.cuda.is_available():
pytest.skip("There are no available CUDA devices")

model = ModelWithMultipleInputs()
input_data = transform_fn(next(iter(dataloader)))
if use_cuda:
model = model.cuda()

def send_to_cuda(tensor):
return tensor.cuda()

input_data = objwalk(input_data, lambda _: True, send_to_cuda)
device = torch.device("cuda" if use_cuda else "cpu")
model = ModelWithMultipleInputs().to(device)
input_data = transform_fn(next(iter(dataloader)), device)

# Check the transformation function
if isinstance(input_data, tuple):
Expand All @@ -93,5 +87,5 @@ def send_to_cuda(tensor):
model(**input_data)

# Start quantization
calibration_dataset = nncf.Dataset(dataloader, transform_fn)
calibration_dataset = nncf.Dataset(dataloader, partial(transform_fn, device=device))
nncf.quantize(model, calibration_dataset)

0 comments on commit 638737d

Please sign in to comment.