Skip to content

Commit

Permalink
[ONNX] Fix weight quantization for GroupConvolution (#3126)
Browse files Browse the repository at this point in the history
### Reason for changes

Weights of GroupConvolution nodes were not quantized in the ONNX backend.

### Related tickets

158085

### Tests

PTQ perf run 84
  • Loading branch information
kshpv authored Dec 5, 2024
1 parent 5189aab commit 6031ccc
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 73 deletions.
22 changes: 12 additions & 10 deletions nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# limitations under the License.

from nncf.onnx.graph.metatypes import onnx_metatypes
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import get_operator_metatypes

QUANTIZE_AGNOSTIC_OPERATIONS = [
onnx_metatypes.ONNXGlobalMaxPoolMetatype,
Expand Down Expand Up @@ -67,14 +69,19 @@
onnx_metatypes.ONNXMinimumMetatype,
]


CONSTANT_WEIGHT_LAYER_METATYPES = [
onnx_metatypes.ONNXConvolutionMetatype,
onnx_metatypes.ONNXDepthwiseConvolutionMetatype,
onnx_metatypes.ONNXConvolutionTransposeMetatype,
onnx_metatypes.ONNXEmbeddingMetatype,
metatype
for metatype in get_operator_metatypes()
if issubclass(metatype, ONNXOpWithWeightsMetatype) and metatype.weight_port_ids
]

POSSIBLE_WEIGHT_LAYER_METATYPES = [
metatype
for metatype in get_operator_metatypes()
if issubclass(metatype, ONNXOpWithWeightsMetatype) and metatype.possible_weight_ports
]

OPERATIONS_WITH_WEIGHTS = list(set().union(CONSTANT_WEIGHT_LAYER_METATYPES, POSSIBLE_WEIGHT_LAYER_METATYPES))

LINEAR_OPERATIONS = [
onnx_metatypes.ONNXConvolutionMetatype,
Expand Down Expand Up @@ -124,11 +131,6 @@
onnx_metatypes.ONNXMeanMetatype,
]

OPERATIONS_WITH_WEIGHTS = [
*CONSTANT_WEIGHT_LAYER_METATYPES,
*MATMUL_METATYPES,
]


BATCH_NORMALIZATION_OPERATIONS = [
onnx_metatypes.ONNXBatchNormMetatype,
Expand Down
20 changes: 9 additions & 11 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ def determine_subtype(cls, model: onnx.ModelProto, node: onnx.NodeProto) -> Opti
class ONNXOpWithWeightsMetatype(ONNXOpMetatype):
"""
Metatype which could have weights.
:param weight_channel_axis: Axis for weight per-channel quantization, meaning the number of output filters.
:param weight_port_ids: Input ports of the node's weight.
If the value is None the weight_port_id should be determined dynamically.
:param bias_port_id: Input port of the node's bias.
If the value is None it means that the Metatype does not have bias.
:param weight_channel_axis: Axis for weight per-channel quantization.
:param weight_port_ids: Constant input ports of the node's weight. Defaults to an empty list.
:param bias_port_id: Input port of the node's bias. If the value is None,
it means that the Metatype does not have bias. Defaults to None.
:param possible_weight_ports: Input ports on which weight could be laid. Defaults to an empty list.
"""

weight_channel_axis: int
weight_port_ids: Optional[List[int]] = None
weight_port_ids: List[int] = []
bias_port_id: Optional[int] = None
possible_weight_ports: List[int] = []


@ONNX_OPERATION_METATYPES.register(is_subtype=True)
Expand Down Expand Up @@ -131,19 +131,17 @@ class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
op_names = ["Gemm"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1 # For port_id=1
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
output_channel_axis = -1


@ONNX_OPERATION_METATYPES.register()
class ONNXMatMulMetatype(ONNXOpMetatype):
class ONNXMatMulMetatype(ONNXOpWithWeightsMetatype):
name = "MatMulOp"
op_names = ["MatMul"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1 # For port_id=1
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
output_channel_axis = -1
Expand Down Expand Up @@ -454,7 +452,7 @@ class ONNXReciprocalMetatype(ONNXOpMetatype):


@ONNX_OPERATION_METATYPES.register(is_subtype=True)
class ONNXEmbeddingMetatype(ONNXOpMetatype):
class ONNXEmbeddingMetatype(ONNXOpWithWeightsMetatype):
name = "EmbeddingOp"
hw_config_names = [HWConfigOpName.EMBEDDING]
weight_port_ids = [0]
Expand Down
4 changes: 2 additions & 2 deletions nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from nncf.common.graph.operator_metatypes import InputNoopMetatype
from nncf.common.graph.operator_metatypes import OutputNoopMetatype
from nncf.onnx.graph.metatypes.groups import CONSTANT_WEIGHT_LAYER_METATYPES
from nncf.onnx.graph.metatypes.groups import MATMUL_METATYPES
from nncf.onnx.graph.metatypes.groups import OPERATIONS_WITH_BIAS
from nncf.onnx.graph.metatypes.groups import POSSIBLE_WEIGHT_LAYER_METATYPES
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXGemmMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXOpWithWeightsMetatype
Expand Down Expand Up @@ -95,7 +95,7 @@ def get_possible_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]:
:param metatype: Metatype.
:return: Port ids.
"""
if metatype in MATMUL_METATYPES:
if metatype in POSSIBLE_WEIGHT_LAYER_METATYPES:
return metatype.possible_weight_ports
return []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: onnx.ModelProto
@staticmethod
def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
activation_port = 0

if hasattr(node.metatype, "possible_weight_ports"):
if node.metatype.possible_weight_ports:
activation_ports = deepcopy(node.metatype.possible_weight_ports)
for weight_port in node.layer_attributes.weight_attrs:
activation_ports.remove(weight_port)
Expand Down
10 changes: 5 additions & 5 deletions tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ torchvision/resnet18_backend_FX_TORCH:
torchvision/mobilenet_v3_small_BC_backend_FP32:
metric_value: 0.6766
torchvision/mobilenet_v3_small_BC_backend_OV:
metric_value: 0.6669
metric_value: 0.6681
torchvision/mobilenet_v3_small_BC_backend_ONNX:
metric_value: 0.6679
torchvision/mobilenet_v3_small_BC_backend_FX_TORCH:
Expand Down Expand Up @@ -103,7 +103,7 @@ timm/dpn68_backend_CUDA_TORCH:
timm/dpn68_backend_FP32:
metric_value: 0.76342
timm/dpn68_backend_ONNX:
metric_value: 0.75906
metric_value: 0.7592
timm/dpn68_backend_OV:
metric_value: 0.75972
timm/dpn68_backend_TORCH:
Expand Down Expand Up @@ -201,7 +201,7 @@ timm/regnetx_002_backend_CUDA_TORCH:
timm/regnetx_002_backend_FP32:
metric_value: 0.68756
timm/regnetx_002_backend_ONNX:
metric_value: 0.6848
metric_value: 0.6854
timm/regnetx_002_backend_OV:
metric_value: 0.6852
timm/regnetx_002_backend_TORCH:
Expand All @@ -211,7 +211,7 @@ timm/resnest14d_backend_CUDA_TORCH:
timm/resnest14d_backend_FP32:
metric_value: 0.75516
timm/resnest14d_backend_ONNX:
metric_value: 0.75428
metric_value: 0.7538
timm/resnest14d_backend_OV:
metric_value: 0.75
timm/resnest14d_backend_TORCH:
Expand Down Expand Up @@ -253,7 +253,7 @@ timm/visformer_small_backend_CUDA_TORCH:
timm/visformer_small_backend_FP32:
metric_value: 0.82098
timm/visformer_small_backend_ONNX:
metric_value: 0.81562
metric_value: 0.8160
timm/visformer_small_backend_OV:
metric_value: 0.81674
timm/visformer_small_backend_TORCH:
Expand Down
2 changes: 0 additions & 2 deletions tests/post_training/data/ptq_reference_data_2024.5.yaml

This file was deleted.

18 changes: 11 additions & 7 deletions tests/post_training/data/wc_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,44 @@ tinyllama_data_aware_backend_OV:
num_int4: 94
num_int8: 124
tinyllama_data_aware_awq_stateful_backend_OV:
metric_value: 0.85571
metric_value: 0.85616
num_int4: 94
num_int8: 124
tinyllama_data_aware_awq_scale_estimation_backend_OV:
metric_value: 0.86355
metric_value: 0.85502
num_int4: 94
num_int8: 124
tinyllama_data_aware_awq_scale_estimation_stateful_backend_OV:
metric_value: 0.86355
metric_value: 0.85502
num_int4: 94
num_int8: 124
tinyllama_int8_data_free_backend_TORCH:
metric_value: 0.95624
num_int4: 0
num_int8: 312
tinyllama_data_aware_gptq_scale_estimation_stateful_backend_OV:
metric_value: 0.86697
metric_value: 0.86503
num_int4: 94
num_int8: 124
metrics_xfail_reason: "Issue-148819"
tinyllama_scale_estimation_per_channel_backend_OV:
metric_value: 0.80798
metric_value: 0.81389
num_int4: 188
num_int8: 124
tinyllama_data_aware_lora_stateful_backend_OV:
metric_value: 0.83446
num_int4: 94
num_int8: 500
tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV:
metric_value: 0.87132
metric_value: 0.88663
num_int4: 11
num_int8: 290
metrics_xfail_reason: "Issue-148819"
tinyllama_awq_backup_mode_none_backend_OV:
metric_value: 0.85679
metric_value: 0.84783
num_int4: 208
num_int8: 0
tinyllama_int4_data_free_backend_TORCH:
metric_value: 0.73873
num_int4: 114
num_int8: 84
34 changes: 0 additions & 34 deletions tests/post_training/data/wc_reference_data_2024.5.yaml

This file was deleted.

0 comments on commit 6031ccc

Please sign in to comment.