[PTQ] Metatypes with bias list alignment (#2506)
### Changes

- Reduced the metatypes in the `OPERATIONS_WITH_BIAS` list for the ONNX backend;
- Reduced the metatypes in the `OPERATIONS_WITH_BIAS` list for the PyTorch backend;
- Separated GroupConvolution from the regular Convolution metatype for ONNX;

### Reason for changes

- Aligns algorithm behavior across backends;

### Related tickets

- 133198
- 104166
- 106469

### Tests

- post_training_quantization/323 - failed with regressions for ONNX (as expected)
- post_training_quantization/324 - passed (updated references)
KodiaqQ authored Mar 21, 2024
1 parent e4f0970 commit 08cb49e
Showing 4 changed files with 46 additions and 19 deletions.
nncf/onnx/graph/metatypes/groups.py (3 changes: 2 additions & 1 deletion)
@@ -119,7 +119,8 @@
# Contains the operation metatypes for which bias can be applied.
OPERATIONS_WITH_BIAS = [
onnx_metatypes.ONNXConvolutionMetatype,
onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
onnx_metatypes.ONNXGemmMetatype,
# TODO: Need to add MatMul with the separate bias support (CVS-135433)
]
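
For context, here is a minimal sketch of how a list like this is typically consumed by a bias-related algorithm. The `GraphNode` dataclass and the `nodes_with_bias` helper are illustrative stand-ins, not NNCF's actual API:

```python
# Minimal sketch, assuming a graph node object that exposes its resolved metatype.
# `GraphNode` and `nodes_with_bias` are illustrative names, not NNCF's public API.
from dataclasses import dataclass
from typing import List, Type


@dataclass
class GraphNode:
    name: str
    metatype: Type  # the operator metatype class resolved for this node


def nodes_with_bias(nodes: List[GraphNode], operations_with_bias: List[Type]) -> List[GraphNode]:
    """Keep only the nodes whose metatype is declared as bias-capable."""
    return [node for node in nodes if node.metatype in operations_with_bias]
```

With `ONNXDepthwiseConvolutionMetatype` removed from the list (the single deletion in this file), depthwise convolutions no longer pass such a filter on the ONNX backend, which is consistent with the expected ONNX regressions noted in the Tests section.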


nncf/onnx/graph/metatypes/onnx_metatypes.py (49 changes: 40 additions & 9 deletions)
@@ -9,11 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from typing import Dict, List, Optional, Type

import onnx

import nncf
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry
from nncf.common.hardware.opset import HWConfigOpName
@@ -44,14 +44,15 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
@classmethod
def determine_subtype(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Optional[Type[OperatorMetatype]]:
matches = []
for subtype in cls.get_subtypes():
subtypes_list = deque(cls.get_subtypes())
while subtypes_list:
subtype = subtypes_list.popleft()
if subtype.matches(model, node):
subtypes_list.extend(subtype.get_subtypes())
matches.append(subtype)
if len(matches) > 1:
raise nncf.InternalError("Multiple subtypes match operator call - cannot determine single subtype.")
if not matches:
return None
return matches[0]
return matches[-1]
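
To make the new resolution order concrete, here is a small self-contained sketch of the same deque-based walk over toy metatype classes that mirror the Conv → GroupConv → DepthwiseConv chain introduced in this commit. The toy classes and the dict-based "node" are illustrative stand-ins, not NNCF or ONNX types:

```python
from collections import deque
from typing import List, Optional, Type


class _ToyMetatype:
    """Toy stand-in for ONNXOpMetatype, used only to illustrate determine_subtype."""

    subtypes: List[Type["_ToyMetatype"]] = []

    @classmethod
    def get_subtypes(cls) -> List[Type["_ToyMetatype"]]:
        return cls.subtypes

    @classmethod
    def matches(cls, node: dict) -> bool:
        raise NotImplementedError


class DepthwiseConv(_ToyMetatype):
    @classmethod
    def matches(cls, node: dict) -> bool:
        return node.get("group", 1) == node.get("in_channels")


class GroupConv(_ToyMetatype):
    subtypes = [DepthwiseConv]

    @classmethod
    def matches(cls, node: dict) -> bool:
        return node.get("group", 1) > 1


class Conv(_ToyMetatype):
    subtypes = [GroupConv]


def determine_subtype(node: dict) -> Optional[Type[_ToyMetatype]]:
    # Same walk as above: when a subtype matches, its own subtypes are enqueued,
    # so the most specific match ends up last in `matches`.
    matches = []
    subtypes_list = deque(Conv.get_subtypes())
    while subtypes_list:
        subtype = subtypes_list.popleft()
        if subtype.matches(node):
            subtypes_list.extend(subtype.get_subtypes())
            matches.append(subtype)
    return matches[-1] if matches else None


# A depthwise conv (group == input channels) resolves to the most specific subtype.
assert determine_subtype({"group": 32, "in_channels": 32}) is DepthwiseConv
# A grouped-but-not-depthwise conv stops at GroupConv.
assert determine_subtype({"group": 4, "in_channels": 32}) is GroupConv
# A regular conv matches no subtype and keeps the base metatype.
assert determine_subtype({"group": 1, "in_channels": 32}) is None
```

Compared with the previous single-level pass, returning `matches[-1]` after descending into matching subtypes picks the deepest match on the chain, which lets GroupConvolution sit between Convolution and DepthwiseConvolution without ambiguity.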


class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
@@ -85,6 +86,22 @@ def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return _is_depthwise_conv(model, node)


@ONNX_OPERATION_METATYPES.register()
class ONNXGroupConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "GroupConvOp"
op_names = ["Conv"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
weight_channel_axis = 0
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
subtypes = [ONNXDepthwiseConvolutionMetatype]

@classmethod
def matches(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
return _is_group_conv(node)


@ONNX_OPERATION_METATYPES.register()
class ONNXConvolutionMetatype(ONNXOpWithWeightsMetatype):
name = "ConvOp"
@@ -94,7 +111,7 @@ class ONNXConvolutionMetatype(ONNXOpWithWeightsMetatype):
weight_port_ids = [1]
bias_port_id = 2
output_channel_axis = 1
subtypes = [ONNXDepthwiseConvolutionMetatype]
subtypes = [ONNXGroupConvolutionMetatype]


@ONNX_OPERATION_METATYPES.register()
@@ -699,6 +716,23 @@ def get_tensor_edge_name(
return None


def _is_group_conv(node: onnx.NodeProto) -> bool:
"""
Returns True if the convolution is a group convolution, False otherwise.
A group convolution is a convolution whose "group" attribute is set to a value greater than 1.
:param node: Convolution node to check whether it is a group convolution.
:return: True if the convolution is a group convolution, False otherwise.
"""
conv_group = None
for attribute in node.attribute:
if attribute.name == "group":
conv_group = onnx.helper.get_attribute_value(attribute)
if conv_group is None or conv_group == 1:
return False
return True
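
A quick usage sketch for the new helper, building bare Conv nodes with `onnx.helper.make_node` (the input/output names and group values are arbitrary, and `_is_group_conv` is assumed to be importable from this module):

```python
import onnx

# Grouped convolution: the "group" attribute is present and greater than 1.
grouped = onnx.helper.make_node("Conv", inputs=["x", "w"], outputs=["y"], group=4)
# Regular convolution: no "group" attribute at all (equivalent to group == 1).
regular = onnx.helper.make_node("Conv", inputs=["x", "w"], outputs=["y"])
# An explicit group == 1 is likewise treated as a regular convolution.
explicit_one = onnx.helper.make_node("Conv", inputs=["x", "w"], outputs=["y"], group=1)

assert _is_group_conv(grouped) is True
assert _is_group_conv(regular) is False
assert _is_group_conv(explicit_one) is False
```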


def _is_depthwise_conv(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
"""
Returns True if the convolution is depthwise, False - otherwise.
@@ -711,12 +745,9 @@ def _is_depthwise_conv(model: onnx.ModelProto, node: onnx.NodeProto) -> bool:
:param node: Convolution node to check whether it is depthwise.
:return: True if the convolution is depthwise, False - otherwise.
"""
conv_group = None
for attribute in node.attribute:
if attribute.name == "group":
conv_group = onnx.helper.get_attribute_value(attribute)
if conv_group is None:
return False
weight_tensor_value = None
initializer_name = node.input[1]
for init in model.graph.initializer:
nncf/torch/graph/operator_metatypes.py (7 changes: 1 addition & 6 deletions)
@@ -1055,12 +1055,7 @@ def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
PTModuleConv1dMetatype,
PTModuleConv2dMetatype,
PTModuleConv3dMetatype,
PTDepthwiseConv1dSubtype,
PTDepthwiseConv2dSubtype,
PTDepthwiseConv3dSubtype,
PTModuleConvTranspose1dMetatype,
PTModuleConvTranspose2dMetatype,
PTModuleConvTranspose3dMetatype,
# TODO: Need to add Linear support (CVS-111111)
]

OPERATORS_FUSED_METATYPES = [
tests/post_training/data/ptq_reference_data.yaml (6 changes: 3 additions & 3 deletions)
@@ -39,7 +39,7 @@ timm/deit3_small_patch16_224_backend_CUDA_TORCH:
timm/deit3_small_patch16_224_backend_FP32:
metric_value: 0.81358
timm/deit3_small_patch16_224_backend_ONNX:
metric_value: 0.81268
metric_value: 0.81154
timm/deit3_small_patch16_224_backend_OV:
metric_value: 0.81276
timm/deit3_small_patch16_224_backend_TORCH:
@@ -131,7 +131,7 @@ timm/mobilenetv2_050_backend_CUDA_TORCH:
timm/mobilenetv2_050_backend_FP32:
metric_value: 0.6594
timm/mobilenetv2_050_backend_ONNX:
metric_value: 0.65462
metric_value: 0.65332
timm/mobilenetv2_050_backend_OV:
metric_value: 0.65282
timm/mobilenetv2_050_backend_TORCH:
@@ -141,7 +141,7 @@ timm/mobilenetv3_small_050_backend_CUDA_TORCH:
timm/mobilenetv3_small_050_backend_FP32:
metric_value: 0.57906
timm/mobilenetv3_small_050_backend_ONNX:
metric_value: 0.5617
metric_value: 0.42104
timm/mobilenetv3_small_050_backend_OV:
metric_value: 0.42184
timm/mobilenetv3_small_050_backend_TORCH:
