
[PTQ][Torch][KQV self attention] Align FQ placement between OV and Torch backend #2166

Merged

21 commits
4260614
Add softmax -> dropout -> mm <- non pattern pattern / add new names t…
daniil-lyakhov Sep 29, 2023
9698e6c
LayerNorm metatype was added to ignored metatypes for MinMax
daniil-lyakhov Oct 4, 2023
39d785b
Add GroupNorm to ignored MinMax metatypes
daniil-lyakhov Oct 5, 2023
4aa0811
Duplicates are removed from ignored torch patterns
daniil-lyakhov Oct 5, 2023
c6dcc08
Dropout removing pass is added to function
daniil-lyakhov Oct 6, 2023
48da94e
metatypes_to_ignore quantization propagation solver test
daniil-lyakhov Oct 9, 2023
42fa63b
Fix test_passes
daniil-lyakhov Oct 9, 2023
8abc6ea
get_inference_graph method is refactored
daniil-lyakhov Oct 11, 2023
bb90e81
Metrics update
daniil-lyakhov Oct 11, 2023
b4bb243
get_inference_graph fix
daniil-lyakhov Oct 11, 2023
c4225e6
Merge remote-tracking branch 'origin/develop' into dl/torch/patterns_…
daniil-lyakhov Oct 11, 2023
9837d2b
ptq_params test microfix
daniil-lyakhov Oct 11, 2023
d478e2b
Revert transform_to_inferece_graph function
daniil-lyakhov Oct 17, 2023
237ba50
Fix tests
daniil-lyakhov Oct 17, 2023
0638a35
Merge remote-tracking branch 'origin/develop' into dl/torch/patterns_…
daniil-lyakhov Oct 18, 2023
2ecb439
Metrics update
daniil-lyakhov Oct 18, 2023
a662180
Dropout removal original graph reference
daniil-lyakhov Oct 18, 2023
a9d99d0
Clean
daniil-lyakhov Oct 18, 2023
03c2b5e
function remove_dropout_node is refactored
daniil-lyakhov Oct 19, 2023
ce14175
Make passes return values
daniil-lyakhov Oct 19, 2023
5b962ab
Merge remote-tracking branch 'origin/develop' into dl/torch/patterns_…
daniil-lyakhov Oct 23, 2023
12 changes: 11 additions & 1 deletion nncf/common/graph/graph.py
@@ -594,12 +594,22 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
attrs_edge = {}
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
label = {}
if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]

if extended:
if edge[NNCFGraph.DTYPE_EDGE_ATTR] is Dtype.INTEGER:
attrs_edge["style"] = "dashed"
else:
attrs_edge["style"] = "solid"
attrs_edge["label"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]
label["shape"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]

if label:
if len(label) == 1 and extended:
attrs_edge["label"] = label.popitem()[1]
else:
attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items()))
out_graph.add_edge(u, v, **attrs_edge)
return out_graph

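For reference, the new edge-label logic in graph.py can be exercised standalone. A minimal Python sketch with made-up attribute values (not part of the diff; the real values come from the NNCFGraph edge dict):

def build_edge_label(parallel_input_port_ids, activation_shape, extended):
    # Mirrors the new logic in get_graph_for_structure_analysis().
    label = {}
    if parallel_input_port_ids:
        label["parallel_input_port_ids"] = parallel_input_port_ids
    if extended:
        label["shape"] = activation_shape
    if not label:
        return None
    if len(label) == 1 and extended:
        # Only the shape is present: keep the bare value, as the old extended output did.
        return label.popitem()[1]
    return ", ".join(f"{k}:{v}" for k, v in label.items())

print(build_edge_label([], [1, 3, 224, 224], extended=True))      # [1, 3, 224, 224]
print(build_edge_label([1, 2], [1, 3, 224, 224], extended=True))  # parallel_input_port_ids:[1, 2], shape:[1, 3, 224, 224]
print(build_edge_label([1, 2], None, extended=False))             # parallel_input_port_ids:[1, 2]
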
7 changes: 4 additions & 3 deletions nncf/quantization/algorithms/accuracy_control/ranker.py
@@ -28,7 +28,7 @@
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator
from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func
from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset
from nncf.quantization.passes import remove_shapeof_subgraphs
from nncf.quantization.passes import remove_shapeof_subgraphs_inplace

TModel = TypeVar("TModel")
TPModel = TypeVar("TPModel")
@@ -98,8 +98,9 @@ def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) ->
if x.metatype in self._algo_backend.get_quantizer_metatypes()
]

quantized_model_graph_without_shapeof = remove_shapeof_subgraphs(
deepcopy(quantized_model_graph), self._algo_backend.get_shapeof_metatypes()
quantized_model_graph_without_shapeof = deepcopy(quantized_model_graph)
remove_shapeof_subgraphs_inplace(
quantized_model_graph_without_shapeof, self._algo_backend.get_shapeof_metatypes()
)

for quantizer_node in reversed(quantizers):
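The ranker now makes the defensive copy itself and runs the ShapeOf-removal pass in place. A hedged sketch of the new calling convention (only the pass import is taken from the diff; the graph and backend objects are assumed to already exist):

from copy import deepcopy

from nncf.quantization.passes import remove_shapeof_subgraphs_inplace

# Keep the original quantized_model_graph intact; strip ShapeOf subgraphs from a copy.
graph_without_shapeof = deepcopy(quantized_model_graph)
remove_shapeof_subgraphs_inplace(graph_without_shapeof, algo_backend.get_shapeof_metatypes())
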
5 changes: 1 addition & 4 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -54,7 +54,6 @@
from nncf.quantization.algorithms.min_max.backend import ALGO_BACKENDS
from nncf.quantization.fake_quantize import calculate_quantizer_parameters
from nncf.quantization.fake_quantize import get_quantizer_narrow_range
from nncf.quantization.passes import transform_to_inference_graph
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.quantization.range_estimator import RangeEstimatorParametersSet
from nncf.scopes import IgnoredScope
@@ -504,9 +503,7 @@ def _get_quantization_target_points(
)
hw_patterns = PatternsManager.get_full_hw_pattern_graph(backend=backend, device=device, model_type=model_type)

inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.read_variable_metatypes
)
inference_nncf_graph = self._backend_entity.transform_to_inference_graph(nncf_graph)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
self._apply_model_type_pass(self._model_type, quantizer_setup, nncf_graph)
21 changes: 7 additions & 14 deletions nncf/quantization/algorithms/min_max/backend.py
@@ -49,13 +49,6 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific post-processing metatypes (NonMaximumSupression, TopK, etc.).
"""

@property
@abstractmethod
def shapeof_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific ShapeOf metatypes.
"""

@property
@abstractmethod
def conv_metatypes(self) -> List[OperatorMetatype]:
@@ -70,13 +63,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific metatypes for which overflow_fix is applicable.
"""

@property
@abstractmethod
def read_variable_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes that also can be interpreted as inputs (ReadValue).
"""

@property
@abstractmethod
def add_metatypes(self) -> List[OperatorMetatype]:
@@ -174,6 +160,13 @@ def get_statistic_collector(
:return: Backend-specific TensorStatisticCollectorBase for the statistics calculation.
"""

@staticmethod
@abstractmethod
def transform_to_inference_graph(graph: NNCFGraph) -> NNCFGraph:
"""
Returns inference NNCFGraph without constant flows and training time operations.
"""

@staticmethod
@abstractmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
20 changes: 12 additions & 8 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict, List, Optional, Set, Union

import numpy as np
@@ -42,6 +43,8 @@
from nncf.quantization.algorithms.min_max.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.passes import filter_constant_nodes_inplace
from nncf.quantization.passes import remove_shapeof_subgraphs_inplace
from nncf.quantization.range_estimator import RangeEstimatorParameters


@@ -56,10 +59,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXTopKMetatype, om.ONNXNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXShapeMetatype]

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype]
@@ -68,10 +67,6 @@ def conv_metatypes(self) -> List[OperatorMetatype]:
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype, om.ONNXConvolutionTransposeMetatype, *MATMUL_METATYPES]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXAddLayerMetatype]
@@ -170,6 +165,15 @@ def get_statistic_collector(
f"{str(range_estimator_params)}"
)

@staticmethod
def transform_to_inference_graph(graph: NNCFGraph) -> NNCFGraph:
inference_graph = deepcopy(graph)
remove_shapeof_subgraphs_inplace(
nncf_graph=inference_graph, shapeof_metatypes=[om.ONNXShapeMetatype], read_variable_metatypes=[]
)
filter_constant_nodes_inplace(nncf_graph=inference_graph)
return inference_graph

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return list(node.layer_attributes.weight_attrs.keys())
23 changes: 15 additions & 8 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
@@ -43,6 +44,8 @@
from nncf.quantization.algorithms.min_max.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.passes import filter_constant_nodes_inplace
from nncf.quantization.passes import remove_shapeof_subgraphs_inplace


# pylint:disable=too-many-public-methods
@@ -56,10 +59,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [om.OVTopKMetatype, om.OVNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.OVShapeOfMetatype]

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConvolutionMetatype]
@@ -74,10 +73,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
om.OVMatMulMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.OVAddMetatype]
@@ -200,6 +195,18 @@ def get_statistic_collector(
collector.register_statistic_branch(container_key, reducer, aggregator)
return collector

@staticmethod
def transform_to_inference_graph(graph: NNCFGraph) -> NNCFGraph:
inference_graph = deepcopy(graph)
read_variable_metatypes = [om.OVReadValueMetatype]
remove_shapeof_subgraphs_inplace(
nncf_graph=inference_graph,
shapeof_metatypes=[om.OVShapeOfMetatype],
read_variable_metatypes=read_variable_metatypes,
)
filter_constant_nodes_inplace(nncf_graph=inference_graph, read_variable_metatypes=read_variable_metatypes)
return inference_graph

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return node.layer_attributes.get_const_port_ids()
22 changes: 14 additions & 8 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict, List, Optional, Set, Tuple

import torch
@@ -32,6 +33,8 @@
from nncf.quantization.algorithms.min_max.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.fake_quantize import FakeQuantizeParameters
from nncf.quantization.passes import filter_constant_nodes_inplace
from nncf.quantization.passes import remove_dropout_nodes_inplace
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.torch.graph.graph import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
@@ -65,10 +68,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
@@ -85,10 +84,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
om.PTModuleConvTranspose3dMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.PTAddMetatype]
@@ -188,6 +183,13 @@ def get_statistic_collector(
collector.register_statistic_branch(container_key, reducer, aggregator)
return collector

@staticmethod
def transform_to_inference_graph(graph: NNCFGraph) -> NNCFGraph:
inference_graph = deepcopy(graph)
remove_dropout_nodes_inplace(nncf_graph=inference_graph, dropout_metatypes=[om.PTDropoutMetatype])
filter_constant_nodes_inplace(nncf_graph=inference_graph)
return inference_graph

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
return [None]
@@ -307,6 +309,10 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
om.PTDivMetatype,
om.PTMaxMetatype,
om.PTSqueezeMetatype,
om.PTLayerNormMetatype,
om.PTModuleLayerNormMetatype,
om.PTGroupNormMetatype,
om.PTModuleGroupNormMetatype,
]
if device != TargetDevice.CPU_SPR:
types.append(om.PTMulMetatype)
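The LayerNorm/GroupNorm metatypes added to the ignored list above are returned by get_ignored_metatypes for the transformer model type. A hedged sketch of the public PTQ call that exercises this path (the model and calibration loader are placeholders, not part of the diff):

import nncf

# `model` is a torch.nn.Module; `calibration_loader` yields batches, and the lambda adapts a batch to the model input.
calibration_dataset = nncf.Dataset(calibration_loader, lambda batch: batch[0])

quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    model_type=nncf.ModelType.TRANSFORMER,  # requests the transformer-specific ignored metatypes
)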