From 6a66f0198a3ec238c739e41eaeaac37fd643c42b Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 27 Mar 2024 16:35:13 +0100 Subject: [PATCH 1/6] [Torch] Save/load NNCFNetwork state API code moved to a separate PR --- .../algorithms/min_max/torch_backend.py | 10 +- .../algorithms/smooth_quant/torch_backend.py | 37 ++- nncf/torch/dynamic_graph/io_handling.py | 20 ++ .../graph/transformations/serialization.py | 104 ++++++ nncf/torch/layer_utils.py | 34 ++ nncf/torch/model_transformer.py | 1 + nncf/torch/pruning/filter_pruning/layers.py | 27 +- nncf/torch/quantization/layers.py | 12 +- nncf/torch/sparsity/layers.py | 18 +- nncf/torch/sparsity/rb/layers.py | 31 +- tests/torch/helpers.py | 16 +- tests/torch/nncf_network/helpers.py | 8 +- tests/torch/test_serialization.py | 308 ++++++++++++++++++ 13 files changed, 591 insertions(+), 35 deletions(-) create mode 100644 nncf/torch/graph/transformations/serialization.py create mode 100644 tests/torch/test_serialization.py diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index a735ad59cb9..eae40dbffe7 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -268,22 +268,22 @@ def _create_quantizer( quantizer = quantizer_cls(quantizer_spec) # Fill it with minmax - PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters) + PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) return quantizer @staticmethod - def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None: + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: if isinstance(quantizer, AsymmetricQuantizer): - quantizer.input_low = torch.nn.Parameter(parameters.input_low.data) + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) input_range = parameters.input_high - parameters.input_low # Subtract eps from the input_range to make quantizer parameters equal to # original parameters on the forward call. - quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps) + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) else: quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) # Subtract eps from the scale to make quantizer parameters equal to # original parameters on the forward call. - quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps) + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) @staticmethod def create_quantizer_insertion_command( diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index 19154231dbc..775eadeffe1 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
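# A minimal sketch (not part of the patch) of the reshape that
# _fill_quantizer_parameters now performs above: FakeQuantizeParameters
# statistics can arrive as flat tensors, and reshaping them to the
# quantizer's scale_shape keeps the filled torch.nn.Parameter consistent
# with the shape declared by PTQuantizerSpec, so a quantizer later rebuilt
# from that spec can load the saved state_dict without a shape mismatch.
# Shapes and values below are hypothetical.
import torch

scale_shape = (1, 3, 1, 1)                  # per-channel scale shape
input_high = torch.tensor([0.5, 1.0, 2.0])  # flat per-channel statistics
scale = torch.nn.Parameter(input_high.reshape(scale_shape))
assert tuple(scale.shape) == scale_shape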
-from typing import Callable, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import numpy as np import torch @@ -30,6 +30,9 @@ from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.model_graph_manager import get_const_data from nncf.torch.model_graph_manager import get_const_node from nncf.torch.nncf_network import NNCFNetwork @@ -38,14 +41,32 @@ from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor -class SQMultiply(torch.nn.Module): - def __init__(self, scale_value): +@COMPRESSION_MODULES.register() +class SQMultiply(torch.nn.Module, StatefullTorchModuleInterface): + SCALE_SHAPE_KEY = "scale_shape" + + def __init__(self, scale_shape: Tuple[int, ...]): super().__init__() - self._scale_value = scale_value + self._scale_value = CompressionParameter(torch.empty(scale_shape)) + + @property + def scale(self) -> torch.nn.Parameter: + return self._scale_value - def forward(self, x): + @scale.setter + def scale(self, value: torch.tensor): + self._scale_value.data = value + + def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.mul(x, self._scale_value) + def get_state(self) -> Dict[str, Any]: + return {self.SCALE_SHAPE_KEY: list(self._scale_value.shape)} + + @classmethod + def from_state(cls, state) -> "SQMultiply": + return SQMultiply(state[cls.SCALE_SHAPE_KEY]) + PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK @@ -122,7 +143,7 @@ def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) @staticmethod def scale_insertion_command( source_node: NNCFNode, - scale_value: np.ndarray, + scale_value: torch.Tensor, source_output_port_id: int, nodes: List[NNCFNode], scale_node_name: str, @@ -132,7 +153,9 @@ def scale_insertion_command( for node in nodes: target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id)) - return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name) + sq_multiply = SQMultiply(scale_value.shape) + sq_multiply.scale = scale_value + return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name) @staticmethod def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: diff --git a/nncf/torch/dynamic_graph/io_handling.py b/nncf/torch/dynamic_graph/io_handling.py index c0f3aeb92e5..2884033e6fb 100644 --- a/nncf/torch/dynamic_graph/io_handling.py +++ b/nncf/torch/dynamic_graph/io_handling.py @@ -122,6 +122,7 @@ def __init__(self, shape: List[int], type_str: str = "float", keyword: str = Non """ self.shape = shape self.type = self._string_to_torch_type(type_str) + self._type_str = type_str self.keyword = keyword if filler is None: self.filler = self.FILLER_TYPE_ONES @@ -157,6 +158,15 @@ def get_tensor_for_input(self) -> torch.Tensor: return torch.rand(size=self.shape, dtype=self.type) raise NotImplementedError + def get_state(self) -> Dict[str, Any]: + return {"shape": self.shape, "type_str": self._type_str, "keyword": self.keyword, "filler": self.filler} + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "FillerInputElement": + return FillerInputElement( + shape=state["shape"], type_str=state["type_str"], 
keyword=state["keyword"], filler=state["filler"] + ) + class FillerInputInfo(ModelInputInfo): """ @@ -220,6 +230,16 @@ def get_forward_inputs( kwargs[fe.keyword] = tensor return tuple(args_list), kwargs + def get_state(self) -> Dict[str, Any]: + return {"elements": [elem.get_state() for elem in self.elements]} + + @classmethod + def from_state(cls, state) -> "FillerInputInfo": + return FillerInputInfo([FillerInputElement.from_state(s) for s in state["elements"]]) + + def __eq__(self, other: "FillerInputInfo") -> bool: + return self.elements == other.elements + class ExactInputsInfo(ModelInputInfo): """ diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py new file mode 100644 index 00000000000..616a4c25e52 --- /dev/null +++ b/nncf/torch/graph/transformations/serialization.py @@ -0,0 +1,104 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum +from typing import Any, Dict, Union + +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTTransformationCommand +from nncf.torch.layer_utils import COMPRESSION_MODULES + +COMPRESSION_STATE_ATTR = "compression_state" + + +class CompressionKeys(Enum): + SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND" + INSERTION_COMMAND = "INSERTION_COMMAND" + + +def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]: + transformation_commands = [] + for command in transformations_layout.transformations: + serialized_command = serialize_command(command) + if serialized_command: + transformation_commands.append(serialized_command) + + return {COMPRESSION_STATE_ATTR: transformation_commands} + + +def load_transformations(transformations_state: Dict[str, Any]) -> TransformationLayout: + transformation_layout = TransformationLayout() + for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]: + command = load_command(serialized_command) + transformation_layout.register(command) + + return transformation_layout + + +def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]: + if not isinstance(command, (PTSharedFnInsertionCommand, PTInsertionCommand)): + return {} + + serialized_transformation = dict() + if isinstance(command, PTSharedFnInsertionCommand): + serialized_transformation["type"] = CompressionKeys.SHARED_INSERTION_COMMAND.value + serialized_transformation["target_points"] = [point.get_state() for point in command.target_points] + serialized_transformation["op_name"] = command.op_name + 
serialized_transformation["compression_module_type"] = command.compression_module_type.value + + elif isinstance(command, PTInsertionCommand): + serialized_transformation["type"] = CompressionKeys.INSERTION_COMMAND.value + serialized_transformation["target_point"] = command.target_point.get_state() + + # Check compression module is registered + compression_module_name = command.fn.__class__.__name__ + if compression_module_name not in COMPRESSION_MODULES.registry_dict: + raise RuntimeError( + f"Could not serialize compression module with name {compression_module_name}." + " Please register your module in the COMPRESSION_MODULES registry." + ) + serialized_transformation["compression_module_name"] = compression_module_name + serialized_transformation["fn_state"] = command.fn.get_state() + serialized_transformation["hooks_group_name"] = command.hooks_group_name + priority = command.priority + serialized_transformation["priority"] = priority.value if isinstance(priority, Enum) else priority + return serialized_transformation + + +def load_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]: + module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"]) + fn = module_cls.from_state(serialized_command["fn_state"]) + priority = serialized_command["priority"] + if priority in iter(TransformationPriority): + priority = TransformationPriority(priority) + + if serialized_command["type"] == CompressionKeys.INSERTION_COMMAND.value: + target_point = PTTargetPoint.from_state(serialized_command["target_point"]) + return PTInsertionCommand( + point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"] + ) + + if serialized_command["type"] == CompressionKeys.SHARED_INSERTION_COMMAND.value: + target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]] + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + op_unique_name=serialized_command["op_name"], + compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]), + priority=priority, + hooks_group_name=serialized_command["hooks_group_name"], + ) + raise RuntimeError(f"Command type {serialized_command['type']} is not supported.") diff --git a/nncf/torch/layer_utils.py b/nncf/torch/layer_utils.py index 85756f13dcc..7a4c7975813 100644 --- a/nncf/torch/layer_utils.py +++ b/nncf/torch/layer_utils.py @@ -9,6 +9,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC +from abc import abstractclassmethod +from abc import abstractmethod +from typing import Any, Dict + import torch from torch import nn @@ -19,6 +24,30 @@ COMPRESSION_MODULES = Registry("compression modules") +class StatefullTorchModuleInterface(ABC): + """ + Interface that should be implemented for every registered compression module to make it possible + to save an compression modules state and create an compression module from the saved state. + State of the module should be json serializable, no python objects except + standart (str, list and etc.) should be present in a compression module state. + The state for attributes with type torch.nn.Parameter + is recovered from the model `state_dict`, so there is no need to keep them in the module state. + + """ + + @abstractmethod + def get_state(self) -> Dict[str, Any]: + """ + Returns the compression module state. 
+ """ + + @abstractclassmethod + def from_state(cls, state: Dict[str, Any]) -> object: + """ + Creates a compression module instance from the given state. + """ + + class ProxyModule: def __init__(self, module): self._module = module @@ -117,7 +146,12 @@ def __init__(self, data: torch.Tensor = None, requires_grad: bool = True, compre """ super().__init__() + self._compression_lr_multiplier = compression_lr_multiplier if compression_lr_multiplier is not None and self.dtype.is_floating_point: self.requires_grad = True self.register_hook(lambda grad: compression_lr_multiplier * grad) self.requires_grad = requires_grad + + @property + def compression_lr_multiplier(self): + return self._compression_lr_multiplier diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index 6bdb95fbe38..b8df3510f55 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -61,6 +61,7 @@ def __init__(self, model: NNCFNetwork): ] def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork: + # self._model.nncf.record_commands(transformation_layout.transformations) transformations = transformation_layout.transformations aggregated_transformations = defaultdict(list) requires_graph_rebuild = False diff --git a/nncf/torch/pruning/filter_pruning/layers.py b/nncf/torch/pruning/filter_pruning/layers.py index 2d95a9d982f..872ce31ca24 100644 --- a/nncf/torch/pruning/filter_pruning/layers.py +++ b/nncf/torch/pruning/filter_pruning/layers.py @@ -8,6 +8,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import Any, Dict + import numpy as np import torch from torch import nn @@ -15,15 +18,20 @@ import nncf from nncf.common.graph import NNCFNodeName from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import StatefullTorchModuleInterface @COMPRESSION_MODULES.register() -class FilterPruningMask(nn.Module): +class FilterPruningMask(nn.Module, StatefullTorchModuleInterface): """ A module contains the mask for pruning. On forward pass applying the mask to weight and bias of the module. 
""" + MASK_APPLYING_DIM_KEY = "dim" + NODE_NAME_KEY = "node_name" + SIZE_KEY = "size_key" + def __init__(self, size, node_name, dim=0): super().__init__() self.register_buffer("_binary_filter_pruning_mask", torch.ones(size)) @@ -31,11 +39,11 @@ def __init__(self, size, node_name, dim=0): self.node_name = node_name @property - def binary_filter_pruning_mask(self): + def binary_filter_pruning_mask(self) -> torch.Tensor: return self._binary_filter_pruning_mask @binary_filter_pruning_mask.setter - def binary_filter_pruning_mask(self, mask): + def binary_filter_pruning_mask(self, mask: torch.Tensor): with torch.no_grad(): self._binary_filter_pruning_mask.set_(mask) @@ -56,6 +64,19 @@ def forward(self, **params): ) return new_params + def get_state(self) -> Dict[str, Any]: + return { + self.MASK_APPLYING_DIM_KEY: self.mask_applying_dim, + self.NODE_NAME_KEY: self.node_name, + self.SIZE_KEY: list(self.binary_filter_pruning_mask.size()), + } + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "FilterPruningMask": + return FilterPruningMask( + size=state[cls.SIZE_KEY], node_name=state[cls.NODE_NAME_KEY], dim=state[cls.MASK_APPLYING_DIM_KEY] + ) + def broadcast_filter_mask(filter_mask, shape, dim=0): broadcasted_shape = np.ones(len(shape), dtype=np.int64) diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index cb8906b1ff0..92ccf89d010 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -41,6 +41,7 @@ from nncf.torch.graph.transformations.commands import TargetType from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.quantization.quantize_functions import ExportQuantizeToFakeQuantize from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant from nncf.torch.quantization.quantize_functions import TuneRange @@ -283,9 +284,10 @@ def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationP self.quantization_points[qp_id] = qp -class BaseQuantizer(nn.Module, ABC): +class BaseQuantizer(nn.Module, StatefullTorchModuleInterface, ABC): def __init__(self, qspec: PTQuantizerSpec): super().__init__() + self._qspec = qspec self._narrow_range = qspec.narrow_range self._signedness_to_force = qspec.signedness_to_force self._is_using_log_scale_storage = qspec.logarithm_scale @@ -563,6 +565,14 @@ def get_parameters_for_torch_fq(self) -> Tuple[int, int, torch.Tensor, torch.Ten zero_point - Quantizer zero point. """ + def get_state(self): + return self._qspec.get_state() + + @classmethod + def from_state(cls, state) -> "BaseQuantizer": + qsetup = PTQuantizerSpec.from_state(state) + return cls(qsetup) + class QuantizersSwitcher: """Enables/disables quantizers with saving and restoring original state""" diff --git a/nncf/torch/sparsity/layers.py b/nncf/torch/sparsity/layers.py index 5a506b87a10..f6c695895dc 100644 --- a/nncf/torch/sparsity/layers.py +++ b/nncf/torch/sparsity/layers.py @@ -8,29 +8,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List +from typing import Any, Dict, List import torch from torch import nn from nncf.torch.layer_utils import COMPRESSION_MODULES +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl from nncf.torch.utils import is_tracing_state @COMPRESSION_MODULES.register() -class BinaryMask(nn.Module): +class BinaryMask(nn.Module, StatefullTorchModuleInterface): + SHAPE_KEY = "shape" + def __init__(self, shape: List[int]): super().__init__() self.register_buffer("_binary_mask", torch.ones(shape)) self.frozen = False @property - def binary_mask(self): + def binary_mask(self) -> torch.Tensor: return self._binary_mask @binary_mask.setter - def binary_mask(self, tensor): + def binary_mask(self, tensor: torch.Tensor): with torch.no_grad(): self._binary_mask.set_(tensor) @@ -45,3 +48,10 @@ def _calc_training_binary_mask(self, weight): def apply_binary_mask(self, weight): return apply_binary_mask_impl(self.binary_mask, weight) + + def get_state(self) -> Dict[str, Any]: + return {self.SHAPE_KEY: list(self.binary_mask.shape)} + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "BinaryMask": + return BinaryMask(state[cls.SHAPE_KEY]) diff --git a/nncf/torch/sparsity/rb/layers.py b/nncf/torch/sparsity/rb/layers.py index c1df48ad563..39b1013e5ea 100644 --- a/nncf/torch/sparsity/rb/layers.py +++ b/nncf/torch/sparsity/rb/layers.py @@ -8,20 +8,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Any, Dict, List import torch from nncf.torch.functions import logit from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.sparsity.layers import BinaryMask from nncf.torch.sparsity.rb.functions import binary_mask from nncf.torch.sparsity.rb.functions import calc_rb_binary_mask @COMPRESSION_MODULES.register() -class RBSparsifyingWeight(BinaryMask): +class RBSparsifyingWeight(BinaryMask, StatefullTorchModuleInterface): + WEIGHTS_SHAPE_KEY = "weight_shape" + FROZEN_KEY = "frozen" + COMPRESSION_LR_MULTIPLIER_KEY = "compression_lr_multiplier" + EPS_KEY = "eps" + def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6): super().__init__(weight_shape) self.frozen = frozen @@ -36,11 +42,11 @@ def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multipli self.mask_calculation_hook = MaskCalculationHook(self) @property - def mask(self): + def mask(self) -> torch.nn.Parameter: return self._mask @mask.setter - def mask(self, tensor): + def mask(self, tensor: torch.Tensor): self._mask.data = tensor self.binary_mask = binary_mask(self._mask) @@ -51,6 +57,23 @@ def _calc_training_binary_mask(self, weight): def loss(self): return binary_mask(self._mask) + def get_state(self) -> Dict[str, Any]: + return { + self.WEIGHTS_SHAPE_KEY: list(self.mask.shape), + self.FROZEN_KEY: self.frozen, + self.COMPRESSION_LR_MULTIPLIER_KEY: self.mask.compression_lr_multiplier, + self.EPS_KEY: self.eps, + } + + @classmethod + def from_state(cls, state: Dict[str, Any]) -> "RBSparsifyingWeight": + return RBSparsifyingWeight( + weight_shape=state[cls.WEIGHTS_SHAPE_KEY], + frozen=state[cls.FROZEN_KEY], + 
compression_lr_multiplier=state[cls.COMPRESSION_LR_MULTIPLIER_KEY], + eps=state[cls.EPS_KEY], + ) + class MaskCalculationHook: def __init__(self, module): diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index bf7e33cae13..d23e76c1ac6 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -215,24 +215,24 @@ def nz_bias_num(self): class TwoSharedConvTestModel(nn.Module): INPUT_SHAPE = [1, 1, 4, 4] NNCF_CONV_NODES_NAMES = [ - "TwoSharedConvTestModel/NNCFConv2d[conv1]/conv2d_0", - "TwoSharedConvTestModel/NNCFConv2d[conv2]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[1]/NNCFConv2d[0]/conv2d_0", ] CONV_NODES_NAMES = [ - "TwoSharedConvTestModel/Conv2d[conv1]/conv2d_0", - "TwoSharedConvTestModel/Conv2d[conv2]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0", + "TwoSharedConvTestModel/Sequential[features]/Sequential[1]/Conv2d[0]/conv2d_0", ] def __init__(self): super().__init__() self.features = [] - self.conv1 = create_conv(1, 1, 1, -1, -2) - self.conv2 = create_conv(1, 1, 1, 0, 0) + self.features.append(nn.Sequential(create_conv(1, 1, 1, -1, -2))) + self.features.append(nn.Sequential(create_conv(1, 1, 1, 0, 0))) + self.features = nn.Sequential(*self.features) def forward(self, x): for _ in range(2): - x = self.conv1(x) - x = self.conv2(x) + x = self.features(x) return x diff --git a/tests/torch/nncf_network/helpers.py b/tests/torch/nncf_network/helpers.py index 06805cd59b7..719166ffa9f 100644 --- a/tests/torch/nncf_network/helpers.py +++ b/tests/torch/nncf_network/helpers.py @@ -54,7 +54,9 @@ class InsertionCommandBuilder: Contains methods which allows to build all possible commands for the given torch.nn.Module. Target module should have NNCF_CONV_NODES_NAMES and CONV_NODES_NAMES with names of - target model convolutions and names of nncf-wrapped target model convolutions + target model convolutions and names of nncf-wrapped target model convolutions. + Convolutions should be placed inside nn.sequential in .features attribute + for test compatibility. """ AVAILABLE_MODELS = (TwoConvTestModel, TwoSharedConvTestModel) @@ -162,7 +164,7 @@ def get_all_available_commands( command_type, target_type ): continue - command = self._create_command( + command = self.create_one_command( command_builder, target_type, priority, @@ -185,7 +187,7 @@ def is_unsupported_by_transformer_command(command_type: PTTransformationCommand, ] @staticmethod - def _create_command( + def create_one_command( command_builder, target_type, priority, diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py new file mode 100644 index 00000000000..577b5194bb6 --- /dev/null +++ b/tests/torch/test_serialization.py @@ -0,0 +1,308 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
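# For reference (not part of the patch): with the definitions above,
# RBSparsifyingWeight.get_state() yields a plain dict such as
#
#     {
#         "weight_shape": [3, 2, 4, 1],   # illustrative shape
#         "frozen": True,                 # constructor default
#         "compression_lr_multiplier": None,
#         "eps": 1e-06,
#     }
#
# The mask tensor itself is deliberately absent: as documented on
# StatefullTorchModuleInterface, torch.nn.Parameter values are recovered
# from the model state_dict rather than from the module state.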
+ +import json +from copy import deepcopy + +import pytest +import torch + +from nncf.common.factory import ModelTransformerFactory +from nncf.common.quantization.structs import QuantizationScheme +from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply +from nncf.torch import wrap_model +from nncf.torch.graph.transformations.serialization import load_command +from nncf.torch.graph.transformations.serialization import load_transformations +from nncf.torch.graph.transformations.serialization import serialize_command +from nncf.torch.graph.transformations.serialization import serialize_transformations +from nncf.torch.module_operations import UpdateWeight +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.pruning.filter_pruning.layers import FilterPruningMask +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import SymmetricQuantizer +from nncf.torch.sparsity.layers import BinaryMask +from nncf.torch.sparsity.rb.layers import RBSparsifyingWeight +from tests.torch.helpers import DummyOpWithState +from tests.torch.helpers import TwoConvTestModel +from tests.torch.helpers import commands_are_equal +from tests.torch.nncf_network.helpers import AVAILABLE_TARGET_TYPES +from tests.torch.nncf_network.helpers import InsertionCommandBuilder + + +@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES) +@pytest.mark.parametrize("command_builder", InsertionCommandBuilder(TwoConvTestModel).get_command_builders()) +@pytest.mark.parametrize("priority", InsertionCommandBuilder.PRIORITIES) +def test_serialize_load_command(target_type, command_builder, priority): + dummy_op_state = "DUMMY_OP_STATE" + op_unique_name = "UNIQUE_NAME" + # The only difference for trace_parameters param in this test is taget nodes names + command = InsertionCommandBuilder(TwoConvTestModel).create_one_command( + command_builder[0], target_type, priority, dummy_op_state, trace_parameters=False, op_unique_name=op_unique_name + ) + + serialized_command = serialize_command(command) + + # Check serialized transformation are json compatible + j_str = json.dumps(serialized_command) + serialized_command = json.loads(j_str) + + recovered_command = load_command(serialized_command) + _check_commands_after_serialization(command, recovered_command, dummy_op_state) + + +def test_serialize_transformations(): + dummy_op_state = "DUMMY_OP_STATE" + # The only difference for trace_parameters param in this test is taget nodes names + layout = InsertionCommandBuilder(TwoConvTestModel).get_all_available_commands( + dummy_op_state=dummy_op_state, trace_parameters=False + ) + + serialized_transformations = serialize_transformations(layout) + + # Check serialized transformation are json compatible + j_str = json.dumps(serialized_transformations) + serialized_transformations = json.loads(j_str) + + recovered_layout = load_transformations(serialized_transformations) + assert len(layout.transformations) == len(recovered_layout.transformations) + # Can zip layouts because the order should not be altered + for command, recovered_command in zip(layout.transformations, recovered_layout.transformations): + _check_commands_after_serialization(command, recovered_command, dummy_op_state) + + +def load_from_config_impl(model: torch.nn.Module, serialized_transformations, example_input, trace_parameters): + transformations_layout = 
load_transformations(serialized_transformations) + + nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters) + transformed_model = ModelTransformerFactory.create(nncf_network).transform(transformations_layout) + + transformed_model.nncf.disable_dynamic_graph_building() + return transformed_model + + +def nncf_get_config_impl( + model: NNCFNetwork, +): + layout = model.nncf.transformation_layout() + return serialize_transformations(layout) + + +@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS) +@pytest.mark.parametrize("trace_parameters", (False, True)) +def test_get_apply_serialization_from_a_model(model_cls, trace_parameters): + dummy_op_state = "DUMMY_OP_STATE" + layout = InsertionCommandBuilder(model_cls).get_all_available_commands( + dummy_op_state, trace_parameters, skip_model_transformer_unsupported=True + ) + model = model_cls() + example_input = torch.ones((1, 1, 4, 4)) + nncf_model = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters) + modified_model = ModelTransformerFactory.create(nncf_model).transform(layout) + + serialized_transformations = nncf_get_config_impl(modified_model) + + # Check serialized transformation are json compatible + j_str = json.dumps(serialized_transformations) + serialized_transformations = json.loads(j_str) + + recovered_model = load_from_config_impl(model, serialized_transformations, example_input, trace_parameters) + + if not trace_parameters: + _check_pre_post_ops(modified_model, recovered_model) + + context = modified_model.nncf._compressed_context + recovered_context = recovered_model.nncf._compressed_context + for hooks_attr in ["_pre_hooks", "_post_hooks"]: + container = getattr(context, hooks_attr) + recovered_container = getattr(recovered_context, hooks_attr) + assert len(container) == len(recovered_container) + for op_address, hooks in container.items(): + recovered_hooks = recovered_container[op_address] + for k, hook in hooks.items(): + recovered_hook = recovered_hooks[k] + _check_hook_are_equal(hook, recovered_hook) + + for attr_name in ["external_quantizers", "external_op"]: + container = getattr(modified_model.nncf, attr_name) + recovered_container = getattr(recovered_model.nncf, attr_name) + assert len(container) == len(recovered_container) + for k, module in container.items(): + recovered_module = recovered_container[k] + _check_hook_are_equal(module, recovered_module) + + +def _check_pre_post_ops(modified_model, recovered_model): + for conv, recovered_conv in zip(modified_model.features, recovered_model.features): + for hooks_attr in ["pre_ops", "post_ops"]: + hooks = getattr(conv[0], hooks_attr) + recovered_hooks = getattr(recovered_conv[0], hooks_attr) + assert len(hooks) == len(recovered_hooks) + for k, hook in hooks.items(): + recovered_hook = recovered_hooks[k] + if isinstance(hook, UpdateWeight): + assert isinstance(recovered_hook, UpdateWeight) + hook = hook.op + recovered_hook = recovered_hook.op + _check_hook_are_equal(hook, recovered_hook) + + +def _check_hook_are_equal(hook, recovered_hook): + assert type(hook) == type(recovered_hook) + if isinstance(hook, DummyOpWithState): + assert hook.get_state() == recovered_hook.get_state() + return + # Hook is external op call hook then + assert hook._storage_name == recovered_hook._storage_name + assert hook._storage_key == recovered_hook._storage_key + + +def _check_commands_after_serialization(command, recovered_command, dummy_op_state=None): + 
commands_are_equal(recovered_command, command, check_fn_ref=False) + assert isinstance(command.fn, DummyOpWithState) + assert command.fn.get_state() == recovered_command.fn.get_state() + if dummy_op_state is not None: + assert command.fn.get_state() == dummy_op_state + + +@pytest.mark.parametrize("size", (4, [3, 4])) +def test_pruning_mask_serialization(size): + node_name = "dummy_node_name" + dim = 2 + mask = FilterPruningMask(size=size, node_name=node_name, dim=dim) + mask.binary_filter_pruning_mask = torch.fill(torch.empty(size), 5) + state_dict = mask.state_dict() + + state = mask.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_mask = FilterPruningMask.from_state(state) + recovered_mask.load_state_dict(state_dict) + + ref_size = size if isinstance(size, list) else [size] + assert list(recovered_mask.binary_filter_pruning_mask.size()) == ref_size + assert recovered_mask.node_name == node_name + assert recovered_mask.mask_applying_dim == dim + + assert torch.all(mask.binary_filter_pruning_mask == recovered_mask.binary_filter_pruning_mask) + + +@pytest.mark.parametrize("quantizer_class", (SymmetricQuantizer, AsymmetricQuantizer)) +def test_quantizer_serialization(quantizer_class: BaseQuantizer): + scale_shape = [1, 3, 1, 1] + ref_qspec = PTQuantizerSpec( + num_bits=4, + mode=QuantizationScheme.ASYMMETRIC, + signedness_to_force=False, + narrow_range=True, + half_range=False, + scale_shape=scale_shape, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=2.0, + ) + quantizer = quantizer_class(ref_qspec) + if isinstance(quantizer, SymmetricQuantizer): + quantizer.scale = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 5)) + elif isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 6)) + quantizer.input_range = torch.nn.Parameter(torch.fill(torch.empty(scale_shape), 7)) + + state_dict = quantizer.state_dict() + + state = quantizer.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_quantizer = quantizer_class.from_state(state) + recovered_quantizer.load_state_dict(state_dict) + + assert recovered_quantizer._qspec == ref_qspec + + assert torch.all(quantizer._num_bits == recovered_quantizer._num_bits) + assert torch.all(quantizer.enabled == recovered_quantizer.enabled) + if isinstance(quantizer, SymmetricQuantizer): + assert torch.all(quantizer.signed_tensor == recovered_quantizer.signed_tensor) + assert torch.all(quantizer.scale == recovered_quantizer.scale) + elif isinstance(quantizer, AsymmetricQuantizer): + assert torch.all(quantizer.input_low == recovered_quantizer.input_low) + assert torch.all(quantizer.input_range == recovered_quantizer.input_range) + else: + raise RuntimeError() + + +def test_sparsity_binary_mask_serialization(): + ref_shape = [4, 2, 1, 3] + mask = BinaryMask(ref_shape) + mask.binary_mask = torch.zeros(ref_shape) + state_dict = mask.state_dict() + + state = mask.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_mask = BinaryMask.from_state(state) + recovered_mask.load_state_dict(state_dict) + + assert list(recovered_mask.binary_mask.shape) == ref_shape + assert torch.all(mask.binary_mask == recovered_mask.binary_mask) + + +def test_rb_sparsity_mask_serialization(): + ref_weights_shape = [3, 2, 4, 1] + ref_frozen = False + ref_compression_lr_multiplier = 2.0 + ref_eps = 0.3 + mask = RBSparsifyingWeight( + weight_shape=ref_weights_shape, + 
frozen=ref_frozen, + compression_lr_multiplier=ref_compression_lr_multiplier, + eps=ref_eps, + ) + mask.binary_mask = torch.zeros(ref_weights_shape) + mask.mask = torch.fill(torch.empty(ref_weights_shape), 5) + state_dict = mask.state_dict() + + state = mask.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_mask = RBSparsifyingWeight.from_state(state) + recovered_mask.load_state_dict(state_dict) + + assert list(recovered_mask.mask.shape) == ref_weights_shape + assert recovered_mask.frozen == ref_frozen + assert recovered_mask.mask.compression_lr_multiplier == ref_compression_lr_multiplier + assert recovered_mask.eps == ref_eps + + assert torch.all(mask.mask == recovered_mask.mask) + assert torch.all(mask.binary_mask == recovered_mask.binary_mask) + assert torch.all(mask.uniform == recovered_mask.uniform) + + +def test_sq_multiply_serialization(): + tensor_shape = [1, 3, 5] + tensor_value = torch.fill(torch.empty(tensor_shape, dtype=torch.float16), 5) + sq_multiply = SQMultiply(tensor_shape) + sq_multiply.scale = tensor_value + state_dict = sq_multiply.state_dict() + + state = sq_multiply.get_state() + json_state = json.dumps(state) + state = json.loads(json_state) + + recovered_sq_multiply = SQMultiply.from_state(state) + recovered_sq_multiply.load_state_dict(state_dict) + + assert torch.all(sq_multiply.scale == recovered_sq_multiply.scale) + assert sq_multiply.scale.shape == recovered_sq_multiply.scale.shape From 9e5197d3f20854e6f32fc297d205890f10d408a0 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 2 Apr 2024 11:37:17 +0200 Subject: [PATCH 2/6] Tests adjusted --- ...TwoConvTestModel_overflow_fix_disable.json | 30 ++++++++++++++----- .../TwoConvTestModel_overflow_fix_enable.json | 30 ++++++++++++++----- ...stModel_overflow_fix_first_layer_only.json | 30 ++++++++++++++----- tests/torch/ptq/helpers.py | 13 -------- tests/torch/ptq/test_graphs.py | 10 +++++-- 5 files changed, 76 insertions(+), 37 deletions(-) diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json index c1dc5e6731a..6be76d49e47 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json @@ -12,14 +12,14 @@ [ [ [ - -1.0 + -1 ] ] ], [ [ [ - -1.0 + -1 ] ] ] @@ -28,14 +28,14 @@ [ [ [ - 1.0 + 1 ] ] ], [ [ [ - 1.0 + 1 ] ] ] @@ -46,7 +46,7 @@ [ [ [ - -1.0 + -1 ] ] ] @@ -55,10 +55,26 @@ [ [ [ - 1.0 + 1 ] ] ] ] + }, + "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { + "input_low": [ + 0 + ], + "input_high": [ + 0.9800970554351807 + ] + }, + "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": { + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] } -} \ No newline at end of file +} diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json index c679824f317..e008c1adb9d 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json @@ -12,14 +12,14 @@ [ [ [ - -2.0 + -2 ] ] ], [ [ [ - -2.0 + -2 ] ] ] @@ -28,14 +28,14 
@@ [ [ [ - 2.0 + 2 ] ] ], [ [ [ - 2.0 + 2 ] ] ] @@ -46,7 +46,7 @@ [ [ [ - -2.0 + -2 ] ] ] @@ -55,10 +55,26 @@ [ [ [ - 2.0 + 2 ] ] ] ] + }, + "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { + "input_low": [ + 0 + ], + "input_high": [ + 0.9800970554351807 + ] + }, + "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": { + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] } -} \ No newline at end of file +} diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json index 708715926b7..4c20675fe0b 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json @@ -12,14 +12,14 @@ [ [ [ - -2.0 + -2 ] ] ], [ [ [ - -2.0 + -2 ] ] ] @@ -28,14 +28,14 @@ [ [ [ - 2.0 + 2 ] ] ], [ [ [ - 2.0 + 2 ] ] ] @@ -46,7 +46,7 @@ [ [ [ - -1.0 + -1 ] ] ] @@ -55,10 +55,26 @@ [ [ [ - 1.0 + 1 ] ] ] ] + }, + "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { + "input_low": [ + 0 + ], + "input_high": [ + 0.9800970554351807 + ] + }, + "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": { + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] } -} \ No newline at end of file +} diff --git a/tests/torch/ptq/helpers.py b/tests/torch/ptq/helpers.py index 7dd88540104..4047f892e21 100644 --- a/tests/torch/ptq/helpers.py +++ b/tests/torch/ptq/helpers.py @@ -20,7 +20,6 @@ from nncf.torch.graph.operator_metatypes import PTModuleDepthwiseConv2dSubtype from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype from nncf.torch.graph.operator_metatypes import PTSumMetatype -from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic from tests.post_training.test_templates.models import NNCFGraphToTest from tests.post_training.test_templates.models import NNCFGraphToTestDepthwiseConv from tests.post_training.test_templates.models import NNCFGraphToTestSumAggregation @@ -81,15 +80,3 @@ def get_nncf_network(model: torch.nn.Module, input_shape: Optional[List[int]] = model = model.eval() device = next(model.named_parameters())[1].device return wrap_model(model, torch.ones(input_shape).to(device=device), trace_parameters=True) - - -def mock_collect_statistics(mocker): - _ = mocker.patch( - "nncf.common.tensor_statistics.aggregator.StatisticsAggregator.collect_statistics", return_value=None - ) - min_, max_ = 0.0, 1.0 - min_, max_ = torch.tensor(min_), torch.tensor(max_) - _ = mocker.patch( - "nncf.experimental.common.tensor_statistics.collectors.TensorCollector.get_statistics", - return_value=PTMinMaxTensorStatistic(min_values=min_, max_values=max_), - ) diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py index ee59703d3ed..93281435104 100644 --- a/tests/torch/ptq/test_graphs.py +++ b/tests/torch/ptq/test_graphs.py @@ -24,7 +24,6 @@ from tests.post_training.test_templates.helpers import EmbeddingModel from tests.post_training.test_templates.helpers import 
get_static_dataset from tests.torch import test_models -from tests.torch.ptq.helpers import mock_collect_statistics from tests.torch.quantization.test_algo_quantization import SharedLayersModel from tests.torch.test_compressed_graph import ModelDesc from tests.torch.test_compressed_graph import check_graph @@ -95,15 +94,20 @@ def get_model_name(description): ("desc", "quantization_parameters"), TEST_MODELS_DESC, ids=[get_model_name(m) for m in TEST_MODELS_DESC] ) def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_parameters, graph_dir, mocker): - mock_collect_statistics(mocker) model = desc.model_builder() nncf_network = wrap_model(model, torch.ones(desc.input_sample_sizes), trace_parameters=True) quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True) + quantization_parameters["subset_size"] = 1 quantization_algorithm = PostTrainingQuantization(**quantization_parameters) + def transform_fn(input_) -> torch.Tensor: + return torch.tensor(input_[0]) + quantized_model = quantization_algorithm.apply( - nncf_network, nncf_network.nncf.get_graph(), dataset=get_static_dataset(desc.input_sample_sizes, None, None) + nncf_network, + nncf_network.nncf.get_graph(), + dataset=get_static_dataset(desc.input_sample_sizes, transform_fn, None), ) check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir) From fa077fa717c8e96915e751c12dad3387e63ab2ff Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 25 Apr 2024 13:26:29 +0200 Subject: [PATCH 3/6] Cleanup --- nncf/torch/dynamic_graph/io_handling.py | 20 ------ .../graph/transformations/serialization.py | 70 ++++++++++++------- nncf/torch/model_transformer.py | 1 - ...TwoConvTestModel_overflow_fix_disable.json | 46 +++++------- .../TwoConvTestModel_overflow_fix_enable.json | 46 +++++------- ...stModel_overflow_fix_first_layer_only.json | 46 +++++------- ...tions.py => test_transformation_layout.py} | 0 tests/torch/test_serialization.py | 65 +++++++++++------ 8 files changed, 146 insertions(+), 148 deletions(-) rename tests/torch/nncf_network/{test_get_applied_modifications.py => test_transformation_layout.py} (100%) diff --git a/nncf/torch/dynamic_graph/io_handling.py b/nncf/torch/dynamic_graph/io_handling.py index 2884033e6fb..c0f3aeb92e5 100644 --- a/nncf/torch/dynamic_graph/io_handling.py +++ b/nncf/torch/dynamic_graph/io_handling.py @@ -122,7 +122,6 @@ def __init__(self, shape: List[int], type_str: str = "float", keyword: str = Non """ self.shape = shape self.type = self._string_to_torch_type(type_str) - self._type_str = type_str self.keyword = keyword if filler is None: self.filler = self.FILLER_TYPE_ONES @@ -158,15 +157,6 @@ def get_tensor_for_input(self) -> torch.Tensor: return torch.rand(size=self.shape, dtype=self.type) raise NotImplementedError - def get_state(self) -> Dict[str, Any]: - return {"shape": self.shape, "type_str": self._type_str, "keyword": self.keyword, "filler": self.filler} - - @classmethod - def from_state(cls, state: Dict[str, Any]) -> "FillerInputElement": - return FillerInputElement( - shape=state["shape"], type_str=state["type_str"], keyword=state["keyword"], filler=state["filler"] - ) - class FillerInputInfo(ModelInputInfo): """ @@ -230,16 +220,6 @@ def get_forward_inputs( kwargs[fe.keyword] = tensor return tuple(args_list), kwargs - def get_state(self) -> Dict[str, Any]: - return {"elements": [elem.get_state() for elem in self.elements]} - - @classmethod - def from_state(cls, state) -> "FillerInputInfo": - return 
FillerInputInfo([FillerInputElement.from_state(s) for s in state["elements"]])
-
-    def __eq__(self, other: "FillerInputInfo") -> bool:
-        return self.elements == other.elements
-
 
 class ExactInputsInfo(ModelInputInfo):
     """
diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py
index 616a4c25e52..27f1ac7286c 100644
--- a/nncf/torch/graph/transformations/serialization.py
+++ b/nncf/torch/graph/transformations/serialization.py
@@ -22,14 +22,16 @@
 from nncf.torch.layer_utils import COMPRESSION_MODULES
 
 COMPRESSION_STATE_ATTR = "compression_state"
-
-
-class CompressionKeys(Enum):
-    SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND"
-    INSERTION_COMMAND = "INSERTION_COMMAND"
+SUPPORTED_COMMANDS = (PTSharedFnInsertionCommand, PTInsertionCommand)
 
 
 def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]:
+    """
+    Serializes given transformation layout to a dict.
+
+    :param transformations_layout: Given transformation layout.
+    :return: Serialized representation of given transformation layout as a dict.
+    """
     transformation_commands = []
     for command in transformations_layout.transformations:
         serialized_command = serialize_command(command)
@@ -39,28 +41,39 @@ def serialize_transformations(transformations_layout: TransformationLayout) -> D
     return {COMPRESSION_STATE_ATTR: transformation_commands}
 
 
-def load_transformations(transformations_state: Dict[str, Any]) -> TransformationLayout:
+def deserialize_transformations(serialized_transformation_layout: Dict[str, Any]) -> TransformationLayout:
+    """
+    Deserializes given serialized transformation layout.
+
+    :param serialized_transformation_layout: Given serialized transformation layout.
+    :return: The deserialized transformation layout.
+    """
     transformation_layout = TransformationLayout()
-    for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]:
-        command = load_command(serialized_command)
+    for serialized_command in serialized_transformation_layout[COMPRESSION_STATE_ATTR]:
+        command = deserialize_command(serialized_command)
         transformation_layout.register(command)
 
     return transformation_layout
 
 
 def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
-    if not isinstance(command, (PTSharedFnInsertionCommand, PTInsertionCommand)):
-        return {}
+    """
+    Serializes given command to a dict.
+
+    :param command: Given command.
+    :return: Serialized representation of given command as a dict.
+ """ + if not isinstance(command, SUPPORTED_COMMANDS): + raise RuntimeError(f"Command type {command.__class__} is not supported.") serialized_transformation = dict() + serialized_transformation["type"] = command.__class__.__name__ if isinstance(command, PTSharedFnInsertionCommand): - serialized_transformation["type"] = CompressionKeys.SHARED_INSERTION_COMMAND.value serialized_transformation["target_points"] = [point.get_state() for point in command.target_points] serialized_transformation["op_name"] = command.op_name serialized_transformation["compression_module_type"] = command.compression_module_type.value elif isinstance(command, PTInsertionCommand): - serialized_transformation["type"] = CompressionKeys.INSERTION_COMMAND.value serialized_transformation["target_point"] = command.target_point.get_state() # Check compression module is registered @@ -78,27 +91,34 @@ def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]: return serialized_transformation -def load_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]: +def deserialize_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]: + """ + Deserializes given serialized command. + + :param serialized_command: Given serialized command. + :return: The deserialized command. + """ + if serialized_command["type"] not in (command_cls.__name__ for command_cls in SUPPORTED_COMMANDS): + raise RuntimeError(f"Command type {serialized_command['type']} is not supported.") + module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"]) fn = module_cls.from_state(serialized_command["fn_state"]) priority = serialized_command["priority"] if priority in iter(TransformationPriority): priority = TransformationPriority(priority) - if serialized_command["type"] == CompressionKeys.INSERTION_COMMAND.value: + if serialized_command["type"] == PTInsertionCommand.__name__: target_point = PTTargetPoint.from_state(serialized_command["target_point"]) return PTInsertionCommand( point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"] ) - if serialized_command["type"] == CompressionKeys.SHARED_INSERTION_COMMAND.value: - target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]] - return PTSharedFnInsertionCommand( - target_points=target_points, - fn=fn, - op_unique_name=serialized_command["op_name"], - compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]), - priority=priority, - hooks_group_name=serialized_command["hooks_group_name"], - ) - raise RuntimeError(f"Command type {serialized_command['type']} is not supported.") + target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]] + return PTSharedFnInsertionCommand( + target_points=target_points, + fn=fn, + op_unique_name=serialized_command["op_name"], + compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]), + priority=priority, + hooks_group_name=serialized_command["hooks_group_name"], + ) diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index b8df3510f55..6bdb95fbe38 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -61,7 +61,6 @@ def __init__(self, model: NNCFNetwork): ] def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork: - # 
self._model.nncf.record_commands(transformation_layout.transformations) transformations = transformation_layout.transformations aggregated_transformations = defaultdict(list) requires_graph_rebuild = False diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json index 6be76d49e47..d462913193c 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_disable.json @@ -1,25 +1,33 @@ { "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": 0.0, - "input_high": 0.9970665574073792 + "input_low": [ + 0.0 + ], + "input_high": [ + 0.9970665574073792 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": { - "input_low": -3.8243322372436523, - "input_high": 3.794454574584961 + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": { "input_low": [ [ [ [ - -1 + -1.0 ] ] ], [ [ [ - -1 + -1.0 ] ] ] @@ -28,14 +36,14 @@ [ [ [ - 1 + 1.0 ] ] ], [ [ [ - 1 + 1.0 ] ] ] @@ -46,7 +54,7 @@ [ [ [ - -1 + -1.0 ] ] ] @@ -55,26 +63,10 @@ [ [ [ - 1 + 1.0 ] ] ] ] - }, - "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": [ - 0 - ], - "input_high": [ - 0.9800970554351807 - ] - }, - "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": { - "input_low": [ - -3.8243322372436523 - ], - "input_high": [ - 3.794454574584961 - ] } -} +} \ No newline at end of file diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json index e008c1adb9d..6f60ba19e5b 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_enable.json @@ -1,25 +1,33 @@ { "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": 0.0, - "input_high": 0.9970665574073792 + "input_low": [ + 0.0 + ], + "input_high": [ + 0.9970665574073792 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": { - "input_low": -3.8243322372436523, - "input_high": 3.794454574584961 + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": { "input_low": [ [ [ [ - -2 + -2.0 ] ] ], [ [ [ - -2 + -2.0 ] ] ] @@ -28,14 +36,14 @@ [ [ [ - 2 + 2.0 ] ] ], [ [ [ - 2 + 2.0 ] ] ] @@ -46,7 +54,7 @@ [ [ [ - -2 + -2.0 ] ] ] @@ -55,26 +63,10 @@ [ [ [ - 2 + 2.0 ] ] ] ] - }, - 
"TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": [ - 0 - ], - "input_high": [ - 0.9800970554351807 - ] - }, - "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": { - "input_low": [ - -3.8243322372436523 - ], - "input_high": [ - 3.794454574584961 - ] } -} +} \ No newline at end of file diff --git a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json index 4c20675fe0b..89d8d054e81 100644 --- a/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json +++ b/tests/torch/data/reference_scales/TwoConvTestModel_overflow_fix_first_layer_only.json @@ -1,25 +1,33 @@ { "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": 0.0, - "input_high": 0.9970665574073792 + "input_low": [ + 0.0 + ], + "input_high": [ + 0.9970665574073792 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|OUTPUT]": { - "input_low": -3.8243322372436523, - "input_high": 3.794454574584961 + "input_low": [ + -3.8243322372436523 + ], + "input_high": [ + 3.794454574584961 + ] }, "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/Conv2d[0]/conv2d_0|INPUT1]": { "input_low": [ [ [ [ - -2 + -2.0 ] ] ], [ [ [ - -2 + -2.0 ] ] ] @@ -28,14 +36,14 @@ [ [ [ - 2 + 2.0 ] ] ], [ [ [ - 2 + 2.0 ] ] ] @@ -46,7 +54,7 @@ [ [ [ - -1 + -1.0 ] ] ] @@ -55,26 +63,10 @@ [ [ [ - 1 + 1.0 ] ] ] ] - }, - "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[/nncf_model_input_0|OUTPUT]": { - "input_low": [ - 0 - ], - "input_high": [ - 0.9800970554351807 - ] - }, - "TwoConvTestModel/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[TwoConvTestModel/Sequential[features]/Sequential[0]/NNCFConv2d[0]/conv2d_0|OUTPUT]": { - "input_low": [ - -3.8243322372436523 - ], - "input_high": [ - 3.794454574584961 - ] } -} +} \ No newline at end of file diff --git a/tests/torch/nncf_network/test_get_applied_modifications.py b/tests/torch/nncf_network/test_transformation_layout.py similarity index 100% rename from tests/torch/nncf_network/test_get_applied_modifications.py rename to tests/torch/nncf_network/test_transformation_layout.py diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py index 577b5194bb6..3b16fb3ed3d 100644 --- a/tests/torch/test_serialization.py +++ b/tests/torch/test_serialization.py @@ -19,8 +19,10 @@ from nncf.common.quantization.structs import QuantizationScheme from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply from nncf.torch import wrap_model -from nncf.torch.graph.transformations.serialization import load_command -from nncf.torch.graph.transformations.serialization import load_transformations +from nncf.torch.graph.transformations.commands import PTTransformationCommand +from nncf.torch.graph.transformations.commands import TransformationType +from nncf.torch.graph.transformations.serialization import deserialize_command +from 
nncf.torch.graph.transformations.serialization import deserialize_transformations
 from nncf.torch.graph.transformations.serialization import serialize_command
 from nncf.torch.graph.transformations.serialization import serialize_transformations
 from nncf.torch.module_operations import UpdateWeight
@@ -39,6 +41,29 @@
 from tests.torch.nncf_network.helpers import InsertionCommandBuilder


+def load_from_config_impl(model: torch.nn.Module, serialized_transformations, example_input, trace_parameters):
+    """
+    Test implementation of nncf.torch.load_from_config(); to be replaced by the actual implementation.
+    """
+    transformations_layout = deserialize_transformations(serialized_transformations)
+
+    nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters)
+    transformed_model = ModelTransformerFactory.create(nncf_network).transform(transformations_layout)
+
+    transformed_model.nncf.disable_dynamic_graph_building()
+    return transformed_model
+
+
+def nncf_get_config_impl(
+    model: NNCFNetwork,
+):
+    """
+    Test implementation of model.nncf.get_config(); to be replaced by the actual implementation.
+    """
+    layout = model.nncf.transformation_layout()
+    return serialize_transformations(layout)
+
+
 @pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES)
 @pytest.mark.parametrize("command_builder", InsertionCommandBuilder(TwoConvTestModel).get_command_builders())
 @pytest.mark.parametrize("priority", InsertionCommandBuilder.PRIORITIES)
@@ -56,10 +81,25 @@ def test_serialize_load_command(target_type, command_builder, priority):
     j_str = json.dumps(serialized_command)
     serialized_command = json.loads(j_str)

-    recovered_command = load_command(serialized_command)
+    recovered_command = deserialize_command(serialized_command)
     _check_commands_after_serialization(command, recovered_command, dummy_op_state)


+def test_non_supported_command_serialization():
+    class NonSupportedCommand(PTTransformationCommand):
+        def __init__(self):
+            super().__init__(TransformationType.INSERT, None)
+
+    command = NonSupportedCommand()
+
+    with pytest.raises(RuntimeError):
+        serialize_command(command)
+
+    serialized_command = {"type": NonSupportedCommand.__name__}
+    with pytest.raises(RuntimeError):
+        deserialize_command(serialized_command)
+
+
 def test_serialize_transformations():
     dummy_op_state = "DUMMY_OP_STATE"
     # The only difference the trace_parameters param makes in this test is the target node names
@@ -73,30 +113,13 @@
     j_str = json.dumps(serialized_transformations)
     serialized_transformations = json.loads(j_str)

-    recovered_layout = load_transformations(serialized_transformations)
+    recovered_layout = deserialize_transformations(serialized_transformations)
     assert len(layout.transformations) == len(recovered_layout.transformations)
     # Can zip layouts because the order should not be altered
     for command, recovered_command in zip(layout.transformations, recovered_layout.transformations):
         _check_commands_after_serialization(command, recovered_command, dummy_op_state)


-def load_from_config_impl(model: torch.nn.Module, serialized_transformations, example_input, trace_parameters):
-    transformations_layout = load_transformations(serialized_transformations)
-
-    nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters)
-    transformed_model = ModelTransformerFactory.create(nncf_network).transform(transformations_layout)
-
-    transformed_model.nncf.disable_dynamic_graph_building()
-    return transformed_model
-
-
-def 
nncf_get_config_impl(
-    model: NNCFNetwork,
-):
-    layout = model.nncf.transformation_layout()
-    return serialize_transformations(layout)
-
-
 @pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS)
 @pytest.mark.parametrize("trace_parameters", (False, True))
 def test_get_apply_serialization_from_a_model(model_cls, trace_parameters):

From d0cc6d2a6af7ad768eb36d6da9f398720c588cd3 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Fri, 26 Apr 2024 16:58:29 +0200
Subject: [PATCH 4/6] Specify additional restriction on serializable modules

---
 nncf/torch/layer_utils.py         |  5 ++++-
 tests/torch/helpers.py            | 11 +++++++++--
 tests/torch/test_serialization.py |  1 +
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/nncf/torch/layer_utils.py b/nncf/torch/layer_utils.py
index 7a4c7975813..a6442b03325 100644
--- a/nncf/torch/layer_utils.py
+++ b/nncf/torch/layer_utils.py
@@ -32,7 +32,10 @@ class StatefullTorchModuleInterface(ABC):
     standart (str, list and etc.) should be present in a compression module state.
     The state for attributes with type torch.nn.Parameter
     is recovered from the model `state_dict`, so there is no need to keep them in the module state.
-
+    Modules should avoid implementing the `__call__` method and use `forward` instead,
+    as torch functions called inside `__call__` cannot be unambiguously
+    separated from the wrapped parent NNCF module's function calls, so NNCF is unable to
+    identify the target point for that call during the transformation recovery process.
     """

     @abstractmethod
diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py
index d23e76c1ac6..3547d41d76a 100644
--- a/tests/torch/helpers.py
+++ b/tests/torch/helpers.py
@@ -269,13 +269,20 @@ class DummyOpWithState(torch.nn.Module):
     def __init__(self, state: str):
         super().__init__()
         self._state = state
+        # Keep a dummy param to check the state dict
+        self._dummy_param = torch.nn.Parameter(
+            torch.tensor(
+                0.0,
+            )
+        )

-    def __call__(self, *args):
+    def forward(self, *args):
         if len(args) == 1:
-            return args[0]
+            return args[0] + self._dummy_param
         # To work correctly with
         # TargetType.PRE_LAYER_OPERATION
         # TargetType.POST_LAYER_OPERATION
+        args[0].weight + self._dummy_param
         return None

     def get_state(self):
diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py
index 3b16fb3ed3d..1cfa0f093ad 100644
--- a/tests/torch/test_serialization.py
+++ b/tests/torch/test_serialization.py
@@ -140,6 +140,7 @@ def test_get_apply_serialization_from_a_model(model_cls, trace_parameters):

     recovered_model = load_from_config_impl(model, serialized_transformations, example_input, trace_parameters)

+    assert modified_model.state_dict().keys() == recovered_model.state_dict().keys()
     if not trace_parameters:
         _check_pre_post_ops(modified_model, recovered_model)

From bef363ea92ad10c099e03342aaa2ad8ea91b76ae Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 2 May 2024 19:02:05 +0200
Subject: [PATCH 5/6] Comments

---
 .../algorithms/smooth_quant/torch_backend.py  |  4 +--
 .../graph/transformations/serialization.py    |  9 +++----
 nncf/torch/layer_utils.py                     | 16 ++++++------
 nncf/torch/pruning/filter_pruning/layers.py   |  4 +--
 nncf/torch/quantization/layers.py             |  4 +--
 nncf/torch/sparsity/layers.py                 |  4 +--
 nncf/torch/sparsity/rb/layers.py              |  4 +--
 tests/torch/helpers.py                        |  7 ++---
 tests/torch/test_serialization.py             | 26 +++++++++----------
 9 files changed, 38 insertions(+), 40 deletions(-)

diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py 
b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index 775eadeffe1..f6c9ccce171 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -60,11 +60,11 @@ def scale(self, value: torch.tensor): def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.mul(x, self._scale_value) - def get_state(self) -> Dict[str, Any]: + def get_config(self) -> Dict[str, Any]: return {self.SCALE_SHAPE_KEY: list(self._scale_value.shape)} @classmethod - def from_state(cls, state) -> "SQMultiply": + def from_config(cls, state) -> "SQMultiply": return SQMultiply(state[cls.SCALE_SHAPE_KEY]) diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py index 27f1ac7286c..282c59453eb 100644 --- a/nncf/torch/graph/transformations/serialization.py +++ b/nncf/torch/graph/transformations/serialization.py @@ -34,9 +34,7 @@ def serialize_transformations(transformations_layout: TransformationLayout) -> D """ transformation_commands = [] for command in transformations_layout.transformations: - serialized_command = serialize_command(command) - if serialized_command: - transformation_commands.append(serialized_command) + transformation_commands.append(serialize_command(command)) return {COMPRESSION_STATE_ATTR: transformation_commands} @@ -72,7 +70,6 @@ def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]: serialized_transformation["target_points"] = [point.get_state() for point in command.target_points] serialized_transformation["op_name"] = command.op_name serialized_transformation["compression_module_type"] = command.compression_module_type.value - elif isinstance(command, PTInsertionCommand): serialized_transformation["target_point"] = command.target_point.get_state() @@ -84,7 +81,7 @@ def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]: " Please register your module in the COMPRESSION_MODULES registry." ) serialized_transformation["compression_module_name"] = compression_module_name - serialized_transformation["fn_state"] = command.fn.get_state() + serialized_transformation["fn_config"] = command.fn.get_config() serialized_transformation["hooks_group_name"] = command.hooks_group_name priority = command.priority serialized_transformation["priority"] = priority.value if isinstance(priority, Enum) else priority @@ -102,7 +99,7 @@ def deserialize_command(serialized_command: Dict[str, Any]) -> Union[PTInsertion raise RuntimeError(f"Command type {serialized_command['type']} is not supported.") module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"]) - fn = module_cls.from_state(serialized_command["fn_state"]) + fn = module_cls.from_config(serialized_command["fn_config"]) priority = serialized_command["priority"] if priority in iter(TransformationPriority): priority = TransformationPriority(priority) diff --git a/nncf/torch/layer_utils.py b/nncf/torch/layer_utils.py index a6442b03325..56814a9b6de 100644 --- a/nncf/torch/layer_utils.py +++ b/nncf/torch/layer_utils.py @@ -28,10 +28,10 @@ class StatefullTorchModuleInterface(ABC): """ Interface that should be implemented for every registered compression module to make it possible to save an compression modules state and create an compression module from the saved state. - State of the module should be json serializable, no python objects except - standart (str, list and etc.) should be present in a compression module state. 
-    The state for attributes with type torch.nn.Parameter
-    is recovered from the model `state_dict`, so there is no need to keep them in the module state.
+    The module config should be JSON serializable; no python objects except
+    standard types (str, list, etc.) should be present in a compression module config.
+    Values of attributes with type torch.nn.Parameter
+    are recovered from the model `state_dict`, so there is no need to keep them in the module config.
     Modules should avoid implementing the `__call__` method and use `forward` instead,
     as torch functions called inside `__call__` cannot be unambiguously
     separated from the wrapped parent NNCF module's function calls, so NNCF is unable to
     identify the target point for that call during the transformation recovery process.
     """

     @abstractmethod
-    def get_state(self) -> Dict[str, Any]:
+    def get_config(self) -> Dict[str, Any]:
         """
-        Returns the compression module state.
+        Returns the compression module config.
         """

     @abstractclassmethod
-    def from_state(cls, state: Dict[str, Any]) -> object:
+    def from_config(cls, state: Dict[str, Any]) -> object:
         """
-        Creates a compression module instance from the given state.
+        Creates a compression module instance from the given config.
         """


diff --git a/nncf/torch/pruning/filter_pruning/layers.py b/nncf/torch/pruning/filter_pruning/layers.py
index 872ce31ca24..3e912cf5400 100644
--- a/nncf/torch/pruning/filter_pruning/layers.py
+++ b/nncf/torch/pruning/filter_pruning/layers.py
@@ -64,7 +64,7 @@ def forward(self, **params):
         )
         return new_params

-    def get_state(self) -> Dict[str, Any]:
+    def get_config(self) -> Dict[str, Any]:
         return {
             self.MASK_APPLYING_DIM_KEY: self.mask_applying_dim,
             self.NODE_NAME_KEY: self.node_name,
@@ -72,7 +72,7 @@
         }

     @classmethod
-    def from_state(cls, state: Dict[str, Any]) -> "FilterPruningMask":
+    def from_config(cls, state: Dict[str, Any]) -> "FilterPruningMask":
         return FilterPruningMask(
             size=state[cls.SIZE_KEY], node_name=state[cls.NODE_NAME_KEY], dim=state[cls.MASK_APPLYING_DIM_KEY]
         )
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 92ccf89d010..5d97374970e 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -565,11 +565,11 @@ def get_parameters_for_torch_fq(self) -> Tuple[int, int, torch.Tensor, torch.Ten
             zero_point - Quantizer zero point.
""" - def get_state(self): + def get_config(self): return self._qspec.get_state() @classmethod - def from_state(cls, state) -> "BaseQuantizer": + def from_config(cls, state) -> "BaseQuantizer": qsetup = PTQuantizerSpec.from_state(state) return cls(qsetup) diff --git a/nncf/torch/sparsity/layers.py b/nncf/torch/sparsity/layers.py index f6c695895dc..6edc0257929 100644 --- a/nncf/torch/sparsity/layers.py +++ b/nncf/torch/sparsity/layers.py @@ -49,9 +49,9 @@ def _calc_training_binary_mask(self, weight): def apply_binary_mask(self, weight): return apply_binary_mask_impl(self.binary_mask, weight) - def get_state(self) -> Dict[str, Any]: + def get_config(self) -> Dict[str, Any]: return {self.SHAPE_KEY: list(self.binary_mask.shape)} @classmethod - def from_state(cls, state: Dict[str, Any]) -> "BinaryMask": + def from_config(cls, state: Dict[str, Any]) -> "BinaryMask": return BinaryMask(state[cls.SHAPE_KEY]) diff --git a/nncf/torch/sparsity/rb/layers.py b/nncf/torch/sparsity/rb/layers.py index 39b1013e5ea..9b80cde6524 100644 --- a/nncf/torch/sparsity/rb/layers.py +++ b/nncf/torch/sparsity/rb/layers.py @@ -57,7 +57,7 @@ def _calc_training_binary_mask(self, weight): def loss(self): return binary_mask(self._mask) - def get_state(self) -> Dict[str, Any]: + def get_config(self) -> Dict[str, Any]: return { self.WEIGHTS_SHAPE_KEY: list(self.mask.shape), self.FROZEN_KEY: self.frozen, @@ -66,7 +66,7 @@ def get_state(self) -> Dict[str, Any]: } @classmethod - def from_state(cls, state: Dict[str, Any]) -> "RBSparsifyingWeight": + def from_config(cls, state: Dict[str, Any]) -> "RBSparsifyingWeight": return RBSparsifyingWeight( weight_shape=state[cls.WEIGHTS_SHAPE_KEY], frozen=state[cls.FROZEN_KEY], diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 3547d41d76a..4c7bf9a6fc7 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -43,6 +43,7 @@ from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args +from nncf.torch.layer_utils import StatefullTorchModuleInterface from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -265,7 +266,7 @@ def num_flat_features(self, x): return num_features -class DummyOpWithState(torch.nn.Module): +class DummyOpWithState(torch.nn.Module, StatefullTorchModuleInterface): def __init__(self, state: str): super().__init__() self._state = state @@ -285,11 +286,11 @@ def forward(self, *args): args[0].weight + self._dummy_param return None - def get_state(self): + def get_config(self): return self._state @classmethod - def from_state(cls, state: str): + def from_config(cls, state: str): return cls(state) diff --git a/tests/torch/test_serialization.py b/tests/torch/test_serialization.py index 1cfa0f093ad..b6830383e9b 100644 --- a/tests/torch/test_serialization.py +++ b/tests/torch/test_serialization.py @@ -183,7 +183,7 @@ def _check_pre_post_ops(modified_model, recovered_model): def _check_hook_are_equal(hook, recovered_hook): assert type(hook) == type(recovered_hook) if isinstance(hook, DummyOpWithState): - assert hook.get_state() == recovered_hook.get_state() + assert hook.get_config() == recovered_hook.get_config() return # Hook is external op call hook then assert hook._storage_name == recovered_hook._storage_name @@ -193,9 +193,9 @@ def _check_hook_are_equal(hook, recovered_hook): def 
_check_commands_after_serialization(command, recovered_command, dummy_op_state=None): commands_are_equal(recovered_command, command, check_fn_ref=False) assert isinstance(command.fn, DummyOpWithState) - assert command.fn.get_state() == recovered_command.fn.get_state() + assert command.fn.get_config() == recovered_command.fn.get_config() if dummy_op_state is not None: - assert command.fn.get_state() == dummy_op_state + assert command.fn.get_config() == dummy_op_state @pytest.mark.parametrize("size", (4, [3, 4])) @@ -206,11 +206,11 @@ def test_pruning_mask_serialization(size): mask.binary_filter_pruning_mask = torch.fill(torch.empty(size), 5) state_dict = mask.state_dict() - state = mask.get_state() + state = mask.get_config() json_state = json.dumps(state) state = json.loads(json_state) - recovered_mask = FilterPruningMask.from_state(state) + recovered_mask = FilterPruningMask.from_config(state) recovered_mask.load_state_dict(state_dict) ref_size = size if isinstance(size, list) else [size] @@ -244,11 +244,11 @@ def test_quantizer_serialization(quantizer_class: BaseQuantizer): state_dict = quantizer.state_dict() - state = quantizer.get_state() + state = quantizer.get_config() json_state = json.dumps(state) state = json.loads(json_state) - recovered_quantizer = quantizer_class.from_state(state) + recovered_quantizer = quantizer_class.from_config(state) recovered_quantizer.load_state_dict(state_dict) assert recovered_quantizer._qspec == ref_qspec @@ -271,11 +271,11 @@ def test_sparsity_binary_mask_serialization(): mask.binary_mask = torch.zeros(ref_shape) state_dict = mask.state_dict() - state = mask.get_state() + state = mask.get_config() json_state = json.dumps(state) state = json.loads(json_state) - recovered_mask = BinaryMask.from_state(state) + recovered_mask = BinaryMask.from_config(state) recovered_mask.load_state_dict(state_dict) assert list(recovered_mask.binary_mask.shape) == ref_shape @@ -297,11 +297,11 @@ def test_rb_sparsity_mask_serialization(): mask.mask = torch.fill(torch.empty(ref_weights_shape), 5) state_dict = mask.state_dict() - state = mask.get_state() + state = mask.get_config() json_state = json.dumps(state) state = json.loads(json_state) - recovered_mask = RBSparsifyingWeight.from_state(state) + recovered_mask = RBSparsifyingWeight.from_config(state) recovered_mask.load_state_dict(state_dict) assert list(recovered_mask.mask.shape) == ref_weights_shape @@ -321,11 +321,11 @@ def test_sq_multiply_serialization(): sq_multiply.scale = tensor_value state_dict = sq_multiply.state_dict() - state = sq_multiply.get_state() + state = sq_multiply.get_config() json_state = json.dumps(state) state = json.loads(json_state) - recovered_sq_multiply = SQMultiply.from_state(state) + recovered_sq_multiply = SQMultiply.from_config(state) recovered_sq_multiply.load_state_dict(state_dict) assert torch.all(sq_multiply.scale == recovered_sq_multiply.scale) From c2242b1a8b25131a5cfcb36d167ac7d9fa7aad8e Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 6 May 2024 11:10:18 +0200 Subject: [PATCH 6/6] StatefullTorchModuleInterface -> StatefullModuleInterface --- nncf/quantization/algorithms/smooth_quant/torch_backend.py | 4 ++-- nncf/torch/layer_utils.py | 2 +- nncf/torch/pruning/filter_pruning/layers.py | 4 ++-- nncf/torch/quantization/layers.py | 4 ++-- nncf/torch/sparsity/layers.py | 4 ++-- nncf/torch/sparsity/rb/layers.py | 4 ++-- tests/torch/helpers.py | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py 
b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index f6c9ccce171..5db48a24f85 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -32,7 +32,7 @@ from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter -from nncf.torch.layer_utils import StatefullTorchModuleInterface +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.model_graph_manager import get_const_data from nncf.torch.model_graph_manager import get_const_node from nncf.torch.nncf_network import NNCFNetwork @@ -42,7 +42,7 @@ @COMPRESSION_MODULES.register() -class SQMultiply(torch.nn.Module, StatefullTorchModuleInterface): +class SQMultiply(torch.nn.Module, StatefullModuleInterface): SCALE_SHAPE_KEY = "scale_shape" def __init__(self, scale_shape: Tuple[int, ...]): diff --git a/nncf/torch/layer_utils.py b/nncf/torch/layer_utils.py index 56814a9b6de..0614d5fd2ea 100644 --- a/nncf/torch/layer_utils.py +++ b/nncf/torch/layer_utils.py @@ -24,7 +24,7 @@ COMPRESSION_MODULES = Registry("compression modules") -class StatefullTorchModuleInterface(ABC): +class StatefullModuleInterface(ABC): """ Interface that should be implemented for every registered compression module to make it possible to save an compression modules state and create an compression module from the saved state. diff --git a/nncf/torch/pruning/filter_pruning/layers.py b/nncf/torch/pruning/filter_pruning/layers.py index 3e912cf5400..644a5eec7f5 100644 --- a/nncf/torch/pruning/filter_pruning/layers.py +++ b/nncf/torch/pruning/filter_pruning/layers.py @@ -18,11 +18,11 @@ import nncf from nncf.common.graph import NNCFNodeName from nncf.torch.layer_utils import COMPRESSION_MODULES -from nncf.torch.layer_utils import StatefullTorchModuleInterface +from nncf.torch.layer_utils import StatefullModuleInterface @COMPRESSION_MODULES.register() -class FilterPruningMask(nn.Module, StatefullTorchModuleInterface): +class FilterPruningMask(nn.Module, StatefullModuleInterface): """ A module contains the mask for pruning. On forward pass applying the mask to weight and bias of the module. 
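
For illustration, below is a minimal sketch of a custom compression module written against the
renamed interface. The `ChannelScale` class is hypothetical and not part of this patch series;
only `COMPRESSION_MODULES`, `StatefullModuleInterface`, and the `get_config`/`from_config` pair
come from the changes above. Per the interface docstring, the config carries only
JSON-serializable data, parameter values are restored from the model `state_dict`, and the op
is implemented via `forward` rather than `__call__`:

    from typing import Any, Dict, List

    import torch

    from nncf.torch.layer_utils import COMPRESSION_MODULES
    from nncf.torch.layer_utils import StatefullModuleInterface


    @COMPRESSION_MODULES.register()
    class ChannelScale(torch.nn.Module, StatefullModuleInterface):
        """Hypothetical example; mirrors the SQMultiply/BinaryMask pattern."""

        SHAPE_KEY = "shape"

        def __init__(self, shape: List[int]):
            super().__init__()
            # Parameter data travels through the state_dict; only the shape goes to the config.
            self._scale = torch.nn.Parameter(torch.ones(shape))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self._scale

        def get_config(self) -> Dict[str, Any]:
            # JSON-serializable types only.
            return {self.SHAPE_KEY: list(self._scale.shape)}

        @classmethod
        def from_config(cls, state: Dict[str, Any]) -> "ChannelScale":
            return cls(state[cls.SHAPE_KEY])

Registration via `COMPRESSION_MODULES.register()` is what allows `deserialize_command()` to look
the class up by name and rebuild it from the serialized `fn_config`.
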
diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index 5d97374970e..4b463600bc5 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -41,7 +41,7 @@ from nncf.torch.graph.transformations.commands import TargetType from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter -from nncf.torch.layer_utils import StatefullTorchModuleInterface +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.quantization.quantize_functions import ExportQuantizeToFakeQuantize from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant from nncf.torch.quantization.quantize_functions import TuneRange @@ -284,7 +284,7 @@ def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationP self.quantization_points[qp_id] = qp -class BaseQuantizer(nn.Module, StatefullTorchModuleInterface, ABC): +class BaseQuantizer(nn.Module, StatefullModuleInterface, ABC): def __init__(self, qspec: PTQuantizerSpec): super().__init__() self._qspec = qspec diff --git a/nncf/torch/sparsity/layers.py b/nncf/torch/sparsity/layers.py index 6edc0257929..bf3794cd716 100644 --- a/nncf/torch/sparsity/layers.py +++ b/nncf/torch/sparsity/layers.py @@ -14,13 +14,13 @@ from torch import nn from nncf.torch.layer_utils import COMPRESSION_MODULES -from nncf.torch.layer_utils import StatefullTorchModuleInterface +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl from nncf.torch.utils import is_tracing_state @COMPRESSION_MODULES.register() -class BinaryMask(nn.Module, StatefullTorchModuleInterface): +class BinaryMask(nn.Module, StatefullModuleInterface): SHAPE_KEY = "shape" def __init__(self, shape: List[int]): diff --git a/nncf/torch/sparsity/rb/layers.py b/nncf/torch/sparsity/rb/layers.py index 9b80cde6524..8d0199046d9 100644 --- a/nncf/torch/sparsity/rb/layers.py +++ b/nncf/torch/sparsity/rb/layers.py @@ -15,14 +15,14 @@ from nncf.torch.functions import logit from nncf.torch.layer_utils import COMPRESSION_MODULES from nncf.torch.layer_utils import CompressionParameter -from nncf.torch.layer_utils import StatefullTorchModuleInterface +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.sparsity.layers import BinaryMask from nncf.torch.sparsity.rb.functions import binary_mask from nncf.torch.sparsity.rb.functions import calc_rb_binary_mask @COMPRESSION_MODULES.register() -class RBSparsifyingWeight(BinaryMask, StatefullTorchModuleInterface): +class RBSparsifyingWeight(BinaryMask, StatefullModuleInterface): WEIGHTS_SHAPE_KEY = "weight_shape" FROZEN_KEY = "frozen" COMPRESSION_LR_MULTIPLIER_KEY = "compression_lr_multiplier" diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index 4c7bf9a6fc7..aa8268ef1d1 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -43,7 +43,7 @@ from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args -from nncf.torch.layer_utils import StatefullTorchModuleInterface +from nncf.torch.layer_utils import StatefullModuleInterface from nncf.torch.layers import NNCF_MODULES_MAP from nncf.torch.model_creation import create_compressed_model from nncf.torch.module_operations import UpdateWeight @@ -266,7 +266,7 @@ def num_flat_features(self, x): return 
num_features -class DummyOpWithState(torch.nn.Module, StatefullTorchModuleInterface): +class DummyOpWithState(torch.nn.Module, StatefullModuleInterface): def __init__(self, state: str): super().__init__() self._state = state
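
Putting the pieces together, here is a sketch of the save/load round trip these patches enable,
assembled from the temporary test helpers above. The model and example input are placeholders,
and the `ModelTransformerFactory` import path is assumed from the test file:

    import json
    from copy import deepcopy

    import torch

    from nncf.common.factory import ModelTransformerFactory
    from nncf.torch import wrap_model
    from nncf.torch.graph.transformations.serialization import deserialize_transformations
    from nncf.torch.graph.transformations.serialization import serialize_transformations

    model = torch.nn.Linear(4, 4)  # placeholder user model
    example_input = torch.ones(1, 4)  # placeholder input

    nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=True)
    # ... compression transformations are applied to nncf_network here ...

    # Save: the serialized layout is plain python data and survives a json round trip.
    saved = json.dumps(serialize_transformations(nncf_network.nncf.transformation_layout()))

    # Load: rebuild the layout and replay it on a freshly wrapped copy of the model.
    layout = deserialize_transformations(json.loads(saved))
    restored = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=True)
    restored = ModelTransformerFactory.create(restored).transform(layout)
    restored.nncf.disable_dynamic_graph_building()
    # Parameter values (scales, masks, ...) are then recovered via restored.load_state_dict(...).
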