diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index aa6358d8c0b..06da97313b8 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -150,21 +150,23 @@ def __init__( dilations: Tuple[int, ...], groups: int, transpose: bool, - padding_values: Tuple[int, ...], + padding_values: Union[Tuple[int, ...], int], with_bias: bool = False, + output_padding_values: Optional[Union[Tuple[int, ...], int]] = None, ): """ :param weight_requires_grad: Is True if gradients need to be computed for the corresponding Tensor, False otherwise. - :param in_channels: number of input channels in the layer's input. - :param out_channels: number of channels produced by the layer. - :param kernel_size: size of the convolving kernel. - :param stride: stride of the convolution. - :param groups: number of blocked connections from input channels to output channels. + :param in_channels: Number of input channels in the layer's input. + :param out_channels: Number of channels produced by the layer. + :param kernel_size: Size of the convolving kernel. + :param stride: Stride of the convolution. + :param groups: Number of blocked connections from input channels to output channels. :param transpose: If set to `True`, the layer is an ordinary convolution, otherwise - transpose one. - :param padding_values: defines the amount of padding applied to the layer's input. + :param padding_values: Defines the amount of padding applied to the layer's input. :param with_bias: Operation include bias. + :param output_padding_values: Defines the amount of output padding applied to the layer's output, for transpose. """ super().__init__(weight_requires_grad=weight_requires_grad, with_bias=with_bias) self.in_channels = in_channels @@ -175,6 +177,7 @@ def __init__( self.groups = groups self.transpose = transpose self.padding_values = padding_values + self.output_padding_values = output_padding_values def get_weight_shape(self) -> List[int]: if not self.transpose: diff --git a/nncf/common/graph/operator_metatypes.py b/nncf/common/graph/operator_metatypes.py index ac3a3c0e9fb..e65f659471f 100644 --- a/nncf/common/graph/operator_metatypes.py +++ b/nncf/common/graph/operator_metatypes.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Type +from typing import List, Optional, Set, Type import nncf from nncf.common.graph.definitions import NNCFGraphNodeType @@ -187,3 +187,13 @@ class ConstNoopMetatype(OperatorMetatype): @classmethod def get_all_aliases(cls) -> List[str]: return [NNCFGraphNodeType.CONST_NODE] + + +def get_all_aliases(*metatypes: OperatorMetatype) -> Set[str]: + """ + Returns a set of all unique aliases from the provided metatypes. + + :param *metatypes: A list of operator metatypes. + :return: A set containing all unique aliases for metatypes. + """ + return set(a for m in metatypes for a in m.get_all_aliases()) diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index e0c291e81b0..3fc90a5e9c5 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple import torch @@ -32,56 +32,16 @@ from nncf.torch.graph import operator_metatypes as om from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph +from nncf.torch.model_graph_manager import get_const_node +from nncf.torch.model_graph_manager import get_module_by_name +from nncf.torch.model_graph_manager import split_const_name from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.layers import WeightsDecompressor from nncf.torch.tensor_statistics.collectors import get_raw_stat_collector -def split_weight_name(weight_name: str) -> Tuple[str, str]: - index = weight_name.rfind(".") - if index == -1: - return str(), weight_name - module_name = weight_name[:index] - weight_attr_name = weight_name[index + 1 :] - return module_name, weight_attr_name - - -def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module: - if not module_name: - return model - curr_module = model - for name in module_name.split("."): - for child_name, child_module in curr_module.named_children(): - if child_name == name: - curr_module = child_module - break - else: - raise nncf.ModuleNotFoundError(f"Could not find the {module_name} module in the model.") - return curr_module - - -def find_weight_node_in_constant_subgraph(node: NNCFNode, graph: NNCFGraph) -> Union[NNCFNode, None]: - if node.metatype == om.PTNoopMetatype: - prev_nodes = graph.get_previous_nodes(node) - if len(prev_nodes) != 1: - return None - return find_weight_node_in_constant_subgraph(prev_nodes[0], graph) - if node.metatype in CONST_NOOP_METATYPES: - return node - return None - - -def get_weight_node(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> NNCFNode: - for prev_node in graph.get_previous_nodes(node_with_weight): - edge = graph.get_edge(prev_node, node_with_weight) - if edge.input_port_id == weight_port_id: - weight_node = find_weight_node_in_constant_subgraph(prev_node, graph) - if weight_node is None: - raise nncf.InternalError("Could not find a constant node in the model graph.") - return weight_node - - class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): TARGET_TYPE_TO_PT_INS_TYPE_MAP = { TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, @@ -125,7 +85,7 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool: edge = graph.get_edge(prev_node, node) if edge.input_port_id not in node.metatype.weight_port_ids: continue - weight_node = find_weight_node_in_constant_subgraph(prev_node, graph) + weight_node = find_const_node_in_constant_subgraph(prev_node, graph) if weight_node is not None: return True return False @@ -134,7 +94,7 @@ def is_node_with_weights(node: NNCFNode, graph: NNCFGraph) -> bool: def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Tuple[str, int]]: weight_port_ids = [] for prev_node in graph.get_previous_nodes(node): - weight_node = find_weight_node_in_constant_subgraph(prev_node, graph) + weight_node = find_const_node_in_constant_subgraph(prev_node, graph) if weight_node is None: continue edge = graph.get_edge(prev_node, node) @@ -146,7 +106,7 @@ def get_weight_names_and_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Tupl def 
get_channel_agnostic_reduction_axes( node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph ) -> Optional[Tuple[int]]: - weight_node = get_weight_node(node_with_weight, weight_port_id, graph) + weight_node = get_const_node(node_with_weight, weight_port_id, graph) ndims = len(weight_node.layer_attributes.shape) reduction_axes = None @@ -200,9 +160,9 @@ def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: def get_weight( self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph ) -> Tensor: - weight_node = get_weight_node(node_with_weight, weight_port_id, graph) + weight_node = get_const_node(node_with_weight, weight_port_id, graph) weight_name = weight_node.layer_attributes.name - module_name, weight_attr_name = split_weight_name(weight_name) + module_name, weight_attr_name = split_const_name(weight_name) module = get_module_by_name(module_name, model) weight = getattr(module, weight_attr_name) if weight is None or not isinstance(weight, torch.nn.Parameter): @@ -229,9 +189,9 @@ def transform_model( ]: raise ValueError(f"{compression_config.mode.value} is not supported.") - weight_node = get_weight_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) + weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) weight_name = weight_node.layer_attributes.name - module_name, weight_attr_name = split_weight_name(weight_name) + module_name, weight_attr_name = split_const_name(weight_name) module = get_module_by_name(module_name, model) weight = getattr(module, weight_attr_name) if weight is None or not isinstance(weight, torch.nn.Parameter): diff --git a/nncf/torch/dynamic_graph/layer_attributes_handlers.py b/nncf/torch/dynamic_graph/layer_attributes_handlers.py index 8338e760c17..bc3c809ecc2 100644 --- a/nncf/torch/dynamic_graph/layer_attributes_handlers.py +++ b/nncf/torch/dynamic_graph/layer_attributes_handlers.py @@ -9,15 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from torch.nn import Conv1d -from torch.nn import Conv2d -from torch.nn import Conv3d -from torch.nn import ConvTranspose1d -from torch.nn import ConvTranspose2d -from torch.nn import ConvTranspose3d -from torch.nn import Linear -from torch.nn import Module as TorchModule +from typing import Any, Dict, List, Tuple, Union +import nncf.torch.graph.operator_metatypes as om from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.layer_attributes import BaseLayerAttributes from nncf.common.graph.layer_attributes import ConstantLayerAttributes @@ -33,89 +27,52 @@ from nncf.common.graph.layer_attributes import ReshapeLayerAttributes from nncf.common.graph.layer_attributes import TransposeLayerAttributes from nncf.common.graph.operator_metatypes import ConstNoopMetatype +from nncf.common.graph.operator_metatypes import get_all_aliases from nncf.common.graph.utils import get_split_axis from nncf.torch.dynamic_graph.trace_tensor import TracedParameter -from nncf.torch.graph.operator_metatypes import PTCatMetatype -from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype -from nncf.torch.graph.operator_metatypes import PTPadMetatype -from nncf.torch.graph.operator_metatypes import PTReshapeMetatype -from nncf.torch.graph.operator_metatypes import PTSplitMetatype -from nncf.torch.graph.operator_metatypes import PTSqueezeMetatype from nncf.torch.layers import NNCF_MODULES_DICT OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + list( - PTGroupNormMetatype.get_all_aliases() + om.PTGroupNormMetatype.get_all_aliases() ) TRANSPOSE_OP_NAMES = ["transpose", "transpose_"] PERMUTE_OP_NAMES = ["permute"] GETITEM_OP_NAMES = ["__getitem__"] -PAD_OP_NAMES = PTPadMetatype.get_all_aliases() -CONCAT_OP_NAMES = PTCatMetatype.get_all_aliases() +CONV_OP_NAMES = get_all_aliases(om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype) +CONV_TRANSPOSE_OP_NAMES = get_all_aliases( + om.PTConvTranspose1dMetatype, om.PTConvTranspose2dMetatype, om.PTConvTranspose3dMetatype +) +LINEAR_OP_NAMES = get_all_aliases(om.PTLinearMetatype) +BATCHNORM_OP_NAMES = get_all_aliases(om.PTBatchNormMetatype) +EMBEDDING_OP_NAMES = get_all_aliases(om.PTEmbeddingMetatype, om.PTEmbeddingBagMetatype) +GROUP_NORM_OP_NAMES = get_all_aliases(om.PTGroupNormMetatype) +LAYER_NORM_OP_NAMES = get_all_aliases(om.PTLayerNormMetatype) +PAD_OP_NAMES = om.PTPadMetatype.get_all_aliases() +CONCAT_OP_NAMES = om.PTCatMetatype.get_all_aliases() CONST_OP_NAMES = ConstNoopMetatype.get_all_aliases() OP_NAMES_REQUIRING_ATTRS_FROM_ARGS_KWARGS = list( TRANSPOSE_OP_NAMES + PERMUTE_OP_NAMES + GETITEM_OP_NAMES + PAD_OP_NAMES + CONCAT_OP_NAMES + CONST_OP_NAMES ) -def get_layer_attributes_from_module(module: TorchModule, operator_name: str) -> BaseLayerAttributes: - if operator_name == "group_norm": - return GroupNormLayerAttributes( - weight_requires_grad=module.weight.requires_grad, - num_channels=module.num_channels, - num_groups=module.num_groups, - ) - # torch.nn.utils.weight_norm replaces weight with weight_g and weight_v - is_weight_norm_applied = hasattr(module, "weight_g") and hasattr(module, "weight_v") - weight_attr = "weight_g" if is_weight_norm_applied else "weight" - with_bias = hasattr(module, "bias") and module.bias is not None - if isinstance(module, (Conv1d, Conv2d, Conv3d)): - return ConvolutionLayerAttributes( - weight_requires_grad=getattr(module, weight_attr).requires_grad, - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - 
stride=module.stride, - dilations=module.dilation, - groups=module.groups, - transpose=False, - padding_values=module.padding, - with_bias=with_bias, - ) - if isinstance(module, (ConvTranspose1d, ConvTranspose2d, ConvTranspose3d)): - return ConvolutionLayerAttributes( - weight_requires_grad=getattr(module, weight_attr).requires_grad, - in_channels=module.in_channels, - out_channels=module.out_channels, - kernel_size=module.kernel_size, - stride=module.stride, - dilations=module.dilation, - groups=module.groups, - transpose=True, - padding_values=module.padding, - with_bias=with_bias, - ) - if isinstance(module, Linear): - return LinearLayerAttributes( - weight_requires_grad=getattr(module, weight_attr).requires_grad, - in_features=module.in_features, - out_features=module.out_features, - with_bias=with_bias, - ) - - if hasattr(module, "weight"): - return GenericWeightedLayerAttributes( - weight_requires_grad=getattr(module, weight_attr).requires_grad, - weight_shape=module.weight.shape, - with_bias=with_bias, - ) - - return GenericWeightedLayerAttributes(weight_requires_grad=False, weight_shape=[1, 1]) - - def get_layer_attributes_from_args_and_kwargs(op_name: str, args, kwargs) -> BaseLayerAttributes: layer_attrs = None - if op_name in TRANSPOSE_OP_NAMES: + if op_name in CONV_OP_NAMES: + layer_attrs = _get_conv_attrs_from_args_kwargs(args, kwargs) + elif op_name in CONV_TRANSPOSE_OP_NAMES: + layer_attrs = _get_conv_transpose_attrs_from_args_kwargs(args, kwargs) + elif op_name in LINEAR_OP_NAMES: + layer_attrs = _get_linear_attrs_from_args_kwargs(args, kwargs) + elif op_name in GROUP_NORM_OP_NAMES: + layer_attrs = _get_group_norm_attrs_from_args_kwargs(args, kwargs) + elif op_name in BATCHNORM_OP_NAMES: + layer_attrs = _get_batchnorm_attrs_from_args_kwargs(args, kwargs) + elif op_name in LAYER_NORM_OP_NAMES: + layer_attrs = _get_layer_norm_attrs_from_args_kwargs(args, kwargs) + elif op_name in EMBEDDING_OP_NAMES: + layer_attrs = _get_embedding_attrs_from_args_kwargs(args, kwargs) + elif op_name in TRANSPOSE_OP_NAMES: layer_attrs = _get_transpose_attrs_from_args_kwargs(args, kwargs) elif op_name in PERMUTE_OP_NAMES: layer_attrs = _get_permute_attrs_from_args_kwargs(args, kwargs) @@ -132,7 +89,7 @@ def get_layer_attributes_from_args_and_kwargs(op_name: str, args, kwargs) -> Bas def set_nodes_attributes_in_nncf_graph(graph: NNCFGraph) -> None: for node in graph.get_all_nodes(): - if node.metatype in [PTReshapeMetatype, PTSqueezeMetatype]: + if node.metatype in [om.PTReshapeMetatype, om.PTSqueezeMetatype]: input_nodes = graph.get_input_edges(node) output_nodes = graph.get_output_edges(node) # In case ReshapeMetatype op is intermediate node @@ -140,7 +97,7 @@ def set_nodes_attributes_in_nncf_graph(graph: NNCFGraph) -> None: layer_attributes = ReshapeLayerAttributes(input_nodes[0].tensor_shape, output_nodes[0].tensor_shape) node.layer_attributes = layer_attributes - if node.metatype is PTSplitMetatype: + if node.metatype is om.PTSplitMetatype: input_edges = graph.get_input_edges(node) output_edges = graph.get_output_edges(node) if input_edges and output_edges: @@ -195,3 +152,171 @@ def _get_const_attrs_from_args_kwargs(args, _) -> ConstantLayerAttributes: name = args[0].name shape = args[0].shape return ConstantLayerAttributes(name, shape) + + +def apply_args_defaults( + args: List[Any], kwargs: Dict[str, Any], args_signature: List[Union[str, Tuple[str, Any]]] + ) -> Dict[str, Any]: + """ + Combines positional arguments (`args`) and keyword arguments (`kwargs`) + according to the provided 
`args_signature`. + + The `args_signature` is a list that defines the expected arguments. + Each element in the list can be either: + + - string: This represents the name of an argument expected to be a positional argument. + - tuple: This represents the name and default value of an argument. + - The first element in the tuple is the argument name. + - The second element in the tuple is the default value. + + :param args: List of positional arguments. + :param kwargs: Dictionary of keyword arguments. + :param args_signature: List defining the expected arguments as described above. + + :return: A dictionary combining arguments from `args` and `kwargs` according to the `args_signature`. + """ + # Manually defining the function signature is necessary because inspection of torch functions is not available + # https://github.com/pytorch/pytorch/issues/74539 + + args_dict: Dict[str, Any] = dict() + for idx, arg_desc in enumerate(args_signature): + if isinstance(arg_desc, str): + args_dict[arg_desc] = kwargs.get(arg_desc, args[idx]) + elif isinstance(arg_desc, Tuple): + arg_name, default = arg_desc + args_dict[arg_name] = kwargs.get(arg_name, args[idx] if idx < len(args) else default) + else: + raise ValueError("Incorrect args_signature, element of list should be str or tuple.") + return args_dict + + +GENERIC_WEIGHT_FUNC_SIGNATURE = ["input", "weight"] +LINEAR_FUNC_SIGNATURE = ["input", "weight", ("bias", None)] +CONV_FUNC_SIGNATURE = ["input", "weight", ("bias", None), ("stride", 1), ("padding", 0), ("dilation", 1), ("groups", 1)] +CONV_TRANSPOSE_FUNC_SIGNATURE = [ + "input", + "weight", + ("bias", None), + ("stride", 1), + ("padding", 0), + ("output_padding", 0), + ("groups", 1), + ("dilation", 1), +] +BATCH_NORM_FUNC_SIGNATURE = [ + "input", + "running_mean", + "running_var", + ("weight", None), + ("bias", None), + ("training", False), + ("momentum", 0.1), + ("eps", 1e-5), +] +GROUP_NORM_FUNC_SIGNATURE = [ + "input", + "num_groups", + ("weight", None), + ("bias", None), + ("eps", None), +] +LAYER_NORM_FUNC_SIGNATURE = [ + "input", + "normalized_shape", + ("weight", None), + ("bias", None), + ("eps", 1e-05), + ("cudnn_enable", True), +] + + +def _get_conv_attrs_from_args_kwargs(args: List[Any], kwargs: Dict[str, Any]) -> ConvolutionLayerAttributes: + args_dict = apply_args_defaults(args, kwargs, CONV_FUNC_SIGNATURE) + + kernel_size = tuple(args_dict["weight"].shape[2:]) + in_channels = args_dict["weight"].shape[1] * args_dict["groups"] + out_channels = args_dict["weight"].shape[0] + + return ConvolutionLayerAttributes( + weight_requires_grad=args_dict["weight"].requires_grad, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=args_dict["stride"], + dilations=args_dict["dilation"], + groups=args_dict["groups"], + transpose=False, + padding_values=args_dict["padding"], + with_bias=args_dict["bias"] is not None, + ) + + +def _get_conv_transpose_attrs_from_args_kwargs(args: List[Any], kwargs: Dict[str, Any]) -> ConvolutionLayerAttributes: + args_dict = apply_args_defaults(args, kwargs, CONV_TRANSPOSE_FUNC_SIGNATURE) + + kernel_size = tuple(args_dict["weight"].shape[2:]) + in_channels = args_dict["weight"].shape[0] + out_channels = args_dict["weight"].shape[1] * args_dict["groups"] + + return ConvolutionLayerAttributes( + weight_requires_grad=args_dict["weight"].requires_grad, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=args_dict["stride"], + dilations=args_dict["dilation"], + groups=args_dict["groups"], + transpose=True, 
+ padding_values=args_dict["padding"], + with_bias=args_dict["bias"] is not None, + output_padding_values=args_dict["output_padding"], + ) + + +def _get_linear_attrs_from_args_kwargs(args, kwargs) -> LinearLayerAttributes: + args_dict = apply_args_defaults(args, kwargs, LINEAR_FUNC_SIGNATURE) + return LinearLayerAttributes( + weight_requires_grad=args_dict["weight"].requires_grad, + in_features=args_dict["weight"].shape[1], + out_features=args_dict["weight"].shape[0], + with_bias=args_dict["bias"] is not None, + ) + + +def _get_batchnorm_attrs_from_args_kwargs(args, kwargs): + args_dict = apply_args_defaults(args, kwargs, BATCH_NORM_FUNC_SIGNATURE) + return GenericWeightedLayerAttributes( + weight_requires_grad=False if args_dict["weight"] is None else args_dict["weight"].requires_grad, + weight_shape=[] if args_dict["weight"] is None else args_dict["weight"].shape, + filter_dimension_idx=0, + with_bias=args_dict["bias"] is not None, + ) + + +def _get_embedding_attrs_from_args_kwargs(args, kwargs): + args_dict = apply_args_defaults(args, kwargs, GENERIC_WEIGHT_FUNC_SIGNATURE) + return GenericWeightedLayerAttributes( + weight_requires_grad=args_dict["weight"].requires_grad, + weight_shape=args_dict["weight"].shape, + filter_dimension_idx=0, + with_bias=False, + ) + + +def _get_group_norm_attrs_from_args_kwargs(args, kwargs): + args_dict = apply_args_defaults(args, kwargs, GROUP_NORM_FUNC_SIGNATURE) + return GroupNormLayerAttributes( + weight_requires_grad=args_dict["weight"].requires_grad, + num_channels=args_dict["weight"].shape[0], + num_groups=args_dict["num_groups"], + ) + + +def _get_layer_norm_attrs_from_args_kwargs(args, kwargs): + args_dict = apply_args_defaults(args, kwargs, LAYER_NORM_FUNC_SIGNATURE) + return GenericWeightedLayerAttributes( + weight_requires_grad=args_dict["weight"].requires_grad, + weight_shape=args_dict["weight"].shape, + filter_dimension_idx=0, + with_bias=args_dict["bias"] is not None, + ) diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 01fcff9ee46..6728d5cd904 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -15,16 +15,14 @@ import torch from torch.nn import DataParallel -import nncf from nncf.common.graph.definitions import MODEL_CONST_OP_NAME from nncf.common.graph.layer_attributes import BaseLayerAttributes +from nncf.common.graph.layer_attributes import WeightedLayerAttributes from nncf.common.logging import nncf_logger from nncf.common.utils.debug import is_debug from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.dynamic_graph.context import get_current_context -from nncf.torch.dynamic_graph.layer_attributes_handlers import OP_NAMES_REQUIRING_ATTRS_FROM_ARGS_KWARGS from nncf.torch.dynamic_graph.layer_attributes_handlers import get_layer_attributes_from_args_and_kwargs -from nncf.torch.dynamic_graph.layer_attributes_handlers import get_layer_attributes_from_module from nncf.torch.dynamic_graph.op_input_processing import OperatorInput from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.structs import NamespaceTarget @@ -202,21 +200,22 @@ def _execute_op( def _collect_module_attrs_and_ignored_algorithms( ctx: TracingContext, op_name: str, args, kwargs ) -> Tuple[BaseLayerAttributes, List[str]]: - layer_attrs = None ignored_algos = [] - from nncf.torch.graph.operator_metatypes import OP_NAMES_WITH_WEIGHTS + layer_attrs = get_layer_attributes_from_args_and_kwargs(op_name, args, kwargs) - if 
op_name in OP_NAMES_WITH_WEIGHTS: - curr_module = ctx.get_current_module() - if curr_module is None: - raise nncf.ValidationError( - f"Operation {op_name} requires module attributes, but it was executed outside any module" - ) - layer_attrs = get_layer_attributes_from_module(curr_module, op_name) + curr_module = ctx.get_current_module() + if curr_module is not None: if isinstance(curr_module, _NNCFModuleMixin): ignored_algos = deepcopy(curr_module.ignored_algorithms) - elif op_name in OP_NAMES_REQUIRING_ATTRS_FROM_ARGS_KWARGS: - layer_attrs = get_layer_attributes_from_args_and_kwargs(op_name, args, kwargs) + + if ( + isinstance(layer_attrs, WeightedLayerAttributes) + and hasattr(curr_module, "weight_g") + and hasattr(curr_module, "weight_v") + ): + # torch.nn.utils.weight_norm replaces weight with weight_g and weight_v + layer_attrs.weight_requires_grad = curr_module.weight_g.requires_grad + return layer_attrs, ignored_algos diff --git a/nncf/torch/extractor.py b/nncf/torch/extractor.py new file mode 100644 index 00000000000..a09b0e5280f --- /dev/null +++ b/nncf/torch/extractor.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from itertools import chain +from typing import Any, Dict, Iterable, List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm + +import nncf +from nncf import nncf_logger +from nncf.common.graph.graph import NNCFNode +from nncf.torch.graph import operator_metatypes as om +from nncf.torch.model_graph_manager import get_const_data +from nncf.torch.model_graph_manager import get_const_node +from nncf.torch.model_graph_manager import get_fake_quantizer +from nncf.torch.nncf_network import NNCFNetwork + +BATCH_NORM_CLASSES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) + + +CONV_METATYPES = ( + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTDepthwiseConv1dSubtype, + om.PTDepthwiseConv2dSubtype, + om.PTDepthwiseConv3dSubtype, +) + +CONV_TRANSPOSE_METATYPES = ( + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, +) + + +class ExtractedFunc(nn.Module): + def __init__( + self, + fn_name: str, + kwargs: Dict[str, Any], + ) -> None: + super().__init__() + self.fn_name = fn_name + self.kwargs = kwargs + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return getattr(F, self.fn_name)(input=x, **self.kwargs) + + +def extract_conv( + input_node: NNCFNode, + output_node: NNCFNode, + model: NNCFNetwork, +) -> ExtractedFunc: + """ + Extracts a convolutional layer from an NNCF graph and constructs an ExtractedFunc module. + + :param input_node: The input node of the extracted subgraph. + :param output_node: The output node of the extracted subgraph. + :param model: The NNCF network containing the layer. + :return: The extracted convolutional layer as an ExtractedFunc module. 
+ """ + graph = model.nncf.get_graph() + weight_node = get_const_node(input_node, input_node.metatype.weight_port_ids[0], graph) + weight = get_const_data(weight_node, model) + w_fq = get_fake_quantizer(input_node, input_node.metatype.weight_port_ids[0], model) + bias_node = get_const_node(input_node, input_node.metatype.bias_port_id, graph) + bias = get_const_data(bias_node, model) if bias_node is not None else None + + with torch.no_grad(): + e_weight = w_fq(weight) if w_fq else weight + + if input_node.metatype in CONV_METATYPES: + kwargs = { + "weight": e_weight.clone(), + "bias": bias.clone() if bias is not None else bias, + "stride": input_node.layer_attributes.stride, + "padding": input_node.layer_attributes.padding_values, + "dilation": input_node.layer_attributes.dilations, + "groups": input_node.layer_attributes.groups, + } + elif input_node.metatype in CONV_TRANSPOSE_METATYPES: + kwargs = { + "weight": e_weight.clone(), + "bias": bias.clone() if bias is not None else bias, + "stride": input_node.layer_attributes.stride, + "padding": input_node.layer_attributes.padding_values, + "output_padding": input_node.layer_attributes.output_padding_values, + "dilation": input_node.layer_attributes.dilations, + } + + extracted_module = ExtractedFunc(input_node.node_type, kwargs) + + if input_node != output_node: + extracted_module = try_to_fuse_conv(input_node, output_node, model, extracted_module) + + return extracted_module + + +def _find_parent_class(cls: type, parent_classes: Iterable[type]) -> Optional[type]: + """ + Finds the first parent class of the given class that is present in the list of possible parent classes. + + :param cls: The class whose parent to find. + :param parent_classes: A list of potential parent classes. + :return: The first matching parent class, or None if no match is found. + """ + for exp_cls in parent_classes: + if issubclass(cls, exp_cls): + return exp_cls + return None + + +def extract_bn(node: NNCFNode, model: NNCFNetwork) -> Optional[Union[nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]]: + """ + Extracts the batch_norm operation. + If the source module inherits from nn.BatchNorm1d, nn.BatchNorm2d, or nn.BatchNorm3d, returns a torch BatchNorm module. + + :param node: Target batch_norm node. + :param model: Source model. + :return: BatchNorm module with the same attributes and parameters as the source module, or None. + """ + bn_module: _BatchNorm = model.nncf.get_containing_module(node.node_name) + bn_class = _find_parent_class(bn_module.__class__, BATCH_NORM_CLASSES) + if bn_class is None: + nncf_logger.debug(f"Module associated with {node} should inherit from one of {BATCH_NORM_CLASSES}") + return None + + extracted_bn: _BatchNorm = bn_class( + num_features=bn_module.num_features, + eps=bn_module.eps, + momentum=bn_module.momentum, + affine=bn_module.affine, + track_running_stats=bn_module.track_running_stats, + device=bn_module.weight.device, + dtype=bn_module.weight.dtype, + ) + + # Copy named parameters and buffers that exist in the native BatchNorm module from the source module. + for name, _ in chain(extracted_bn.named_parameters(), extracted_bn.named_buffers()): + setattr(extracted_bn, name, deepcopy(getattr(bn_module, name))) + extracted_bn.eval() + return extracted_bn + + +def try_to_fuse_conv( + input_node: NNCFNode, output_node: NNCFNode, model: NNCFNetwork, extracted_module: nn.Module +) -> nn.Module: + """ + Fuses the convolution operation with the next batch norm node if possible. + + :param input_node: Input subgraph node. 
+ :param output_node: Output subgraph node (fused with input node). + :param model: Source model. + :param extracted_module: Extracted module. + """ + next_nodes = model.nncf.get_graph().get_next_nodes(input_node) + + if len(next_nodes) != 1: + return extracted_module + + if output_node != next_nodes[0]: + raise nncf.InternalError(f"Output node {output_node} not found after {input_node}") + + if next_nodes[0].metatype != om.PTBatchNormMetatype: + raise nncf.InternalError("Only BatchNorm layers are supported") + + extracted_bn = extract_bn(next_nodes[0], model) + if extracted_bn is None: + nncf_logger.debug( + f"Can't extract fused batchnorm module for {input_node.node_name}, " + f"the module containing the batchnorm operator should inherit from one of {BATCH_NORM_CLASSES}." + ) + return None + return nn.Sequential(extracted_module, extracted_bn) + + +def extract_model(model: NNCFNetwork, input_nodes: List[str], output_nodes: List[str]) -> Optional[nn.Module]: + """ + Extracts a submodule from a given NNCF network containing only the nodes from the input to the output node. + + :param model: The NNCF network to extract the submodule from. + :param input_nodes: List containing names of the input nodes for the submodule. + :param output_nodes: List containing names of the output nodes for the submodule. + :return: An nn.Module containing the extracted submodel, or None if extraction is not supported. + """ + + if len(input_nodes) != 1 or len(output_nodes) != 1: + raise nncf.InternalError("input_nodes and output_nodes should contain only one node.") + + graph = model.nncf.get_graph() + input_node = graph.get_node_by_name(input_nodes[0]) + output_node = graph.get_node_by_name(output_nodes[0]) + + if input_node.metatype in CONV_METATYPES + CONV_TRANSPOSE_METATYPES: + return extract_conv(input_node, output_node, model) + + nncf_logger.debug(f"Can't extract module for {input_node.node_name}") + return None diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index f4752d41c5e..28fcbea0c88 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -175,6 +175,7 @@ class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -186,6 +187,7 @@ class PTModuleConv1dMetatype(PTModuleOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -197,6 +199,7 @@ class PTConv1dMetatype(PTOperatorMetatype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -207,6 +210,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -218,6 +222,7 @@ class PTModuleConv2dMetatype(PTModuleOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -229,6 +234,7 @@ class PTConv2dMetatype(PTOperatorMetatype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -239,6 +245,7 @@ class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 
@PT_OPERATOR_METATYPES.register() @@ -250,6 +257,7 @@ class PTModuleConv3dMetatype(PTModuleOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -261,6 +269,7 @@ class PTConv3dMetatype(PTOperatorMetatype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -271,6 +280,7 @@ class PTModuleConvTranspose1dMetatype(PTModuleOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -282,6 +292,7 @@ class PTConvTranspose1dMetatype(PTOperatorMetatype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -292,6 +303,7 @@ class PTModuleConvTranspose2dMetatype(PTModuleOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -303,6 +315,7 @@ class PTConvTranspose2dMetatype(PTOperatorMetatype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -313,6 +326,7 @@ class PTModuleConvTranspose3dMetatype(PTModuleOperatorSubtype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -324,6 +338,7 @@ class PTConvTranspose3dMetatype(PTOperatorMetatype): output_channel_axis = 1 num_expected_input_edges = 2 weight_port_ids = [1] + bias_port_id = 2 @PT_OPERATOR_METATYPES.register() @@ -626,6 +641,8 @@ class PTBatchNormMetatype(PTOperatorMetatype): name = "BatchNormOp" module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} subtypes = [PTModuleBatchNormMetatype] + weight_port_ids = [3] + bias_port_id = 4 @PT_OPERATOR_METATYPES.register() diff --git a/nncf/torch/model_graph_manager.py b/nncf/torch/model_graph_manager.py new file mode 100644 index 00000000000..83a0f02b6d4 --- /dev/null +++ b/nncf/torch/model_graph_manager.py @@ -0,0 +1,301 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Tuple, Union + +import torch + +import nncf +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES +from nncf.torch.dynamic_graph.context import PreHookId +from nncf.torch.graph import operator_metatypes as om +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import SymmetricQuantizer + +CONV_META_TYPES = [ + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTDepthwiseConv1dSubtype, + om.PTDepthwiseConv2dSubtype, + om.PTDepthwiseConv3dSubtype, + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, +] + +OPERATORS_WITH_BIAS_METATYPES = CONV_META_TYPES +CONV_FUSED_META_TYPES = [om.PTBatchNormMetatype] + + +def find_const_node_in_constant_subgraph(node: NNCFNode, graph: NNCFGraph) -> Optional[NNCFNode]: + """ + Finds a constant node within a constant subgraph, recursively traversing noop and quantize nodes. + + :param node: The starting node to search from. + :param graph: The NNCFGraph. + :return: The constant node found within the subgraph, or None if no constant node is found. + """ + if node.metatype == om.PTNoopMetatype or node.node_type in om.QUANTIZE_NODE_TYPES: + prev_nodes = graph.get_previous_nodes(node) + if len(prev_nodes) != 1: + return None + return find_const_node_in_constant_subgraph(prev_nodes[0], graph) + if node.metatype in CONST_NOOP_METATYPES: + return node + return None + + +def get_const_node(node: NNCFNode, port_id: int, graph: NNCFGraph) -> Optional[NNCFNode]: + """ + Retrieves the constant node providing the input to a specific port of a given node in the NNCF graph. + + :param node: The NNCF node for which to find the constant input node. + :param port_id: The ID of the input port to consider. + :param graph: The NNCF graph containing the nodes. + :return: The NNCF node providing the constant input to the specified port, or None if no such node is found. + """ + for prev_node in graph.get_previous_nodes(node): + edge = graph.get_edge(prev_node, node) + if edge.input_port_id == port_id: + weight_node = find_const_node_in_constant_subgraph(prev_node, graph) + if weight_node is None: + raise nncf.InternalError("Could not find a constant node in the model graph.") + return weight_node + + +def split_const_name(const_name: str) -> Tuple[str, str]: + """ + Splits the constant name into module and attribute names. + + :param const_name: The full name of the constant, including module and attribute. + :return: + - module_name: The name of the module containing the constant. + - weight_attr_name: The name of the constant attribute within the module. + """ + index = const_name.rfind(".") + if index == -1: + return str(), const_name + module_name = const_name[:index] + weight_attr_name = const_name[index + 1 :] + return module_name, weight_attr_name + + +def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module: + """ + Retrieves a module from a PyTorch model by its hierarchical name. + + :param module_name: The name of the module to retrieve (e.g., "module1.submodule2"). + :param model: The model to search within. + :return: The retrieved module. 
+ """ + if not module_name: + return model + curr_module = model + for name in module_name.split("."): + for child_name, child_module in curr_module.named_children(): + if child_name == name: + curr_module = child_module + break + else: + raise nncf.ModuleNotFoundError(f"Could not find the {module_name} module in the model.") + return curr_module + + +def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor: + """ + Retrieves a constant tensor associated with a given node. + + :param const_node: The node associated with const data. + :param model: The NNCFNetwork object. + :return: A torch.Tensor object containing the constant value. + """ + const_name = const_node.layer_attributes.name + module_name, const_attr_name = split_const_name(const_name) + module = get_module_by_name(module_name, model) + data = getattr(module, const_attr_name) + if isinstance(data, torch.nn.Parameter): + return data.data + return data + + +def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor: + """ + Retrieves a constant tensor associated with a given node and input port in an NNCF graph. + + :param node: The node to retrieve the constant from. + :param port_id: The port id within the node that holds the constant. + :param model: The NNCFNetwork object. + :return: A torch.Tensor object containing the constant value, or None if the constant is not found. + """ + graph = model.nncf.get_graph() + const_node = get_const_node(node, port_id, graph) + if const_node is None: + return None + return get_const_data(const_node, model) + + +def get_potential_fused_node(node_name: str, nncf_graph: NNCFGraph) -> Optional[NNCFNode]: + """ + Retrieves the next node in the NNCF graph that could be fused with the provided node during runtime optimization. + + :param node_name: The node name. + :param nncf_graph: The NNCF graph. + :return: The node that can be fused or None if no suitable node is found. + """ + target_node = nncf_graph.get_node_by_name(node_name) + + if target_node.metatype in CONV_META_TYPES: + next_nodes = nncf_graph.get_next_nodes(target_node) + for node in next_nodes: + if node.metatype in CONV_FUSED_META_TYPES: + return node + return None + + +def is_node_with_fused_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + """ + Checks if the node has a fused bias. + + :param node: The node to check. + :param nncf_graph: The NNCF graph. + :return: Return `True` if `node` corresponds to the operation + with bias (bias is added to the output tensor of that operation), + `False` otherwise. + """ + if node.metatype not in OPERATORS_WITH_BIAS_METATYPES: + return False + fused_node = get_potential_fused_node(node.node_name, nncf_graph) + if fused_node is not None: + node = fused_node + bias_port = node.metatype.bias_port_id + bias = get_const_node(node, bias_port, nncf_graph) + return bias is not None + + +def get_fused_bias_value(node: NNCFNode, model: NNCFNetwork) -> Optional[torch.Tensor]: + """ + Returns the bias tensor for the node or for potential fused node. + + :param node: The node that corresponds to the operation with bias. + :param model: The model that contains this operation. + :return: The bias value that is applied to the output tensor of the node's operation. 
+ """ + nncf_graph = model.nncf.get_graph() + fused_node = get_potential_fused_node(node.node_name, nncf_graph) + target_node_name = fused_node.node_name if fused_node else node.node_name + target_node = nncf_graph.get_node_by_name(target_node_name) + return get_const_data_on_port(target_node, target_node.metatype.bias_port_id, model) + + +def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[int]: + """ + Returns a list of input port ids that contain a traced constant tensor. + + :param node: Target node that contains weights. + :param graph: The NNCF graph. + :return: List of ports with weights. + """ + weight_port_ids = [] + for edge in graph.get_input_edges(node): + if edge.input_port_id in node.metatype.weight_port_ids: + weight_node = find_const_node_in_constant_subgraph(edge.from_node, graph) + if weight_node: + weight_port_ids.append(edge.input_port_id) + return weight_port_ids + + +def set_const_data(data: torch.Tensor, const_node: NNCFNode, model: NNCFNetwork) -> None: + """ + Sets the constant data associated with a specific constant node in an NNCF network model. + + :param data: The constant data tensor to be set. + :param const_node: The NNCF node representing the constant data. + :param model: The NNCF network model. + """ + const_name = const_node.layer_attributes.name + module_name, const_attr_name = split_const_name(const_name) + module = get_module_by_name(module_name, model) + const = getattr(module, const_attr_name) + if isinstance(const, torch.nn.Parameter): + const.data = data + else: + setattr(module, const_attr_name, data) + + +def set_const_data_to_port_id(data: torch.Tensor, node: NNCFNode, port_id: int, model: NNCFNetwork) -> None: + """ + Sets the value of a constant tensor within a specified node in an NNCFNetwork. + + :param data: The tensor containing the new value to be set for the constant. + :param node: The NNCF node representing the operation that uses the constant. + :param port_id: The input port id of the node that receives the constant. + :param model: The NNCF network containing the module to be modified. + """ + graph = model.nncf.get_graph() + const_node = get_const_node(node, port_id, graph) + if const_node is None: + raise nncf.InternalError(f"Could not find a constant node for {node.node_name} on port {port_id}") + const_name = const_node.layer_attributes.name + module_name, const_attr_name = split_const_name(const_name) + module = get_module_by_name(module_name, model) + const = getattr(module, const_attr_name) + if isinstance(const, torch.nn.Parameter): + const.data = data + else: + setattr(module, const_attr_name, data) + + +def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + """ + Checks that the node has a fake quantizer for its weights (supports only metatypes with a single weight port). + + :param node: The target node. + :param nncf_graph: The NNCF graph. + :return: `True` if the node's weights are quantized, `False` otherwise. + """ + assert len(node.metatype.weight_port_ids) == 1, "Supports only metatypes with a single weight port" + for edge in nncf_graph.get_input_edges(node): + if edge.input_port_id in node.metatype.weight_port_ids and edge.from_node.node_type in om.QUANTIZE_NODE_TYPES: + return True + return False + + +def get_fake_quantizer( + node: NNCFNode, port_id: Optional[int], model: NNCFNetwork +) -> Union[SymmetricQuantizer, AsymmetricQuantizer]: + """ + Retrieves the fake quantizer associated with a specific node and input port id. 
+ + :param node: The NNCFNode representing the node for which to retrieve the quantizer. + :param port_id: The port id for which to retrieve the quantizer module; None means the output port. + :param model: The NNCFNetwork instance. + :return: Fake quantizer module if it exists, otherwise None. + """ + + address_map = model.nncf.get_node_to_op_address_mapping() + op_addr = address_map[node.node_name] + + if port_id is not None: + id = PreHookId(op_address=op_addr, input_port_id=port_id) + hook_container = model.nncf._compressed_context._pre_hooks.get(id, {}) + else: + hook_container = model.nncf._compressed_context._post_hooks.get(op_addr, {}) + + for call_hook in hook_container.values(): + if isinstance(call_hook, ExternalQuantizerCallHook): + storage = getattr(model.nncf, call_hook._storage_name) + return storage[call_hook._storage_key] + return None diff --git a/tests/post_training/test_templates/helpers.py b/tests/post_training/test_templates/helpers.py index 2c5f4975014..885eb6c4263 100644 --- a/tests/post_training/test_templates/helpers.py +++ b/tests/post_training/test_templates/helpers.py @@ -92,7 +92,7 @@ def forward(self, x): return x -class BiasConvBiasBNTestModel(torch.nn.Module): +class ConvBiasBNTestModel(torch.nn.Module): INPUT_SIZE = [1, 1, 4, 4] def __init__(self): @@ -110,6 +110,49 @@ def forward(self, x): return x +class CustomConv(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.Tensor([[[[0.1, -2.0], [1.0, 0.1]]], [[[0.1, 2.0], [-1.0, 0.1]]]])) + self.bias = nn.Parameter(torch.Tensor([0.1, 1.0])) + self.act = nn.Identity() + + def forward(self, x): + return self.act(F.conv2d(x, self.weight, self.bias)) + + +class CustomConvTestModel(nn.Module): + INPUT_SIZE = [1, 1, 4, 4] + + def __init__(self): + super().__init__() + self.conv = CustomConv() + self.drop = nn.Dropout(0) + + def forward(self, x): + return self.drop(self.conv(x)) + + +class CustomBN2d(nn.BatchNorm2d): + def __init__(self): + super().__init__(2) + self.bias.data = torch.Tensor([0.1, 1.0]) + self.weight.data = torch.Tensor([0.2, 2.0]) + self.act = nn.Identity() + + +class CustomConvBNTestModel(nn.Module): + INPUT_SIZE = [1, 1, 4, 4] + + def __init__(self): + super().__init__() + self.conv = CustomConv() + self.bn = CustomBN2d() + + def forward(self, x): + return self.bn(self.conv(x)) + + class FCTestModel(nn.Module): INPUT_SIZE = [1, 1, 4, 4] diff --git a/tests/torch/test_extractor.py b/tests/torch/test_extractor.py new file mode 100644 index 00000000000..b9ba7858d66 --- /dev/null +++ b/tests/torch/test_extractor.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +from torch import nn + +import tests.post_training.test_templates.helpers as helpers +from nncf.common.graph.transformations.commands import TargetType +from nncf.torch import wrap_model +from nncf.torch.extractor import extract_model +from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_transformer import PTModelTransformer +from nncf.torch.model_transformer import PTTransformationLayout +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import QuantizationMode +from nncf.torch.quantization.layers import SymmetricQuantizer + + +@pytest.mark.parametrize( + "model_cls, input_node_name, output_node_name", + ( + ( + helpers.ConvBiasBNTestModel, + "ConvBiasBNTestModel/Conv2d[conv]/conv2d_0", + "ConvBiasBNTestModel/BatchNorm2d[bn]/batch_norm_0", + ), + ( + helpers.ConvBNTestModel, + "ConvBNTestModel/Conv2d[conv]/conv2d_0", + "ConvBNTestModel/BatchNorm2d[bn]/batch_norm_0", + ), + ( + helpers.ConvTestModel, + "ConvTestModel/Conv2d[conv]/conv2d_0", + "ConvTestModel/Conv2d[conv]/conv2d_0", + ), + ( + helpers.CustomConvBNTestModel, + "CustomConvBNTestModel/CustomConv[conv]/conv2d_0", + "CustomConvBNTestModel/CustomBN2d[bn]/batch_norm_0", + ), + ( + helpers.CustomConvTestModel, + "CustomConvTestModel/CustomConv[conv]/conv2d_0", + "CustomConvTestModel/CustomConv[conv]/conv2d_0", + ), + ), +) +def test_extract_model(model_cls, input_node_name, output_node_name): + example_input = torch.ones(model_cls.INPUT_SIZE) + + model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True) + extracted_module = extract_model(model, [input_node_name], [output_node_name]) + with torch.no_grad(): + ret1 = model(example_input) + ret2 = extracted_module(example_input) + assert torch.any(torch.isclose(ret1, ret2)) + + +@pytest.mark.parametrize( + "model_cls, input_node_name, output_node_name", + ( + ( + helpers.ConvBiasBNTestModel, + "ConvBiasBNTestModel/Conv2d[conv]/conv2d_0", + "ConvBiasBNTestModel/BatchNorm2d[bn]/batch_norm_0", + ), + ( + helpers.ConvBNTestModel, + "ConvBNTestModel/Conv2d[conv]/conv2d_0", + "ConvBNTestModel/BatchNorm2d[bn]/batch_norm_0", + ), + ( + helpers.ConvTestModel, + "ConvTestModel/Conv2d[conv]/conv2d_0", + "ConvTestModel/Conv2d[conv]/conv2d_0", + ), + ( + helpers.CustomConvBNTestModel, + "CustomConvBNTestModel/CustomConv[conv]/conv2d_0", + "CustomConvBNTestModel/CustomBN2d[bn]/batch_norm_0", + ), + ( + helpers.CustomConvTestModel, + "CustomConvTestModel/CustomConv[conv]/conv2d_0", + "CustomConvTestModel/CustomConv[conv]/conv2d_0", + ), + ), +) +def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_name): + example_input = torch.ones(model_cls.INPUT_SIZE) + + model = wrap_model(model_cls().eval(), example_input=example_input, trace_parameters=True) + + transformer = PTModelTransformer(model) + qspec = PTQuantizerSpec( + num_bits=8, + mode=QuantizationMode.SYMMETRIC, + signedness_to_force=None, + scale_shape=(1,), + narrow_range=False, + half_range=False, + logarithm_scale=False, + ) + + fq = SymmetricQuantizer(qspec) + command = PTQuantizerInsertionCommand( + PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, input_node_name, input_port_id=1), fq + ) + layout = PTTransformationLayout() + layout.register(command) + q_model = transformer.transform(layout) + + extracted_module = extract_model(model, [input_node_name], [output_node_name]) + with 
torch.no_grad(): + ret1 = q_model(example_input) + ret2 = extracted_module(example_input) + assert torch.any(torch.isclose(ret1, ret2)) + + if isinstance(extracted_module, nn.Sequential): + assert extracted_module[0].w_fq is not None + else: + assert extracted_module.w_fq is not None diff --git a/tests/torch/test_layer_attributes.py b/tests/torch/test_layer_attributes.py index 018d9a6dbef..bfcae281ef3 100644 --- a/tests/torch/test_layer_attributes.py +++ b/tests/torch/test_layer_attributes.py @@ -26,6 +26,7 @@ from nncf.common.graph.layer_attributes import ReshapeLayerAttributes from nncf.common.graph.layer_attributes import TransposeLayerAttributes from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.torch import wrap_model from nncf.torch.dynamic_graph.graph_tracer import create_dummy_forward_fn from nncf.torch.dynamic_graph.io_handling import FillerInputElement from nncf.torch.dynamic_graph.io_handling import FillerInputInfo @@ -44,6 +45,7 @@ from nncf.torch.graph.operator_metatypes import PTGatherMetatype from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype +from nncf.torch.graph.operator_metatypes import PTLayerNormMetatype from nncf.torch.graph.operator_metatypes import PTLinearMetatype from nncf.torch.graph.operator_metatypes import PTOutputNoopMetatype from nncf.torch.graph.operator_metatypes import PTReshapeMetatype @@ -103,13 +105,13 @@ def default_comparator(first_attr: BaseLayerAttributes, second_attr: BaseLayerAt class LayerAttributesTestDesc: def __init__( self, - module: nn.Module, + module_fn: nn.Module, model_input_info: ModelInputInfo, layer_attributes: BaseLayerAttributes, metatype_cls: Type[OperatorMetatype], layer_attributes_comparator: COMPARATOR_TYPE = default_comparator, ): - self.module = module + self.module_fn = module_fn self.layer_attributes = layer_attributes self.model_input_info = model_input_info self.metatype_cls = metatype_cls @@ -124,35 +126,35 @@ def __str__(self): ) LIST_TEST_DESCS = [ LayerAttributesTestDesc( - module=nn.GroupNorm(1, 2), + module_fn=lambda: nn.GroupNorm(1, 2), model_input_info=FillerInputInfo([FillerInputElement([1, 2, 1, 1])]), layer_attributes=GroupNormLayerAttributes(weight_requires_grad=True, num_channels=2, num_groups=1), metatype_cls=PTGroupNormMetatype, ), LayerAttributesTestDesc( - module=nn.BatchNorm2d(1), + module_fn=lambda: nn.BatchNorm2d(1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1])]), layer_attributes=BATCH_NORM_REF_ATTR, metatype_cls=PTBatchNormMetatype, ), LayerAttributesTestDesc( - module=nn.BatchNorm3d(1), + module_fn=lambda: nn.BatchNorm3d(1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1, 1])]), layer_attributes=BATCH_NORM_REF_ATTR, metatype_cls=PTBatchNormMetatype, ), LayerAttributesTestDesc( - module=nn.BatchNorm1d(1), + module_fn=lambda: nn.BatchNorm1d(1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1])]), layer_attributes=BATCH_NORM_REF_ATTR, metatype_cls=PTBatchNormMetatype, ), LayerAttributesTestDesc( - module=nn.Conv2d(1, 1, 1), - model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1])]), + module_fn=lambda: nn.Conv2d(2, 1, 1), + model_input_info=FillerInputInfo([FillerInputElement([1, 2, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, - in_channels=1, + in_channels=2, out_channels=1, kernel_size=(1, 1), stride=(1, 1), @@ -165,12 +167,12 @@ def __str__(self): metatype_cls=PTConv2dMetatype, ), 
LayerAttributesTestDesc( - module=nn.Conv2d(2, 2, 1, groups=2), + module_fn=lambda: nn.Conv2d(2, 4, 1, groups=2), model_input_info=FillerInputInfo([FillerInputElement([1, 2, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, in_channels=2, - out_channels=2, + out_channels=4, kernel_size=(1, 1), stride=(1, 1), dilations=(1, 1), @@ -182,12 +184,12 @@ def __str__(self): metatype_cls=PTConv2dMetatype, ), LayerAttributesTestDesc( - module=nn.Conv1d(1, 1, 1), + module_fn=lambda: nn.Conv1d(1, 2, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, in_channels=1, - out_channels=1, + out_channels=2, kernel_size=(1,), stride=(1,), dilations=(1,), @@ -199,12 +201,12 @@ def __str__(self): metatype_cls=PTConv1dMetatype, ), LayerAttributesTestDesc( - module=nn.Conv3d(1, 1, 1), + module_fn=lambda: nn.Conv3d(1, 2, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, in_channels=1, - out_channels=1, + out_channels=2, kernel_size=(1, 1, 1), stride=(1, 1, 1), dilations=(1, 1, 1), @@ -216,12 +218,12 @@ def __str__(self): metatype_cls=PTConv3dMetatype, ), LayerAttributesTestDesc( - module=nn.ConvTranspose1d(1, 1, 1), + module_fn=lambda: nn.ConvTranspose1d(1, 2, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, in_channels=1, - out_channels=1, + out_channels=2, kernel_size=(1,), stride=(1,), dilations=(1,), @@ -229,16 +231,17 @@ def __str__(self): transpose=True, padding_values=(0,), with_bias=True, + output_padding_values=(0,), ), metatype_cls=PTConvTranspose1dMetatype, ), LayerAttributesTestDesc( - module=nn.ConvTranspose2d(1, 1, 1), + module_fn=lambda: nn.ConvTranspose2d(1, 2, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, in_channels=1, - out_channels=1, + out_channels=2, kernel_size=(1, 1), stride=(1, 1), dilations=(1, 1), @@ -246,16 +249,35 @@ def __str__(self): transpose=True, padding_values=(0, 0), with_bias=True, + output_padding_values=(0, 0), + ), + metatype_cls=PTConvTranspose2dMetatype, + ), + LayerAttributesTestDesc( + module_fn=lambda: nn.ConvTranspose2d(2, 4, 1, groups=2), + model_input_info=FillerInputInfo([FillerInputElement([1, 2, 1, 1])]), + layer_attributes=ConvolutionLayerAttributes( + weight_requires_grad=True, + in_channels=2, + out_channels=4, + kernel_size=(1, 1), + stride=(1, 1), + dilations=(1, 1), + groups=2, + transpose=True, + padding_values=(0, 0), + with_bias=True, + output_padding_values=(0, 0), ), metatype_cls=PTConvTranspose2dMetatype, ), LayerAttributesTestDesc( - module=nn.ConvTranspose3d(1, 1, 1), + module_fn=lambda: nn.ConvTranspose3d(1, 2, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1, 1])]), layer_attributes=ConvolutionLayerAttributes( weight_requires_grad=True, in_channels=1, - out_channels=1, + out_channels=2, kernel_size=(1, 1, 1), stride=(1, 1, 1), dilations=(1, 1, 1), @@ -263,17 +285,18 @@ def __str__(self): transpose=True, padding_values=(0, 0, 0), with_bias=True, + output_padding_values=(0, 0, 0), ), metatype_cls=PTConvTranspose3dMetatype, ), LayerAttributesTestDesc( - module=nn.Linear(1, 1), + module_fn=lambda: nn.Linear(1, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1])]), 
layer_attributes=LinearLayerAttributes(weight_requires_grad=True, in_features=1, out_features=1), metatype_cls=PTLinearMetatype, ), LayerAttributesTestDesc( - module=nn.Linear(1, 1, bias=False), + module_fn=lambda: nn.Linear(1, 1, bias=False), model_input_info=FillerInputInfo([FillerInputElement([1, 1, 1, 1])]), layer_attributes=LinearLayerAttributes( weight_requires_grad=True, in_features=1, out_features=1, with_bias=False @@ -281,7 +304,7 @@ def __str__(self): metatype_cls=PTLinearMetatype, ), LayerAttributesTestDesc( - module=nn.Embedding(2, 1), + module_fn=lambda: nn.Embedding(2, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1], type_str="long")]), layer_attributes=GenericWeightedLayerAttributes( weight_requires_grad=True, weight_shape=Size([2, 1]), filter_dimension_idx=0 @@ -289,19 +312,27 @@ def __str__(self): metatype_cls=PTEmbeddingMetatype, ), LayerAttributesTestDesc( - module=nn.EmbeddingBag(1, 1), + module_fn=lambda: nn.EmbeddingBag(1, 1), model_input_info=FillerInputInfo([FillerInputElement([1, 1], type_str="long", filler="zeros")]), layer_attributes=GenericWeightedLayerAttributes( weight_requires_grad=True, weight_shape=Size([1, 1]), filter_dimension_idx=0 ), metatype_cls=PTEmbeddingBagMetatype, ), + LayerAttributesTestDesc( + module_fn=lambda: nn.LayerNorm(1, 1), + model_input_info=FillerInputInfo([FillerInputElement([1, 1])]), + layer_attributes=GenericWeightedLayerAttributes( + weight_requires_grad=True, weight_shape=Size([1]), filter_dimension_idx=0, with_bias=True + ), + metatype_cls=PTLayerNormMetatype, + ), ] @pytest.mark.parametrize("desc", LIST_TEST_DESCS, ids=map(str, LIST_TEST_DESCS)) def test_can_set_valid_layer_attributes(desc: LayerAttributesTestDesc): - single_layer_model = desc.module + single_layer_model = desc.module_fn() nncf_network = NNCFNetwork(single_layer_model, desc.model_input_info) @@ -507,3 +538,14 @@ def test_getitem_attributes(input_shape): else: assert node.layer_attributes is None assert getitem_nodes_with_attributes[node.node_name] is None + + +@pytest.mark.parametrize("desc", LIST_TEST_DESCS, ids=map(str, LIST_TEST_DESCS)) +def test_can_set_valid_layer_attributes_wrap_model(desc: LayerAttributesTestDesc): + nncf_network = wrap_model(desc.module_fn(), desc.model_input_info.get_forward_inputs()[0], trace_parameters=True) + graph = nncf_network.nncf.get_graph() + ref_values = [RefNodeDesc(desc.metatype_cls, desc.layer_attributes, desc.layer_attributes_comparator)] + actual_values = [ + RefNodeDesc(node.metatype, node.layer_attributes) for node in graph.get_nodes_by_metatypes([desc.metatype_cls]) + ] + assert ref_values == actual_values diff --git a/tests/torch/test_model_analyzer.py b/tests/torch/test_model_analyzer.py index 1c36292b9ee..6e4e9b14fe5 100644 --- a/tests/torch/test_model_analyzer.py +++ b/tests/torch/test_model_analyzer.py @@ -15,7 +15,7 @@ from nncf.torch import wrap_model from nncf.torch.model_analyzer import get_fused_bias_value from nncf.torch.model_transformer import update_fused_bias -from tests.post_training.test_templates.helpers import BiasConvBiasBNTestModel +from tests.post_training.test_templates.helpers import ConvBiasBNTestModel from tests.post_training.test_templates.helpers import ConvBNTestModel from tests.post_training.test_templates.helpers import ConvTestModel @@ -25,7 +25,7 @@ ( (ConvTestModel, [0.1000, 1.0000]), # conv.bias (ConvBNTestModel, [0.1000, 1.0000]), # bn.bias - (BiasConvBiasBNTestModel, [0.1600, 3.6000]), # conv.bias*bn.weight + bn.bias + (ConvBiasBNTestModel, [0.1600, 3.6000]), # 
conv.bias*bn.weight + bn.bias ), ) def test_get_fused_bias_value(model_cls, ref): @@ -43,7 +43,7 @@ def test_get_fused_bias_value(model_cls, ref): ( (ConvTestModel), # conv.bias (ConvBNTestModel), # bn.bias - (BiasConvBiasBNTestModel), # conv.bias*bn.weight + bn.bias + (ConvBiasBNTestModel), # conv.bias*bn.weight + bn.bias ), ) def test_update_fused_bias(model_cls): @@ -61,7 +61,7 @@ def test_update_fused_bias(model_cls): if model_cls == ConvBNTestModel: assert model.conv.bias is None assert torch.all(torch.isclose(model.bn.bias, ref_new_bias)) - if model_cls == BiasConvBiasBNTestModel: + if model_cls == ConvBiasBNTestModel: assert torch.all(torch.isclose(model.conv.bias, torch.tensor([0.3000, 1.3000]))) assert torch.all(torch.isclose(model.bn.bias, torch.tensor([-1.0600, -3.6000]))) assert torch.all(torch.isclose(model.conv.bias * model.bn.weight + model.bn.bias, ref_new_bias)) diff --git a/tests/torch/test_model_graph_manager.py b/tests/torch/test_model_graph_manager.py new file mode 100644 index 00000000000..89c21c4b883 --- /dev/null +++ b/tests/torch/test_model_graph_manager.py @@ -0,0 +1,313 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Tuple + +import pytest +import torch +from torch import nn + +import tests.post_training.test_templates.helpers as helpers +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.torch import wrap_model +from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_graph_manager import get_const_data +from nncf.torch.model_graph_manager import get_const_data_on_port +from nncf.torch.model_graph_manager import get_const_node +from nncf.torch.model_graph_manager import get_fake_quantizer +from nncf.torch.model_graph_manager import get_module_by_name +from nncf.torch.model_graph_manager import get_potential_fused_node +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids +from nncf.torch.model_graph_manager import is_node_with_fused_bias +from nncf.torch.model_graph_manager import is_quantized_weights +from nncf.torch.model_graph_manager import set_const_data +from nncf.torch.model_graph_manager import split_const_name +from nncf.torch.model_transformer import PTModelTransformer +from nncf.torch.model_transformer import PTTransformationLayout +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import QuantizationMode +from nncf.torch.quantization.layers import SymmetricQuantizer +from tests.torch.helpers import create_conv + + +@dataclass +class ModelDesc: + model: NNCFNetwork + graph: NNCFGraph + node: NNCFNode + node_name: str + model_name: str + + def __init__(self, model_cls: type, node_name: str): + self.model = 
wrap_model(model_cls(), example_input=torch.ones(model_cls.INPUT_SIZE), trace_parameters=True) + self.graph = self.model.nncf.get_graph() + self.node = self.graph.get_node_by_name(node_name) + self.node_name = node_name + + +MODELS_LIST = [ + "ConvBiasBNTestModel", + "ConvBNTestModel", + "ConvTestModel", + "FCTestModel", + "MultipleConvTestModel", + "CustomConvTestModel", + "CustomConvBNTestModel", +] + + +class TestManagerForOriginalModels: + @pytest.fixture(autouse=True, scope="function") + def init_models(self): + self.models: Dict[str, ModelDesc] = { + "ConvBiasBNTestModel": ModelDesc(helpers.ConvBiasBNTestModel, "ConvBiasBNTestModel/Conv2d[conv]/conv2d_0"), + "ConvBNTestModel": ModelDesc(helpers.ConvBNTestModel, "ConvBNTestModel/Conv2d[conv]/conv2d_0"), + "ConvTestModel": ModelDesc(helpers.ConvTestModel, "ConvTestModel/Conv2d[conv]/conv2d_0"), + "FCTestModel": ModelDesc(helpers.FCTestModel, "FCTestModel/Linear[fc]/linear_0"), + "MultipleConvTestModel": ModelDesc( + helpers.MultipleConvTestModel, "MultipleConvTestModel/Conv2d[conv_1]/conv2d_0" + ), + "CustomConvTestModel": ModelDesc( + helpers.CustomConvTestModel, "CustomConvTestModel/CustomConv[conv]/conv2d_0" + ), + "CustomConvBNTestModel": ModelDesc( + helpers.CustomConvBNTestModel, "CustomConvBNTestModel/CustomConv[conv]/conv2d_0" + ), + } + + REF_FUSED_NODE = { + "ConvBiasBNTestModel": "ConvBiasBNTestModel/BatchNorm2d[bn]/batch_norm_0", + "ConvBNTestModel": "ConvBNTestModel/BatchNorm2d[bn]/batch_norm_0", + "ConvTestModel": None, + "FCTestModel": None, + "MultipleConvTestModel": None, + "CustomConvTestModel": None, + "CustomConvBNTestModel": "CustomConvBNTestModel/CustomBN2d[bn]/batch_norm_0", + } + + @pytest.fixture(params=MODELS_LIST) + def model_desc(self, request) -> Tuple[str, ModelDesc]: + return request.param, self.models[request.param] + + def test_get_potential_fused_node(self, model_desc): + model_name, desc = model_desc + ref = self.REF_FUSED_NODE[model_name] + fused_node = get_potential_fused_node(desc.node_name, desc.graph) + result = fused_node.node_name if fused_node is not None else None + assert result == ref + + REF_WITH_FUSED_BIAS = { + "ConvBiasBNTestModel": True, + "ConvBNTestModel": True, + "ConvTestModel": True, + "FCTestModel": False, + "MultipleConvTestModel": True, + "CustomConvTestModel": True, + "CustomConvBNTestModel": True, + } + + def test_is_node_with_fused_bias(self, model_desc): + model_name, desc = model_desc + ref = bool(self.REF_WITH_FUSED_BIAS[model_name]) + result = is_node_with_fused_bias(desc.node, desc.graph) + assert result == ref + + REF_GET_CONST_NODE = { + "ConvBiasBNTestModel": ("conv.weight", "conv.bias"), + "ConvBNTestModel": ("conv.weight", None), + "ConvTestModel": ("conv.weight", "conv.bias"), + "FCTestModel": ("fc.weight", "fc.bias"), + "MultipleConvTestModel": ("conv_1.weight", "conv_1.bias"), + "CustomConvTestModel": ("conv.weight", "conv.bias"), + "CustomConvBNTestModel": ("conv.weight", "conv.bias"), + } + + @pytest.mark.parametrize("port_id", (1, 2)) + def test_get_const_node(self, model_desc, port_id): + model_name, desc = model_desc + const_node = get_const_node(desc.node, port_id, desc.graph) + ref = self.REF_GET_CONST_NODE[model_name][port_id - 1] + result = const_node.node_name if const_node is not None else None + assert result == ref + + REF_GET_CONST_DATA = { + "ConvBiasBNTestModel": ( + [[[[0.1000, -2.0000], [1.0000, 0.1000]]], [[[0.1000, 2.0000], [-1.0000, 0.1000]]]], + [0.3000, 1.3000], + ), + "ConvBNTestModel": ([[[[0.1000, 
-2.0000], [1.0000, 0.1000]]], [[[0.1000, 2.0000], [-1.0000, 0.1000]]]], None), + "ConvTestModel": ( + [[[[0.1000, -2.0000], [1.0000, 0.1000]]], [[[0.1000, 2.0000], [-1.0000, 0.1000]]]], + [0.1000, 1.0000], + ), + "FCTestModel": ([[0.1000, 0.2000, 0.3000, 0.2000], [0.3000, -0.1000, 0.2000, 0.4000]], [1.0000, 1.1000]), + "MultipleConvTestModel": ( + [[[[-2.4661, 0.3623], [0.3765, -0.1808]]], [[[0.3930, 0.4327], [-1.3627, 1.3564]]]], + [0.6688, -0.7077], + ), + "CustomConvTestModel": ( + [[[[0.1000, -2.0000], [1.0000, 0.1000]]], [[[0.1000, 2.0000], [-1.0000, 0.1000]]]], + [0.1000, 1.0000], + ), + "CustomConvBNTestModel": ( + [[[[0.1000, -2.0000], [1.0000, 0.1000]]], [[[0.1000, 2.0000], [-1.0000, 0.1000]]]], + [0.1000, 1.0000], + ), + } + + @pytest.mark.parametrize("port_id", (1, 2)) + def test_get_const_data_on_port(self, model_desc, port_id): + model_name, desc = model_desc + ref = self.REF_GET_CONST_DATA[model_name][port_id - 1] + + data = get_const_data_on_port(desc.node, port_id, desc.model) + if ref is None: + assert data is None + else: + assert torch.all(torch.isclose(data, torch.tensor(ref), atol=1e-4)) + + REF_WEIGHT_PORT_ID = { + "ConvBiasBNTestModel": [1], + "ConvBNTestModel": [1], + "ConvTestModel": [1], + "FCTestModel": [1], + "MultipleConvTestModel": [1], + "CustomConvTestModel": [1], + "CustomConvBNTestModel": [1], + } + + def test_get_weight_tensor_port_ids(self, model_desc): + model_name, desc = model_desc + result = get_weight_tensor_port_ids(desc.node, desc.graph) + assert result == self.REF_WEIGHT_PORT_ID[model_name] + + +@pytest.mark.parametrize( + "const_name, ref", + ( + ("conv.weight", ("conv", "weight")), + ("module.head.conv.bias", ("module.head.conv", "bias")), + ), + ids=["conv.weight", "module.head.conv.bias"], +) +def test_split_const_name(const_name, ref): + assert split_const_name(const_name) == ref + + +class ModelToGet(nn.Module): + def __init__(self): + super().__init__() + self.conv = create_conv(1, 1, 1) + self.customconv = helpers.CustomConv() + self.seq = nn.Sequential(nn.Identity(), helpers.CustomConv()) + + def forward(self, x: torch.Tensor): + return self.seq(self.conv(x)) + + +def test_get_module_by_name(): + model = ModelToGet() + assert get_module_by_name("", model) is model + assert get_module_by_name("conv", model) is model.conv + assert get_module_by_name("customconv.act", model) is model.customconv.act + assert get_module_by_name("seq.0", model) is model.seq[0] + assert get_module_by_name("seq.1", model) is model.seq[1] + assert get_module_by_name("seq.1.act", model) is model.seq[1].act + + +def test_get_set_const_data(): + model_cls = helpers.CustomConvBNTestModel + model = wrap_model(model_cls(), example_input=torch.ones(model_cls.INPUT_SIZE), trace_parameters=True) + graph = model.nncf.get_graph() + const_node = graph.get_node_by_name("conv.bias") + + data = get_const_data(const_node, model) + assert torch.all(model.conv.bias.data == data) + set_const_data(torch.ones_like(data), const_node, model) + assert torch.all(model.conv.bias.data == torch.ones_like(data)) + + +@pytest.mark.parametrize( + "target_type, port_id", + ( + (TargetType.OPERATOR_POST_HOOK, None), + (TargetType.OPERATOR_PRE_HOOK, 1), + ), + ids=["post_hook", "pre_hook"], +) +def test_get_fake_quantizer(target_type, port_id): + model = wrap_model( + helpers.CustomConvTestModel().eval(), + example_input=torch.ones(helpers.CustomConvTestModel.INPUT_SIZE), + trace_parameters=True, + ) + node_name = "CustomConvTestModel/CustomConv[conv]/conv2d_0" + transformer = 
PTModelTransformer(model) + qspec = PTQuantizerSpec( + num_bits=8, + mode=QuantizationMode.SYMMETRIC, + signedness_to_force=None, + scale_shape=(1,), + narrow_range=False, + half_range=False, + logarithm_scale=False, + ) + + fq = SymmetricQuantizer(qspec) + command = PTQuantizerInsertionCommand(PTTargetPoint(target_type, node_name, input_port_id=port_id), fq) + layout = PTTransformationLayout() + layout.register(command) + q_model = transformer.transform(layout) + + graph = q_model.nncf.get_graph() + q_node = graph.get_node_by_name("CustomConvTestModel/CustomConv[conv]/conv2d_0") + + found_fq = get_fake_quantizer(q_node, port_id, q_model) + assert fq is found_fq + + +def test_is_quantized_weights(): + model = wrap_model( + helpers.CustomConvTestModel().eval(), + example_input=torch.ones(helpers.CustomConvTestModel.INPUT_SIZE), + trace_parameters=True, + ) + node_name = "CustomConvTestModel/CustomConv[conv]/conv2d_0" + graph = model.nncf.get_graph() + node = graph.get_node_by_name(node_name) + assert not is_quantized_weights(node, graph) + + transformer = PTModelTransformer(model) + qspec = PTQuantizerSpec( + num_bits=8, + mode=QuantizationMode.SYMMETRIC, + signedness_to_force=None, + scale_shape=(1,), + narrow_range=False, + half_range=False, + logarithm_scale=False, + ) + + fq = SymmetricQuantizer(qspec) + command = PTQuantizerInsertionCommand(PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, node_name, input_port_id=1), fq) + layout = PTTransformationLayout() + layout.register(command) + q_model = transformer.transform(layout) + + q_graph = q_model.nncf.get_graph() + q_node = q_graph.get_node_by_name(node_name) + assert is_quantized_weights(q_node, q_graph)
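
Usage note (editorial, not part of the diff): the sketch below shows how the helpers exercised by these tests fit together — wrap_model, get_const_node, get_const_data, split_const_name, get_module_by_name, and extract_model — using only the call signatures that appear in the tests above. The toy module, its input shape, and the "ToyModel/Conv2d[conv]/conv2d_0" node name are illustrative assumptions that follow the node-naming convention seen in the parametrized cases; they are not taken from the NNCF sources.

import torch
from torch import nn

from nncf.torch import wrap_model
from nncf.torch.extractor import extract_model
from nncf.torch.model_graph_manager import get_const_data
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name


class ToyModel(nn.Module):
    # Hypothetical single-conv model, mirroring the helpers.* test models used above.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)


example_input = torch.ones(1, 1, 4, 4)
model = wrap_model(ToyModel().eval(), example_input=example_input, trace_parameters=True)
graph = model.nncf.get_graph()

# Node name assumed to follow the "<Model>/<Module>[<attr>]/<op>_<idx>" convention from the tests.
conv_node = graph.get_node_by_name("ToyModel/Conv2d[conv]/conv2d_0")

# Input port 1 carries the convolution weight, as reflected in REF_GET_CONST_NODE above.
weight_node = get_const_node(conv_node, 1, graph)
weight = get_const_data(weight_node, model)

# Map the constant back to its owning module: "conv.weight" -> ("conv", "weight").
module_name, attr_name = split_const_name(weight_node.node_name)
conv_module = get_module_by_name(module_name, model)
assert getattr(conv_module, attr_name).shape == weight.shape

# Extract the single-node subgraph and check that it reproduces the model output.
extracted = extract_model(model, ["ToyModel/Conv2d[conv]/conv2d_0"], ["ToyModel/Conv2d[conv]/conv2d_0"])
with torch.no_grad():
    assert torch.allclose(model(example_input), extracted(example_input))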