[Torch] Save/load NNCFNetwork state
daniil-lyakhov committed Apr 2, 2024
1 parent f7a5660 commit f89019b
Showing 20 changed files with 1,104 additions and 25 deletions.
2 changes: 2 additions & 0 deletions nncf/__init__.py
@@ -39,9 +39,11 @@
 from nncf.parameters import SensitivityMetric as SensitivityMetric
 from nncf.parameters import TargetDevice as TargetDevice
 from nncf.quantization import QuantizationPreset as QuantizationPreset
+from nncf.quantization import apply_serialized_transformations as apply_serialized_transformations
 from nncf.quantization import compress_weights as compress_weights
 from nncf.quantization import quantize as quantize
 from nncf.quantization import quantize_with_accuracy_control as quantize_with_accuracy_control
+from nncf.quantization import serialize_transformations as serialize_transformations
 from nncf.quantization.advanced_parameters import (
     AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters,
 )
2 changes: 2 additions & 0 deletions nncf/quantization/__init__.py
@@ -10,6 +10,8 @@
 # limitations under the License.
 """Post-training quantization APIs."""
 from nncf.common.quantization.structs import QuantizationPreset as QuantizationPreset
+from nncf.quantization.quantize_model import apply_serialized_transformations as apply_serialized_transformations
 from nncf.quantization.quantize_model import compress_weights as compress_weights
 from nncf.quantization.quantize_model import quantize as quantize
 from nncf.quantization.quantize_model import quantize_with_accuracy_control as quantize_with_accuracy_control
+from nncf.quantization.quantize_model import serialize_transformations as serialize_transformations
10 changes: 5 additions & 5 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -263,22 +263,22 @@ def _create_quantizer(
         quantizer = quantizer_cls(quantizer_spec)
 
         # Fill it with minmax
-        PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters)
+        PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
         return quantizer
 
     @staticmethod
-    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None:
+    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
         if isinstance(quantizer, AsymmetricQuantizer):
-            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data)
+            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
             input_range = parameters.input_high - parameters.input_low
             # Subtract eps from the input_range to make quantizer parameters equal to
             # original parameters on the forward call.
-            quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps)
+            quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
         else:
             quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
             # Subtract eps from the scale to make quantizer parameters equal to
             # original parameters on the forward call.
-            quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)
+            quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))
 
     @staticmethod
     def create_quantizer_insertion_command(
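The reshape added above matters for per-channel quantization, where NNCF keeps quantizer parameters in a shape that broadcasts over the quantized tensor rather than as a flat vector. A minimal sketch of the idea (the shapes below are illustrative assumptions, not taken from the diff):

import torch

# Calibration statistics often arrive as a flat per-channel vector ...
input_high = torch.rand(64)

# ... while a per-channel quantizer stores its scale broadcastable against
# an NCHW weight tensor, e.g. (64, 1, 1, 1) for 64 output channels.
scale_shape = (64, 1, 1, 1)
scale = torch.nn.Parameter(input_high.reshape(scale_shape))
assert scale.shape == torch.Size(scale_shape)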
37 changes: 30 additions & 7 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -30,20 +30,41 @@
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.layer_utils import CompressionParameter
+from nncf.torch.layer_utils import StatefullTorchModuleInterface
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
 from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
 from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
 
 
-class SQMultiply(torch.nn.Module):
-    def __init__(self, scale_value):
+@COMPRESSION_MODULES.register()
+class SQMultiply(torch.nn.Module, StatefullTorchModuleInterface):
+    SCALE_SHAPE_KEY = "scale_shape"
+
+    def __init__(self, scale_shape: Tuple[int, ...]):
         super().__init__()
-        self._scale_value = scale_value
+        self._scale_value = CompressionParameter(torch.empty(scale_shape))
+
+    @property
+    def scale(self) -> torch.nn.Parameter:
+        return self._scale_value
 
-    def forward(self, x):
+    @scale.setter
+    def scale(self, value: torch.Tensor):
+        self._scale_value.data = value
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.mul(x, self._scale_value)
 
+    def get_state(self) -> Dict[str, Any]:
+        return {self.SCALE_SHAPE_KEY: list(self._scale_value.shape)}
+
+    @classmethod
+    def from_state(cls, state: Dict[str, Any]) -> "SQMultiply":
+        return SQMultiply(state[cls.SCALE_SHAPE_KEY])
 
 
 PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK

@@ -117,7 +138,7 @@ def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray)
     @staticmethod
     def scale_insertion_command(
         source_node: NNCFNode,
-        scale_value: np.ndarray,
+        scale_value: torch.Tensor,
         source_output_port_id: int,
         nodes: List[NNCFNode],
         scale_node_name: str,
@@ -127,7 +148,9 @@ def scale_insertion_command(
         for node in nodes:
             target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id))
 
-        return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
+        sq_multiply = SQMultiply(scale_value.shape)
+        sq_multiply.scale = scale_value
+        return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
 
     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
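A quick round-trip sketch of the reworked SQMultiply: get_state() carries only the JSON-serializable scale shape, while the scale values themselves travel through the module's regular state_dict. The shape and values here are illustrative:

import torch

from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply

# Construct as scale_insertion_command now does: shape first, values second
sq = SQMultiply((1, 4, 1))
sq.scale = torch.full((1, 4, 1), 0.5)

# Shape metadata round-trips through get_state()/from_state() ...
restored = SQMultiply.from_state(sq.get_state())

# ... and the actual parameter values through the state_dict
restored.load_state_dict(sq.state_dict())
assert torch.equal(restored.scale, sq.scale)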
33 changes: 32 additions & 1 deletion nncf/quantization/quantize_model.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 
 import nncf
 from nncf.api.compression import TModel
@@ -540,3 +540,34 @@ def quantize_with_tune_hyperparams(
     quantized_model = hyperparameter_tuner.apply(model, validation_dataset)
 
     return quantized_model
+
+
+@api(canonical_alias="nncf.apply_serialized_transformations")
+def apply_serialized_transformations(
+    model: TModel,
+    serialized_transformations: Dict[str, Any],
+) -> TModel:
+    """
+    Applies serialized transformations to the model.
+    """
+    backend = get_backend(model)
+    if backend == BackendType.TORCH:
+        from nncf.torch.quantization.quantize_model import apply_serialized_transformations_impl
+
+        return apply_serialized_transformations_impl(model, serialized_transformations)
+    raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}")
+
+
+@api(canonical_alias="nncf.serialize_transformations")
+def serialize_transformations(
+    model: TModel,
+) -> Dict[str, Any]:
+    """
+    Serializes the transformations applied to the model into a JSON-compatible dict.
+    """
+    backend = get_backend(model)
+    if backend == BackendType.TORCH:
+        from nncf.torch.quantization.quantize_model import serialize_transformations_impl
+
+        return serialize_transformations_impl(model)
+    raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}")
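Putting the two new entry points together, the intended save/load flow looks roughly like this; original_model, quantized_model, and the checkpoint layout are placeholders, not part of this commit:

import torch
import nncf

# --- save: transformations (JSON-like dict) + weights (state_dict) ---
ckpt = {
    "transformations": nncf.serialize_transformations(quantized_model),
    "state_dict": quantized_model.state_dict(),
}
torch.save(ckpt, "compressed_checkpoint.pt")

# --- load: replay transformations on the original model, then weights ---
ckpt = torch.load("compressed_checkpoint.pt")
restored = nncf.apply_serialized_transformations(original_model, ckpt["transformations"])
restored.load_state_dict(ckpt["state_dict"])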
20 changes: 20 additions & 0 deletions nncf/torch/dynamic_graph/io_handling.py
@@ -122,6 +122,7 @@ def __init__(self, shape: List[int], type_str: str = "float", keyword: str = Non
         """
         self.shape = shape
         self.type = self._string_to_torch_type(type_str)
+        self._type_str = type_str
         self.keyword = keyword
         if filler is None:
             self.filler = self.FILLER_TYPE_ONES
@@ -157,6 +158,15 @@ def get_tensor_for_input(self) -> torch.Tensor:
             return torch.rand(size=self.shape, dtype=self.type)
         raise NotImplementedError
 
+    def get_state(self) -> Dict[str, Any]:
+        return {"shape": self.shape, "type_str": self._type_str, "keyword": self.keyword, "filler": self.filler}
+
+    @classmethod
+    def from_state(cls, state: Dict[str, Any]) -> "FillerInputElement":
+        return FillerInputElement(
+            shape=state["shape"], type_str=state["type_str"], keyword=state["keyword"], filler=state["filler"]
+        )


class FillerInputInfo(ModelInputInfo):
"""
@@ -220,6 +230,16 @@ def get_forward_inputs(
                 kwargs[fe.keyword] = tensor
         return tuple(args_list), kwargs
 
+    def get_state(self) -> Dict[str, Any]:
+        return {"elements": [elem.get_state() for elem in self.elements]}
+
+    @classmethod
+    def from_state(cls, state: Dict[str, Any]) -> "FillerInputInfo":
+        return FillerInputInfo([FillerInputElement.from_state(s) for s in state["elements"]])
+
+    def __eq__(self, other: "FillerInputInfo") -> bool:
+        return self.elements == other.elements


class ExactInputsInfo(ModelInputInfo):
"""
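A short round-trip of the new input-info state API (the input shape is an illustrative assumption):

from nncf.torch.dynamic_graph.io_handling import FillerInputElement
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo

info = FillerInputInfo([FillerInputElement(shape=[1, 3, 224, 224], type_str="float")])
state = info.get_state()  # JSON-serializable: {"elements": [{"shape": [...], ...}]}
restored = FillerInputInfo.from_state(state)
# __eq__ was added above precisely to support this kind of check
# (assuming FillerInputElement compares elements by value)
assert restored == info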
7 changes: 7 additions & 0 deletions nncf/torch/graph/graph.py
@@ -60,6 +60,13 @@ def get_op_nodes_in_scope(self, scope: Scope) -> List[NNCFNode]:
             matching_graph_op_nodes.extend(nodes_in_module)
         return matching_graph_op_nodes
 
+    def get_op_node_in_scope(self, scope: Scope) -> List[NNCFNode]:
+        for scope_str, nodes_in_module in self._layer_name_vs_shared_nodes.items():
+            module_scope = Scope.from_str(scope_str)
+            if module_scope == scope:
+                return nodes_in_module
+        return []
+
     def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
         matches = []
         for node_id, scope_str in self._node_ids_vs_layer_names.items():
113 changes: 113 additions & 0 deletions nncf/torch/graph/transformations/serialization.py
@@ -0,0 +1,113 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Tuple, Union

import torch

from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTTransformationCommand
from nncf.torch.layer_utils import COMPRESSION_MODULES

COMPRESSION_STATE_ATTR = "compression_state"
INPUT_INFO_ATTR = "example_input"


class CompressionKeys(Enum):
    SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND"
    INSERTION_COMMAND = "INSERTION_COMMAND"


def serialize_transformations(model: torch.nn.Module, transformations_layout: TransformationLayout) -> Dict[str, Any]:
    input_info = model.nncf._input_info
    if not isinstance(input_info, FillerInputInfo):
        raise RuntimeError(f"Could not serialize model input info: {input_info}")

    transformation_commands = []
    for command in transformations_layout.transformations:
        serialized_command = serialize_command(command)
        if serialized_command:
            transformation_commands.append(serialized_command)

    return {COMPRESSION_STATE_ATTR: transformation_commands, INPUT_INFO_ATTR: input_info.get_state()}


def load_transformations(transformations_state: Dict[str, Any]) -> Tuple[TransformationLayout, FillerInputInfo]:
    transformation_layout = TransformationLayout()
    for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]:
        command = load_command(serialized_command)
        transformation_layout.register(command)

    input_info = FillerInputInfo.from_state(transformations_state[INPUT_INFO_ATTR])
    return transformation_layout, input_info


def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
    if not isinstance(command, (PTSharedFnInsertionCommand, PTInsertionCommand)):
        return {}

    serialized_transformation = dict()
    if isinstance(command, PTSharedFnInsertionCommand):
        serialized_transformation["type"] = CompressionKeys.SHARED_INSERTION_COMMAND.value
        serialized_transformation["target_points"] = [point.get_state() for point in command.target_points]
        serialized_transformation["op_name"] = command.op_name
        serialized_transformation["compression_module_type"] = command.compression_module_type.value

    elif isinstance(command, PTInsertionCommand):
        serialized_transformation["type"] = CompressionKeys.INSERTION_COMMAND.value
        serialized_transformation["target_point"] = command.target_point.get_state()

    # Check that the compression module is registered
    compression_module_name = command.fn.__class__.__name__
    if compression_module_name not in COMPRESSION_MODULES.registry_dict:
        raise RuntimeError(
            f"Could not serialize compression module with name {compression_module_name}."
            " Please register your module in the COMPRESSION_MODULES registry."
        )
    serialized_transformation["compression_module_name"] = compression_module_name
    serialized_transformation["fn_state"] = command.fn.get_state()
    serialized_transformation["hooks_group_name"] = command.hooks_group_name
    priority = command.priority
    serialized_transformation["priority"] = priority.value if isinstance(priority, Enum) else priority
    return serialized_transformation


def load_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
    module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"])
    fn = module_cls.from_state(serialized_command["fn_state"])
    priority = serialized_command["priority"]
    if priority in iter(TransformationPriority):
        priority = TransformationPriority(priority)

    if serialized_command["type"] == CompressionKeys.INSERTION_COMMAND.value:
        target_point = PTTargetPoint.from_state(serialized_command["target_point"])
        return PTInsertionCommand(
            point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"]
        )

    if serialized_command["type"] == CompressionKeys.SHARED_INSERTION_COMMAND.value:
        target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]]
        return PTSharedFnInsertionCommand(
            target_points=target_points,
            fn=fn,
            op_unique_name=serialized_command["op_name"],
            compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]),
            priority=priority,
            hooks_group_name=serialized_command["hooks_group_name"],
        )
    raise RuntimeError(f"Command type {serialized_command['type']} is not supported.")
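For reference, a serialized command is a plain dict keyed by the fields written above. Roughly what serialize_command would emit for an SQMultiply shared-fn command — every value below is illustrative, not taken from a real model:

serialized_command = {
    "type": "SHARED_INSERTION_COMMAND",
    "target_points": [...],                   # PTTargetPoint.get_state() dicts
    "op_name": "sq_multiply_0",               # hypothetical unique op name
    "compression_module_type": 2,             # ExtraCompressionModuleType value
    "compression_module_name": "SQMultiply",  # lookup key in COMPRESSION_MODULES
    "fn_state": {"scale_shape": [1, 4, 1]},   # SQMultiply.get_state()
    "hooks_group_name": "default",
    "priority": 0,
}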
34 changes: 34 additions & 0 deletions nncf/torch/layer_utils.py
@@ -9,6 +9,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from abc import ABC
+from abc import abstractclassmethod
+from abc import abstractmethod
+from typing import Any, Dict
+
 import torch
 from torch import nn
@@ -19,6 +24,30 @@
 COMPRESSION_MODULES = Registry("compression modules")
 
 
+class StatefullTorchModuleInterface(ABC):
+    """
+    Interface that should be implemented by every registered compression module so that
+    a compression module's state can be saved and the module can be recreated from that
+    saved state. The state of the module must be JSON-serializable: no Python objects
+    other than standard types (str, list, etc.) should be present in a compression
+    module state. The state of attributes of type torch.nn.Parameter is recovered from
+    the model `state_dict`, so there is no need to keep them in the module state.
+    """
+
+    @abstractmethod
+    def get_state(self) -> Dict[str, Any]:
+        """
+        Returns the compression module state.
+        """
+
+    @abstractclassmethod
+    def from_state(cls, state: Dict[str, Any]) -> object:
+        """
+        Creates a compression module instance from the given state.
+        """
+
+
 class ProxyModule:
     def __init__(self, module):
         self._module = module
@@ -117,7 +146,12 @@ def __init__(self, data: torch.Tensor = None, requires_grad: bool = True, compre
         """
         super().__init__()
 
+        self._compression_lr_multiplier = compression_lr_multiplier
         if compression_lr_multiplier is not None and self.dtype.is_floating_point:
             self.requires_grad = True
             self.register_hook(lambda grad: compression_lr_multiplier * grad)
         self.requires_grad = requires_grad
 
+    @property
+    def compression_lr_multiplier(self):
+        return self._compression_lr_multiplier
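To make the new contract concrete, a minimal sketch of a custom compression module implementing the interface; the BiasAdd module itself is hypothetical:

from typing import Any, Dict

import torch

from nncf.torch.layer_utils import COMPRESSION_MODULES
from nncf.torch.layer_utils import CompressionParameter
from nncf.torch.layer_utils import StatefullTorchModuleInterface


@COMPRESSION_MODULES.register()
class BiasAdd(torch.nn.Module, StatefullTorchModuleInterface):
    """Hypothetical compression module that adds a learnable bias."""

    BIAS_SHAPE_KEY = "bias_shape"

    def __init__(self, bias_shape):
        super().__init__()
        # Parameter values are restored from the model state_dict,
        # so get_state() needs only the shape metadata.
        self._bias = CompressionParameter(torch.zeros(bias_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self._bias

    def get_state(self) -> Dict[str, Any]:
        return {self.BIAS_SHAPE_KEY: list(self._bias.shape)}

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "BiasAdd":
        return cls(state[cls.BIAS_SHAPE_KEY])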