From f89019b92b84b8a57d9dfbd26318c7f67ebd89db Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 27 Mar 2024 16:35:13 +0100 Subject: [PATCH] [Torch] Save/load NNCFNetwork state --- nncf/__init__.py | 2 + nncf/quantization/__init__.py | 2 + .../algorithms/min_max/torch_backend.py | 10 +- .../algorithms/smooth_quant/torch_backend.py | 37 ++- nncf/quantization/quantize_model.py | 33 ++- nncf/torch/dynamic_graph/io_handling.py | 20 ++ nncf/torch/graph/graph.py | 7 + .../graph/transformations/serialization.py | 113 +++++++ nncf/torch/layer_utils.py | 34 +++ nncf/torch/model_transformer.py | 1 + nncf/torch/nncf_network.py | 130 +++++++++ nncf/torch/pruning/filter_pruning/layers.py | 27 +- nncf/torch/quantization/layers.py | 12 +- nncf/torch/quantization/quantize_model.py | 22 ++ nncf/torch/sparsity/layers.py | 18 +- nncf/torch/sparsity/rb/layers.py | 31 +- tests/torch/helpers.py | 207 +++++++++++++ tests/torch/qat/test_qat_classification.py | 64 ++++ tests/torch/test_nncf_network.py | 84 ++++++ tests/torch/test_serialization.py | 275 ++++++++++++++++++ 20 files changed, 1104 insertions(+), 25 deletions(-) create mode 100644 nncf/torch/graph/transformations/serialization.py create mode 100644 tests/torch/test_serialization.py diff --git a/nncf/__init__.py b/nncf/__init__.py index eaaa755a49e..897444e9450 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -39,9 +39,11 @@ from nncf.parameters import SensitivityMetric as SensitivityMetric from nncf.parameters import TargetDevice as TargetDevice from nncf.quantization import QuantizationPreset as QuantizationPreset +from nncf.quantization import apply_serialized_transformations as apply_serialized_transformations from nncf.quantization import compress_weights as compress_weights from nncf.quantization import quantize as quantize from nncf.quantization import quantize_with_accuracy_control as quantize_with_accuracy_control +from nncf.quantization import serialize_transformations as serialize_transformations from nncf.quantization.advanced_parameters import ( AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters, ) diff --git a/nncf/quantization/__init__.py b/nncf/quantization/__init__.py index a1b78c774e1..a42f0247f12 100644 --- a/nncf/quantization/__init__.py +++ b/nncf/quantization/__init__.py @@ -10,6 +10,8 @@ # limitations under the License. """Post-training quantization APIs.""" from nncf.common.quantization.structs import QuantizationPreset as QuantizationPreset +from nncf.quantization.quantize_model import apply_serialized_transformations as apply_serialized_transformations from nncf.quantization.quantize_model import compress_weights as compress_weights from nncf.quantization.quantize_model import quantize as quantize from nncf.quantization.quantize_model import quantize_with_accuracy_control as quantize_with_accuracy_control +from nncf.quantization.quantize_model import serialize_transformations as serialize_transformations diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index d5a37ddcbe5..d10b1575ba3 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -263,22 +263,22 @@ def _create_quantizer( quantizer = quantizer_cls(quantizer_spec) # Fill it with minmax - PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters) + PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) return quantizer @staticmethod - def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None: + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: if isinstance(quantizer, AsymmetricQuantizer): - quantizer.input_low = torch.nn.Parameter(parameters.input_low.data) + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) input_range = parameters.input_high - parameters.input_low # Subtract eps from the input_range to make quantizer parameters equal to # original parameters on the forward call. - quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps) + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) else: quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) # Subtract eps from the scale to make quantizer parameters equal to # original parameters on the forward call. - quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps) + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) @staticmethod def create_quantizer_insertion_command( diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index 275f9a2523e..5d66449f60d 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import numpy as np import torch @@ -30,20 +30,41 @@ from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor -class SQMultiply(torch.nn.Module): - def __init__(self, scale_value): +@COMPRESSION_MODULES.register() +class SQMultiply(torch.nn.Module, StatefullTorchModuleInterface): + SCALE_SHAPE_KEY = "scale_shape" + + def __init__(self, scale_shape: Tuple[int, ...]): super().__init__() - self._scale_value = scale_value + self._scale_value = CompressionParameter(torch.empty(scale_shape)) + + @property + def scale(self) -> torch.nn.Parameter: + return self._scale_value - def forward(self, x): + @scale.setter + def scale(self, value: torch.tensor): + self._scale_value.data = value + + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.mul(x, self._scale_value) + def get_state(self) -> Dict[str, Any]: + return {self.SCALE_SHAPE_KEY: list(self._scale_value.shape)} + + @classmethod + def from_state(cls, state) -> "SQMultiply": + return SQMultiply(state[cls.SCALE_SHAPE_KEY]) + PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK @@ -117,7 +138,7 @@ def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) @staticmethod def scale_insertion_command( source_node: NNCFNode, - scale_value: np.ndarray, + scale_value: torch.Tensor, source_output_port_id: int, nodes: List[NNCFNode], scale_node_name: str, @@ -127,7 +148,9 @@ def scale_insertion_command( for node in nodes: target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id)) - return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name) + sq_multiply = SQMultiply(scale_value.shape) + sq_multiply.scale = scale_value + return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name) @staticmethod def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index fe8a69ace20..04837e06912 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union import nncf from nncf.api.compression import TModel @@ -540,3 +540,34 @@ def quantize_with_tune_hyperparams( quantized_model = hyperparameter_tuner.apply(model, validation_dataset) return quantized_model + + +@api(canonical_alias="nncf.apply_serialized_transformations") +def apply_serialized_transformations( + model: TModel, + serialized_transformations, +) -> TModel: + """ + Applies transformation layout to the model. + """ + backend = get_backend(model) + if backend == BackendType.TORCH: + from nncf.torch.quantization.quantize_model import apply_serialized_transformations_impl + + return apply_serialized_transformations_impl(model, serialized_transformations) + raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") + + +@api(canonical_alias="nncf.serialize_transformations") +def serialize_transformations( + model: TModel, +) -> Dict[str, Any]: + """ + Applies transformation layout to the model. + """ + backend = get_backend(model) + if backend == BackendType.TORCH: + from nncf.torch.quantization.quantize_model import serialize_transformations_impl + + return serialize_transformations_impl(model) + raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/dynamic_graph/io_handling.py b/nncf/torch/dynamic_graph/io_handling.py index c0f3aeb92e5..2884033e6fb 100644 --- a/nncf/torch/dynamic_graph/io_handling.py +++ b/nncf/torch/dynamic_graph/io_handling.py @@ -122,6 +122,7 @@ def __init__(self, shape: List[int], type_str: str = "float", keyword: str = Non """ self.shape = shape self.type = self._string_to_torch_type(type_str) + self._type_str = type_str self.keyword = keyword if filler is None: self.filler = self.FILLER_TYPE_ONES @@ -157,6 +158,15 @@ def get_tensor_for_input(self) -> torch.Tensor: return torch.rand(size=self.shape, dtype=self.type) raise NotImplementedError + def get_state(self) -> Dict[str, Any]: + return {"shape": self.shape, "type_str": self._type_str, "keyword": self.keyword, "filler": self.filler} + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "FillerInputElement": + return FillerInputElement( + shape=state["shape"], type_str=state["type_str"], keyword=state["keyword"], filler=state["filler"] + ) + class FillerInputInfo(ModelInputInfo): """ @@ -220,6 +230,16 @@ def get_forward_inputs( kwargs[fe.keyword] = tensor return tuple(args_list), kwargs + def get_state(self) -> Dict[str, Any]: + return {"elements": [elem.get_state() for elem in self.elements]} + + @classmethod + def from_state(cls, state) -> "FillerInputInfo": + return FillerInputInfo([FillerInputElement.from_state(s) for s in state["elements"]]) + + def __eq__(self, other: "FillerInputInfo") -> bool: + return self.elements == other.elements + class ExactInputsInfo(ModelInputInfo): """ diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 1a651f9a363..63637f759b0 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -60,6 +60,13 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]: matching_graph_op_nodes.extend(nodes_in_module) return matching_graph_op_nodes + def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]: + for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items(): + module_scope = Scope.from_str(scope_str) + if module_scope == scope: + return nodes_in_module + return [] + def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: matches = [] for node_id, scope_str in self._node_ids_vs_layer_names.items(): diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py new file mode 100644 index 00000000000..9f9a136f12d --- /dev/null +++ b/nncf/torch/graph/transformations/serialization.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Any, Dict, Tuple, Union + +import torch + +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.torch.dynamic_graph.io_handling import FillerInputInfo +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTTransformationCommand +from nncf.torch.layer_utils import COMPRESSION_MODULES + +COMPRESSION_STATE_ATTR = "compression_state" +INPUT_INFO_ATTR = "example_input" + + +class CompressionKeys(Enum): + SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND" + INSERTION_COMMAND = "INSERTION_COMMAND" + + +def serialize_transformations(model: torch.nn.Module, transformations_layout: TransformationLayout) -> Dict[str, Any]: + input_info = model.nncf._input_info + if not isinstance(input_info, FillerInputInfo): + raise RuntimeError("Could not serialize model inputs input: {input_info}") + + transformation_commands = [] + for command in transformations_layout.transformations: + serialized_command = serialize_command(command) + if serialized_command: + transformation_commands.append(serialized_command) + + return {COMPRESSION_STATE_ATTR: transformation_commands, INPUT_INFO_ATTR: input_info.get_state()} + + +def load_transformations(transformations_state: Dict[str, Any]) -> Tuple[TransformationLayout, FillerInputInfo]: + transformation_layout = TransformationLayout() + for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]: + command = load_command(serialized_command) + transformation_layout.register(command) + + input_info = FillerInputInfo.from_state(transformations_state[INPUT_INFO_ATTR]) + return transformation_layout, input_info + + +def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]: + if not isinstance(command, (PTSharedFnInsertionCommand, PTInsertionCommand)): + return {} + + serialized_transformation = dict() + if isinstance(command, PTSharedFnInsertionCommand): + serialized_transformation["type"] = CompressionKeys.SHARED_INSERTION_COMMAND.value + serialized_transformation["target_points"] = [point.get_state() for point in command.target_points] + serialized_transformation["op_name"] = command.op_name + serialized_transformation["compression_module_type"] = command.compression_module_type.value + + elif isinstance(command, PTInsertionCommand): + serialized_transformation["type"] = CompressionKeys.INSERTION_COMMAND.value + serialized_transformation["target_point"] = command.target_point.get_state() + + # Check compression module is registered + compression_module_name = command.fn.__class__.__name__ + if compression_module_name not in COMPRESSION_MODULES.registry_dict: + raise RuntimeError( + f"Could not serialize compression module with name {compression_module_name}." + " Please register your module in the COMPRESSION_MODULES registry." + ) + serialized_transformation["compression_module_name"] = compression_module_name + serialized_transformation["fn_state"] = command.fn.get_state() + serialized_transformation["hooks_group_name"] = command.hooks_group_name + priority = command.priority + serialized_transformation["priority"] = priority.value if isinstance(priority, Enum) else priority + return serialized_transformation + + +def load_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]: + module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"]) + fn = module_cls.from_state(serialized_command["fn_state"]) + priority = serialized_command["priority"] + if priority in iter(TransformationPriority): + priority = TransformationPriority(priority) + + if serialized_command["type"] == CompressionKeys.INSERTION_COMMAND.value: + target_point = PTTargetPoint.from_state(serialized_command["target_point"]) + return PTInsertionCommand( + point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"] + ) + + if serialized_command["type"] == CompressionKeys.SHARED_INSERTION_COMMAND.value: + target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]] + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + op_unique_name=serialized_command["op_name"], + compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]), + priority=priority, + hooks_group_name=serialized_command["hooks_group_name"], + ) + raise RuntimeError(f"Command type {serialized_command['type']} is not supported.") diff --git a/nncf/torch/layer_utils.py b/nncf/torch/layer_utils.py index 85756f13dcc..7a4c7975813 100644 --- a/nncf/torch/layer_utils.py +++ b/nncf/torch/layer_utils.py @@ -9,6 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC +from abc import abstractclassmethod +from abc import abstractmethod +from typing import Any, Dict + import torch from torch import nn @@ -19,6 +24,30 @@ COMPRESSION_MODULES = Registry("compression modules") +class StatefullTorchModuleInterface(ABC): + """ + Interface that should be implemented for every registered compression module to make it possible + to save an compression modules state and create an compression module from the saved state. + State of the module should be json serializable, no python objects except + standart (str, list and etc.) should be present in a compression module state. + The state for attributes with type torch.nn.Parameter + is recovered from the model `state_dict`, so there is no need to keep them in the module state. + + """ + + @abstractmethod + def get_state(self) -> Dict[str, Any]: + """ + Returns the compression module state. + """ + + @abstractclassmethod + def from_state(cls, state: Dict[str, Any]) -> object: + """ + Creates a compression module instance from the given state. + """ + + class ProxyModule: def __init__(self, module): self._module = module @@ -117,7 +146,12 @@ def __init__(self, data: torch.Tensor = None, requires_grad: bool = True, compre """ super().__init__() + self._compression_lr_multiplier = compression_lr_multiplier if compression_lr_multiplier is not None and self.dtype.is_floating_point: self.requires_grad = True self.register_hook(lambda grad: compression_lr_multiplier * grad) self.requires_grad = requires_grad + + @property + def compression_lr_multiplier(self): + return self._compression_lr_multiplier diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index 19c2c647b05..dc5b0a48522 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -61,6 +61,7 @@ def __init__(self, model: NNCFNetwork): ] def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork: + # self._model.nncf.record_commands(transformation_layout.transformations) transformations = transformation_layout.transformations aggregated_transformations = defaultdict(list) requires_graph_rebuild = False diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index a27d338a77a..d02f4038d31 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -41,6 +41,7 @@ from nncf.common.utils.debug import is_debug from nncf.torch.debug import CombinedDebugInterface from nncf.torch.debug import debuggable_forward +from nncf.torch.dynamic_graph.context import PreHookId from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph import ShapeIgnoringTensorMetaComparator @@ -60,6 +61,7 @@ from nncf.torch.dynamic_graph.wrappers import wrap_module_call from nncf.torch.dynamic_graph.wrappers import wrap_parameters from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME +from nncf.torch.external_hook import ExternalOpCallHook from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.graph_builder import GraphBuilder from nncf.torch.graph.graph_builder import GraphConverter @@ -67,9 +69,13 @@ from nncf.torch.graph.operator_metatypes import PTSplitMetatype from nncf.torch.graph.transformations.commands import DEFAULT_HOOKS_GROUP_NAME from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler from nncf.torch.layer_utils import _NNCFModuleMixin +from nncf.torch.module_operations import UpdateWeight from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from nncf.torch.utils import compute_FLOPs_hook @@ -778,6 +784,119 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable) result.append(scope_in_model) return result + def get_applied_transformation_layout(self) -> PTTransformationLayout: + """ + Collects all hooks applied to the NNCFNetwork, converts them to insertion commands + and returns in PTTransformationLayout format. Default hooks group name is used in + recovered commands, so hooks group names specified diring the model modification + become outdated. + + :return: Transformation layout with all commands applied to the NNCFNetwork. + """ + + def _create_pt_insert_command(module, target_type, target_node_name, priority, input_port_id): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=target_node_name, input_port_id=input_port_id + ) + return PTInsertionCommand(point=target_point, fn=module, priority=priority) + + def _check_external_call_hook_is_valid(hook: ExternalOpCallHook, info: str): + assert hasattr( + self, hook._storage_name + ), f"Storage name {hook._storage_name} is not registered. Info: {info}" + assert hook._storage_key in getattr( + self, hook._storage_name + ), f"Storage key {hook._storage_key} is not registered. Info: {info}" + + context_hooks = defaultdict(lambda: defaultdict(list)) + transformation_layout = PTTransformationLayout() + nncf_graph = self.get_graph() + nncf_node_names_map = self.get_op_address_to_op_name_map() + + # Collect pre/post layer and op with weights insertion commands + for nncf_module, module_scope in self.get_nncf_modules().items(): + for ops, target_type in ( + (nncf_module.pre_ops, TargetType.PRE_LAYER_OPERATION), + (nncf_module.post_ops, TargetType.POST_LAYER_OPERATION), + ): + for priority, module in enumerate(ops.values()): + nodes_in_scope = nncf_graph.get_op_node_in_scope(module_scope) + assert len(nodes_in_scope) == 1 + nncf_node = nodes_in_scope[0] + command_target_type = target_type + if isinstance(module, UpdateWeight): + command_target_type = TargetType.OPERATION_WITH_WEIGHTS + module = module.op + if not isinstance(module, ExternalOpCallHook): + command = _create_pt_insert_command( + module, command_target_type, nncf_node.node_name, priority, None + ) + transformation_layout.register(command) + continue + + info = f"TargetType: {command_target_type}, nncf node name: {nncf_node.node_name}," + f" priority: {priority}, fn: {module}" + _check_external_call_hook_is_valid(module, info) + + context_hooks[module._storage_name][module._storage_key].append( + (command_target_type, nncf_node.node_name, priority, module, None) + ) + + # Collect all pre/post hooks commands + for ops, target_type in ( + (self._compressed_context._pre_hooks, TargetType.OPERATOR_PRE_HOOK), + (self._compressed_context._post_hooks, TargetType.OPERATOR_POST_HOOK), + ): + for op_address, hooks in ops.items(): + if isinstance(op_address, PreHookId): + input_port_id = op_address.input_port_id + op_address = op_address.op_address + else: + input_port_id = None + for priority, fn in enumerate(hooks.values()): + target_node_names = nncf_node_names_map[op_address] + assert len(target_node_names) == 1 + target_node_name = target_node_names[0] + + if not isinstance(fn, ExternalOpCallHook): + command = _create_pt_insert_command(fn, target_type, target_node_name, priority, input_port_id) + transformation_layout.register(command) + continue + + info = f"TargetType: {target_type}, op_address: {op_address}, priority: {priority}, fn: {fn}" + _check_external_call_hook_is_valid(fn, info) + + context_hooks[fn._storage_name][fn._storage_key].append( + (target_type, target_node_name, priority, fn, input_port_id) + ) + + # Create shared fn insertion commands according to external hooks collected from + # pre/post layer, pre/post hooks and op with weights target points. + for module_type_name, storage in context_hooks.items(): + for storage_key, call_hook_list_info in storage.items(): + compression_module = getattr(self, module_type_name)[storage_key] + target_points = [] + for target_type, target_node_name, priority, fn, input_port_id in call_hook_list_info: + target_points.append(PTTargetPoint(target_type, target_node_name, input_port_id=input_port_id)) + + if module_type_name == EXTERNAL_QUANTIZERS_STORAGE_NAME: + module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER + elif module_type_name == EXTERNAL_OP_STORAGE_NAME: + module_type = ExtraCompressionModuleType.EXTERNAL_OP + else: + raise RuntimeError(f"Module type {module_type_name} is not supported") + + command = PTSharedFnInsertionCommand( + target_points=target_points, + fn=compression_module, + op_unique_name=storage_key, + compression_module_type=module_type, + priority=priority, + ) + transformation_layout.register(command) + + return transformation_layout + def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]: """ Returns map of NNCFGraph node names vs DynamicGraph operation addresses. @@ -796,6 +915,17 @@ def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress] retval[nncf_node.node_name] = op_address return retval + def get_op_address_to_op_name_map(self) -> Dict[OperationAddress, NNCFNodeName]: + """ + Returns map of DynamicGraph operation addresses vs NNCFGraph node names. + + :return: DynamicGraph operation addresses vs NNCFGraph node names. + """ + retval = defaultdict(list) + for nncf_node_name, op_address in self.get_node_to_op_address_mapping().items(): + retval[op_address].append(nncf_node_name) + return retval + def set_compression_controller(self, ctrl: CompressionAlgorithmController): self.compression_controller = ctrl diff --git a/nncf/torch/pruning/filter_pruning/layers.py b/nncf/torch/pruning/filter_pruning/layers.py index 2d95a9d982f..872ce31ca24 100644 --- a/nncf/torch/pruning/filter_pruning/layers.py +++ b/nncf/torch/pruning/filter_pruning/layers.py @@ -8,6 +8,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import Any, Dict + import numpy as np import torch from torch import nn @@ -15,15 +18,20 @@ import nncf from nncf.common.graph import NNCFNodeName from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import StatefullTorchModuleInterface @COMPRESSION_MODULES.register() -class FilterPruningMask(nn.Module): +class FilterPruningMask(nn.Module, StatefullTorchModuleInterface): """ A module contains the mask for pruning. On forward pass applying the mask to weight and bias of the module. """ + MASK_APPLYING_DIM_KEY = "dim" + NODE_NAME_KEY = "node_name" + SIZE_KEY = "size_key" + def __init__(self, size, node_name, dim=0): super().__init__() self.register_buffer("_binary_filter_pruning_mask", torch.ones(size)) @@ -31,11 +39,11 @@ def __init__(self, size, node_name, dim=0): self.node_name = node_name @property - def binary_filter_pruning_mask(self): + def binary_filter_pruning_mask(self) -> torch.Tensor: return self._binary_filter_pruning_mask @binary_filter_pruning_mask.setter - def binary_filter_pruning_mask(self, mask): + def binary_filter_pruning_mask(self, mask: torch.Tensor): with torch.no_grad(): self._binary_filter_pruning_mask.set_(mask) @@ -56,6 +64,19 @@ def forward(self, **params): ) return new_params + def get_state(self) -> Dict[str, Any]: + return { + self.MASK_APPLYING_DIM_KEY: self.mask_applying_dim, + self.NODE_NAME_KEY: self.node_name, + self.SIZE_KEY: list(self.binary_filter_pruning_mask.size()), + } + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "FilterPruningMask": + return FilterPruningMask( + size=state[cls.SIZE_KEY], node_name=state[cls.NODE_NAME_KEY], dim=state[cls.MASK_APPLYING_DIM_KEY] + ) + def broadcast_filter_mask(filter_mask, shape, dim=0): broadcasted_shape = np.ones(len(shape), dtype=np.int64) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index cb8906b1ff0..92ccf89d010 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -41,6 +41,7 @@ from nncf.torch.graph.transformations.commands import TargetType from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.quantization.quantize_functions import ExportQuantizeToFakeQuantize from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant from nncf.torch.quantization.quantize_functions import TuneRange @@ -283,9 +284,10 @@ def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationP self.quantization_points[qp_id] = qp -class BaseQuantizer(nn.Module, ABC): +class BaseQuantizer(nn.Module, StatefullTorchModuleInterface, ABC): def __init__(self, qspec: PTQuantizerSpec): super().__init__() + self._qspec = qspec self._narrow_range = qspec.narrow_range self._signedness_to_force = qspec.signedness_to_force self._is_using_log_scale_storage = qspec.logarithm_scale @@ -563,6 +565,14 @@ def get_parameters_for_torch_fq(self) -> Tuple[int, int, torch.Tensor, torch.Ten zero_point - Quantizer zero point. """ + def get_state(self): + return self._qspec.get_state() + + @classmethod + def from_state(cls, state) -> "BaseQuantizer": + qsetup = PTQuantizerSpec.from_state(state) + return cls(qsetup) + class QuantizersSwitcher: """Enables/disables quantizers with saving and restoring original state""" diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 48f3ddefae2..81aadc5007e 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -15,6 +15,7 @@ import torch import nncf +from nncf.common.factory import ModelTransformerFactory from nncf.common.factory import NNCFGraphFactory from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset @@ -29,7 +30,10 @@ from nncf.quantization.quantize_model import warning_model_no_batchwise_support from nncf.scopes import IgnoredScope from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS +from nncf.torch.graph.transformations.serialization import load_transformations +from nncf.torch.graph.transformations.serialization import serialize_transformations from nncf.torch.model_creation import wrap_model +from nncf.torch.nncf_network import NNCFNetwork DEFAULT_RANGE_TYPE = "mean_min_max" @@ -100,3 +104,21 @@ def compress_weights_impl( ) graph = NNCFGraphFactory.create(model) return compression_algorithm.apply(model, graph, dataset=dataset) + + +def apply_serialized_transformations_impl(model: torch.nn.Module, serialized_transformations): + transformations_layout, input_info = load_transformations(serialized_transformations) + + nncf_network = NNCFNetwork(deepcopy(model), input_info=input_info) + model_transformer = ModelTransformerFactory.create(nncf_network) + transformed_model = model_transformer.transform(transformations_layout) + + transformed_model.nncf.disable_dynamic_graph_building() + return transformed_model + + +def serialize_transformations_impl( + model: NNCFNetwork, +): + layout = model.nncf.get_applied_transformation_layout() + return serialize_transformations(model, layout) diff --git a/nncf/torch/sparsity/layers.py b/nncf/torch/sparsity/layers.py index 5a506b87a10..f6c695895dc 100644 --- a/nncf/torch/sparsity/layers.py +++ b/nncf/torch/sparsity/layers.py @@ -8,29 +8,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Any, Dict, List import torch from torch import nn from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl from nncf.torch.utils import is_tracing_state @COMPRESSION_MODULES.register() -class BinaryMask(nn.Module): +class BinaryMask(nn.Module, StatefullTorchModuleInterface): + SHAPE_KEY = "shape" + def __init__(self, shape: List[int]): super().__init__() self.register_buffer("_binary_mask", torch.ones(shape)) self.frozen = False @property - def binary_mask(self): + def binary_mask(self) -> torch.Tensor: return self._binary_mask @binary_mask.setter - def binary_mask(self, tensor): + def binary_mask(self, tensor: torch.Tensor): with torch.no_grad(): self._binary_mask.set_(tensor) @@ -45,3 +48,10 @@ def _calc_training_binary_mask(self, weight): def apply_binary_mask(self, weight): return apply_binary_mask_impl(self.binary_mask, weight) + + def get_state(self) -> Dict[str, Any]: + return {self.SHAPE_KEY: list(self.binary_mask.shape)} + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "BinaryMask": + return BinaryMask(state[cls.SHAPE_KEY]) diff --git a/nncf/torch/sparsity/rb/layers.py b/nncf/torch/sparsity/rb/layers.py index c1df48ad563..39b1013e5ea 100644 --- a/nncf/torch/sparsity/rb/layers.py +++ b/nncf/torch/sparsity/rb/layers.py @@ -8,20 +8,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Any, Dict, List import torch from nncf.torch.functions import logit from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.sparsity.layers import BinaryMask from nncf.torch.sparsity.rb.functions import binary_mask from nncf.torch.sparsity.rb.functions import calc_rb_binary_mask @COMPRESSION_MODULES.register() -class RBSparsifyingWeight(BinaryMask): +class RBSparsifyingWeight(BinaryMask, StatefullTorchModuleInterface): + WEIGHTS_SHAPE_KEY = "weight_shape" + FROZEN_KEY = "frozen" + COMPRESSION_LR_MULTIPLIER_KEY = "compression_lr_multiplier" + EPS_KEY = "eps" + def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6): super().__init__(weight_shape) self.frozen = frozen @@ -36,11 +42,11 @@ def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multipli self.mask_calculation_hook = MaskCalculationHook(self) @property - def mask(self): + def mask(self) -> torch.nn.Parameter: return self._mask @mask.setter - def mask(self, tensor): + def mask(self, tensor: torch.Tensor): self._mask.data = tensor self.binary_mask = binary_mask(self._mask) @@ -51,6 +57,23 @@ def _calc_training_binary_mask(self, weight): def loss(self): return binary_mask(self._mask) + def get_state(self) -> Dict[str, Any]: + return { + self.WEIGHTS_SHAPE_KEY: list(self.mask.shape), + self.FROZEN_KEY: self.frozen, + self.COMPRESSION_LR_MULTIPLIER_KEY: self.mask.compression_lr_multiplier, + self.EPS_KEY: self.eps, + } + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "RBSparsifyingWeight": + return RBSparsifyingWeight( + weight_shape=state[cls.WEIGHTS_SHAPE_KEY], + frozen=state[cls.FROZEN_KEY], + compression_lr_multiplier=state[cls.COMPRESSION_LR_MULTIPLIER_KEY], + eps=state[cls.EPS_KEY], + ) + class MaskCalculationHook: def __init__(self, module): diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 3dfe3a3df7e..ca42278f9a3 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -8,7 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import contextlib +import functools +import itertools import numbers from abc import ABC from abc import abstractmethod @@ -29,6 +32,8 @@ import nncf from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout from nncf.config import NNCFConfig from nncf.config.extractors import extract_algorithm_names from nncf.config.structures import BNAdaptationInitArgs @@ -38,8 +43,13 @@ from nncf.torch.dynamic_graph.io_handling import FillerInputInfo from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.scope import Scope +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args +from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -172,6 +182,12 @@ def nz_bias_num(self): class TwoConvTestModel(nn.Module): + INPUT_SHAPE = [1, 1, 4, 4] + NNCF_CONV_NODES_NAMES = [ + "TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", + ] + def __init__(self): super().__init__() self.features = [] @@ -198,6 +214,114 @@ def nz_weights_num(self): def nz_bias_num(self): return 2 + @staticmethod + def create_pt_insertion_command( + target_type: TargetType, priority: TransformationPriority, fn=None, group: str = "default_group" + ): + target_point = PTTargetPoint( + target_type=target_type, target_node_name=TwoConvTestModel.NNCF_CONV_NODES_NAMES[0], input_port_id=0 + ) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTInsertionCommand(point=target_point, fn=fn, priority=priority, hooks_group_name=group) + + @staticmethod + def create_pt_shared_fn_insertion_command( + target_type: TargetType, + priority: TransformationPriority, + compression_module_type: ExtraCompressionModuleType, + fn=None, + group: str = "default_group", + op_unique_name: str = "UNIQUE_NAME", + ): + target_points = [] + + for node_name in TwoConvTestModel.NNCF_CONV_NODES_NAMES: + target_points.append(PTTargetPoint(target_type=target_type, target_node_name=node_name, input_port_id=0)) + if fn is None: + fn = DummyOpWithState("DUMMY_STATE") + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + compression_module_type=compression_module_type, + op_unique_name=op_unique_name, + priority=priority, + hooks_group_name=group, + ) + + AVAILABLE_TARGET_TYPES = ( + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.OPERATOR_PRE_HOOK, + TargetType.OPERATOR_POST_HOOK, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ) + + @staticmethod + def get_command_builders(): + return ( + TwoConvTestModel.create_pt_insertion_command, + functools.partial( + TwoConvTestModel.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_OP, + ), + functools.partial( + TwoConvTestModel.create_pt_shared_fn_insertion_command, + compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER, + ), + ) + + COMMAND_TYPES = [PTInsertionCommand, PTSharedFnInsertionCommand, PTSharedFnInsertionCommand] + PRIORITIES = (TransformationPriority.QUANTIZATION_PRIORITY, TransformationPriority.QUANTIZATION_PRIORITY.value + 1) + + @classmethod + def get_all_available_commands(cls, skip_model_transformer_unsupported=False) -> TransformationLayout: + """ + Returns all possible commands to insert: + all target types x all command class x all compression module types x different priorities. + DummyOpWithState is used as insertion module. DummyOpWithState states are unique for each + unique command. + """ + layout = TransformationLayout() + for idx, (target_type, (command_builder, command_type), priority) in enumerate( + itertools.product( + cls.AVAILABLE_TARGET_TYPES, zip(cls.get_command_builders(), cls.COMMAND_TYPES), cls.PRIORITIES + ) + ): + dummy_op_state = f"DUMMY_OP_STATE_{idx}" + if command_type is PTSharedFnInsertionCommand: + if skip_model_transformer_unsupported and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + continue + command = cls.create_one_command( + command_builder, target_type, priority, dummy_op_state, op_unique_name=f"UNIQUE_NAME_{idx}" + ) + else: + command = cls.create_one_command(command_builder, target_type, priority, dummy_op_state) + + layout.register(command) + return layout + + @staticmethod + def create_one_command(command_builder, target_type, priority, dummy_op_state, op_unique_name=None): + group_name = "CUSTOM_HOOKS_GROUP_NAME" + + if DummyOpWithState.__name__ not in COMPRESSION_MODULES.registry_dict: + registered_dummy_op_cls = COMPRESSION_MODULES.register()(DummyOpWithState) + else: + registered_dummy_op_cls = DummyOpWithState + dummy_op = registered_dummy_op_cls(dummy_op_state) + if op_unique_name is None: + command = command_builder(target_type, priority, fn=dummy_op, group=group_name) + else: + command = command_builder( + target_type, priority, fn=dummy_op, group=group_name, op_unique_name=op_unique_name + ) + + return command + class LeNet(nn.Module): INPUT_SIZE = 1, 32, 32 @@ -228,6 +352,37 @@ def num_flat_features(self, x): return num_features +def are_commands_equal( + command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True +): + if type(applied_command) is not type(command): + return False + + # Check reference to functions are equal. + if check_fn_ref and applied_command.fn is not command.fn: + return False + if check_hooks_group_name and applied_command.hooks_group_name != command.hooks_group_name: + return False + if check_priority and applied_command.priority != command.priority: + return False + + if isinstance(applied_command, PTInsertionCommand): + if not target_points_are_equal(command.target_point, applied_command.target_point): + return False + elif isinstance(applied_command, PTSharedFnInsertionCommand): + if not all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)): + return False + if ( + applied_command.target_points != command.target_points + or applied_command.op_name != command.op_name + or applied_command.compression_module_type != command.compression_module_type + ): + return False + else: + raise RuntimeError() + return True + + class SharedConv(nn.Module): INPUT_SIZE = [1, 1, 4, 4] @@ -254,6 +409,58 @@ def forward(self, x): return a + b +class DummyOpWithState(torch.nn.Module): + def __init__(self, state: str): + super().__init__() + self._state = state + + def __call__(self, *args): + if len(args) == 1: + return args[0] + # To work correctly with + # TargetType.PRE_LAYER_OPERATION + # TargetType.POST_LAYER_OPERATION + return None + + def get_state(self): + return self._state + + @classmethod + def from_state(cls, state: str): + return cls(state) + + +def target_points_are_equal(tp_original: PTTargetPoint, tp_recovered: PTTargetPoint): + if tp_original != tp_recovered: + return False + if tp_original.target_type == TargetType.OPERATOR_PRE_HOOK: + return tp_original.input_port_id == tp_recovered.input_port_id + return True + + +def check_commands_are_equal( + command, applied_command, check_priority: bool = True, check_hooks_group_name: bool = True, check_fn_ref=True +): + assert type(applied_command) is type(command) + # Check reference to functions are equal. + if check_fn_ref: + assert applied_command.fn is command.fn + if check_hooks_group_name: + assert applied_command.hooks_group_name == command.hooks_group_name + if check_priority: + assert applied_command.priority == command.priority + + if isinstance(applied_command, PTInsertionCommand): + assert target_points_are_equal(command.target_point, applied_command.target_point) + elif isinstance(applied_command, PTSharedFnInsertionCommand): + all(target_points_are_equal(a, b) for a, b in zip(command.target_points, applied_command.target_points)) + assert applied_command.target_points == command.target_points + assert applied_command.op_name == command.op_name + assert applied_command.compression_module_type == command.compression_module_type + else: + raise RuntimeError() + + def get_empty_config( model_size=4, input_sample_sizes: Union[Tuple[List[int]], List[int]] = None, input_info: Dict = None ) -> NNCFConfig: diff --git a/tests/torch/qat/test_qat_classification.py b/tests/torch/qat/test_qat_classification.py index 394ebfa0071..86c95e2973a 100644 --- a/tests/torch/qat/test_qat_classification.py +++ b/tests/torch/qat/test_qat_classification.py @@ -293,3 +293,67 @@ def test_compression_training(quantization_config_path: Path, sota_data_dir): del sample_config["compression"]["initializer"]["range"] start_worker_clean_memory(main_worker, sample_config) + + +def save_load_main_worker(current_gpu: int, config: SampleConfig): + configure_device(current_gpu, config) + if is_main_process(): + configure_logging(logger, config) + else: + config.tb = None + + pretrained = is_pretrained_model_requested(config) + model_name = config["model"] + # create model + logger.info(f"\nCreating model from config: {config.config}") + model = load_model( + model_name, + pretrained=pretrained, + num_classes=config.get("num_classes", 1000), + model_params=config.get("model_params"), + weights_path=config.get("weights"), + ) + model.to(config.device) + + datasets = get_datasets(config) + criterion = nn.CrossEntropyLoss() + criterion = criterion.to(config.device) + + logger.info("Original model validation:") + # original_accuracy, *_ = validate(datasets.val_data_loader, model, criterion, config) + original_accuracy = 100.0 + + logger.info("Apply quantization to the model:") + config_quantization_params = config["compression"] + + preset = get_quantization_preset(config_quantization_params) + advanced_parameters = get_advanced_ptq_parameters(config_quantization_params) + # subset_size = get_num_samples(config_quantization_params) + + quantized_model = nncf.quantize( + model, + datasets.calibration_dataset, + preset=preset, + advanced_parameters=advanced_parameters, + subset_size=1, + ) + + transformations_state = nncf.serialize_transformations(quantized_model) + state_dict = quantized_model.state_dict() + del quantized_model + quantized_model = nncf.apply_serialized_transformations(model, transformations_state) + quantized_model.load_state_dict(state_dict) + + train_criterion_fn = inception_criterion_fn if "inception" in model_name else default_criterion_fn + acc_drop = train( + quantized_model, + config, + criterion, + train_criterion_fn, + datasets, + original_accuracy, + get_mocked_compression_ctrl(), + ) + assert accuracy_drop_is_acceptable(acc_drop) + check_training_correctness(config, model, datasets, criterion, train_criterion_fn) + logger.info("Done!") diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index c4da4be8c82..795fc704699 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -27,6 +27,7 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority from nncf.common.hook_handle import HookHandle from nncf.torch import register_module from nncf.torch.dynamic_graph.io_handling import ExampleInputInfo @@ -40,9 +41,12 @@ from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import NNCFConv2d from nncf.torch.model_creation import wrap_model +from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint @@ -52,6 +56,7 @@ from tests.torch.helpers import BasicConvTestModel from tests.torch.helpers import HookChecker from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import are_commands_equal from tests.torch.helpers import check_correct_nncf_modules_replacement from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import register_bn_adaptation_init_args @@ -1024,3 +1029,82 @@ def test_insert_at_point_hook_handles(self, target_type: TargetType, target_node del ref_hooks[-2] _check(ref_hooks) + + +@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder", TwoConvTestModel.get_command_builders()) +def test_get_applied_modification_commands(command_builder, target_type): + command = command_builder(target_type, TransformationPriority.DEFAULT_PRIORITY) + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + layout = PTTransformationLayout() + layout.register(command) + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + + assert len(applied_commands.transformations) == 1 + applied_command = applied_commands.transformations[0] + are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + + +@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize( + "command_builder,command_type", tuple(zip(TwoConvTestModel.get_command_builders(), TwoConvTestModel.COMMAND_TYPES)) +) +def test_priority_of_get_applied_modification_commands(command_builder, target_type, command_type): + layout = PTTransformationLayout() + commands = dict() + for priority in (0, 3, 2, 4, 1): + if command_type is PTSharedFnInsertionCommand: + command = command_builder(target_type, priority, op_unique_name=f"UNIQUE_NAME_{priority}") + else: + command = command_builder(target_type, priority) + layout.register(command) + commands[priority] = command + else: + if isinstance(command, PTSharedFnInsertionCommand) and target_type in [ + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: + pytest.skip(f"PTSharedFnInsertionCommand is not supporting target type {target_type}") + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(layout) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands) + for applied_command in applied_commands.transformations: + command = commands[applied_command.priority] + are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + + +def test_all_possible_combinations_of_commands_for_get_applied_commands(): + commands = TwoConvTestModel.get_all_available_commands(skip_model_transformer_unsupported=True) + + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + model_tranformer = PTModelTransformer(nncf_model) + + model_tranformer.transform(commands) + + applied_commands = nncf_model.nncf.get_applied_transformation_layout() + assert len(applied_commands.transformations) == len(commands.transformations) + for command in commands.transformations: + eq_commands = ( + are_commands_equal(command, applied_command, check_priority=False, check_hooks_group_name=False) + for applied_command in applied_commands.transformations + ) + if sum(map(int, eq_commands)) != 1: + raise RuntimeError(f"Command {command} has no pair in recovered commands") diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py new file mode 100644 index 00000000000..737e9c1b7a6 --- /dev/null +++ b/tests/torch/test_serialization.py @@ -0,0 +1,275 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from copy import deepcopy + +import pytest +import torch + +import nncf +from nncf.common.factory import ModelTransformerFactory +from nncf.common.quantization.structs import QuantizationScheme +from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply +from nncf.torch.dynamic_graph.io_handling import FillerInputElement +from nncf.torch.dynamic_graph.io_handling import FillerInputInfo +from nncf.torch.graph.transformations.serialization import load_command +from nncf.torch.graph.transformations.serialization import load_transformations +from nncf.torch.graph.transformations.serialization import serialize_command +from nncf.torch.graph.transformations.serialization import serialize_transformations +from nncf.torch.module_operations import UpdateWeight +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.pruning.filter_pruning.layers import FilterPruningMask +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import SymmetricQuantizer +from nncf.torch.sparsity.layers import BinaryMask +from nncf.torch.sparsity.rb.layers import RBSparsifyingWeight +from tests.torch.helpers import DummyOpWithState +from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import check_commands_are_equal + + +@pytest.mark.parametrize("target_type", TwoConvTestModel.AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder", TwoConvTestModel.get_command_builders()) +@pytest.mark.parametrize("priority", TwoConvTestModel.PRIORITIES) +def test_serialize_load_command(target_type, command_builder, priority): + dummy_op_state = "DUMMY_OP_STATE" + command = TwoConvTestModel.create_one_command(command_builder, target_type, priority, dummy_op_state) + + serialized_command = serialize_command(command) + + # Check serialized transformation are json compatible + j_str = json.dumps(serialized_command) + serialized_command = json.loads(j_str) + + recovered_command = load_command(serialized_command) + _check_commands_after_serialization(command, recovered_command, dummy_op_state) + + +def test_serialize_transformations(mocker): + layout = TwoConvTestModel.get_all_available_commands() + model = mocker.MagicMock() + input_info_ref = FillerInputInfo([FillerInputElement([1, 1, 4, 4])]) + model.nncf._input_info = input_info_ref + + serialized_transformations = serialize_transformations(model, layout) + + # Check serialized transformation are json compatible + j_str = json.dumps(serialized_transformations) + serialized_transformations = json.loads(j_str) + + recovered_layout, input_info = load_transformations(serialized_transformations) + assert input_info == input_info_ref + assert len(layout.transformations) == len(recovered_layout.transformations) + # Can zip layouts because the order should not be altered + for command, recovered_command in zip(layout.transformations, recovered_layout.transformations): + _check_commands_after_serialization(command, recovered_command) + + +def test_get_apply_serialization_from_a_model(): + layout = TwoConvTestModel.get_all_available_commands(skip_model_transformer_unsupported=True) + model = TwoConvTestModel() + nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])])) + modified_model = ModelTransformerFactory.create(nncf_model).transform(layout) + + serialized_transformations = nncf.serialize_transformations(modified_model) + + # Check serialized transformation are json compatible + j_str = json.dumps(serialized_transformations) + serialized_transformations = json.loads(j_str) + + recovered_model = nncf.apply_serialized_transformations(model, serialized_transformations) + for conv, recovered_conv in zip(modified_model.features, recovered_model.features): + for hooks_attr in ["pre_ops", "post_ops"]: + hooks = getattr(conv[0], hooks_attr) + recovered_hooks = getattr(recovered_conv[0], hooks_attr) + assert len(hooks) == len(recovered_hooks) + for k, hook in hooks.items(): + recovered_hook = recovered_hooks[k] + if isinstance(hook, UpdateWeight): + assert isinstance(recovered_hook, UpdateWeight) + hook = hook.op + recovered_hook = recovered_hook.op + _check_hook_are_equal(hook, recovered_hook) + + context = modified_model.nncf._compressed_context + recovered_context = recovered_model.nncf._compressed_context + for hooks_attr in ["_pre_hooks", "_post_hooks"]: + container = getattr(context, hooks_attr) + recovered_container = getattr(recovered_context, hooks_attr) + assert len(container) == len(recovered_container) + for op_address, hooks in container.items(): + recovered_hooks = recovered_container[op_address] + for k, hook in hooks.items(): + recovered_hook = recovered_hooks[k] + _check_hook_are_equal(hook, recovered_hook) + + for attr_name in ["external_quantizers", "external_op"]: + container = getattr(modified_model.nncf, attr_name) + recovered_container = getattr(recovered_model.nncf, attr_name) + assert len(container) == len(recovered_container) + for k, module in container.items(): + recovered_module = recovered_container[k] + _check_hook_are_equal(module, recovered_module) + + +def _check_hook_are_equal(hook, recovered_hook): + assert type(hook) == type(recovered_hook) + if isinstance(hook, DummyOpWithState): + assert hook.get_state() == recovered_hook.get_state() + return + # Hook is external op call hook then + assert hook._storage_name == recovered_hook._storage_name + assert hook._storage_key == recovered_hook._storage_key + + +def _check_commands_after_serialization(command, recovered_command, dummy_op_state=None): + check_commands_are_equal(recovered_command, command, check_fn_ref=False) + assert isinstance(command.fn, DummyOpWithState) + assert command.fn.get_state() == recovered_command.fn.get_state() + if dummy_op_state is not None: + assert command.fn.get_state() == dummy_op_state + + +@pytest.mark.parametrize("size", (4, [3, 4])) +def test_pruning_mask_serialization(size): + node_name = "dummy_node_name" + dim = 2 + mask = FilterPruningMask(size=size, node_name=node_name, dim=dim) + mask.binary_filter_pruning_mask = torch.fill(torch.empty(size), 5) + state_dict = mask.state_dict() + + state = mask.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_mask = FilterPruningMask.from_state(state) + recovered_mask.load_state_dict(state_dict) + + ref_size = size if isinstance(size, list) else [size] + assert list(recovered_mask.binary_filter_pruning_mask.size()) == ref_size + assert recovered_mask.node_name == node_name + assert recovered_mask.mask_applying_dim == dim + + assert torch.all(mask.binary_filter_pruning_mask == recovered_mask.binary_filter_pruning_mask) + + +@pytest.mark.parametrize("quantizer_class", (SymmetricQuantizer, AsymmetricQuantizer)) +def test_quantizer_serialization(quantizer_class: BaseQuantizer): + scale_shape = [1, 3, 1, 1] + ref_qspec = PTQuantizerSpec( + num_bits=4, + mode=QuantizationScheme.ASYMMETRIC, + signedness_to_force=False, + narrow_range=True, + half_range=False, + scale_shape=scale_shape, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=2.0, + ) + quantizer = quantizer_class(ref_qspec) + if isinstance(quantizer, SymmetricQuantizer): + quantizer.scale = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 5)) + elif isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 6)) + quantizer.input_range = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 7)) + + state_dict = quantizer.state_dict() + + state = quantizer.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_quantizer = quantizer_class.from_state(state) + recovered_quantizer.load_state_dict(state_dict) + + assert recovered_quantizer._qspec == ref_qspec + + assert torch.all(quantizer._num_bits == recovered_quantizer._num_bits) + assert torch.all(quantizer.enabled == recovered_quantizer.enabled) + if isinstance(quantizer, SymmetricQuantizer): + assert torch.all(quantizer.signed_tensor == recovered_quantizer.signed_tensor) + assert torch.all(quantizer.scale == recovered_quantizer.scale) + elif isinstance(quantizer, AsymmetricQuantizer): + assert torch.all(quantizer.input_low == recovered_quantizer.input_low) + assert torch.all(quantizer.input_range == recovered_quantizer.input_range) + else: + raise RuntimeError() + + +def test_sparsity_binary_mask_serialization(): + ref_shape = [4, 2, 1, 3] + mask = BinaryMask(ref_shape) + mask.binary_mask = torch.zeros(ref_shape) + state_dict = mask.state_dict() + + state = mask.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_mask = BinaryMask.from_state(state) + recovered_mask.load_state_dict(state_dict) + + assert list(recovered_mask.binary_mask.shape) == ref_shape + assert torch.all(mask.binary_mask == recovered_mask.binary_mask) + + +def test_rb_sparsity_mask_serialization(): + ref_weights_shape = [3, 2, 4, 1] + ref_frozen = False + ref_compression_lr_multiplier = 2.0 + ref_eps = 0.3 + mask = RBSparsifyingWeight( + weight_shape=ref_weights_shape, + frozen=ref_frozen, + compression_lr_multiplier=ref_compression_lr_multiplier, + eps=ref_eps, + ) + mask.binary_mask = torch.zeros(ref_weights_shape) + mask.mask = torch.fill(torch.empty(ref_weights_shape), 5) + state_dict = mask.state_dict() + + state = mask.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_mask = RBSparsifyingWeight.from_state(state) + recovered_mask.load_state_dict(state_dict) + + assert list(recovered_mask.mask.shape) == ref_weights_shape + assert recovered_mask.frozen == ref_frozen + assert recovered_mask.mask.compression_lr_multiplier == ref_compression_lr_multiplier + assert recovered_mask.eps == ref_eps + + assert torch.all(mask.mask == recovered_mask.mask) + assert torch.all(mask.binary_mask == recovered_mask.binary_mask) + assert torch.all(mask.uniform == recovered_mask.uniform) + + +def test_sq_multiply_serialization(): + tensor_shape = [1, 3, 5] + tensor_value = torch.fill(torch.empty(tensor_shape, dtype=torch.float16), 5) + sq_multiply = SQMultiply(tensor_shape) + sq_multiply.scale = tensor_value + state_dict = sq_multiply.state_dict() + + state = sq_multiply.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_sq_multiply = SQMultiply.from_state(state) + recovered_sq_multiply.load_state_dict(state_dict) + + assert torch.all(sq_multiply.scale == recovered_sq_multiply.scale) + assert sq_multiply.scale.shape == recovered_sq_multiply.scale.shape