[PTQ][MinMax][Torch] One shared quantizer is used for all unified scale quantization points (#2622)

### Changes

* MinMax: a new backend method,
`create_unified_scales_quantizers_insertion_commands`, is introduced. It
receives several target points and a single set of quantization parameters;
depending on the backend implementation, one or several insertion commands
are generated and returned to the common algorithm (sketched below).
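
For illustration, a minimal self-contained sketch of this contract. The names here (`TargetPoint`, `Command`, `register_unified_scale_group`) are stand-ins, not NNCF's real classes; the actual signatures are in the `backend.py` diff below:

```python
from typing import Callable, Dict, List

# Toy sketch of the new control flow: the common MinMax algorithm computes ONE
# set of parameters per unified scale group and delegates command creation for
# the WHOLE group to the backend.
TargetPoint = str  # stands in for the backend-specific target point classes
Command = dict     # stands in for backend-specific TransformationCommands


def register_unified_scale_group(
    unified_scale_group: List[TargetPoint],
    parameters: Dict,
    backend_create_commands: Callable[[List[TargetPoint], Dict], List[Command]],
    transformation_layout: List[Command],
) -> None:
    # The backend decides whether the group becomes one command or many.
    for command in backend_create_commands(unified_scale_group, parameters):
        transformation_layout.append(command)
```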

### Reason for changes

* The Torch backend requires a single `PTSharedFnInsertionCommand` to keep
quantizers aligned during QAT, whereas the OV/ONNX backends can use a
separate command/quantizer for each insertion point without any restrictions
(contrasted in the sketch below).
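
A hedged, self-contained sketch of the two strategies; the function names and dictionary shapes are hypothetical, and the real implementations are in the backend diffs below:

```python
from typing import Dict, List

TargetPoint = str  # stand-in for the real NNCF target point classes


def separate_commands(group: List[TargetPoint], params: Dict) -> List[Dict]:
    # OV/ONNX-style: an independent quantizer per insertion point. Unified
    # statistics already make the parameters equal, so no module sharing is needed.
    return [{"points": [tp], "quantizer": dict(params)} for tp in group]


def shared_command(group: List[TargetPoint], params: Dict) -> List[Dict]:
    # Torch-style: a single quantizer object referenced from every point, so the
    # scales cannot drift apart while being fine-tuned during QAT.
    shared_quantizer = dict(params)
    return [{"points": list(group), "quantizer": shared_quantizer}]


# Both satisfy the same backend hook; only the command granularity differs.
print(len(separate_commands(["cat_0", "cat_1"], {"scale": 0.0315})))  # 2
print(len(shared_command(["cat_0", "cat_1"], {"scale": 0.0315})))     # 1
```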

### Related tickets

104304

### Tests

[Template test] test_ptq_params:
* `test_unified_scales_command_creation`
* `test_create_shared_quantizer_insertion_command`

### Jobs
manual/job/post_training_quantization/350/: passed

---------

Co-authored-by: Alexander Dokuchaev <[email protected]>
daniil-lyakhov and AlexanderDokuchaev authored Apr 12, 2024
1 parent f878143 commit 35f1215
Showing 12 changed files with 305 additions and 76 deletions.
38 changes: 21 additions & 17 deletions nncf/quantization/algorithms/min_max/algorithm.py
```diff
@@ -863,25 +863,29 @@ def filter_func(point: StatisticPoint) -> bool:
                 group_statistics.append(statistics)
 
             unified_values = self._backend_entity.unify_statistics(group_statistics)
-            for quantization_target_point in unified_scale_group:
-                qconfig = quantization_target_points[quantization_target_point]
-                q_group = QuantizerGroup.ACTIVATIONS
-                narrow_range = get_quantizer_narrow_range(qconfig, q_group)
-                if self._mode is not None:
-                    destination_type = self._quantization_params[q_group].destination_type
-                    parameters = calculate_convert_parameters(
-                        unified_values, is_per_channel=qconfig.per_channel, destination_type=destination_type
-                    )
-                    command = self._backend_entity.create_convert_insertion_command(
-                        quantization_target_point, parameters
-                    )
-                else:
-                    parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
-                    command = self._backend_entity.create_quantizer_insertion_command(
-                        graph, quantization_target_point, qconfig, parameters
-                    )
-                transformation_layout.register(command)
-                unified_ops_list.add(quantization_target_point)
+            qconfigs = [quantization_target_points[qtp] for qtp in unified_scale_group]
+            if any(qconfigs[0] != qconfig for qconfig in qconfigs[1:]):
+                raise nncf.InternalError(f"QConfigs for unified scale group {unified_scale_group} are not equal")
+            qconfig = qconfigs[0]
+            q_group = QuantizerGroup.ACTIVATIONS
+            narrow_range = get_quantizer_narrow_range(qconfig, q_group)
+            if self._mode is not None:
+                destination_type = self._quantization_params[q_group].destination_type
+                parameters = calculate_convert_parameters(
+                    unified_values, is_per_channel=qconfig.per_channel, destination_type=destination_type
+                )
+                for quantization_target_point in unified_scale_group:
+                    transformation_layout.register(
+                        self._backend_entity.create_convert_insertion_command(quantization_target_point, parameters)
+                    )
+                continue
+            parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
+            commands = self._backend_entity.create_unified_scales_quantizers_insertion_commands(
+                graph, unified_scale_group, qconfig, parameters
+            )
+            for command in commands:
+                transformation_layout.register(command)
+            unified_ops_list.update(unified_scale_group)
 
         for quantization_target_point, qconfig in quantization_target_points.items():
             if quantization_target_point in unified_ops_list:
```
21 changes: 20 additions & 1 deletion nncf/quantization/algorithms/min_max/backend.py
```diff
@@ -141,12 +141,31 @@ def create_quantizer_insertion_command(
         Returns backend-specific quantizer insertion command.
 
         :param nncf_graph: NNCFGraph to get input/output shapes for the target point.
-        :param target_point: Target location for the correction.
+        :param target_point: Target location for the quantizer insertion.
         :param quantizer_config: QuantizerConfig instance for the current layer.
         :param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
         :return: Backend-specific TransformationCommand for the quantizer insertion operation.
         """
 
+    @staticmethod
+    @abstractmethod
+    def create_unified_scales_quantizers_insertion_commands(
+        nncf_graph: NNCFGraph,
+        target_points: List[TargetPoint],
+        quantizer_config: QuantizerConfig,
+        parameters: FakeQuantizeParameters,
+    ) -> List[TransformationCommand]:
+        """
+        Returns backend-specific unified scales quantizers insertion commands.
+
+        :param nncf_graph: NNCFGraph to get input/output shapes for the target point.
+        :param target_points: List of target locations for the quantizers insertion.
+        :param quantizer_config: QuantizerConfig instance for the current layer.
+        :param parameters: FakeQuantizeParameters to calculate activation quantization parameters.
+        :return: List of backend-specific TransformationCommands
+            for the quantizers with unified scales insertion operations.
+        """
+
     @staticmethod
     @abstractmethod
     def create_convert_insertion_command(
```
16 changes: 15 additions & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py
```diff
@@ -118,7 +118,7 @@ def create_quantizer_insertion_command(
         target_point: ONNXTargetPoint,
         quantizer_config: QuantizerConfig,
         parameters: FakeQuantizeParameters,
-    ):
+    ) -> ONNXQuantizerInsertionCommand:
         tensor_type = np.int8 if np.any(parameters.input_low.data < 0) else np.uint8
         is_weight = target_point.is_weight_target_point()
         if is_weight:
@@ -131,6 +131,20 @@ def create_quantizer_insertion_command(
         onnx_parameters = convert_fq_params_to_onnx_params(parameters, quantizer_config.num_bits, tensor_type, axis)
         return ONNXQuantizerInsertionCommand(target_point, nncf_input_node_next_nodes, onnx_parameters)
 
+    @staticmethod
+    def create_unified_scales_quantizers_insertion_commands(
+        nncf_graph: NNCFGraph,
+        target_points: List[ONNXTargetPoint],
+        quantizer_config: QuantizerConfig,
+        parameters: FakeQuantizeParameters,
+    ) -> List[ONNXQuantizerInsertionCommand]:
+        return [
+            ONNXMinMaxAlgoBackend.create_quantizer_insertion_command(
+                nncf_graph, target_point, quantizer_config, parameters
+            )
+            for target_point in target_points
+        ]
+
     @staticmethod
     def create_convert_insertion_command(
         target_point: ONNXTargetPoint,
```
9 changes: 9 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
```diff
@@ -120,6 +120,15 @@ def create_quantizer_insertion_command(
     ) -> OVQuantizerInsertionCommand:
         return OVQuantizerInsertionCommand(target_point, parameters)
 
+    @staticmethod
+    def create_unified_scales_quantizers_insertion_commands(
+        nncf_graph: NNCFGraph,
+        target_points: List[OVTargetPoint],
+        quantizer_config: QuantizerConfig,
+        parameters: FakeQuantizeParameters,
+    ) -> List[OVQuantizerInsertionCommand]:
+        return [OVQuantizerInsertionCommand(target_point, parameters) for target_point in target_points]
+
     @staticmethod
     def create_convert_insertion_command(
         target_point: OVTargetPoint,
```
17 changes: 17 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
```diff
@@ -37,6 +37,7 @@
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph import PTTargetPoint
 from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
+from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.hardware.config import PTHWConfig
@@ -296,6 +297,22 @@ def create_quantizer_insertion_command(
         )
         return create_quantizer_insertion_command(target_point, quantizer)
 
+    @staticmethod
+    def create_unified_scales_quantizers_insertion_commands(
+        nncf_graph: NNCFGraph,
+        target_points: List[PTTargetPoint],
+        quantizer_config: QuantizerConfig,
+        parameters: FakeQuantizeParameters,
+    ) -> List[PTSharedFnInsertionCommand]:
+        _, scale_shape, _ = PTMinMaxAlgoBackend._get_input_scale_shape(
+            nncf_graph, target_points[0], quantizer_config.per_channel
+        )
+
+        quantizer = PTMinMaxAlgoBackend._create_quantizer(
+            quantizer_config, scale_shape, parameters, target_points[0].target_type
+        )
+        return [create_shared_quantizer_insertion_command(target_points, quantizer)]
+
     @staticmethod
     def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
         types = []
```
19 changes: 18 additions & 1 deletion nncf/torch/graph/transformations/command_creation.py
```diff
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import List, Union
 
 from torch import Tensor
 
@@ -65,3 +65,20 @@ def create_quantizer_insertion_command(
         compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
         priority=TransformationPriority.QUANTIZATION_PRIORITY,
     )
+
+
+def create_shared_quantizer_insertion_command(
+    target_points: List[PTTargetPoint], quantizer: BaseQuantizer
+) -> PTSharedFnInsertionCommand:
+    quantizers_ids = []
+    for target_point in target_points:
+        quantizers_ids.append(NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id))
+
+    storage_key = ";".join(str(quantizer_id) for quantizer_id in sorted(quantizers_ids, key=str))
+    return PTSharedFnInsertionCommand(
+        target_points=target_points,
+        fn=quantizer,
+        op_unique_name=storage_key,
+        compression_module_type=ExtraCompressionModuleType.EXTERNAL_QUANTIZER,
+        priority=TransformationPriority.QUANTIZATION_PRIORITY,
+    )
```
23 changes: 23 additions & 0 deletions tests/onnx/quantization/test_ptq_params.py
```diff
@@ -9,22 +9,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import pytest
 
+from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.patterns import GraphPattern
 from nncf.common.graph.patterns.manager import PatternsManager
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationType
 from nncf.common.utils.backend import BackendType
+from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConcatMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
 from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXSoftmaxMetatype
 from nncf.onnx.graph.nncf_graph_builder import GraphConverter
 from nncf.onnx.graph.nncf_graph_builder import ONNXLayerAttributes
+from nncf.onnx.graph.transformations.commands import ONNXQuantizerInsertionCommand
 from nncf.onnx.graph.transformations.commands import ONNXTargetPoint
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
+from tests.common.quantization.metatypes import CatTestMetatype
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
 from tests.common.quantization.metatypes import SoftmaxTestMetatype
@@ -61,17 +67,34 @@ def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_q):
         assert act_num_q == 1
         assert weight_num_q == 1
 
+    def check_unified_scale_layout(self, layout, unified_scale_group):
+        assert len(layout.transformations) == len(unified_scale_group)
+        for t, ref_tp in zip(layout.transformations, unified_scale_group):
+            assert isinstance(t, ONNXQuantizerInsertionCommand)
+            assert t.target_point == ref_tp
+            assert t.type == TransformationType.INSERT
+            assert t.quantizer_parameters.zero_point == 0
+            assert np.isclose(t.quantizer_parameters.scale, 0.03149606)
+
     def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint:
         return ONNXTargetPoint(target_type, target_node_name, port_id)
 
+    def get_backend_tensor(self, value):
+        return np.array(value)
+
     @property
     def metatypes_mapping(self):
         return {
             Conv2dTestMetatype: ONNXConvolutionMetatype,
             LinearTestMetatype: ONNXGemmMetatype,
             SoftmaxTestMetatype: ONNXSoftmaxMetatype,
+            CatTestMetatype: ONNXConcatMetatype,
         }
 
+    @property
+    def nncf_graph_cls(self):
+        return NNCFGraph
+
     @pytest.fixture(scope="session")
     def test_params(self):
         linear_model = LinearModel().onnx_model
```
23 changes: 23 additions & 0 deletions tests/openvino/native/quantization/test_ptq_params.py
```diff
@@ -9,22 +9,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import pytest
 
+from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.patterns import GraphPattern
 from nncf.common.graph.patterns.manager import PatternsManager
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationType
 from nncf.common.hardware.config import HW_CONFIG_TYPE_TARGET_DEVICE_MAP
 from nncf.common.utils.backend import BackendType
+from nncf.openvino.graph.metatypes.openvino_metatypes import OVConcatMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVSoftmaxMetatype
 from nncf.openvino.graph.nncf_graph_builder import GraphConverter
+from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
 from nncf.openvino.graph.transformations.commands import OVTargetPoint
 from nncf.parameters import TargetDevice
 from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
 from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend
 from nncf.scopes import IgnoredScope
+from tests.common.quantization.metatypes import CatTestMetatype
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
 from tests.common.quantization.metatypes import SoftmaxTestMetatype
@@ -60,17 +66,34 @@ def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_q):
         assert act_num_q == 1
         assert weight_num_q == 1
 
+    def check_unified_scale_layout(self, layout, unified_scale_group):
+        assert len(layout.transformations) == len(unified_scale_group)
+        for t, ref_tp in zip(layout.transformations, unified_scale_group):
+            assert isinstance(t, OVQuantizerInsertionCommand)
+            assert t.target_point == ref_tp
+            assert t.type == TransformationType.INSERT
+            assert np.isclose(t.quantizer_parameters.input_low.data, -4.031496)
+            assert np.isclose(t.quantizer_parameters.input_high.data, 4)
+
     def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
         return OVTargetPoint(target_type, target_node_name, port_id)
 
+    def get_backend_tensor(self, value):
+        return np.array(value)
+
     @property
     def metatypes_mapping(self):
         return {
             Conv2dTestMetatype: OVConvolutionMetatype,
             LinearTestMetatype: OVMatMulMetatype,
             SoftmaxTestMetatype: OVSoftmaxMetatype,
+            CatTestMetatype: OVConcatMetatype,
         }
 
+    @property
+    def nncf_graph_cls(self):
+        return NNCFGraph
+
     @pytest.fixture(scope="session")
     def test_params(self):
         linear_model = LinearModel().ov_model
```
(Diffs for the remaining 4 changed files are not shown.)
