[Torch] Serialize and load NNCF transformations #2531

Merged · 6 commits · May 6, 2024

Changes from 5 commits
10 changes: 5 additions & 5 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -268,22 +268,22 @@ def _create_quantizer(
         quantizer = quantizer_cls(quantizer_spec)
 
         # Fill it with minmax
-        PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters)
+        PTMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape)
         return quantizer
 
     @staticmethod
-    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters) -> None:
+    def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None:
         if isinstance(quantizer, AsymmetricQuantizer):
-            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data)
+            quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape))
             input_range = parameters.input_high - parameters.input_low
             # Subtract eps from the input_range to make quantizer parameters equal to
             # original parameters on the forward call.
-            quantizer.input_range = torch.nn.Parameter(input_range.data - quantizer.eps)
+            quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape))
         else:
             quantizer.signed = bool(torch.any(parameters.input_low.data < 0))
             # Subtract eps from the scale to make quantizer parameters equal to
             # original parameters on the forward call.
-            quantizer.scale = torch.nn.Parameter(parameters.input_high.data - quantizer.eps)
+            quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape))
 
     @staticmethod
     def create_quantizer_insertion_command(
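Editor's note: the reshape to quantizer_spec.scale_shape matters for serialization because a quantizer re-created from its serialized spec allocates its parameters with exactly that shape, so a checkpoint only loads if the fitted parameters were stored in the same shape. A minimal standalone sketch of the idea (plain torch with hypothetical shapes, not NNCF's actual classes):

import torch

# Per-channel statistics may arrive flattened, e.g. shape (4,), while the
# quantizer stores its scale with an explicit broadcastable shape, e.g.
# (4, 1, 1, 1) for a conv weight with 4 output channels.
scale_shape = (4, 1, 1, 1)
collected_input_high = torch.rand(4)  # hypothetical statistics

fitted = torch.nn.Module()
fitted.scale = torch.nn.Parameter(collected_input_high.reshape(scale_shape))

# A quantizer re-created from the serialized spec allocates the same shape...
restored = torch.nn.Module()
restored.scale = torch.nn.Parameter(torch.empty(scale_shape))
# ...so the state_dict loads without shape mismatches.
restored.load_state_dict(fitted.state_dict())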
37 changes: 30 additions & 7 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -30,6 +30,9 @@
 from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.layer_utils import CompressionParameter
+from nncf.torch.layer_utils import StatefullTorchModuleInterface
 from nncf.torch.model_graph_manager import get_const_data
 from nncf.torch.model_graph_manager import get_const_node
 from nncf.torch.nncf_network import NNCFNetwork
@@ -38,14 +41,32 @@
 from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
 
 
-class SQMultiply(torch.nn.Module):
-    def __init__(self, scale_value):
+@COMPRESSION_MODULES.register()
+class SQMultiply(torch.nn.Module, StatefullTorchModuleInterface):
+    SCALE_SHAPE_KEY = "scale_shape"
+
+    def __init__(self, scale_shape: Tuple[int, ...]):
         super().__init__()
-        self._scale_value = scale_value
+        self._scale_value = CompressionParameter(torch.empty(scale_shape))
 
-    def forward(self, x):
+    @property
+    def scale(self) -> torch.nn.Parameter:
+        return self._scale_value
+
+    @scale.setter
+    def scale(self, value: torch.tensor):
+        self._scale_value.data = value
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.mul(x, self._scale_value)
 
+    def get_config(self) -> Dict[str, Any]:
+        return {self.SCALE_SHAPE_KEY: list(self._scale_value.shape)}
+
+    @classmethod
+    def from_config(cls, state) -> "SQMultiply":
+        return SQMultiply(state[cls.SCALE_SHAPE_KEY])
+
 
 PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK
 
@@ -122,7 +143,7 @@ def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray)
     @staticmethod
     def scale_insertion_command(
         source_node: NNCFNode,
-        scale_value: np.ndarray,
+        scale_value: torch.Tensor,
         source_output_port_id: int,
         nodes: List[NNCFNode],
         scale_node_name: str,
@@ -132,7 +153,9 @@ def scale_insertion_command(
         for node in nodes:
             target_points.append(PTTargetPoint(PT_PRE_LAYER_TARGET_TYPE, node.node_name, input_port_id=input_port_id))
 
-        return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
+        sq_multiply = SQMultiply(scale_value.shape)
+        sq_multiply.scale = scale_value
+        return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
 
     @staticmethod
     def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
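Editor's note: get_config/from_config carry only the scale shape, while the scale values travel through the model state_dict. An illustrative round trip (assuming SQMultiply as defined in this diff; the shape is hypothetical):

import torch

from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply

scale = torch.rand(1, 16, 1)  # hypothetical per-channel smoothing scale

sq = SQMultiply(scale.shape)
sq.scale = scale

config = sq.get_config()  # {"scale_shape": [1, 16, 1]}
restored = SQMultiply.from_config(config)

# The scale values themselves are restored from the state_dict:
restored.load_state_dict(sq.state_dict())
assert torch.equal(restored.scale, sq.scale)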
121 changes: 121 additions & 0 deletions nncf/torch/graph/transformations/serialization.py
@@ -0,0 +1,121 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Union

from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTTransformationCommand
from nncf.torch.layer_utils import COMPRESSION_MODULES

COMPRESSION_STATE_ATTR = "compression_state"
SUPPORTED_COMMANDS = (PTSharedFnInsertionCommand, PTInsertionCommand)


def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]:
    """
    Serializes the given transformation layout to a dict.

    :param transformations_layout: Given transformation layout.
    :return: Serialized representation of the given transformation layout as a dict.
    """
    transformation_commands = []
    for command in transformations_layout.transformations:
        transformation_commands.append(serialize_command(command))

    return {COMPRESSION_STATE_ATTR: transformation_commands}


def deserialize_transformations(serialized_transformation_layout: Dict[str, Any]) -> TransformationLayout:
    """
    Deserializes the given serialized transformation layout.

    :param serialized_transformation_layout: Given serialized transformation layout.
    :return: The deserialized transformation layout.
    """
    transformation_layout = TransformationLayout()
    for serialized_command in serialized_transformation_layout[COMPRESSION_STATE_ATTR]:
        command = deserialize_command(serialized_command)
        transformation_layout.register(command)

    return transformation_layout


def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
    """
    Serializes the given command to a dict.

    :param command: Given command.
    :return: Serialized representation of the given command as a dict.
    """
    if not isinstance(command, SUPPORTED_COMMANDS):
        raise RuntimeError(f"Command type {command.__class__} is not supported.")

    serialized_transformation = dict()
    serialized_transformation["type"] = command.__class__.__name__
    if isinstance(command, PTSharedFnInsertionCommand):
        serialized_transformation["target_points"] = [point.get_state() for point in command.target_points]
        serialized_transformation["op_name"] = command.op_name
        serialized_transformation["compression_module_type"] = command.compression_module_type.value
    elif isinstance(command, PTInsertionCommand):
        serialized_transformation["target_point"] = command.target_point.get_state()

    # Check that the compression module is registered
    compression_module_name = command.fn.__class__.__name__
    if compression_module_name not in COMPRESSION_MODULES.registry_dict:
        raise RuntimeError(
            f"Could not serialize compression module with name {compression_module_name}."
            " Please register your module in the COMPRESSION_MODULES registry."
        )
    serialized_transformation["compression_module_name"] = compression_module_name
    serialized_transformation["fn_config"] = command.fn.get_config()
    serialized_transformation["hooks_group_name"] = command.hooks_group_name
    priority = command.priority
    serialized_transformation["priority"] = priority.value if isinstance(priority, Enum) else priority
    return serialized_transformation


def deserialize_command(serialized_command: Dict[str, Any]) -> Union[PTInsertionCommand, PTSharedFnInsertionCommand]:
    """
    Deserializes the given serialized command.

    :param serialized_command: Given serialized command.
    :return: The deserialized command.
    """
    if serialized_command["type"] not in (command_cls.__name__ for command_cls in SUPPORTED_COMMANDS):
        raise RuntimeError(f"Command type {serialized_command['type']} is not supported.")

    module_cls = COMPRESSION_MODULES.get(serialized_command["compression_module_name"])
    fn = module_cls.from_config(serialized_command["fn_config"])
    priority = serialized_command["priority"]
    if priority in iter(TransformationPriority):
        priority = TransformationPriority(priority)

    if serialized_command["type"] == PTInsertionCommand.__name__:
        target_point = PTTargetPoint.from_state(serialized_command["target_point"])
        return PTInsertionCommand(
            point=target_point, fn=fn, priority=priority, hooks_group_name=serialized_command["hooks_group_name"]
        )

    target_points = [PTTargetPoint.from_state(state) for state in serialized_command["target_points"]]
    return PTSharedFnInsertionCommand(
        target_points=target_points,
        fn=fn,
        op_unique_name=serialized_command["op_name"],
        compression_module_type=ExtraCompressionModuleType(serialized_command["compression_module_type"]),
        priority=priority,
        hooks_group_name=serialized_command["hooks_group_name"],
    )
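Editor's note: taken together, this module gives a JSON-serializable round trip for a layout. A sketch of the intended usage (assuming the module path from this PR; an empty layout is used so the snippet is self-contained, and the commented dict shows the rough shape a populated entry would take, with hypothetical values):

import json

from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.torch.graph.transformations.serialization import deserialize_transformations
from nncf.torch.graph.transformations.serialization import serialize_transformations

layout = TransformationLayout()  # commands would normally be registered by an algorithm

state = serialize_transformations(layout)
assert state == {"compression_state": []}
json.dumps(state)  # JSON-serializable by design

# A populated entry would look roughly like:
# {"type": "PTInsertionCommand",
#  "target_point": {...},
#  "compression_module_name": "SQMultiply",
#  "fn_config": {"scale_shape": [1, 16, 1]},
#  "priority": 1,
#  "hooks_group_name": "..."}

restored_layout = deserialize_transformations(state)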
37 changes: 37 additions & 0 deletions nncf/torch/layer_utils.py
@@ -9,6 +9,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from abc import ABC
+from abc import abstractclassmethod
+from abc import abstractmethod
+from typing import Any, Dict
+
 import torch
 from torch import nn
@@ -19,6 +24,33 @@
 COMPRESSION_MODULES = Registry("compression modules")
 
 
+class StatefullTorchModuleInterface(ABC):
+    """
+    Interface that should be implemented by every registered compression module so that
+    a compression module's state can be saved and the module re-created from that state.
+    The module config must be JSON-serializable: no Python objects other than standard
+    types (str, list, etc.) may appear in a compression module config. Values of
+    attributes of type torch.nn.Parameter are recovered from the model `state_dict`,
+    so there is no need to keep them in the module config.
+    Modules should avoid implementing the `__call__` method and use `forward` instead:
+    torch functions called inside `__call__` cannot be unambiguously separated from the
+    calls of the wrapped parent NNCF module, so NNCF is unable to identify the target
+    point of such a call when recovering transformations.
+    """
+
+    @abstractmethod
+    def get_config(self) -> Dict[str, Any]:
+        """
+        Returns the compression module config.
+        """
+
+    @abstractclassmethod
+    def from_config(cls, state: Dict[str, Any]) -> object:
+        """
+        Creates a compression module instance from the given config.
+        """
+
+
 class ProxyModule:
     def __init__(self, module):
         self._module = module
@@ -117,7 +149,12 @@ def __init__(self, data: torch.Tensor = None, requires_grad: bool = True, compre
         """
         super().__init__()
 
-        self._compression_lr_multiplier = compression_lr_multiplier
-
-    @property
-    def compression_lr_multiplier(self):
-        return self._compression_lr_multiplier
+        if compression_lr_multiplier is not None and self.dtype.is_floating_point:
+            self.requires_grad = True
+            self.register_hook(lambda grad: compression_lr_multiplier * grad)
+        self.requires_grad = requires_grad
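Editor's note: for reference, a minimal module satisfying this interface could look as follows (a hypothetical AddBias module, not part of this PR; only the shape goes into the config, since the parameter values are restored from the state_dict):

from typing import Any, Dict

import torch

from nncf.torch.layer_utils import COMPRESSION_MODULES
from nncf.torch.layer_utils import StatefullTorchModuleInterface


@COMPRESSION_MODULES.register()
class AddBias(torch.nn.Module, StatefullTorchModuleInterface):
    """Hypothetical compression module that adds a learnable bias."""

    BIAS_SHAPE_KEY = "bias_shape"

    def __init__(self, bias_shape):
        super().__init__()
        # Parameter values are recovered from the model state_dict,
        # so only the shape goes into the config.
        self._bias = torch.nn.Parameter(torch.zeros(bias_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self._bias

    def get_config(self) -> Dict[str, Any]:
        return {self.BIAS_SHAPE_KEY: list(self._bias.shape)}

    @classmethod
    def from_config(cls, state: Dict[str, Any]) -> "AddBias":
        return cls(state[cls.BIAS_SHAPE_KEY])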
27 changes: 24 additions & 3 deletions nncf/torch/pruning/filter_pruning/layers.py
@@ -8,34 +8,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict
+
 import numpy as np
 import torch
 from torch import nn
 
 import nncf
 from nncf.common.graph import NNCFNodeName
 from nncf.torch.layer_utils import COMPRESSION_MODULES
+from nncf.torch.layer_utils import StatefullTorchModuleInterface
 
 
 @COMPRESSION_MODULES.register()
-class FilterPruningMask(nn.Module):
+class FilterPruningMask(nn.Module, StatefullTorchModuleInterface):
     """
     A module that contains the pruning mask.
     On the forward pass it applies the mask to the weight and bias of the module.
     """
 
+    MASK_APPLYING_DIM_KEY = "dim"
+    NODE_NAME_KEY = "node_name"
+    SIZE_KEY = "size_key"
+
     def __init__(self, size, node_name, dim=0):
         super().__init__()
         self.register_buffer("_binary_filter_pruning_mask", torch.ones(size))
         self.mask_applying_dim = dim
         self.node_name = node_name
 
     @property
-    def binary_filter_pruning_mask(self):
+    def binary_filter_pruning_mask(self) -> torch.Tensor:
         return self._binary_filter_pruning_mask
 
     @binary_filter_pruning_mask.setter
-    def binary_filter_pruning_mask(self, mask):
+    def binary_filter_pruning_mask(self, mask: torch.Tensor):
         with torch.no_grad():
             self._binary_filter_pruning_mask.set_(mask)
@@ -56,6 +64,19 @@ def forward(self, **params):
         )
         return new_params
 
+    def get_config(self) -> Dict[str, Any]:
+        return {
+            self.MASK_APPLYING_DIM_KEY: self.mask_applying_dim,
+            self.NODE_NAME_KEY: self.node_name,
+            self.SIZE_KEY: list(self.binary_filter_pruning_mask.size()),
+        }
+
+    @classmethod
+    def from_config(cls, state: Dict[str, Any]) -> "FilterPruningMask":
+        return FilterPruningMask(
+            size=state[cls.SIZE_KEY], node_name=state[cls.NODE_NAME_KEY], dim=state[cls.MASK_APPLYING_DIM_KEY]
+        )
+
 
 def broadcast_filter_mask(filter_mask, shape, dim=0):
     broadcasted_shape = np.ones(len(shape), dtype=np.int64)
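Editor's note: the same split applies here; the config holds the mask's metadata while the mask values themselves come from the state_dict. An illustrative round trip (hypothetical node name and size, assuming FilterPruningMask as defined in this diff):

import torch

from nncf.torch.pruning.filter_pruning.layers import FilterPruningMask

mask_module = FilterPruningMask(size=8, node_name="conv1", dim=0)

config = mask_module.get_config()
# {"dim": 0, "node_name": "conv1", "size_key": [8]}

restored = FilterPruningMask.from_config(config)
# The mask buffer is restored from the state_dict:
restored.load_state_dict(mask_module.state_dict())
assert torch.equal(restored.binary_filter_pruning_mask,
                   mask_module.binary_filter_pruning_mask)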
12 changes: 11 additions & 1 deletion nncf/torch/quantization/layers.py
@@ -41,6 +41,7 @@
 from nncf.torch.graph.transformations.commands import TargetType
 from nncf.torch.layer_utils import COMPRESSION_MODULES
 from nncf.torch.layer_utils import CompressionParameter
+from nncf.torch.layer_utils import StatefullTorchModuleInterface
 from nncf.torch.quantization.quantize_functions import ExportQuantizeToFakeQuantize
 from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant
 from nncf.torch.quantization.quantize_functions import TuneRange
@@ -283,9 +284,10 @@ def add_quantization_point(self, qp_id: QuantizationPointId, qp: PTQuantizationP
         self.quantization_points[qp_id] = qp
 
 
-class BaseQuantizer(nn.Module, ABC):
+class BaseQuantizer(nn.Module, StatefullTorchModuleInterface, ABC):
     def __init__(self, qspec: PTQuantizerSpec):
         super().__init__()
+        self._qspec = qspec
         self._narrow_range = qspec.narrow_range
         self._signedness_to_force = qspec.signedness_to_force
         self._is_using_log_scale_storage = qspec.logarithm_scale
@@ -563,6 +565,14 @@ def get_parameters_for_torch_fq(self) -> Tuple[int, int, torch.Tensor, torch.Ten
         zero_point - Quantizer zero point.
         """
 
+    def get_config(self):
+        return self._qspec.get_state()
+
+    @classmethod
+    def from_config(cls, state) -> "BaseQuantizer":
+        qsetup = PTQuantizerSpec.from_state(state)
+        return cls(qsetup)
+
 
 class QuantizersSwitcher:
     """Enables/disables quantizers with saving and restoring original state"""
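Editor's note: since the full PTQuantizerSpec is now stored on the instance, any concrete quantizer can be re-created from its spec state alone, with the fitted scales coming from the checkpoint. A hedged sketch of that split (the helper functions are illustrative, not NNCF API):

import json


def save_quantizer_config(quantizer) -> str:
    # get_config() returns the JSON-serializable PTQuantizerSpec state;
    # scales and other torch.nn.Parameter values live in the model state_dict.
    return json.dumps(quantizer.get_config())


def load_quantizer(quantizer_cls, config_json: str, state_dict):
    restored = quantizer_cls.from_config(json.loads(config_json))
    restored.load_state_dict(state_dict)
    return restored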