Skip to content

Commit

Permalink
[Quantization] Add quantizer filter pass to Solver (#2903)
Browse files Browse the repository at this point in the history
### Changes

Add a filtering pass to QuantizerPropagationSolver that removes unnecessary
quantizers.


### Reason for changes

The motivation is to remove quantizers that were not propagated and that sit
before elementwise operations with constants or constant subgraphs. These
quantizers do not affect performance and can degrade accuracy metrics.

### Related tickets

144218

### Tests

Many test references were updated for all backends. I checked every
change, and in all cases the new references look correct.

LSTM synthetic model from OpenVINO backend:

![image](https://github.com/user-attachments/assets/2f197cb0-8284-4caa-9895-6ff6faf7fa0f)

![image](https://github.com/user-attachments/assets/c5dfbf59-9e5e-4ff3-80eb-077cda45c722)
  • Loading branch information
kshpv authored Sep 2, 2024
1 parent 3454626 commit 03f65a7
Show file tree
Hide file tree
Showing 58 changed files with 54,481 additions and 13,271 deletions.
70 changes: 69 additions & 1 deletion nncf/common/quantization/quantizer_propagation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,9 @@ def _filter_by_weight_ignored_target_scopes(
nncf_logger.debug(f"Ignored adding weight quantizer for: {node_name}")
return weight_quantizable_node_names_vs_qconfigs

def run_on_ip_graph(self, ip_graph: InsertionPointGraph) -> QuantizationProposal:
def run_on_ip_graph(
self, ip_graph: InsertionPointGraph, metatypes_for_filter: Optional[List[OperatorMetatype]] = None
) -> QuantizationProposal:
"""
The main function to be used on an InsertionPointGraph to produce
the list of insertion commands and configs corresponding to the desired quantized
Expand All @@ -479,6 +481,7 @@ def run_on_ip_graph(self, ip_graph: InsertionPointGraph) -> QuantizationProposal
:param ip_graph: The InsertionPointGraph, potentially with fused operations w.r.t. the
original model graph. The propagating quantizers will travel along the pre- and post-
hook nodes registered in this graph.
:param metatypes_for_filter: Metatypes are used for the removal criterion.
:return: The intermediate propagation state in the form of QuantizationProposal, which
defines unambiguously the locations of the propagating quantizers, but not the final
configurations.
Expand Down Expand Up @@ -513,6 +516,8 @@ def run_on_ip_graph(self, ip_graph: InsertionPointGraph) -> QuantizationProposal
iteration_counter += 1

quant_prop_graph = self._filter_integer_input_quantizers(quant_prop_graph)
if metatypes_for_filter:
quant_prop_graph = self._filter_quantizers_by_metatypes(quant_prop_graph, metatypes_for_filter)

if self._visualizer is not None:
self._visualizer.visualize_quantizer_propagation(self, quant_prop_graph, "proposed")
Expand Down Expand Up @@ -1597,3 +1602,66 @@ def _filter_integer_input_quantizers(
quant_prop_graph.remove_propagating_quantizer(integer_input_pq)

return quant_prop_graph

def _filter_quantizers_by_metatypes(
    self, quant_prop_graph: QuantizerPropagationStateGraph, metatypes: List[OperatorMetatype]
) -> QuantizerPropagationStateGraph:
    """
    Removes the finished propagating quantizers for which `_is_quantizer_to_remove` returns True.

    Every removed quantizer is deleted both from `quant_prop_graph` and from
    `self._finished_propagating_quantizers`.

    :param quant_prop_graph: The quantizer propagation state graph.
    :param metatypes: Metatypes that are used for the removal criterion.
    :return: Filtered quantizer propagation state graph (the object that was passed in).
    """

    def _is_quantizer_to_remove(
        quant_prop_graph: QuantizerPropagationStateGraph,
        quantizer: PropagatingQuantizer,
        metatypes: List[OperatorMetatype],
    ) -> bool:
        """
        Returns True if the quantizer meets the criteria for removal. The criteria are as follows:
        1. The quantizer is generated from a node whose metatype is in the provided metatypes.
        2. The quantizer is not propagated.
        3. The quantizer has only one child.
        4. The quantized node generates only one activation quantizer.
        The function relies on the fact that the considered metatypes should have two inputs.
        In that case, if the considered node in the InsertionPointGraph has only one input,
        it means that the other one is a constant.

        :param quant_prop_graph: The quantizer propagation state graph holding the `quantizer`.
        :param quantizer: The propagating quantizer to be currently considered.
        :param metatypes: Metatypes that are used for the criterion.
        :return: True if the quantizer satisfies the criteria, otherwise False.
        """
        quantizer_children = quantizer.quantized_input_sink_operator_nodes
        # Checked first: with zero or several sink nodes there is no single quantized node
        # to inspect, and `next(iter(...))` on an empty collection would raise StopIteration.
        if len(quantizer_children) != 1:
            return False
        if len(quantizer.propagation_path) > 1:
            # The quantizer has been propagated away from the node that generated it.
            return False

        # The node key is resolved here instead of being read from a variable of the
        # enclosing loop, so the predicate depends only on its own arguments.
        quantized_node_key = next(iter(quantizer_children))
        quantized_node = quant_prop_graph.nodes[quantized_node_key]
        if quantized_node[quant_prop_graph.OPERATOR_METATYPE_NODE_ATTR] not in metatypes:
            return False
        quantizers_generated_for_node = quantized_node[quant_prop_graph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
        return len(quantizers_generated_for_node) == 1

    to_remove_quantizers = []
    for quantizer in self._finished_propagating_quantizers:
        if not _is_quantizer_to_remove(quant_prop_graph, quantizer, metatypes):
            continue
        # Exactly one sink node exists here; the removal criteria guarantee it.
        quantized_node_key = next(iter(quantizer.quantized_input_sink_operator_nodes))
        nncf_logger.debug(f"Quantizer generated for a node {quantized_node_key} will be removed.")
        to_remove_quantizers.append(quantizer)

    # Removal is deferred so that the list of finished quantizers is not mutated while iterated.
    for quantizer in to_remove_quantizers:
        quant_prop_graph.remove_propagating_quantizer(quantizer)
        self._finished_propagating_quantizers.remove(quantizer)
    return quant_prop_graph
19 changes: 19 additions & 0 deletions nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,25 @@
onnx_metatypes.ONNXDivLayerMetatype,
]

# ONNX metatypes grouped as elementwise operations. The ONNX min-max backend imports this
# list and returns it from its `elementwise_metatypes` property.
# NOTE(review): the quantizer filter that receives these metatypes documents an assumption
# that the considered operations have two inputs; `ONNXNotMetatype` and `ONNXMeanMetatype`
# may not fit that assumption - confirm that their inclusion is intended.
ELEMENTWISE_OPERATIONS = [
onnx_metatypes.ONNXAddLayerMetatype,
onnx_metatypes.ONNXMulLayerMetatype,
onnx_metatypes.ONNXSubMetatype,
onnx_metatypes.ONNXDivLayerMetatype,
onnx_metatypes.ONNXLessMetatype,
onnx_metatypes.ONNXLessOrEqualMetatype,
onnx_metatypes.ONNXGreaterMetatype,
onnx_metatypes.ONNXGreaterOrEqualMetatype,
onnx_metatypes.ONNXEqualMetatype,
onnx_metatypes.ONNXModMetatype,
onnx_metatypes.ONNXOrMetatype,
onnx_metatypes.ONNXNotMetatype,
onnx_metatypes.ONNXAndMetatype,
onnx_metatypes.ONNXXOrMetatype,
onnx_metatypes.ONNXMaximumMetatype,
onnx_metatypes.ONNXMinimumMetatype,
onnx_metatypes.ONNXMeanMetatype,
]

OPERATIONS_WITH_WEIGHTS = [
*CONSTANT_WEIGHT_LAYER_METATYPES,
Expand Down
35 changes: 34 additions & 1 deletion nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,27 @@ class ONNXLessMetatype(ONNXOpMetatype):
hw_config_names = [HWConfigOpName.LESS]


# Metatype matching the ONNX "LessOrEqual" op name; declares HWConfigOpName.LESSEQUAL
# as its hardware config name.
@ONNX_OPERATION_METATYPES.register()
class ONNXLessOrEqualMetatype(ONNXOpMetatype):
name = "LessOrEqualOp"
op_names = ["LessOrEqual"]
hw_config_names = [HWConfigOpName.LESSEQUAL]


# Metatype matching the ONNX "Greater" op name; declares HWConfigOpName.GREATER
# as its hardware config name.
@ONNX_OPERATION_METATYPES.register()
class ONNXGreaterMetatype(ONNXOpMetatype):
name = "GreaterOp"
op_names = ["Greater"]
hw_config_names = [HWConfigOpName.GREATER]


# Metatype matching the ONNX "GreaterOrEqual" op name; declares HWConfigOpName.GREATEREQUAL
# as its hardware config name.
@ONNX_OPERATION_METATYPES.register()
class ONNXGreaterOrEqualMetatype(ONNXOpMetatype):
name = "GreaterOrEqualOp"
op_names = ["GreaterOrEqual"]
hw_config_names = [HWConfigOpName.GREATEREQUAL]


@ONNX_OPERATION_METATYPES.register()
class ONNXEqualMetatype(ONNXOpMetatype):
name = "EqualOp"
Expand Down Expand Up @@ -378,6 +392,20 @@ class ONNXOrMetatype(ONNXOpMetatype):
hw_config_names = [HWConfigOpName.LOGICALOR]


# Metatype matching the ONNX "Xor" op name; declares HWConfigOpName.LOGICALXOR
# as its hardware config name.
@ONNX_OPERATION_METATYPES.register()
class ONNXXOrMetatype(ONNXOpMetatype):
name = "XorOp"
op_names = ["Xor"]
hw_config_names = [HWConfigOpName.LOGICALXOR]


# Metatype matching the ONNX "Mod" op name; declares HWConfigOpName.FLOORMOD
# as its hardware config name.
@ONNX_OPERATION_METATYPES.register()
class ONNXModMetatype(ONNXOpMetatype):
name = "ModOp"
op_names = ["Mod"]
hw_config_names = [HWConfigOpName.FLOORMOD]


@ONNX_OPERATION_METATYPES.register()
class ONNXMaximumMetatype(ONNXOpMetatype):
name = "MaxOp"
Expand All @@ -392,11 +420,16 @@ class ONNXMinimumMetatype(ONNXOpMetatype):
hw_config_names = [HWConfigOpName.MINIMUM]


# Metatype matching the ONNX "Mean" op name. No `hw_config_names` is declared in this class.
@ONNX_OPERATION_METATYPES.register()
class ONNXMeanMetatype(ONNXOpMetatype):
name = "MeanOp"
op_names = ["Mean"]


# Metatype matching the ONNX "Floor" op name.
# NOTE(review): HWConfigOpName.FLOORMOD is also declared by ONNXModMetatype. In this diff
# rendering the `hw_config_names` line below may be the line the commit removes - confirm
# against the actual file before relying on it.
@ONNX_OPERATION_METATYPES.register()
class ONNXFloorMetatype(ONNXOpMetatype):
name = "FloorOp"
op_names = ["Floor"]
hw_config_names = [HWConfigOpName.FLOORMOD]


@ONNX_OPERATION_METATYPES.register()
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def _get_quantizer_setup(
scope_overrides=scope_overrides,
)

quantization_proposal = solver.run_on_ip_graph(ip_graph)
quantization_proposal = solver.run_on_ip_graph(ip_graph, self._backend_entity.elementwise_metatypes)
multi_config_setup = quantization_proposal.quantizer_setup
single_config_setup = multi_config_setup.select_first_qconfig_for_each_point()
finalized_proposal = quantization_proposal.finalize(single_config_setup)
Expand Down
7 changes: 7 additions & 0 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific Dropout metatypes.
"""

@property
@abstractmethod
def elementwise_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific elementwise operation metatypes.
"""

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
Expand Down
5 changes: 5 additions & 0 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.onnx.graph.metatypes import onnx_metatypes as om
from nncf.onnx.graph.metatypes.groups import ELEMENTWISE_OPERATIONS
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
from nncf.onnx.graph.node_utils import get_input_edges_mapping
from nncf.onnx.graph.node_utils import get_quantized_tensor_shape
Expand Down Expand Up @@ -62,6 +63,10 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype]

# Elementwise metatypes for the ONNX backend: the ELEMENTWISE_OPERATIONS list imported
# from nncf.onnx.graph.metatypes.groups.
@property
def elementwise_metatypes(self) -> List[OperatorMetatype]:
return ELEMENTWISE_OPERATIONS

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
Expand Down
5 changes: 5 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes import openvino_metatypes as om
from nncf.openvino.graph.metatypes.groups import ELEMENTWISE_OPERATIONS
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS
from nncf.openvino.graph.model_utils import get_start_nodes_for_activation_path_tracing
from nncf.openvino.graph.node_utils import get_weight_channel_axes
Expand Down Expand Up @@ -59,6 +60,10 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConvolutionMetatype]

# Elementwise metatypes for the OpenVINO backend: the ELEMENTWISE_OPERATIONS list imported
# from nncf.openvino.graph.metatypes.groups.
@property
def elementwise_metatypes(self) -> List[OperatorMetatype]:
return ELEMENTWISE_OPERATIONS

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
Expand Down
5 changes: 5 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
from nncf.torch.graph.transformations.command_creation import create_quantizer_insertion_command
from nncf.torch.graph.transformations.command_creation import create_shared_quantizer_insertion_command
from nncf.torch.graph.transformations.commands import PTInsertionCommand
Expand Down Expand Up @@ -89,6 +90,10 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]

# Elementwise metatypes for the Torch backend: the ELEMENTWISE_OPERATIONS list imported
# from nncf.torch.graph.operator_metatypes.
@property
def elementwise_metatypes(self) -> List[OperatorMetatype]:
return ELEMENTWISE_OPERATIONS

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
Expand Down
5 changes: 5 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
Expand Down Expand Up @@ -83,6 +84,10 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype]

# Elementwise metatypes for the Torch FX backend: the ELEMENTWISE_OPERATIONS list imported
# from nncf.torch.graph.operator_metatypes (the same list the Torch backend uses).
@property
def elementwise_metatypes(self) -> List[OperatorMetatype]:
return ELEMENTWISE_OPERATIONS

@property
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [
Expand Down
4 changes: 3 additions & 1 deletion nncf/tensorflow/quantization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ def _build_insertion_commands_for_quantizer_setup(
non_unified_scales_quantization_point_ids = set(range(len(quantization_points)))

for unified_scales_group in quantizer_setup.get_unified_scale_groups():
if not unified_scales_group:
continue
us_qp_id = unified_scales_group[0]
qp = quantization_points[us_qp_id]
quantizer_spec = qp.quantizer_spec
Expand Down Expand Up @@ -640,7 +642,7 @@ def _get_quantizer_propagation_solution(
scales_unification_map=scales_unification_map,
)

quantization_proposal = solver.run_on_ip_graph(ip_graph)
quantization_proposal = solver.run_on_ip_graph(ip_graph, ELEMENTWISE_LAYER_METATYPES)
multi_config_setup = quantization_proposal.quantizer_setup
single_config_setup = multi_config_setup.select_first_qconfig_for_each_point()
finalized_proposal = quantization_proposal.finalize(single_config_setup)
Expand Down
19 changes: 19 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,25 @@ def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
PTModuleLinearMetatype,
]

# Torch metatypes grouped as elementwise operations (arithmetic, comparison, logical and
# min/max). The list is imported by the Torch and Torch FX min-max backends and by
# nncf/torch/quantization/algo.py, which passes it to the solver's `run_on_ip_graph`.
ELEMENTWISE_OPERATIONS = [
PTAddMetatype,
PTMulMetatype,
PTSubMetatype,
PTDivMetatype,
PTLessMetatype,
PTLessEqualMetatype,
PTGreaterMetatype,
PTGreaterEqualMetatype,
PTEqualsMetatype,
PTNotEqualMetatype,
PTModMetatype,
PTLogicalOrMetatype,
PTLogicalXorMetatype,
PTLogicalAndMetatype,
PTMaxMetatype,
PTMinMetatype,
]

OP_NAMES_WITH_WEIGHTS = [x for meta in OPERATORS_WITH_WEIGHTS_METATYPES for x in meta.get_all_aliases()]

QUANTIZE_NODE_TYPES = [
Expand Down
3 changes: 2 additions & 1 deletion nncf/torch/quantization/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import ELEMENTWISE_OPERATIONS
from nncf.torch.graph.operator_metatypes import UNIFICATION_PRODUCING_METATYPES
from nncf.torch.graph.operator_metatypes import PTCatMetatype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
Expand Down Expand Up @@ -375,7 +376,7 @@ def generate_setup(self) -> SingleConfigQuantizerSetup:
merged_ip_graph = insertion_point_graph.get_ip_graph_with_merged_hw_optimized_operations(
self._pattern_fusing_graph
)
quantization_proposal = prop_graph_solver.run_on_ip_graph(merged_ip_graph)
quantization_proposal = prop_graph_solver.run_on_ip_graph(merged_ip_graph, ELEMENTWISE_OPERATIONS)
self._num_potential_quantized_activations = prop_graph_solver.get_num_potential_quantized_activations()

quantizer_setup = deepcopy(quantization_proposal.quantizer_setup)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
strict digraph {
"0 Reciprocal" [id=0, type=Reciprocal];
"1 Cast" [id=1, type=Cast];
"2 Mul" [id=2, type=Mul];
"2 Conv1" [id=2, type=Conv];
"3 nncf_model_input_0" [id=3, type=nncf_model_input];
"4 nncf_model_output_0" [id=4, type=nncf_model_output];
"0 Reciprocal" -> "1 Cast" [label="[1, 3, 10, 10]", style=dashed];
"1 Cast" -> "2 Mul" [label="[1, 3, 10, 10]", style=solid];
"2 Mul" -> "4 nncf_model_output_0" [label="[1, 3, 10, 10]", style=solid];
"1 Cast" -> "2 Conv1" [label="[1, 3, 10, 10]", style=solid];
"2 Conv1" -> "4 nncf_model_output_0" [label="[1, 3, 10, 10]", style=solid];
"3 nncf_model_input_0" -> "0 Reciprocal" [label="[1, 3, 10, 10]", style=dashed];
}
Loading

0 comments on commit 03f65a7

Please sign in to comment.