From 35f1215b7e8d265060f889736ed7da5f11af1cd9 Mon Sep 17 00:00:00 2001
From: Daniil Lyakhov
Date: Fri, 12 Apr 2024 17:34:10 +0200
Subject: [PATCH] [PTQ][MinMax][Torch] One shared quantizer is used for all
 unified scale quantization points (#2622)

### Changes

* MinMax: a new backend method
  `create_unified_scales_quantizers_insertion_commands` is introduced: it
  receives several target points and a single set of quantization parameters.
  Depending on the implementation, one or several insertion commands are
  generated and returned to the common algorithm.

### Reason for changes

* The Torch backend requires a single `PTSharedFnInsertionCommand` to keep the
  quantizers of a unified scale group aligned during QAT, whereas the OV/ONNX
  backends can use separate commands/quantizers for each insertion point
  without any restrictions.

### Related tickets

104304

### Tests

[Template test] test_ptq_params:
* test_unified_scales_command_creation
* test_create_shared_quantizer_insertion_command

### Jobs

manual/job/post_training_quantization/350/: passed

---------

Co-authored-by: Alexander Dokuchaev
---
 .../algorithms/min_max/algorithm.py           |  38 ++++---
 .../algorithms/min_max/backend.py             |  21 +++-
 .../algorithms/min_max/onnx_backend.py        |  16 ++-
 .../algorithms/min_max/openvino_backend.py    |   9 ++
 .../algorithms/min_max/torch_backend.py       |  17 +++
 .../graph/transformations/command_creation.py |  19 +++-
 tests/onnx/quantization/test_ptq_params.py    |  23 ++++
 .../native/quantization/test_ptq_params.py    |  23 ++++
 .../test_templates/test_ptq_params.py         | 103 ++++++++++++++++--
 .../quantized/ptq/symmetric/inception_v3.dot  |  18 +--
 tests/torch/ptq/test_ptq_params.py            |  23 ++++
 tests/torch/test_model_transformer.py         |  71 ++++++------
 12 files changed, 305 insertions(+), 76 deletions(-)

diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py
index 97f6cc53ca2..92b2d5067f5 100644
--- a/nncf/quantization/algorithms/min_max/algorithm.py
+++ b/nncf/quantization/algorithms/min_max/algorithm.py
@@ -863,25 +863,29 @@ def filter_func(point: StatisticPoint) -> bool:
                 group_statistics.append(statistics)
 
             unified_values = self._backend_entity.unify_statistics(group_statistics)
-            for quantization_target_point in unified_scale_group:
-                qconfig = quantization_target_points[quantization_target_point]
-                q_group = QuantizerGroup.ACTIVATIONS
-                narrow_range = get_quantizer_narrow_range(qconfig, q_group)
-                if self._mode is not None:
-                    destination_type = self._quantization_params[q_group].destination_type
-                    parameters = calculate_convert_parameters(
-                        unified_values, is_per_channel=qconfig.per_channel, destination_type=destination_type
-                    )
-                    command = self._backend_entity.create_convert_insertion_command(
-                        quantization_target_point, parameters
-                    )
-                else:
-                    parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
-                    command = self._backend_entity.create_quantizer_insertion_command(
-                        graph, quantization_target_point, qconfig, parameters
+            qconfigs = [quantization_target_points[qtp] for qtp in unified_scale_group]
+            if any(qconfigs[0] != qconfig for qconfig in qconfigs[1:]):
+                raise nncf.InternalError(f"QConfigs for unified scale group {unified_scale_group} are not equal")
+            qconfig = qconfigs[0]
+            q_group = QuantizerGroup.ACTIVATIONS
+            narrow_range = get_quantizer_narrow_range(qconfig, q_group)
+            if self._mode is not None:
+                destination_type = self._quantization_params[q_group].destination_type
+                parameters = calculate_convert_parameters(
+                    unified_values, is_per_channel=qconfig.per_channel, destination_type=destination_type
+                )
+                for quantization_target_point in unified_scale_group:
+                    transformation_layout.register(
+                        self._backend_entity.create_convert_insertion_command(quantization_target_point, parameters)
                     )
+                continue
+            parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
+            commands = self._backend_entity.create_unified_scales_quantizers_insertion_commands(
+                graph, unified_scale_group, qconfig, parameters
+            )
+            for command in commands:
                 transformation_layout.register(command)
-                unified_ops_list.add(quantization_target_point)
+            unified_ops_list.update(unified_scale_group)
 
         for quantization_target_point, qconfig in quantization_target_points.items():
             if quantization_target_point in unified_ops_list:
diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py
index d521f233243..2f2e7b7361d 100644
--- a/nncf/quantization/algorithms/min_max/backend.py
+++ b/nncf/quantization/algorithms/min_max/backend.py
@@ -141,12 +141,31 @@ def create_quantizer_insertion_command(
         Returns backend-specific quantizer insertion command.
 
         :param nncf_graph: NNCFGraph to get input/output shapes for the target point.
-        :param target_point: Target location for the correction.
+        :param target_point: Target location for the quantizer insertion.
         :param quantizer_config: QuantizerConfig instance for the current layer.
         :param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
         :return: Backend-specific TransformationCommand for the quantizer insertion operation.
         """
 
+    @staticmethod
+    @abstractmethod
+    def create_unified_scales_quantizers_insertion_commands(
+        nncf_graph: NNCFGraph,
+        target_points: List[TargetPoint],
+        quantizer_config: QuantizerConfig,
+        parameters: FakeQuantizeParameters,
+    ) -> List[TransformationCommand]:
+        """
+        Returns backend-specific commands for insertion of quantizers with unified scales.
+
+        :param nncf_graph: NNCFGraph to get input/output shapes for the target point.
+        :param target_points: List of target locations for the quantizers insertion.
+        :param quantizer_config: QuantizerConfig instance for the current layer.
+        :param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
+        :return: List of backend-specific TransformationCommands
+            for insertion of quantizers with unified scales.
+ """ + @staticmethod @abstractmethod def create_convert_insertion_command( diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 16295518de1..f58299a5d10 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -118,7 +118,7 @@ def create_quantizer_insertion_command( target_point: ONNXTargetPoint, quantizer_config: QuantizerConfig, parameters: FakeQuantizeParameters, - ): + ) -> ONNXQuantizerInsertionCommand: tensor_type = np.int8 if np.any(parameters.input_low.data < 0) else np.uint8 is_weight = target_point.is_weight_target_point() if is_weight: @@ -131,6 +131,20 @@ def create_quantizer_insertion_command( onnx_parameters = convert_fq_params_to_onnx_params(parameters, quantizer_config.num_bits, tensor_type, axis) return ONNXQuantizerInsertionCommand(target_point, nncf_input_node_next_nodes, onnx_parameters) + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[ONNXTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[ONNXQuantizerInsertionCommand]: + return [ + ONNXMinMaxAlgoBackend.create_quantizer_insertion_command( + nncf_graph, target_point, quantizer_config, parameters + ) + for target_point in target_points + ] + @staticmethod def create_convert_insertion_command( target_point: ONNXTargetPoint, diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index d621993c3ae..417f9c7cbec 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -120,6 +120,15 @@ def create_quantizer_insertion_command( ) -> OVQuantizerInsertionCommand: return OVQuantizerInsertionCommand(target_point, parameters) + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[OVTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[OVQuantizerInsertionCommand]: + return [OVQuantizerInsertionCommand(target_point, parameters) for target_point in target_points] + @staticmethod def create_convert_insertion_command( target_point: OVTargetPoint, diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index d5a37ddcbe5..541792eca78 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -37,6 +37,7 @@ from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command +from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.hardware.config import PTHWConfig @@ -296,6 +297,22 @@ def create_quantizer_insertion_command( ) return create_quantizer_insertion_command(target_point, quantizer) + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[PTTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[PTSharedFnInsertionCommand]: + 
+        _, scale_shape, _ = PTMinMaxAlgoBackend._get_input_scale_shape(
+            nncf_graph, target_points[0], quantizer_config.per_channel
+        )
+
+        quantizer = PTMinMaxAlgoBackend._create_quantizer(
+            quantizer_config, scale_shape, parameters, target_points[0].target_type
+        )
+        return [create_shared_quantizer_insertion_command(target_points, quantizer)]
+
     @staticmethod
     def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
         types = []
diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py
index 16b14c3e172..6146803ae19 100644
--- a/nncf/torch/graph/transformations/command_creation.py
+++ b/nncf/torch/graph/transformations/command_creation.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import List, Union
 
 from torch import Tensor
 
@@ -65,3 +65,20 @@ def create_quantizer_insertion_command(
         compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
         priority=TransformationPriority.QUANTIZATION_PRIORITY,
     )
+
+
+def create_shared_quantizer_insertion_command(
+    target_points: List[PTTargetPoint], quantizer: BaseQuantizer
+) -> PTSharedFnInsertionCommand:
+    quantizers_ids = []
+    for target_point in target_points:
+        quantizers_ids.append(NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id))
+
+    storage_key = ";".join(str(quantizer_id) for quantizer_id in sorted(quantizers_ids, key=str))
+    return PTSharedFnInsertionCommand(
+        target_points=target_points,
+        fn=quantizer,
+        op_unique_name=storage_key,
+        compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
+        priority=TransformationPriority.QUANTIZATION_PRIORITY,
+    )
diff --git a/tests/onnx/quantization/test_ptq_params.py b/tests/onnx/quantization/test_ptq_params.py
index daadc8a8337..f6c6c041459 100644
--- a/tests/onnx/quantization/test_ptq_params.py
+++ b/tests/onnx/quantization/test_ptq_params.py
@@ -9,22 +9,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import pytest
 
+from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.patterns import GraphPattern
 from nncf.common.graph.patterns.manager import PatternsManager
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationType
 from nncf.common.utils.backend import BackendType
+from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConcatMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXSoftmaxMetatype
 from nncf.onnx.graph.nncf_graph_builder import GraphConverter
 from nncf.onnx.graph.nncf_graph_builder import ONNXLayerAttributes
+from nncf.onnx.graph.transformations.commands import ONNXQuantizerInsertionCommand
 from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
+from tests.common.quantization.metatypes import CatTestMetatype
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
 from tests.common.quantization.metatypes import SoftmaxTestMetatype
@@ -61,17 +67,34 @@ def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_
         assert act_num_q == 1
         assert weight_num_q == 1
 
+    def check_unified_scale_layout(self, layout, unified_scale_group):
+        assert len(layout.transformations) == len(unified_scale_group)
+        for t, ref_tp in zip(layout.transformations, unified_scale_group):
+            assert isinstance(t, ONNXQuantizerInsertionCommand)
+            assert t.target_point == ref_tp
+            assert t.type == TransformationType.INSERT
+            assert t.quantizer_parameters.zero_point == 0
+            assert np.isclose(t.quantizer_parameters.scale, 0.03149606)
+
     def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint:
         return ONNXTargetPoint(target_type, target_node_name, port_id)
 
+    def get_backend_tensor(self, value):
+        return np.array(value)
+
     @property
     def metatypes_mapping(self):
         return {
             Conv2dTestMetatype: ONNXConvolutionMetatype,
             LinearTestMetatype: ONNXGemmMetatype,
             SoftmaxTestMetatype: ONNXSoftmaxMetatype,
+            CatTestMetatype: ONNXConcatMetatype,
         }
 
+    @property
+    def nncf_graph_cls(self):
+        return NNCFGraph
+
     @pytest.fixture(scope="session")
     def test_params(self):
         linear_model = LinearModel().onnx_model
diff --git a/tests/openvino/native/quantization/test_ptq_params.py b/tests/openvino/native/quantization/test_ptq_params.py
index 71fda6d1a7c..73b73e511ea 100644
--- a/tests/openvino/native/quantization/test_ptq_params.py
+++ b/tests/openvino/native/quantization/test_ptq_params.py
@@ -9,22 +9,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import pytest
 
+from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.patterns import GraphPattern
 from nncf.common.graph.patterns.manager import PatternsManager
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationType
 from nncf.common.hardware.config import HW_CONFIG_TYPE_TARGET_DEVICE_MAP
 from nncf.common.utils.backend import BackendType
+from nncf.openvino.graph.metatypes.openvino_metatypes import OVConcatMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVSoftmaxMetatype
 from nncf.openvino.graph.nncf_graph_builder import GraphConverter
+from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
 from nncf.openvino.graph.transformations.commands import OVTargetPoint
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
+from tests.common.quantization.metatypes import CatTestMetatype
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
 from tests.common.quantization.metatypes import SoftmaxTestMetatype
@@ -60,17 +66,34 @@ def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_
         assert act_num_q == 1
         assert weight_num_q == 1
 
+    def check_unified_scale_layout(self, layout, unified_scale_group):
+        assert len(layout.transformations) == len(unified_scale_group)
+        for t, ref_tp in zip(layout.transformations, unified_scale_group):
+            assert isinstance(t, OVQuantizerInsertionCommand)
+            assert t.target_point == ref_tp
+            assert t.type == TransformationType.INSERT
+            assert np.isclose(t.quantizer_parameters.input_low.data, -4.031496)
+            assert np.isclose(t.quantizer_parameters.input_high.data, 4)
+
     def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
         return OVTargetPoint(target_type, target_node_name, port_id)
 
+    def get_backend_tensor(self, value):
+        return np.array(value)
+
     @property
     def metatypes_mapping(self):
         return {
             Conv2dTestMetatype: OVConvolutionMetatype,
             LinearTestMetatype: OVMatMulMetatype,
             SoftmaxTestMetatype: OVSoftmaxMetatype,
+            CatTestMetatype: OVConcatMetatype,
         }
 
+    @property
+    def nncf_graph_cls(self):
+        return NNCFGraph
+
     @pytest.fixture(scope="session")
     def test_params(self):
         linear_model = LinearModel().ov_model
diff --git a/tests/post_training/test_templates/test_ptq_params.py b/tests/post_training/test_templates/test_ptq_params.py
index c7a0e45799c..f56db68f938 100644
--- a/tests/post_training/test_templates/test_ptq_params.py
+++ b/tests/post_training/test_templates/test_ptq_params.py
@@ -16,6 +16,7 @@
 import pytest
 
 import nncf
+from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.operator_metatypes import InputNoopMetatype
 from nncf.common.graph.operator_metatypes import OperatorMetatype
 from nncf.common.graph.operator_metatypes import OutputNoopMetatype
@@ -37,6 +38,7 @@
 from nncf.quantization.passes import transform_to_inference_graph
 from nncf.quantization.range_estimator import RangeEstimatorParametersSet
 from nncf.scopes import IgnoredScope
+from tests.common.quantization.metatypes import CatTestMetatype
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import IdentityTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
@@ -91,6 +93,50 @@ def __init__(self, metatypes: Dict[TestMetatype, OperatorMetatype]):
                 self.weight_quantization_target_point_names.append(node.node_name)
 
 
+class ModelWithUnifiedScales:
+    #       Input_1
+    #      /   |   \
+    #  Conv_1 Conv_2 Conv_3
+    #      \   |   /
+    #        Cat_1
+    #          |
+    #       Output_1
+
+    def __init__(self, metatypes: Dict[TestMetatype, OperatorMetatype], nncf_graph_cls=NNCFGraph):
+        nodes = [
+            NodeWithType("Input_1", InputNoopMetatype),
+            NodeWithType("Conv_1", metatypes[Conv2dTestMetatype]),
+            NodeWithType("Conv_2", metatypes[Conv2dTestMetatype]),
+            NodeWithType("Conv_3", metatypes[Conv2dTestMetatype]),
+            NodeWithType("Cat_1", metatypes[CatTestMetatype]),
+            NodeWithType("Output_1", OutputNoopMetatype),
+        ]
+        node_edges = [
+            ("Input_1", "Conv_1"),
+            ("Input_1", "Conv_2"),
+            ("Input_1", "Conv_3"),
+            ("Conv_1", "Cat_1"),
+            ("Conv_2", "Cat_1"),
+            ("Conv_3", "Cat_1"),
+            ("Cat_1", "Output_1"),
+        ]
+        original_mock_graph = create_mock_graph(nodes, node_edges)
+        self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls=nncf_graph_cls)
+
+
+class DummyMinMaxTensorStatistic(MinMaxTensorStatistic):
+    def tensor_eq(self):
+        return True
+
+
+class DummyMinMaxTensorCollector:
+    def __init__(self, min_val, max_val):
+        self._stat = DummyMinMaxTensorStatistic(min_values=min_val, max_values=max_val)
+
+    def get_statistics(self):
+        return self._stat
+
+
 class TemplateTestPTQParams:
     @abstractmethod
     def get_algo_backend(self):
@@ -112,6 +158,13 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl
     def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_q):
         pass
 
+    @abstractmethod
+    def check_unified_scale_layout(self, layout, unified_scales_group):
+        """
+        Checks that the given transformation layout and the unified_scales_group
+        target points correspond to each other and to the test parameters.
+        """
+
     @abstractmethod
     @pytest.fixture(scope="session")
     def test_params(self):
@@ -131,6 +184,15 @@ def target_point(self, target_type: TargetType, target_node_name: str, port_id:
     def metatypes_mapping(self):
         pass
 
+    @property
+    @abstractmethod
+    def nncf_graph_cls(self):
+        pass
+
+    @abstractmethod
+    def get_backend_tensor(self, value):
+        pass
+
     @pytest.mark.parametrize(
         "range_estimator_params", [RangeEstimatorParametersSet.MINMAX, RangeEstimatorParametersSet.MEAN_MINMAX, None]
     )
@@ -282,6 +344,35 @@ def test_quantization_points_overflow_fix(self, overflow_fix, affected_target_po
         )
         assert Counter([t_p.target_node_name for t_p in target_points_overflow_fix]) == Counter(affected_target_points)
 
+    def test_unified_scales_command_creation(self, mocker):
+        model = ModelWithUnifiedScales(self.metatypes_mapping, self.nncf_graph_cls)
+        algo = MinMaxQuantization()
+        algo._backend_entity = self.get_algo_backend()
+        # Imitate the quantization setup built by the solver
+        q_tp_vs_qcf = {}
+        unified_scales_group = []
+        for i in range(1, 4):
+            tp = self.target_point(TargetType.POST_LAYER_OPERATION, f"/Conv_{i}_0", port_id=0)
+            q_tp_vs_qcf[tp] = QuantizerConfig()
+            unified_scales_group.append(tp)
+
+        algo._quantization_target_points_to_qconfig = q_tp_vs_qcf
+        algo._unified_scale_groups = [unified_scales_group]
+
+        mock_transformer = mocker.MagicMock()
+        mocker.patch(
+            "nncf.quantization.algorithms.min_max.algorithm.ModelTransformerFactory.create",
+            return_value=mock_transformer,
+        )
+        stats = StatisticPointsContainer()
+        for idx, tp in enumerate(unified_scales_group):
+            tc = DummyMinMaxTensorCollector(self.get_backend_tensor(idx - 1), self.get_backend_tensor(idx + 2))
+            stats.add_statistic_point(StatisticPoint(tp, tc, algo._algorithm_key))
+        algo.apply(model, model.nncf_graph, stats)
+        mock_transformer.transform.assert_called_once()
+        layout = mock_transformer.transform.call_args.args[0]
+        self.check_unified_scale_layout(layout, unified_scales_group)
+
     @pytest.mark.parametrize("validate_scopes", (True, False))
     def test_validate_scope(self, test_params, validate_scopes):
         nncf_graph = test_params["test_model_type_pass"]["nncf_graph"]
@@ -308,20 +399,14 @@ def test_empty_statistics(self, mode, mocker):
         target_point = self.target_point(TargetType.PRE_LAYER_OPERATION, "A", 0)
         stat_points = StatisticPointsContainer()
 
-        class DummyMinMaxTensorStatistic(MinMaxTensorStatistic):
-            def tensor_eq(self):
-                return True
-
-        class EmptyTensorCollector:
-            def get_statistics(self):
-                return DummyMinMaxTensorStatistic(None, None)
-
         dummy_tp = {target_point: QuantizerConfig()}
         if mode == "target_point":
             dummy_tps = (dummy_tp, {})
         else:
             dummy_tps = ({}, ((target_point,),))
-        stat_points.add_statistic_point(StatisticPoint(target_point, EmptyTensorCollector(), algo._algorithm_key))
+        stat_points.add_statistic_point(
+            StatisticPoint(target_point, DummyMinMaxTensorCollector(None, None), algo._algorithm_key)
+        )
         mocker.patch("nncf.common.factory.ModelTransformerFactory.create", return_value=mocker.MagicMock())
         mocker.patch(
             "nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization._get_quantization_target_points",
diff --git a/tests/torch/data/reference_graphs/quantized/ptq/symmetric/inception_v3.dot b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/inception_v3.dot
index 3eca81edfce..1082c9c5847 100644
--- a/tests/torch/data/reference_graphs/quantized/ptq/symmetric/inception_v3.dot
+++ b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/inception_v3.dot
@@ -6,19 +6,19 @@ strict digraph {
 "4 Inception3/__mul___0" [id=4, type=__mul__];
 "5 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___0|OUTPUT]/symmetric_quantize_0" [id=5, type=symmetric_quantize];
 "6 Inception3/__add___0" [id=6, type=__add__];
-"7 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT]/symmetric_quantize_0" [id=7, type=symmetric_quantize];
+"7 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_0" [id=7, type=symmetric_quantize];
 "8 Inception3/__getitem___1" [id=8, type=__getitem__];
 "9 Inception3/unsqueeze_1" [id=9, type=unsqueeze];
 "10 Inception3/__mul___1" [id=10, type=__mul__];
 "11 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___1|OUTPUT]/symmetric_quantize_0" [id=11, type=symmetric_quantize];
 "12 Inception3/__add___1" [id=12, type=__add__];
-"13 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___1|OUTPUT]/symmetric_quantize_0" [id=13, type=symmetric_quantize];
+"13 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_1" [id=13, type=symmetric_quantize];
 "14 Inception3/__getitem___2" [id=14, type=__getitem__];
 "15 Inception3/unsqueeze_2" [id=15, type=unsqueeze];
 "16 Inception3/__mul___2" [id=16, type=__mul__];
 "17 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___2|OUTPUT]/symmetric_quantize_0" [id=17, type=symmetric_quantize];
 "18 Inception3/__add___2" [id=18, type=__add__];
-"19 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___2|OUTPUT]/symmetric_quantize_0" [id=19, type=symmetric_quantize];
+"19 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_2" [id=19, type=symmetric_quantize];
 "20 Inception3/cat_0" [id=20, type=cat];
 "21 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" [id=21, type=symmetric_quantize];
 "22 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFConv2d[conv]/conv2d_0" [id=22, type=conv2d];
@@ -542,20 +542,20 @@ strict digraph {
 "3 Inception3/unsqueeze_0" -> "4 Inception3/__mul___0";
 "4 Inception3/__mul___0" -> "5 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___0|OUTPUT]/symmetric_quantize_0";
 "5 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___0|OUTPUT]/symmetric_quantize_0" -> "6 Inception3/__add___0";
-"6 Inception3/__add___0" -> "7 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT]/symmetric_quantize_0";
-"7 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT]/symmetric_quantize_0" -> "20 Inception3/cat_0";
+"6 Inception3/__add___0" -> "7 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_0";
+"7 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_0" -> "20 Inception3/cat_0";
 "8 Inception3/__getitem___1" -> "9 Inception3/unsqueeze_1";
 "9 Inception3/unsqueeze_1" -> "10 Inception3/__mul___1";
 "10 Inception3/__mul___1" -> "11 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___1|OUTPUT]/symmetric_quantize_0";
 "11 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___1|OUTPUT]/symmetric_quantize_0" -> "12 Inception3/__add___1";
-"12 Inception3/__add___1" -> "13 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___1|OUTPUT]/symmetric_quantize_0";
-"13 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___1|OUTPUT]/symmetric_quantize_0" -> "20 Inception3/cat_0";
+"12 Inception3/__add___1" -> "13 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_1";
+"13 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_1" -> "20 Inception3/cat_0";
 "14 Inception3/__getitem___2" -> "15 Inception3/unsqueeze_2";
 "15 Inception3/unsqueeze_2" -> "16 Inception3/__mul___2";
 "16 Inception3/__mul___2" -> "17 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___2|OUTPUT]/symmetric_quantize_0";
 "17 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__mul___2|OUTPUT]/symmetric_quantize_0" -> "18 Inception3/__add___2";
-"18 Inception3/__add___2" -> "19 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___2|OUTPUT]/symmetric_quantize_0";
-"19 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___2|OUTPUT]/symmetric_quantize_0" -> "20 Inception3/cat_0";
+"18 Inception3/__add___2" -> "19 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_2";
+"19 Inception3/NNCFNetworkInterface[_nncf]/ModuleDict[external_quantizers]/SymmetricQuantizer[Inception3/__add___0|OUTPUT;Inception3/__add___1|OUTPUT;Inception3/__add___2|OUTPUT]/symmetric_quantize_2" -> "20 Inception3/cat_0";
 "20 Inception3/cat_0" -> "22 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFConv2d[conv]/conv2d_0";
 "21 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/SymmetricQuantizer[op]/symmetric_quantize_0" -> "22 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFConv2d[conv]/conv2d_0";
 "22 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFConv2d[conv]/conv2d_0" -> "23 Inception3/BasicConv2d[Conv2d_1a_3x3]/NNCFBatchNorm2d[bn]/batch_norm_0";
diff --git a/tests/torch/ptq/test_ptq_params.py b/tests/torch/ptq/test_ptq_params.py
index 8ebdb1f1b2d..88a691c47fc 100644
--- a/tests/torch/ptq/test_ptq_params.py
+++ b/tests/torch/ptq/test_ptq_params.py
@@ -10,20 +10,26 @@
 # limitations under the License.
 
 import pytest
+import torch
 from torch import nn
 
 from nncf.common.graph.patterns import GraphPattern
 from nncf.common.graph.patterns.manager import PatternsManager
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationType
 from nncf.common.utils.backend import BackendType
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
+from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
+from nncf.torch.graph.operator_metatypes import PTCatMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
 from nncf.torch.graph.operator_metatypes import PTSoftmaxMetatype
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
+from tests.common.quantization.metatypes import CatTestMetatype
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
 from tests.common.quantization.metatypes import SoftmaxTestMetatype
@@ -97,17 +103,34 @@ def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_
         assert act_num_q == 1
         assert weight_num_q == 1
 
+    def check_unified_scale_layout(self, layout, unified_scale_group):
+        assert len(layout.transformations) == 1
+        command = layout.transformations[0]
+        assert isinstance(command, PTSharedFnInsertionCommand)
+        assert command.op_name == "/Conv_1_0|INPUT0;/Conv_2_0|INPUT0;/Conv_3_0|INPUT0"
+        assert command.target_points == unified_scale_group
+        assert torch.allclose(command.fn.scale, torch.tensor(4.0))
+        assert command.type == TransformationType.INSERT
+
     def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
         return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)
 
+    def get_backend_tensor(self, value):
+        return torch.tensor(value)
+
     @property
     def metatypes_mapping(self):
         return {
             Conv2dTestMetatype: PTModuleConv2dMetatype,
             LinearTestMetatype: PTModuleLinearMetatype,
             SoftmaxTestMetatype: PTSoftmaxMetatype,
+            CatTestMetatype: PTCatMetatype,
         }
 
+    @property
+    def nncf_graph_cls(self):
+        return PTNNCFGraph
+
     @pytest.fixture(scope="session")
     def test_params(self):
         linear_model = LinearTestModel().get_nncf_network()
diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py
index c6348644f68..c554a39ccb3 100644
--- a/tests/torch/test_model_transformer.py
+++ b/tests/torch/test_model_transformer.py
@@ -50,6 +50,7 @@
 from nncf.torch.graph.operator_metatypes import PTOutputNoopMetatype
 from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
 from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
+from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command
 from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
@@ -638,32 +639,26 @@ def to(self, device):
         self.to_device = device
 
 
-@pytest.mark.parametrize(
-    "target_type, node_name, input_port_id, ref_name, compression_module_registered",
-    (
-        (
-            TargetType.OPERATOR_POST_HOOK,
-            "/nncf_model_input_0",
-            None,
-            "/nncf_model_input_0|OUTPUT",
-            True,
-        ),
-        (
-            TargetType.OPERATOR_PRE_HOOK,
-            "InsertionPointTestModel/linear_0",
-            0,
-            "InsertionPointTestModel/linear_0|INPUT0",
-            True,
-        ),
-        (TargetType.OPERATION_WITH_WEIGHTS, "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0", None, None, False),
+SHARED_FN_TARGET_POINTS = (
+    PTTargetPoint(
+        TargetType.OPERATOR_POST_HOOK,
+        "/nncf_model_input_0",
+    ),
+    PTTargetPoint(
+        TargetType.OPERATOR_PRE_HOOK,
+        "InsertionPointTestModel/linear_0",
+        input_port_id=0,
+    ),
+    PTTargetPoint(
+        TargetType.OPERATION_WITH_WEIGHTS,
+        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
     ),
 )
-def test_quantizer_insertion_transformations(
-    target_type, node_name, input_port_id, ref_name, compression_module_registered
-):
-    hook = Hook()
-    target_point = PTTargetPoint(target_type, node_name, input_port_id=input_port_id)
+
+
+@pytest.mark.parametrize("target_point", SHARED_FN_TARGET_POINTS)
+def test_create_quantizer_insertion_command(target_point):
+    hook = Hook()
 
     command = create_quantizer_insertion_command(target_point, hook)
 
     assert command.fn is hook
@@ -679,21 +674,21 @@ def test_quantizer_insertion_transformations(
     assert command.compression_module_type is ExtraCompressionModuleType.EXTERNAL_QUANTIZER
 
 
-SHARED_FN_TARGET_POINTS = [
-    PTTargetPoint(
-        TargetType.OPERATOR_POST_HOOK,
-        "/nncf_model_input_0",
-    ),
-    PTTargetPoint(
-        TargetType.OPERATOR_PRE_HOOK,
-        "InsertionPointTestModel/linear_0",
-        input_port_id=0,
-    ),
-    PTTargetPoint(
-        TargetType.OPERATION_WITH_WEIGHTS,
-        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
-    ),
-]
+def test_create_shared_quantizer_insertion_command():
+    ref_storage_key = (
+        "/nncf_model_input_0|OUTPUT;"
+        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0|OUTPUT;"
+        "InsertionPointTestModel/linear_0|INPUT0"
+    )
+    hook = Hook()
+
+    command = create_shared_quantizer_insertion_command(list(SHARED_FN_TARGET_POINTS), hook)
+    assert command.fn is hook
+    assert isinstance(command, PTSharedFnInsertionCommand)
+    assert command.target_points == list(SHARED_FN_TARGET_POINTS)
+    assert command.fn is hook
+    assert command.op_name == ref_storage_key
+    assert command.compression_module_type is ExtraCompressionModuleType.EXTERNAL_QUANTIZER
 
 
 @pytest.mark.parametrize("compression_module_type", ExtraCompressionModuleType)