Reenable scale unification #2199

Closed · wants to merge 21 commits
8 changes: 7 additions & 1 deletion nncf/common/hardware/configs/cpu.json
@@ -297,6 +297,12 @@
"weights": ["q8_w_sym", "q8_w_asym"]
}
},
{"type": "EmbeddingBag"}
{"type": "EmbeddingBag"},
{
"type": "BatchNorm",
"quantization": {
"activations": ["q8_a_ch"]
}
}
]
}
1 change: 1 addition & 0 deletions nncf/common/hardware/opset.py
@@ -59,3 +59,4 @@ class HWConfigOpName:
GRUSEQUENCE = "GRUSequence"
GROUPNORMALIZATION = "GroupNormalization"
SCALED_DOT_PRODUCT_ATTENTION = "ScaledDotProductAttention"
BATCH_NORM = "BatchNorm"
38 changes: 7 additions & 31 deletions nncf/common/quantization/quantizer_propagation/graph.py
@@ -542,10 +542,11 @@ def unify_pq_scales(
secondary_pq: PropagatingQuantizer,
unified_scale_type: Optional[UnifiedScaleType] = None,
):
if unified_scale_type is None:
primary_pq.unified_scale_type = UnifiedScaleType.UNIFY_ALWAYS
else:
primary_pq.unified_scale_type = unified_scale_type
if primary_pq.unified_scale_type is None:
if unified_scale_type is None:
primary_pq.unified_scale_type = UnifiedScaleType.UNIFY_ALWAYS
else:
primary_pq.unified_scale_type = unified_scale_type
secondary_pq.unified_scale_type = primary_pq.unified_scale_type
primary_gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(primary_pq.id)
if primary_gid is None:
@@ -774,40 +775,17 @@ get_paths_to_immediately_dominating_insertion_points(
self, insertion_point_node_key: str
) -> List[PropagationPath]:
group_dict = self.get_paths_to_immediately_dominating_insertion_points_grouped_by_unified_scales(
insertion_point_node_key, set(), {}
insertion_point_node_key, set()
)
return group_dict[None]

def get_paths_to_immediately_dominating_insertion_points_grouped_by_unified_scales(
self,
insertion_point_node_key: str,
unified_scale_op_metatypes: Set[Type[OperatorMetatype]],
scales_unification_map: Dict[OperatorMetatype, OperatorMetatype],
self, insertion_point_node_key: str, unified_scale_op_metatypes: Set[Type[OperatorMetatype]]
) -> Dict[Optional[int], List[PropagationPath]]:
"""Paths are lists of edges."""
next_group_idx = 0
paths = {}

def followed_by_weighted_types(curr_node_key, curr_node_metatype) -> bool:
nodes_queue = deque(self.successors(curr_node_key))
while nodes_queue:
next_node_key = nodes_queue.popleft()
next_node = self.nodes[next_node_key]
next_node_type = next_node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
if next_node_type != QuantizerPropagationStateGraphNodeType.OPERATOR:
nodes_queue.extend(self.successors(next_node_key))
else:
next_node_metatype = next_node[QuantizerPropagationStateGraph.OPERATOR_METATYPE_NODE_ATTR]
next_node_trait = next_node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR]
if (
next_node_trait == QuantizationTrait.QUANTIZATION_AGNOSTIC
or next_node_metatype in unified_scale_op_metatypes
):
nodes_queue.extend(self.successors(next_node_key))
if next_node_metatype in scales_unification_map[curr_node_metatype]:
return True
return False

def recursive_helper(curr_edge, curr_path, all_paths, curr_group):
nonlocal next_group_idx
curr_path.append(curr_edge)
@@ -828,8 +806,6 @@ def recursive_helper(curr_edge, curr_path, all_paths, curr_group):
curr_group is None,
len(self.in_edges(curr_node_key)) > 1,
]
if scales_unification_map is not None and metatype in scales_unification_map:
unify_conditions.append(followed_by_weighted_types(curr_node_key, metatype))
if all(unify_conditions):
curr_group = next_group_idx
next_group_idx += 1
6 changes: 1 addition & 5 deletions nncf/common/quantization/quantizer_propagation/solver.py
@@ -337,7 +337,6 @@ def __init__(
quantize_outputs: bool = False,
post_processing_marker_metatypes: List[OperatorMetatype] = None,
metatypes_to_ignore: List[OperatorMetatype] = None,
scales_unification_map: Dict[OperatorMetatype, OperatorMetatype] = None,
):
"""
Initializes the solver with parameters affecting the resulting quantizer setup.
@@ -389,8 +388,6 @@ def __init__(
If None automatic ignoring will be skipped.
:param metatypes_to_ignore: The framework specific NNCF metatypes,
which should be automatically ignored.
:param scales_unification_map: The framework-specific map with NNCF metatypes, which generating a quantizer
that can be unified if it so requires based on metatype.
"""
if default_trait_to_metatype_map is None:
self._default_trait_to_metatype_map = {}
@@ -446,7 +443,6 @@ def __init__(
self._quantizable_layer_nodes = quantizable_layer_nodes
self._post_processing_marker_metatypes = post_processing_marker_metatypes
self._metatypes_to_ignore = metatypes_to_ignore
self._scales_unification_map = scales_unification_map

def _filter_by_weight_ignored_target_scopes(
self,
@@ -729,7 +725,7 @@ def propagation_step(
# only concat unified scale groups appear here
unified_scale_grouped_paths = (
quant_prop_graph.get_paths_to_immediately_dominating_insertion_points_grouped_by_unified_scales(
curr_node_key, self._unified_scales_operation_set, self._scales_unification_map
curr_node_key, self._unified_scales_operation_set
)
)

34 changes: 26 additions & 8 deletions nncf/common/quantization/quantizer_setup.py
@@ -251,19 +251,19 @@ def add_independent_quantization_point(self, qp: QuantizationPointBase):

def register_unified_scale_group(self, qp_group: List[QuantizationPointId]) -> int:
for qp_id in qp_group:
gid = self.get_unified_scale_group_id(qp_id) is not None
if gid:
raise nncf.InternalError("QP id {} is already in unified scale group {}".format(qp_id, gid))
gid = self.get_unified_scale_group_id(qp_id)
if gid is not None:
raise RuntimeError(f"QP id {qp_id} is already in unified scale group {gid}")
gid = self._next_unified_scale_gid
self.unified_scale_groups[self._next_unified_scale_gid] = set(qp_group)
self._next_unified_scale_gid += 1
return gid

def register_shared_inputs_group(self, qp_group: List[QuantizationPointId]) -> int:
for qp_id in qp_group:
gid = self.get_shared_inputs_group_id(qp_id) is not None
if gid:
raise nncf.InternalError("QP id {} is already in shared input group {}".format(qp_id, gid))
gid = self.get_shared_inputs_group_id(qp_id)
if gid is not None:
raise RuntimeError(f"QP id {qp_id} is already in shared input group {gid}")
gid = self._next_shared_inputs_gid
self.shared_input_operation_set_groups[self._next_shared_inputs_gid] = set(qp_group)
self._next_shared_inputs_gid += 1
@@ -495,17 +495,35 @@ def select_qconfigs(
for qid in per_tensor_qids:
retval.remove_unified_scale_from_point(qid)

retval.register_unified_scale_group(list(per_tensor_qids))
if len(per_tensor_qids) > 1:
retval.register_unified_scale_group(list(per_tensor_qids))
else:
nncf_logger.debug(
"Not making a unified scale group out of single per-tensor quantizer remaining in "
"the group after segregating per-tensor and per-channel quantizers within same original "
"unified scale group"
)

remaining_per_channel_qids = []
for per_channel_qid in per_channel_qids:
retval.remove_unified_scale_from_point(per_channel_qid)
us_type = self._unified_scale_qpid_vs_type[per_channel_qid]
if us_type is UnifiedScaleType.UNIFY_ONLY_PER_TENSOR:
nncf_logger.debug(
"Per-channel quantizer config selected in a MultiConfigQuantizerSetup for a "
"unified scale point that only supports per-tensor scale unification, disabling "
"unified scales for this point."
)
retval.remove_unified_scale_from_point(per_channel_qid)
else:
remaining_per_channel_qids.append(per_channel_qid)

if len(remaining_per_channel_qids) > 1:
retval.register_unified_scale_group(list(remaining_per_channel_qids))
elif len(remaining_per_channel_qids) == 1:
nncf_logger.debug(
"Not making a unified scale group out of single per-channel quantizer remaining in "
"the group after removing quantizers with UnifiedScaleType.UNIFY_ONLY_PER_TENSOR"
)

return retval

1 change: 1 addition & 0 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
@@ -253,6 +253,7 @@ class ONNXConcatMetatype(ONNXOpMetatype):
class ONNXBatchNormMetatype(ONNXOpMetatype):
name = "BatchNormalizationOp"
op_names = ["BatchNormalization"]
hw_config_names = [HWConfigOpName.BATCH_NORM]


@ONNX_OPERATION_METATYPES.register()
1 change: 1 addition & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
@@ -216,6 +216,7 @@ class OVConcatMetatype(OVOpMetatype):
class OVBatchNormMetatype(OVOpMetatype):
name = "BatchNormalizationOp"
op_names = ["BatchNormInference"]
hw_config_names = [HWConfigOpName.BATCH_NORM]


@OV_OPERATOR_METATYPES.register()
1 change: 0 additions & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
@@ -537,7 +537,6 @@ def _get_quantizer_setup(
global_constraints=self._global_quantizer_constraints,
post_processing_marker_metatypes=post_processing_types,
metatypes_to_ignore=metatypes_to_ignore,
scales_unification_map=self._backend_entity.scales_unification_map,
scope_overrides=scope_overrides,
)

7 changes: 0 additions & 7 deletions nncf/quantization/algorithms/min_max/backend.py
@@ -97,13 +97,6 @@ def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific Scaled Dot Product Attention metatypes.
"""

@property
@abstractmethod
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
"""
Property for the backend-specific metatypes that produces quantizers that might be unified.
"""

@property
@abstractmethod
def hw_config(self) -> HWConfig:
4 changes: 0 additions & 4 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -92,10 +92,6 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}

@property
def hw_config(self) -> HWConfig:
return ONNXHWConfig
4 changes: 0 additions & 4 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -94,10 +94,6 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
return [om.OVScaledDotProductAttentionMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.OVConcatMetatype: self.overflow_fix_metatypes}

@property
def hw_config(self) -> HWConfig:
return OVHWConfig
7 changes: 0 additions & 7 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -106,10 +106,6 @@ def group_conv_metatypes(self) -> List[OperatorMetatype]:
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.PTCatMetatype: self.overflow_fix_metatypes}

@property
def hw_config(self) -> HWConfig:
return PTHWConfig
@@ -333,9 +329,6 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]:
om.PTModuleLayerNormMetatype,
om.PTGroupNormMetatype,
om.PTModuleGroupNormMetatype,
# Batchnorm
om.PTBatchNormMetatype,
om.PTModuleBatchNormMetatype,
]
if device != TargetDevice.CPU_SPR:
types.append(om.PTMulMetatype)
1 change: 1 addition & 0 deletions nncf/tensorflow/graph/metatypes/keras_layers.py
@@ -219,6 +219,7 @@ class TFBatchNormalizationLayerMetatype(TFLayerWithWeightsMetatype):

weight_definitions = [WeightDef(weight_attr_name="gamma", channel_axes=0)]
bias_attr_name = "beta"
hw_config_names = [HWConfigOpName.BATCH_NORM]


@KERAS_LAYER_METATYPES.register()
7 changes: 0 additions & 7 deletions nncf/tensorflow/quantization/algorithm.py
@@ -64,10 +64,8 @@
from nncf.tensorflow.graph.metatypes.common import ELEMENTWISE_LAYER_METATYPES
from nncf.tensorflow.graph.metatypes.common import GENERAL_CONV_LAYER_METATYPES
from nncf.tensorflow.graph.metatypes.common import LINEAR_LAYER_METATYPES
from nncf.tensorflow.graph.metatypes.keras_layers import TFConcatenateLayerMetatype
from nncf.tensorflow.graph.metatypes.keras_layers import TFLambdaLayerMetatype
from nncf.tensorflow.graph.metatypes.keras_layers import TFLayerWithWeightsMetatype
from nncf.tensorflow.graph.metatypes.tf_ops import TFConcatOpMetatype
from nncf.tensorflow.graph.metatypes.tf_ops import TFIdentityOpMetatype
from nncf.tensorflow.graph.metatypes.tf_ops import TFOpWithWeightsMetatype
from nncf.tensorflow.graph.transformations.commands import TFAfterLayer
@@ -620,10 +618,6 @@ def _get_quantizer_propagation_solution(
**{name: IgnoreReason.AUTOGENERATED for name in input_preprocessing_node_names},
**{name: IgnoreReason.AUTOGENERATED for name in custom_layer_node_names},
}
scales_unification_map = {
TFConcatenateLayerMetatype: GENERAL_CONV_LAYER_METATYPES + LINEAR_LAYER_METATYPES,
TFConcatOpMetatype: GENERAL_CONV_LAYER_METATYPES + LINEAR_LAYER_METATYPES,
}
solver = QuantizerPropagationSolver(
activation_ignored_scopes=ignored_scopes_for_solver,
weight_ignored_scopes=self.ignored_scopes_per_group[QuantizerGroup.WEIGHTS],
Expand All @@ -637,7 +631,6 @@ def _get_quantizer_propagation_solution(
quantizable_layer_nodes=quantizable_weighted_layer_nodes,
global_constraints=self.global_quantizer_constraints,
quantize_outputs=self.quantize_outputs,
scales_unification_map=scales_unification_map,
)

quantization_proposal = solver.run_on_ip_graph(ip_graph)
1 change: 1 addition & 0 deletions nncf/torch/graph/operator_metatypes.py
@@ -626,6 +626,7 @@ class PTBatchNormMetatype(PTOperatorMetatype):
name = "BatchNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]}
subtypes = [PTModuleBatchNormMetatype]
hw_config_names = [HWConfigOpName.BATCH_NORM]


@PT_OPERATOR_METATYPES.register()
4 changes: 0 additions & 4 deletions nncf/torch/quantization/algo.py
@@ -79,8 +79,6 @@
from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.operator_metatypes import UNIFICATION_PRODUCING_METATYPES
from nncf.torch.graph.operator_metatypes import PTCatMetatype
from nncf.torch.graph.operator_metatypes import PTDepthwiseConv2dSubtype
from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
from nncf.torch.graph.transformations.commands import PTInsertionCommand
@@ -350,7 +348,6 @@ def generate_setup(self) -> SingleConfigQuantizerSetup:
self._debug_interface.visualize_insertion_point_graph(insertion_point_graph)
from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationSolver

scales_unification_map = {PTCatMetatype: UNIFICATION_PRODUCING_METATYPES}
ignored_scopes_for_solver = {
name: IgnoreReason.USER_REQUESTED for name in self._ignored_scopes_per_group[QuantizerGroup.ACTIVATIONS]
}
@@ -369,7 +366,6 @@ def generate_setup(self) -> SingleConfigQuantizerSetup:
global_constraints=self.global_quantizer_constraints,
additional_unified_scale_op_scopes=self._unified_scale_ops,
quantize_outputs=self._quantize_outputs,
scales_unification_map=scales_unification_map,
)

merged_ip_graph = insertion_point_graph.get_ip_graph_with_merged_hw_optimized_operations(
10 changes: 10 additions & 0 deletions tests/common/quantization/metatypes.py
@@ -157,6 +157,16 @@ class DivideTestMetatype(TestMetatype):
name = "divide"


@METATYPES_FOR_TEST.register()
class GenericBinaryUnifiedScaleOpMetatype(TestMetatype):
name = "binary_unified_scale_op"


@METATYPES_FOR_TEST.register()
class GenericBinaryOpMetatype(TestMetatype):
name = "binary_op"


@METATYPES_FOR_TEST.register()
@INPUT_NOOP_METATYPES.register()
class ParameterTestMetatype(TestMetatype):
@@ -37,25 +37,23 @@
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import UnifiedScaleType
from tests.common.quantization.metatypes import WEIGHT_LAYER_METATYPES
from tests.common.quantization.metatypes import CatTestMetatype
from tests.common.quantization.metatypes import Conv2dTestMetatype
from tests.common.quantization.mock_graphs import get_ip_graph_for_test
from tests.common.quantization.mock_graphs import get_mock_nncf_node_attrs
from tests.common.quantization.mock_graphs import get_nncf_graph_from_mock_nx_graph
from tests.common.quantization.mock_graphs import get_two_branch_mock_model_graph
from tests.common.quantization.mock_graphs import mark_input_ports_lexicographically_based_on_input_node_key


def get_edge_paths(graph, start_node_key, finish_node_key) -> List[List[Tuple]]:
def get_edge_paths(graph: QPSG, start_node_key: str, finish_node_key: str) -> List[List[Tuple]]:
node_paths = list(nx.all_simple_paths(graph, start_node_key, finish_node_key))
edge_paths = []
for path in node_paths:
edge_paths.append([(path[i], path[i + 1]) for i in range(0, len(path) - 1)])
return edge_paths


def get_edge_paths_for_propagation(graph, start_node_key, finish_node_key) -> List[List[Tuple]]:
def get_edge_paths_for_propagation(graph: QPSG, start_node_key: str, finish_node_key: str) -> List[List[Tuple]]:
paths = get_edge_paths(graph, start_node_key, finish_node_key)
return [list(reversed(path)) for path in paths]

@@ -68,7 +66,6 @@ def mock_qp_graph():
qpsg = QPSG(ip_graph)

qpsg.nodes["5 /F_0"][QPSG.OPERATOR_METATYPE_NODE_ATTR] = CatTestMetatype
qpsg.nodes["6 /G_0"][QPSG.OPERATOR_METATYPE_NODE_ATTR] = Conv2dTestMetatype
qpsg.skip_check = False
yield qpsg
if not qpsg.skip_check:
@@ -287,7 +284,7 @@ def test_get_paths_to_immediately_dominating_insertion_points_grouped_by_unified
ref_groups_vs_paths = start_ip_node_and_dom_node_grouped_paths.ref_groups_vs_paths
test_groups_vs_paths = (
mock_qp_graph.get_paths_to_immediately_dominating_insertion_points_grouped_by_unified_scales(
start_node_key, {CatTestMetatype}, {CatTestMetatype: WEIGHT_LAYER_METATYPES}
start_node_key, {CatTestMetatype}
)
)
