From 1493149ae35987442a13a60de151e7842159955f Mon Sep 17 00:00:00 2001
From: Daniil Lyakhov
Date: Fri, 10 Nov 2023 11:35:42 +0100
Subject: [PATCH] [Torch][PTQ] Examples are updated for the new PTQ TORCH backend (#2246)

### Changes

- Do not filter constant nodes for the torch backend in the inference graph
- Fix versions in requirements.txt for the post_training_quantization examples: ssd300_vgg16 cannot use torch 2.1.0 (export to ONNX fails with `Unsupported: ONNX export of operator get_pool_ceil_padding`; tracing is not supported either)
- Update metrics
- Add input-to-device conversion to PTEngine so its behavior matches `create_compressed_model`
- The mobilenet_v2 example converts the PyTorch model to IR by tracing (without ONNX); a short usage sketch of this flow follows the patch
- nncf.quantize for PyTorch works with a copy of the target model

### Reason for changes

To make PTQ work properly with disconnected graphs (like in [example](https://github.com/openvinotoolkit/nncf/blob/develop/examples/post_training_quantization/torch/ssd300_vgg16/main.py))

### Related tickets

124417

### Tests

test_examples build 128

---------

Co-authored-by: Alexander Dokuchaev
---
 .../torch/mobilenet_v2/main.py | 39 ++++++-------------
 .../torch/mobilenet_v2/requirements.txt | 9 ++---
 .../torch/ssd300_vgg16/main.py | 21 +++++-----
 .../torch/ssd300_vgg16/requirements.txt | 10 ++---
 .../algorithms/min_max/algorithm.py | 1 +
 nncf/quantization/passes.py | 5 ++-
 nncf/torch/engine.py | 10 +++++
 nncf/torch/quantization/quantize_model.py | 4 +-
 .../test_constant_filtering_model_after.dot | 12 ++++++
 .../test_constant_filtering_model_before.dot | 18 +++++++++
 tests/common/quantization/test_passes.py | 29 +++++++++++---
 tests/cross_fw/examples/example_scope.json | 8 ++--
 tests/cross_fw/examples/test_examples.py | 6 +--
 tests/post_training/test_templates/models.py | 27 +++++++++++++
 tests/torch/test_transform_fn.py | 27 +++++++++++--
 15 files changed, 160 insertions(+), 66 deletions(-)
 create mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot
 create mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot

diff --git a/examples/post_training_quantization/torch/mobilenet_v2/main.py b/examples/post_training_quantization/torch/mobilenet_v2/main.py index 35f9b35c06c..152ffc7e31f 100644 --- a/examples/post_training_quantization/torch/mobilenet_v2/main.py +++ b/examples/post_training_quantization/torch/mobilenet_v2/main.py @@ -13,7 +13,7 @@ import re import subprocess from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np import openvino as ov @@ -24,9 +24,9 @@ from torchvision import datasets from torchvision import models from torchvision import transforms -from tqdm import tqdm import nncf +from nncf.common.logging.track_progress import track ROOT = Path(__file__).parent.resolve() CHECKPOINT_URL = "https://huggingface.co/alexsu52/mobilenet_v2_imagenette/resolve/main/pytorch_model.bin" @@ -53,7 +53,7 @@ def validate(model: ov.Model, val_loader: torch.utils.data.DataLoader) -> float: compiled_model = ov.compile_model(model) output = compiled_model.outputs[0] - for images, target in tqdm(val_loader): + for images, target in track(val_loader, description="Validating"): pred = compiled_model(images)[output] predictions.append(np.argmax(pred, axis=1)) references.append(target) @@ -84,9 +84,9 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl bin_size /= 1024 model_size = xml_size + bin_size if 
verbose: - print(f"Model graph (xml): {xml_size:.3f} Mb") - print(f"Model weights (bin): {bin_size:.3f} Mb") - print(f"Model size: {model_size:.3f} Mb") + print(f"Model graph (xml): {xml_size:.3f} {m_type}") + print(f"Model weights (bin): {bin_size:.3f} {m_type}") + print(f"Model size: {model_size:.3f} {m_type}") return model_size @@ -123,7 +123,7 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl # >> model(transform_fn(data_item)) -def transform_fn(data_item): +def transform_fn(data_item: Tuple[torch.Tensor, int]) -> torch.Tensor: images, _ = data_item return images @@ -149,28 +149,11 @@ def transform_fn(data_item): # Benchmark performance, calculate compression rate and validate accuracy dummy_input = torch.randn(1, 3, 224, 224) - -fp32_onnx_path = f"{ROOT}/mobilenet_v2_fp32.onnx" -torch.onnx.export( - torch_model.cpu(), - dummy_input, - fp32_onnx_path, - input_names=["input"], - output_names=["output"], - dynamic_axes={"input": {0: "-1"}}, -) -ov_model = mo.convert_model(fp32_onnx_path) - -int8_onnx_path = f"{ROOT}/mobilenet_v2_int8.onnx" -torch.onnx.export( - torch_quantized_model.cpu(), - dummy_input, - int8_onnx_path, - input_names=["input"], - output_names=["output"], - dynamic_axes={"input": {0: "-1"}}, +ov_input_shape = (-1, 3, 224, 224) +ov_model = mo.convert_model(torch_model.cpu(), example_input=dummy_input, input_shape=ov_input_shape) +ov_quantized_model = mo.convert_model( + torch_quantized_model.cpu(), example_input=dummy_input, input_shape=ov_input_shape ) -ov_quantized_model = mo.convert_model(int8_onnx_path) fp32_ir_path = f"{ROOT}/mobilenet_v2_fp32.xml" ov.save_model(ov_model, fp32_ir_path, compress_to_fp16=False) diff --git a/examples/post_training_quantization/torch/mobilenet_v2/requirements.txt b/examples/post_training_quantization/torch/mobilenet_v2/requirements.txt index 189882bdf42..7f0bf782ccf 100644 --- a/examples/post_training_quantization/torch/mobilenet_v2/requirements.txt +++ b/examples/post_training_quantization/torch/mobilenet_v2/requirements.txt @@ -1,6 +1,5 @@ -torchvision>=0.10.1,<0.16 -tqdm -scikit-learn -fastdownload +fastdownload==0.0.7 openvino-dev==2023.1 -onnx +scikit-learn +torch==2.1.0 +torchvision==0.16.0 diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/main.py b/examples/post_training_quantization/torch/ssd300_vgg16/main.py index 6c495ec03ce..6d6b9365a34 100644 --- a/examples/post_training_quantization/torch/ssd300_vgg16/main.py +++ b/examples/post_training_quantization/torch/ssd300_vgg16/main.py @@ -13,6 +13,7 @@ import re import subprocess from pathlib import Path +from typing import Callable, Tuple, Dict # nncf.torch must be imported before torchvision import nncf @@ -27,7 +28,7 @@ from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models.detection.ssd import SSD from torchvision.models.detection.ssd import GeneralizedRCNNTransform -from tqdm import tqdm +from nncf.common.logging.track_progress import track ROOT = Path(__file__).parent.resolve() DATASET_URL = "https://ultralytics.com/assets/coco128.zip" @@ -49,9 +50,9 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl bin_size /= 1024 model_size = xml_size + bin_size if verbose: - print(f"Model graph (xml): {xml_size:.3f} Mb") - print(f"Model weights (bin): {bin_size:.3f} Mb") - print(f"Model size: {model_size:.3f} Mb") + print(f"Model graph (xml): {xml_size:.3f} {m_type}") + print(f"Model weights (bin): {bin_size:.3f} {m_type}") + print(f"Model size: {model_size:.3f} 
{m_type}") return model_size @@ -73,7 +74,7 @@ class COCO128Dataset(torch.utils.data.Dataset): 61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90 ] # fmt: skip - def __init__(self, data_path, transform): + def __init__(self, data_path: str, transform: Callable): super().__init__() self.transform = transform self.data_path = Path(data_path) @@ -81,7 +82,7 @@ def __init__(self, data_path, transform): self.labels_path = self.data_path / "labels" / "train2017" self.image_ids = sorted(map(lambda p: int(p.stem), self.images_path.glob("*.jpg"))) - def __getitem__(self, item): + def __getitem__(self, item: int) -> Tuple[torch.Tensor, Dict]: image_id = self.image_ids[item] img = Image.open(self.images_path / f"{image_id:012d}.jpg") @@ -106,16 +107,16 @@ def __getitem__(self, item): img, target = self.transform(img, target) return img, target - def __len__(self): + def __len__(self) -> int: return len(self.image_ids) -def validate(model, dataset, device): +def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.device): model.to(device) model.eval() metric = MeanAveragePrecision() with torch.no_grad(): - for img, target in tqdm(dataset, desc="Validating"): + for img, target in track(dataset, description="Validating"): prediction = model(img.to(device)[None])[0] for k in prediction.keys(): prediction[k] = prediction[k].to(torch.device("cpu")) @@ -124,7 +125,7 @@ def validate(model, dataset, device): return computed_metrics["map_50"] -def transform_fn(data_item): +def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor: # Skip label and add a batch dimension to an image tensor images, _ = data_item return images[None] diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt b/examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt index b3f92061890..952c35e0a94 100644 --- a/examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt +++ b/examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt @@ -1,7 +1,7 @@ -fastdownload +fastdownload==0.0.7 +onnx==1.13.1 openvino-dev==2023.1 +pycocotools==2.0.7 +torch==2.0.1 # ssd300_vgg16 can not be exported with 2.1.0, reference: https://github.com/pytorch/pytorch/issues/113155 torchmetrics==1.0.1 -pycocotools -torchvision~=0.15.1 -tqdm -onnx +torchvision==0.15.2 diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 5f6ad5f0376..e698c77946b 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -540,6 +540,7 @@ def _get_quantization_target_points( self._backend_entity.shapeof_metatypes, self._backend_entity.dropout_metatypes, self._backend_entity.read_variable_metatypes, + nncf_graph_contains_constants=backend != BackendType.TORCH, ) quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index b3eb27c18c7..055f1f27a5b 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -23,6 +23,7 @@ def transform_to_inference_graph( shapeof_metatypes: List[OperatorMetatype], dropout_metatypes: List[OperatorMetatype], read_variable_metatypes: Optional[List[OperatorMetatype]] = None, + nncf_graph_contains_constants: bool = True, ) -> NNCFGraph: """ This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows. 
@@ -32,11 +33,13 @@ def transform_to_inference_graph( :param dropout_metatypes: List of backend-specific Dropout metatypes. :param read_variable_metatypes: List of backend-specific metatypes that also can be interpreted as inputs (ReadValue). + :param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not. :return: NNCFGraph in the inference style. """ remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes) remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes) - filter_constant_nodes(nncf_graph, read_variable_metatypes) + if nncf_graph_contains_constants: + filter_constant_nodes(nncf_graph, read_variable_metatypes) return nncf_graph diff --git a/nncf/torch/engine.py b/nncf/torch/engine.py index b30ba4dbef1..63b4e93f114 100644 --- a/nncf/torch/engine.py +++ b/nncf/torch/engine.py @@ -15,6 +15,9 @@ from torch import nn from nncf.common.engine import Engine +from nncf.torch.nested_objects_traversal import objwalk +from nncf.torch.utils import get_model_device +from nncf.torch.utils import is_tensor class PTEngine(Engine): @@ -31,6 +34,7 @@ def __init__(self, model: nn.Module): self._model = model self._model.eval() + self._device = get_model_device(model) def infer( self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] @@ -41,6 +45,12 @@ def infer( :param input_data: Inputs for the model. :return: Model outputs. """ + + def send_to_device(tensor): + return tensor.to(self._device) + + input_data = objwalk(input_data, is_tensor, send_to_device) + if isinstance(input_data, dict): return self._model(**input_data) if isinstance(input_data, tuple): diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 205b2e187a6..152c7d59dc9 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy from typing import Optional, Union import torch @@ -68,7 +69,8 @@ def quantize_impl( if target_device == TargetDevice.CPU_SPR: raise RuntimeError("target_device == CPU_SPR is not supported") - nncf_network = create_nncf_network_ptq(model.eval(), calibration_dataset) + copied_model = deepcopy(model) + nncf_network = create_nncf_network_ptq(copied_model.eval(), calibration_dataset) quantization_algorithm = PostTrainingQuantization( preset=preset, diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot new file mode 100644 index 00000000000..3dc12b5ad16 --- /dev/null +++ b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot @@ -0,0 +1,12 @@ +strict digraph { +"0 /Input_1_0" [id=0, type=Input_1]; +"1 /ReadVariable_0" [id=1, type=ReadVariable]; +"4 /Conv_0" [id=4, type=Conv]; +"6 /Conv2_0" [id=6, type=Conv2]; +"7 /Add_0" [id=7, type=Add]; +"8 /Final_node_0" [id=8, type=Final_node]; +"0 /Input_1_0" -> "4 /Conv_0"; +"1 /ReadVariable_0" -> "7 /Add_0"; +"6 /Conv2_0" -> "7 /Add_0"; +"7 /Add_0" -> "8 /Final_node_0"; +} diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot new file mode 100644 index 00000000000..b4590b212a3 --- /dev/null +++ b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot @@ -0,0 +1,18 @@ +strict digraph { +"0 /Input_1_0" [id=0, type=Input_1]; +"1 /ReadVariable_0" [id=1, type=ReadVariable]; +"2 /Weights_0" [id=2, type=Weights]; +"3 /AnyNodeBetweenWeightAndConv_0" [id=3, type=AnyNodeBetweenWeightAndConv]; +"4 /Conv_0" [id=4, type=Conv]; +"5 /Weights2_0" [id=5, type=Weights2]; +"6 /Conv2_0" [id=6, type=Conv2]; +"7 /Add_0" [id=7, type=Add]; +"8 /Final_node_0" [id=8, type=Final_node]; +"0 /Input_1_0" -> "4 /Conv_0"; +"1 /ReadVariable_0" -> "7 /Add_0"; +"2 /Weights_0" -> "3 /AnyNodeBetweenWeightAndConv_0"; +"3 /AnyNodeBetweenWeightAndConv_0" -> "4 /Conv_0"; +"5 /Weights2_0" -> "6 /Conv2_0"; +"6 /Conv2_0" -> "7 /Add_0"; +"7 /Add_0" -> "8 /Final_node_0"; +} diff --git a/tests/common/quantization/test_passes.py b/tests/common/quantization/test_passes.py index c744c2297a4..87d38d09805 100644 --- a/tests/common/quantization/test_passes.py +++ b/tests/common/quantization/test_passes.py @@ -14,8 +14,10 @@ import pytest +from nncf.quantization.passes import filter_constant_nodes from nncf.quantization.passes import remove_nodes_and_reconnect_graph from tests.post_training.test_templates.models import NNCFGraphDropoutRemovingCase +from tests.post_training.test_templates.models import NNCFGraphToTestConstantFiltering from tests.shared.nx_graph import compare_nx_graph_with_reference from tests.shared.paths import TEST_ROOT @@ -28,13 +30,14 @@ class TestModes(Enum): WRONG_PARALLEL_EDGES = "wrong_parallel_edges" +def _check_graphs(dot_file_name, nncf_graph) -> None: + nx_graph = nncf_graph.get_graph_for_structure_analysis() + path_to_dot = DATA_ROOT / dot_file_name + compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True) + + @pytest.mark.parametrize("mode", [TestModes.VALID, TestModes.WRONG_TENSOR_SHAPE, TestModes.WRONG_PARALLEL_EDGES]) def test_remove_nodes_and_reconnect_graph(mode: TestModes): - def _check_graphs(dot_file_name, nncf_graph) -> None: - nx_graph = nncf_graph.get_graph_for_structure_analysis() - path_to_dot = DATA_ROOT / 
dot_file_name - compare_nx_graph_with_reference(nx_graph, path_to_dot, check_edge_attrs=True) - dot_reference_path_before = Path("passes") / "dropout_synthetic_model_before.dot" dot_reference_path_after = Path("passes") / "dropout_synthetic_model_after.dot" dropout_metatype = "DROPOUT_METATYPE" @@ -52,3 +55,19 @@ def _check_graphs(dot_file_name, nncf_graph) -> None: _check_graphs(dot_reference_path_before, nncf_graph) remove_nodes_and_reconnect_graph(nncf_graph, [dropout_metatype]) _check_graphs(dot_reference_path_after, nncf_graph) + + +@pytest.mark.xfail +def test_filter_constant_nodes(): + dot_reference_path_before = Path("passes") / "test_constant_filtering_model_before.dot" + dot_reference_path_after = Path("passes") / "test_constant_filtering_model_after.dot" + + constant_metatype = "CONSTANT_METATYPE" + read_variable_metatype = "READ_VARIABLE_METATYPE" + + nncf_graph = NNCFGraphToTestConstantFiltering(constant_metatype, read_variable_metatype).nncf_graph + _check_graphs(dot_reference_path_before, nncf_graph) + filter_constant_nodes( + nncf_graph, read_variable_metatypes=[read_variable_metatype], constant_nodes_metatypes=[constant_metatype] + ) + _check_graphs(dot_reference_path_after, nncf_graph) diff --git a/tests/cross_fw/examples/example_scope.json b/tests/cross_fw/examples/example_scope.json index a215b096679..1bb9f43f0b0 100644 --- a/tests/cross_fw/examples/example_scope.json +++ b/tests/cross_fw/examples/example_scope.json @@ -110,7 +110,7 @@ "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz", "accuracy_metrics": { "fp32_top1": 0.9864968152866243, - "int8_top1": 0.9829299363057324, + "int8_top1": 0.9836942675159236, "accuracy_drop": 0.0035668789808918078 }, "performance_metrics": { @@ -129,8 +129,8 @@ "requirements": "examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt", "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz", "accuracy_metrics": { - "fp32_mAP": 0.5232756733894348, - "int8_mAP": 0.5140125155448914, + "fp32_mAP": 0.5228869318962097, + "int8_mAP": 0.5148677825927734, "accuracy_drop": 0.009263157844543457 }, "performance_metrics": { @@ -144,4 +144,4 @@ "model_compression_rate": 3.8631822183889652 } } -} \ No newline at end of file +} diff --git a/tests/cross_fw/examples/test_examples.py b/tests/cross_fw/examples/test_examples.py index 8f98961da15..e9708d7daaa 100644 --- a/tests/cross_fw/examples/test_examples.py +++ b/tests/cross_fw/examples/test_examples.py @@ -34,7 +34,7 @@ ACCURACY_METRICS = "accuracy_metrics" MODEL_SIZE_METRICS = "model_size_metrics" -PERFORMNACE_METRICS = "performance_metrics" +PERFORMANCE_METRICS = "performance_metrics" def example_test_cases(): @@ -85,6 +85,6 @@ def test_examples( for name, value in example_params[MODEL_SIZE_METRICS].items(): assert measured_metrics[name] == pytest.approx(value, rel=MODEL_SIZE_RELATIVE_TOLERANCE) - if is_check_performance and PERFORMNACE_METRICS in example_params: - for name, value in example_params[PERFORMNACE_METRICS].items(): + if is_check_performance and PERFORMANCE_METRICS in example_params: + for name, value in example_params[PERFORMANCE_METRICS].items(): assert measured_metrics[name] == pytest.approx(value, rel=PERFORMANCE_RELATIVE_TOLERANCE) diff --git a/tests/post_training/test_templates/models.py b/tests/post_training/test_templates/models.py index 5d83db95117..fd0c4e773aa 100644 --- a/tests/post_training/test_templates/models.py +++ b/tests/post_training/test_templates/models.py @@ -297,3 +297,30 @@ def __init__( dtype=Dtype.FLOAT, parallel_input_port_ids=list(range(1, 10)), ) + 
+ +class NNCFGraphToTestConstantFiltering: + def __init__(self, constant_metatype, read_variable_metatype, nncf_graph_cls=NNCFGraph) -> None: + nodes = [ + NodeWithType("Input_1", InputNoopMetatype), + NodeWithType("Conv", None), + NodeWithType("Weights", constant_metatype), + NodeWithType("AnyNodeBetweenWeightAndConv", None), + NodeWithType("Weights2", constant_metatype), + NodeWithType("Conv2", None), + NodeWithType("ReadVariable", read_variable_metatype), + NodeWithType("Add", None), + NodeWithType("Final_node", None), + ] + + edges = [ + ("Input_1", "Conv"), + ("Weights", "AnyNodeBetweenWeightAndConv"), + ("AnyNodeBetweenWeightAndConv", "Conv"), + ("Weights2", "Conv2"), + ("Conv2", "Add"), + ("ReadVariable", "Add"), + ("Add", "Final_node"), + ] + original_mock_graph = create_mock_graph(nodes, edges) + self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls) diff --git a/tests/torch/test_transform_fn.py b/tests/torch/test_transform_fn.py index f0a19438cbb..3d3d80acbb0 100644 --- a/tests/torch/test_transform_fn.py +++ b/tests/torch/test_transform_fn.py @@ -14,6 +14,7 @@ from torch import nn import nncf +from nncf.torch.nested_objects_traversal import objwalk from tests.torch.test_models.alexnet import AlexNet as ModelWithSingleInput @@ -43,11 +44,18 @@ def single_input_transform_fn(data_item): return data_item[0] -def test_transform_fn_single_input(): +def test_transform_fn_single_input(use_cuda): + if use_cuda and not torch.cuda.is_available(): + pytest.skip("There are no available CUDA devices") + model = ModelWithSingleInput() + input_data = single_input_transform_fn(next(iter(dataloader))) + if use_cuda: + model = model.cuda() + input_data = input_data.cuda() # Check the transformation function - model(single_input_transform_fn(next(iter(dataloader)))) + model(input_data) # Start quantization calibration_dataset = nncf.Dataset(dataloader, single_input_transform_fn) nncf.quantize(model, calibration_dataset) @@ -64,15 +72,26 @@ def multiple_inputs_transform_dict_fn(data_item): @pytest.mark.parametrize( "transform_fn", (multiple_inputs_transform_tuple_fn, multiple_inputs_transform_dict_fn), ids=["tuple", "dict"] ) -def test_transform_fn_multiple_inputs(transform_fn): +def test_transform_fn_multiple_inputs(transform_fn, use_cuda): + if use_cuda and not torch.cuda.is_available(): + pytest.skip("There are no available CUDA devices") + model = ModelWithMultipleInputs() + input_data = transform_fn(next(iter(dataloader))) + if use_cuda: + model = model.cuda() + + def send_to_cuda(tensor): + return tensor.cuda() + + input_data = objwalk(input_data, lambda _: True, send_to_cuda) # Check the transformation function - input_data = transform_fn(next(iter(dataloader))) if isinstance(input_data, tuple): model(*input_data) if isinstance(input_data, dict): model(**input_data) + # Start quantization calibration_dataset = nncf.Dataset(dataloader, transform_fn) nncf.quantize(model, calibration_dataset)
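
For reference, here is a minimal usage sketch of the ONNX-free flow that the updated mobilenet_v2 example follows: calibrate and quantize the PyTorch model with nncf.quantize, then convert both the FP32 and INT8 models to OpenVINO IR by tracing via mo.convert_model. The environment is assumed to match the pinned requirements (torch==2.1.0, torchvision==0.16.0, openvino-dev==2023.1); the random calibration data, the untrained model, and the output file names are illustrative placeholders, not part of the patch.

```python
import openvino as ov
import torch
from openvino.tools import mo
from torchvision import models

import nncf

# Illustrative calibration data: a few random tensors shaped like the Imagenette
# images used in the real example (batch of 1, 3x224x224), paired with dummy labels.
calibration_data = [(torch.randn(1, 3, 224, 224), 0) for _ in range(16)]


def transform_fn(data_item):
    # Skip the label; nncf.quantize only needs the input tensor.
    images, _ = data_item
    return images


# Placeholder model; the real example loads fine-tuned Imagenette weights.
torch_model = models.mobilenet_v2(num_classes=10).eval()

calibration_dataset = nncf.Dataset(calibration_data, transform_fn)
torch_quantized_model = nncf.quantize(torch_model, calibration_dataset, subset_size=len(calibration_data))

# Convert both models to OpenVINO IR by tracing, without an intermediate ONNX export.
dummy_input = torch.randn(1, 3, 224, 224)
ov_input_shape = (-1, 3, 224, 224)
ov_model = mo.convert_model(torch_model.cpu(), example_input=dummy_input, input_shape=ov_input_shape)
ov_quantized_model = mo.convert_model(
    torch_quantized_model.cpu(), example_input=dummy_input, input_shape=ov_input_shape
)

ov.save_model(ov_model, "mobilenet_v2_fp32.xml", compress_to_fp16=False)
ov.save_model(ov_quantized_model, "mobilenet_v2_int8.xml", compress_to_fp16=False)
```

Converting directly from the traced PyTorch module avoids the ONNX export path, which breaks for some models (e.g. ssd300_vgg16 on torch 2.1.0, as noted in the description above).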