Skip to content

Commit

Permalink
Fix ONNX without shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
KodiaqQ committed Jul 1, 2024
1 parent 4fb1f2b commit 8056be4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
9 changes: 3 additions & 6 deletions nncf/onnx/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,34 +131,31 @@ def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
:return: Axis along which quantizer parameters are calculated.
"""
weight_channel_axis = node.metatype.weight_channel_axis
weight_axes = list(range(len(node.layer_attributes.weight_attrs[port_id]["shape"])))

if node.metatype == om.ONNXGemmMetatype:
trans_attr = "transB" if port_id else "transA"
transpose = node.layer_attributes.node_attrs[trans_attr]
# 0 - (M, K), 1 - (K, N)
weight_channel_axis = -1 - port_id if transpose else -2 + port_id
return weight_axes[weight_channel_axis]
return weight_channel_axis


def get_act_quantization_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
def get_act_quantization_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns activation tensor axis, along which quantizer parameters are calculated.
:param node: NNCFNode, with the activation on input port_id.
:param port_id: Input port id on which there is an activation of a node.
:param input_shape: Shape of the input tensor.
:return: Axis along which quantizer parameters are calculated.
"""
act_channel_axis = node.metatype.output_channel_axis
act_axes = list(range(len(input_shape)))

if node.metatype == om.ONNXGemmMetatype:
trans_attr = "transB" if port_id else "transA"
transpose = node.layer_attributes.node_attrs[trans_attr]
# 0 - (M, K), 1 - (K, N)
act_channel_axis = -2 + port_id if transpose else -1 - port_id
return act_axes[act_channel_axis]
return act_channel_axis


def _get_activation_tensor_shape(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,4 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG

@staticmethod
def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
return get_act_quantization_axis(node, port_id, input_shape)
return get_act_quantization_axis(node, port_id)
13 changes: 13 additions & 0 deletions tests/onnx/quantization/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend
from tests.post_training.test_templates.models import NNCFGraphToTest
from tests.post_training.test_templates.test_min_max import MATMUL_WEIGHT_SHAPE
from tests.post_training.test_templates.test_min_max import TemplateTestGetChannelAxes
from tests.post_training.test_templates.test_min_max import TemplateTestGetTargetPointShape
from tests.post_training.test_templates.test_min_max import TemplateTestMinMaxAlgorithm
Expand Down Expand Up @@ -79,3 +80,15 @@ def test_get_channel_axes_deptwiseconv_node_ov(self):

def test_get_channel_axes_matmul_torch(self):
pytest.skip("Test is not applied for ONNX backend.")

@pytest.mark.parametrize(
"weight_shape, weight_port_id, transpose_weight, ref_axes",
(
(MATMUL_WEIGHT_SHAPE, 1, False, (-1,)),
(MATMUL_WEIGHT_SHAPE, 1, True, (-2,)),
(MATMUL_WEIGHT_SHAPE, 0, True, (-1,)),
(MATMUL_WEIGHT_SHAPE, 0, False, (-2,)),
),
)
def test_get_channel_axes_matmul_node_ov_onnx(self, weight_shape, weight_port_id, transpose_weight, ref_axes):
super().test_get_channel_axes_matmul_node_ov_onnx(weight_shape, weight_port_id, transpose_weight, ref_axes)
18 changes: 9 additions & 9 deletions tests/onnx/test_node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ def test_get_bias_value(model):
@pytest.mark.parametrize(
"layer_attrs, port_id, ref_axis",
[
[{"node_attrs": {"transA": 0, "transB": 0}}, 0, 1],
[{"node_attrs": {"transA": 0, "transB": 1}}, 0, 1],
[{"node_attrs": {"transA": 1, "transB": 0}}, 0, 0],
[{"node_attrs": {"transA": 1, "transB": 1}}, 0, 0],
[{"node_attrs": {"transA": 0, "transB": 0}}, 1, 0],
[{"node_attrs": {"transA": 0, "transB": 1}}, 1, 1],
[{"node_attrs": {"transA": 1, "transB": 0}}, 1, 0],
[{"node_attrs": {"transA": 1, "transB": 1}}, 1, 1],
[{"node_attrs": {"transA": 0, "transB": 0}}, 0, -1],
[{"node_attrs": {"transA": 0, "transB": 1}}, 0, -1],
[{"node_attrs": {"transA": 1, "transB": 0}}, 0, -2],
[{"node_attrs": {"transA": 1, "transB": 1}}, 0, -2],
[{"node_attrs": {"transA": 0, "transB": 0}}, 1, -2],
[{"node_attrs": {"transA": 0, "transB": 1}}, 1, -1],
[{"node_attrs": {"transA": 1, "transB": 0}}, 1, -2],
[{"node_attrs": {"transA": 1, "transB": 1}}, 1, -1],
],
)
def test_get_act_quantization_axis(layer_attrs, port_id, ref_axis):
node = create_nncf_node(**layer_attrs)
channel_axis = get_act_quantization_axis(node, port_id, (2, 3))
channel_axis = get_act_quantization_axis(node, port_id)
assert channel_axis == ref_axis

0 comments on commit 8056be4

Please sign in to comment.