Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OV] Introduce support of quantization If operation #2101

Merged
merged 66 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
e7eddd0
Intorduce support of quantization If op for OV
kshpv Sep 1, 2023
846f0fd
add backend entities
kshpv Sep 1, 2023
9a6fcd4
code improvement
kshpv Sep 4, 2023
704a72d
typo
kshpv Sep 4, 2023
8314131
add reinitialization of cached variables for MinMax
kshpv Sep 5, 2023
3fe3b5e
Apply comments
kshpv Sep 6, 2023
9f7400c
update logic
kshpv Sep 6, 2023
e097cd4
Implement dfs approach
kshpv Sep 6, 2023
5139a18
typo
kshpv Sep 6, 2023
3b00f73
remove torch onnx backend impl
kshpv Sep 6, 2023
9ea3ff2
code improvements
kshpv Sep 6, 2023
b168892
make collect_dataitems_for_children_models common
kshpv Sep 7, 2023
c2473a9
generalize make_dataset_for_child_models
kshpv Sep 7, 2023
9201534
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 7, 2023
c192331
unification
kshpv Sep 7, 2023
f20892d
generalize and remove set_child_model
kshpv Sep 7, 2023
6a7aeaf
update dataset calculation for submodels
kshpv Sep 8, 2023
9d0c27b
generalize logic
kshpv Sep 8, 2023
4970d9d
add extract if subgraph transform
kshpv Sep 8, 2023
589493b
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 8, 2023
538c4b8
make private
kshpv Sep 8, 2023
10de119
code improvements
kshpv Sep 8, 2023
fb4559f
code improvements
kshpv Sep 8, 2023
08776a0
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 11, 2023
24d88ca
doctrings; update method names
kshpv Sep 11, 2023
4ac706e
typo
kshpv Sep 11, 2023
09c8693
lint
kshpv Sep 11, 2023
9cd5dca
separate method for if condition input name and submodel input names
kshpv Sep 11, 2023
db79117
fix merge typos
kshpv Sep 11, 2023
68d04a8
lint
kshpv Sep 11, 2023
f374441
add minimum statistic sample
kshpv Sep 12, 2023
b7ec929
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 12, 2023
0a9696d
apply comments
kshpv Sep 12, 2023
62c1954
make IF input quantizible for OV
kshpv Sep 12, 2023
25717a9
add hw config for IF to CPU
kshpv Sep 12, 2023
e35d1a3
Update placement of IF op quantization logic
kshpv Sep 15, 2023
af127ac
Remove model_cnt and dumping model method
kshpv Sep 15, 2023
dabd0d2
Rename module
kshpv Sep 15, 2023
bab6a71
remove intermediate_model_dir param for PTQ
kshpv Sep 15, 2023
1160ac5
revert MIN_SAMPLES_NUM
kshpv Sep 15, 2023
88873c6
typo
kshpv Sep 15, 2023
443332c
revert collectores changes
kshpv Sep 15, 2023
2958405
optimize dataset collection
kshpv Sep 15, 2023
b8e1dc6
update main method
kshpv Sep 15, 2023
67032d4
Update method name
kshpv Sep 15, 2023
f13ba41
revert formatting of config
kshpv Sep 18, 2023
9cd0f78
add error if turn on BiasCorrection
kshpv Sep 18, 2023
30d80a0
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 18, 2023
3e831fb
Apply comments
kshpv Sep 18, 2023
587d5e7
update logs
kshpv Sep 19, 2023
8f34fc1
Improve logging
kshpv Sep 19, 2023
346951e
Apply comments
kshpv Sep 20, 2023
2c6f21b
typo
kshpv Sep 20, 2023
52adfd4
update logs
kshpv Sep 20, 2023
4135289
update track and revert FBC changes
kshpv Sep 20, 2023
7bd9e2e
update log msg
kshpv Sep 20, 2023
6e70357
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 20, 2023
7a32199
Remove loggin from PTQ; put loggin into MinMax
kshpv Sep 21, 2023
6af3309
add WA for new OV
kshpv Sep 21, 2023
e9b4131
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 21, 2023
b21cd59
lint
kshpv Sep 21, 2023
372141e
Merge remote-tracking branch 'remote/develop' into ov_if_op_support
kshpv Sep 22, 2023
df6cfcb
add graph test
kshpv Sep 22, 2023
1255f1e
minor
kshpv Sep 22, 2023
781e26c
typehints
kshpv Sep 22, 2023
36381a7
lint
kshpv Sep 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nncf/common/logging/track_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def __init__(
TimeRemainingColumn(),
)
)

disable = disable or (hasattr(sequence, "__len__") and len(sequence) == 0)

self.progress = Progress(
*self.columns,
auto_refresh=auto_refresh,
Expand Down
6 changes: 6 additions & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,12 @@ class OVAbsMetatype(OVOpMetatype):
op_names = ["Abs"]


@OV_OPERATOR_METATYPES.register()
class OVIfMetatype(OVOpMetatype):
name = "IfOp"
op_names = ["If"]
KodiaqQ marked this conversation as resolved.
Show resolved Hide resolved


@OV_OPERATOR_METATYPES.register()
class OVGroupNormalizationMetatype(OVOpMetatype):
name = "GroupNormalizationOp"
Expand Down
43 changes: 43 additions & 0 deletions nncf/openvino/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from nncf.openvino.graph.node_utils import get_result_node_name
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVExtractIfBodyCommand
from nncf.openvino.graph.transformations.commands import OVFQNodeRemovingCommand
from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand
from nncf.openvino.graph.transformations.commands import OVModelExtractionCommand
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVUpdateIfBodyCommand
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.fake_quantize import FakeQuantizeParameters

Expand All @@ -52,6 +54,8 @@ def __init__(self, model: TModel):
(OVOutputInsertionCommand, self._apply_output_insertion_transformations),
(OVBiasInsertionCommand, self._apply_bias_insertion_transformations),
(OVMultiplyInsertionCommand, self._apply_multiply_insertion_transformations),
(OVUpdateIfBodyCommand, self._apply_update_if_body_transformations),
(OVExtractIfBodyCommand, self._apply_extract_if_body_transformation),
]

@staticmethod
Expand Down Expand Up @@ -526,3 +530,42 @@ def _apply_multiply_insertion_transformations(
destination_port.replace_source_output(multiply_node.output(0))

return model

@staticmethod
def _apply_update_if_body_transformations(
model: ov.Model, transformations: List[OVUpdateIfBodyCommand]
) -> ov.Model:
"""
Update model body for IF node.

:param model: Model to update and insert a new subgraph.
:param transformations: Transformations with information of If node and an updated subgraph.
:return: Original model with an updated subgraph.
"""
name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
for transformation in transformations:
subgraph_model = transformation.subgraph_model
KodiaqQ marked this conversation as resolved.
Show resolved Hide resolved
port_id = transformation.target_point.port_id
node_name = transformation.target_point.target_node_name
node = name_to_node_mapping[node_name]
node.set_function(port_id, subgraph_model)
return model

@staticmethod
def _apply_extract_if_body_transformation(
model: ov.Model, transformations: List[OVExtractIfBodyCommand]
) -> ov.Model:
"""
Extract a model body from If node.

:param model: Model from which extracts a subgraph.
:param transformations: Transformations with information from which
If node and input port extract a model subgraph.
:return: Model subgraph.
"""
transformation = transformations[-1]
name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
ov_node = name_to_node_mapping[transformation.if_node_name]
if transformation.if_body_condition:
return ov.Model(ov_node.get_function(0)) # ticket: 121115
return ov.Model(ov_node.get_function(1)) # ticket: 121115
21 changes: 21 additions & 0 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConstantMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvertMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype

InplaceInsertionFnType = Callable[[ov.Node, int], ov.Node]

Expand All @@ -49,6 +51,25 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return bias_constant is not None


def get_number_if_op(model: ov.Model) -> int:
"""
Returns number of If operation in a model.

:param model: Model.
:return: True if Model has If operation, False - otherwise.
"""

def cnt_if_op(model: ov.Model, cnt: int) -> int:
for op in model.get_ops():
if get_node_metatype(op) == OVIfMetatype:
cnt += 1
cnt = cnt_if_op(op.get_function(0), cnt)
cnt = cnt_if_op(op.get_function(1), cnt)
return cnt

return cnt_if_op(model, 0)


def get_const_value(const_node: ov.Node) -> np.ndarray:
"""
Returns the constant tensor for the node.
Expand Down
38 changes: 38 additions & 0 deletions nncf/openvino/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import List

import numpy as np
import openvino.runtime as ov

from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetPoint
Expand Down Expand Up @@ -191,3 +192,40 @@ def __init__(
def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
raise NotImplementedError()


class OVUpdateIfBodyCommand(TransformationCommand):
"""
Updates If node body.
"""

def __init__(self, target_point: OVTargetPoint, body_model: ov.Model):
"""
:param target_point: The TargetPoint instance for the change that contains layer's information.
:param body_model: A new model to set.
"""
super().__init__(TransformationType.CHANGE, target_point)
self.subgraph_model = body_model

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
raise NotImplementedError()


class OVExtractIfBodyCommand(Command):
"""
Extracts If node body.
"""

def __init__(self, if_node_name: str, if_body_condition: bool):
"""
:param target_point: The TargetPoint instance for the extraction that contains layer's information.
:param if_body_condition: If true extracts then body, else - else body.
"""
super().__init__(TransformationType.EXTRACT)
self.if_node_name = if_node_name
self.if_body_condition = if_body_condition

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
raise NotImplementedError()
Loading