From f6e13077d90ec4c30fd62e7f859227ba387e76b4 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 28 Mar 2024 18:13:00 +0100 Subject: [PATCH] API code moved to a separate PR --- nncf/__init__.py | 2 - nncf/quantization/__init__.py | 2 - nncf/quantization/quantize_model.py | 33 +--------- .../graph/transformations/serialization.py | 19 ++---- nncf/torch/quantization/quantize_model.py | 22 ------- tests/torch/qat/test_qat_classification.py | 64 ------------------- tests/torch/test_serialization.py | 39 +++++++---- 7 files changed, 34 insertions(+), 147 deletions(-) diff --git a/nncf/__init__.py b/nncf/__init__.py index 897444e9450..eaaa755a49e 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -39,11 +39,9 @@ from nncf.parameters import SensitivityMetric as SensitivityMetric from nncf.parameters import TargetDevice as TargetDevice from nncf.quantization import QuantizationPreset as QuantizationPreset -from nncf.quantization import apply_serialized_transformations as apply_serialized_transformations from nncf.quantization import compress_weights as compress_weights from nncf.quantization import quantize as quantize from nncf.quantization import quantize_with_accuracy_control as quantize_with_accuracy_control -from nncf.quantization import serialize_transformations as serialize_transformations from nncf.quantization.advanced_parameters import ( AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters, ) diff --git a/nncf/quantization/__init__.py b/nncf/quantization/__init__.py index a42f0247f12..a1b78c774e1 100644 --- a/nncf/quantization/__init__.py +++ b/nncf/quantization/__init__.py @@ -10,8 +10,6 @@ # limitations under the License. """Post-training quantization APIs.""" from nncf.common.quantization.structs import QuantizationPreset as QuantizationPreset -from nncf.quantization.quantize_model import apply_serialized_transformations as apply_serialized_transformations from nncf.quantization.quantize_model import compress_weights as compress_weights from nncf.quantization.quantize_model import quantize as quantize from nncf.quantization.quantize_model import quantize_with_accuracy_control as quantize_with_accuracy_control -from nncf.quantization.quantize_model import serialize_transformations as serialize_transformations diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 04837e06912..fe8a69ace20 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union import nncf from nncf.api.compression import TModel @@ -540,34 +540,3 @@ def quantize_with_tune_hyperparams( quantized_model = hyperparameter_tuner.apply(model, validation_dataset) return quantized_model - - -@api(canonical_alias="nncf.apply_serialized_transformations") -def apply_serialized_transformations( - model: TModel, - serialized_transformations, -) -> TModel: - """ - Applies transformation layout to the model. - """ - backend = get_backend(model) - if backend == BackendType.TORCH: - from nncf.torch.quantization.quantize_model import apply_serialized_transformations_impl - - return apply_serialized_transformations_impl(model, serialized_transformations) - raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") - - -@api(canonical_alias="nncf.serialize_transformations") -def serialize_transformations( - model: TModel, -) -> Dict[str, Any]: - """ - Applies transformation layout to the model. - """ - backend = get_backend(model) - if backend == BackendType.TORCH: - from nncf.torch.quantization.quantize_model import serialize_transformations_impl - - return serialize_transformations_impl(model) - raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py index 9f9a136f12d..616a4c25e52 100644 --- a/nncf/torch/graph/transformations/serialization.py +++ b/nncf/torch/graph/transformations/serialization.py @@ -10,13 +10,10 @@ # limitations under the License. from enum import Enum -from typing import Any, Dict, Tuple, Union - -import torch +from typing import Any, Dict, Union from nncf.common.graph.transformations.commands import TransformationPriority from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.torch.dynamic_graph.io_handling import FillerInputInfo from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand @@ -25,7 +22,6 @@ from nncf.torch.layer_utils import COMPRESSION_MODULES COMPRESSION_STATE_ATTR = "compression_state" -INPUT_INFO_ATTR = "example_input" class CompressionKeys(Enum): @@ -33,28 +29,23 @@ class CompressionKeys(Enum): INSERTION_COMMAND = "INSERTION_COMMAND" -def serialize_transformations(model: torch.nn.Module, transformations_layout: TransformationLayout) -> Dict[str, Any]: - input_info = model.nncf._input_info - if not isinstance(input_info, FillerInputInfo): - raise RuntimeError("Could not serialize model inputs input: {input_info}") - +def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]: transformation_commands = [] for command in transformations_layout.transformations: serialized_command = serialize_command(command) if serialized_command: transformation_commands.append(serialized_command) - return {COMPRESSION_STATE_ATTR: transformation_commands, INPUT_INFO_ATTR: input_info.get_state()} + return {COMPRESSION_STATE_ATTR: transformation_commands} -def load_transformations(transformations_state: Dict[str, Any]) -> Tuple[TransformationLayout, FillerInputInfo]: +def load_transformations(transformations_state: Dict[str, Any]) -> TransformationLayout: transformation_layout = TransformationLayout() for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]: command = load_command(serialized_command) transformation_layout.register(command) - input_info = FillerInputInfo.from_state(transformations_state[INPUT_INFO_ATTR]) - return transformation_layout, input_info + return transformation_layout def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]: diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 81aadc5007e..48f3ddefae2 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -15,7 +15,6 @@ import torch import nncf -from nncf.common.factory import ModelTransformerFactory from nncf.common.factory import NNCFGraphFactory from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset @@ -30,10 +29,7 @@ from nncf.quantization.quantize_model import warning_model_no_batchwise_support from nncf.scopes import IgnoredScope from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS -from nncf.torch.graph.transformations.serialization import load_transformations -from nncf.torch.graph.transformations.serialization import serialize_transformations from nncf.torch.model_creation import wrap_model -from nncf.torch.nncf_network import NNCFNetwork DEFAULT_RANGE_TYPE = "mean_min_max" @@ -104,21 +100,3 @@ def compress_weights_impl( ) graph = NNCFGraphFactory.create(model) return compression_algorithm.apply(model, graph, dataset=dataset) - - -def apply_serialized_transformations_impl(model: torch.nn.Module, serialized_transformations): - transformations_layout, input_info = load_transformations(serialized_transformations) - - nncf_network = NNCFNetwork(deepcopy(model), input_info=input_info) - model_transformer = ModelTransformerFactory.create(nncf_network) - transformed_model = model_transformer.transform(transformations_layout) - - transformed_model.nncf.disable_dynamic_graph_building() - return transformed_model - - -def serialize_transformations_impl( - model: NNCFNetwork, -): - layout = model.nncf.get_applied_transformation_layout() - return serialize_transformations(model, layout) diff --git a/tests/torch/qat/test_qat_classification.py b/tests/torch/qat/test_qat_classification.py index 86c95e2973a..394ebfa0071 100644 --- a/tests/torch/qat/test_qat_classification.py +++ b/tests/torch/qat/test_qat_classification.py @@ -293,67 +293,3 @@ def test_compression_training(quantization_config_path: Path, sota_data_dir): del sample_config["compression"]["initializer"]["range"] start_worker_clean_memory(main_worker, sample_config) - - -def save_load_main_worker(current_gpu: int, config: SampleConfig): - configure_device(current_gpu, config) - if is_main_process(): - configure_logging(logger, config) - else: - config.tb = None - - pretrained = is_pretrained_model_requested(config) - model_name = config["model"] - # create model - logger.info(f"\nCreating model from config: {config.config}") - model = load_model( - model_name, - pretrained=pretrained, - num_classes=config.get("num_classes", 1000), - model_params=config.get("model_params"), - weights_path=config.get("weights"), - ) - model.to(config.device) - - datasets = get_datasets(config) - criterion = nn.CrossEntropyLoss() - criterion = criterion.to(config.device) - - logger.info("Original model validation:") - # original_accuracy, *_ = validate(datasets.val_data_loader, model, criterion, config) - original_accuracy = 100.0 - - logger.info("Apply quantization to the model:") - config_quantization_params = config["compression"] - - preset = get_quantization_preset(config_quantization_params) - advanced_parameters = get_advanced_ptq_parameters(config_quantization_params) - # subset_size = get_num_samples(config_quantization_params) - - quantized_model = nncf.quantize( - model, - datasets.calibration_dataset, - preset=preset, - advanced_parameters=advanced_parameters, - subset_size=1, - ) - - transformations_state = nncf.serialize_transformations(quantized_model) - state_dict = quantized_model.state_dict() - del quantized_model - quantized_model = nncf.apply_serialized_transformations(model, transformations_state) - quantized_model.load_state_dict(state_dict) - - train_criterion_fn = inception_criterion_fn if "inception" in model_name else default_criterion_fn - acc_drop = train( - quantized_model, - config, - criterion, - train_criterion_fn, - datasets, - original_accuracy, - get_mocked_compression_ctrl(), - ) - assert accuracy_drop_is_acceptable(acc_drop) - check_training_correctness(config, model, datasets, criterion, train_criterion_fn) - logger.info("Done!") diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py index 737e9c1b7a6..d9b125d5540 100644 --- a/tests/torch/test_serialization.py +++ b/tests/torch/test_serialization.py @@ -15,12 +15,12 @@ import pytest import torch -import nncf from nncf.common.factory import ModelTransformerFactory from nncf.common.quantization.structs import QuantizationScheme from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply from nncf.torch.dynamic_graph.io_handling import FillerInputElement from nncf.torch.dynamic_graph.io_handling import FillerInputInfo +from nncf.torch.dynamic_graph.io_handling import ModelInputInfo from nncf.torch.graph.transformations.serialization import load_command from nncf.torch.graph.transformations.serialization import load_transformations from nncf.torch.graph.transformations.serialization import serialize_command @@ -56,39 +56,56 @@ def test_serialize_load_command(target_type, command_builder, priority): _check_commands_after_serialization(command, recovered_command, dummy_op_state) -def test_serialize_transformations(mocker): +def test_serialize_transformations(): layout = TwoConvTestModel.get_all_available_commands() - model = mocker.MagicMock() - input_info_ref = FillerInputInfo([FillerInputElement([1, 1, 4, 4])]) - model.nncf._input_info = input_info_ref - serialized_transformations = serialize_transformations(model, layout) + serialized_transformations = serialize_transformations(layout) # Check serialized transformation are json compatible j_str = json.dumps(serialized_transformations) serialized_transformations = json.loads(j_str) - recovered_layout, input_info = load_transformations(serialized_transformations) - assert input_info == input_info_ref + recovered_layout = load_transformations(serialized_transformations) assert len(layout.transformations) == len(recovered_layout.transformations) # Can zip layouts because the order should not be altered for command, recovered_command in zip(layout.transformations, recovered_layout.transformations): _check_commands_after_serialization(command, recovered_command) +def apply_serialized_transformations_impl( + model: torch.nn.Module, input_info: ModelInputInfo, serialized_transformations +): + transformations_layout = load_transformations(serialized_transformations) + + nncf_network = NNCFNetwork(deepcopy(model), input_info=input_info) + model_transformer = ModelTransformerFactory.create(nncf_network) + transformed_model = model_transformer.transform(transformations_layout) + + transformed_model.nncf.disable_dynamic_graph_building() + return transformed_model + + +def serialize_transformations_impl( + model: NNCFNetwork, +): + layout = model.nncf.get_applied_transformation_layout() + return serialize_transformations(layout) + + def test_get_apply_serialization_from_a_model(): layout = TwoConvTestModel.get_all_available_commands(skip_model_transformer_unsupported=True) model = TwoConvTestModel() - nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + input_info = FillerInputInfo([FillerInputElement([1, 1, 4, 4])]) + nncf_model = NNCFNetwork(deepcopy(model), input_info=input_info) modified_model = ModelTransformerFactory.create(nncf_model).transform(layout) - serialized_transformations = nncf.serialize_transformations(modified_model) + serialized_transformations = serialize_transformations_impl(modified_model) # Check serialized transformation are json compatible j_str = json.dumps(serialized_transformations) serialized_transformations = json.loads(j_str) - recovered_model = nncf.apply_serialized_transformations(model, serialized_transformations) + recovered_model = apply_serialized_transformations_impl(model, input_info, serialized_transformations) for conv, recovered_conv in zip(modified_model.features, recovered_model.features): for hooks_attr in ["pre_ops", "post_ops"]: hooks = getattr(conv[0], hooks_attr)