Skip to content

Commit

Permalink
Update Fast-/BC algorithms (#2747)
Browse files Browse the repository at this point in the history
### Changes

- Updated FastBiasCorrection algorithm to work correctly with transposed
MatMul layers.
- Updated BiasCorrection algorithm to avoid recursion depth error.

### Reason for changes

- FastBiasCorrection algorithm adaptation.
- BiasCorrection improvement for many correctable layers.

### Related tickets

- 144240

---------

Co-authored-by: Aleksei Kashapov <[email protected]>
  • Loading branch information
KodiaqQ and kshpv authored Jun 28, 2024
1 parent 2ba43bd commit beb9c04
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 125 deletions.
63 changes: 38 additions & 25 deletions nncf/onnx/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import numpy as np
import onnx
Expand Down Expand Up @@ -122,18 +122,6 @@ def is_port_quantized(node: NNCFNode, nncf_graph: NNCFGraph, port_id: int) -> bo
return False


def transpose_axis(shape: List[int], axis: int) -> int:
    """
    Returns the position of an axis after the order of all tensor axes is reversed.

    :param shape: Tensor shape; only its length (the rank) is used.
    :param axis: Axis before transpose (non-negative only).
    :return: Axis after transpose.
    """
    assert axis >= 0
    # Axes listed from last to first; indexing by the original axis yields
    # its counterpart in the reversed order.
    reversed_axes = tuple(reversed(range(len(shape))))
    return reversed_axes[axis]


def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns weight tensor axis, along which quantizer parameters are calculated.
Expand All @@ -143,21 +131,26 @@ def get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int:
:return: Axis, along which quantizer parameters are calculated.
"""
weight_channel_axis = node.metatype.weight_channel_axis
if node.layer_attributes.has_node_attrs() and node.metatype == om.ONNXGemmMetatype:
weight_shape = node.layer_attributes.weight_attrs[port_id]["shape"]
weight_channel_axis %= len(weight_shape) # Make axis positive
if port_id == 0:
weight_channel_axis -= 1
if (
port_id == 0
and node.layer_attributes.node_attrs["transA"] == 1
or port_id == 1
and node.layer_attributes.node_attrs["transB"] == 1
):
weight_channel_axis = transpose_axis(weight_shape, weight_channel_axis)
if node.metatype == om.ONNXGemmMetatype:
weight_channel_axis = calculate_gemm_channel_axis(node, port_id)
return weight_channel_axis


def get_act_quantization_axis(node: NNCFNode, port_id: int) -> int:
    """
    Returns the activation tensor axis along which quantizer parameters are calculated.

    :param node: NNCFNode with an activation on input port_id.
    :param port_id: Input port id on which the node receives the activation.
    :return: Axis, along which quantizer parameters are calculated.
    """
    # Gemm may transpose its inputs, so its channel axis depends on the
    # port and on the node's transpose attributes.
    if node.metatype == om.ONNXGemmMetatype:
        return calculate_gemm_channel_axis(node, port_id)
    # In case of the ONNX, [N, C, ..] layout applicable for most quantizable layers.
    return 1


def _get_activation_tensor_shape(
nncf_graph: NNCFGraph, node: NNCFNode, target_point: ONNXTargetPoint
) -> Optional[Tuple[int, ...]]:
Expand Down Expand Up @@ -209,3 +202,23 @@ def get_quantized_tensor_shape(
if target_point.is_weight_target_point():
return node.layer_attributes.weight_attrs[target_point.port_id]["shape"]
return _get_activation_tensor_shape(nncf_graph, node, target_point)


def calculate_gemm_channel_axis(node: NNCFNode, port_id: int) -> int:
    """
    Calculates Gemm channel axis based on the port and node attributes.

    :param node: NNCFNode instance with layer attributes.
    :param port_id: Port ID used to choose the transpose attribute:
        transA for port_id 0, transB for any non-zero port_id.
    :return: Channel axis number.
    """
    # Gemm metatype supports only 2D inputs according to the documentation -
    # https://onnx.ai/onnx/operators/onnx__Gemm.html
    # Usage of the tensor shape is not possible,
    # because ONNX allows to contain empty shape even after the shape inference.
    gemm_axes = (0, 1)
    is_transposed = node.layer_attributes.node_attrs["transB" if port_id else "transA"]
    if is_transposed:
        axis_position = -1 - port_id
    else:
        axis_position = port_id - 2
    return gemm_axes[axis_position]
53 changes: 36 additions & 17 deletions nncf/openvino/graph/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from enum import Enum
from typing import List, Tuple
from typing import Tuple

from nncf.common.graph.graph import NNCFNode
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
Expand Down Expand Up @@ -46,7 +46,7 @@ class OVLayoutElem(Enum):
}


def get_conv_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
def get_conv_weights_layout_from_node(node: NNCFNode) -> Tuple[OVLayoutElem]:
"""
Calculates weights layout for a target convolution node.
Expand All @@ -60,7 +60,7 @@ def get_conv_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
)


def get_linear_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
def get_linear_weights_layout_from_node(node: NNCFNode) -> Tuple[OVLayoutElem]:
"""
Calculates weights layout for a target linear node.
Expand All @@ -70,14 +70,33 @@ def get_linear_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
layer_attributes = node.layer_attributes
port_id = _get_constant_port_id_from_layer_attributes(layer_attributes)
constant_layer_attrs = layer_attributes.constant_attributes[port_id]
return get_linear_weights_layout(
weights_shape=constant_layer_attrs["shape"],
return get_linear_input_layout(
input_shape=constant_layer_attrs["shape"],
transpose=constant_layer_attrs["transpose"],
port_id=port_id,
)


def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> List[OVLayoutElem]:
def get_linear_activations_layout_from_node(
    node: NNCFNode, port_id: int, input_shape: Tuple[int]
) -> Tuple[OVLayoutElem]:
    """
    Calculates activations layout for a target linear node.

    :param node: Target linear node.
    :param port_id: Target input port ID.
    :param input_shape: Shape of the input.
    :return: Target linear node activation layout.
    """
    # The transpose flag is read from the node's input attributes.
    is_transposed = node.layer_attributes.input_attributes["transpose"]
    return get_linear_input_layout(input_shape=input_shape, transpose=is_transposed, port_id=port_id)


def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> Tuple[OVLayoutElem]:
"""
Calculates weights layout for a target convolution node.
Expand All @@ -91,23 +110,23 @@ def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int,
return tuple(weights_layout)


def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> List[OVLayoutElem]:
def get_linear_input_layout(input_shape: Tuple[int, ...], transpose: bool, port_id: int) -> Tuple[OVLayoutElem]:
"""
Calculates weights layout for a target linear node.
Calculates input layout for a target linear node.
:param weights_shape: Shape of the target linear node weight.
:param port_id: Port id of the target liner node weights.
:return: Target linear node weight layout.
:param input_shape: Shape of the target linear node input.
:param port_id: Port id of the target linear node input.
:return: Target linear node input layout.
"""
weights_layout = [OVLayoutElem.SPATIAL] * (len(weights_shape) - 2)
if len(weights_shape) > 1:
input_layout = [OVLayoutElem.SPATIAL] * (len(input_shape) - 2)
if len(input_shape) > 1:
if (transpose and port_id == 0) or (not transpose and port_id == 1):
weights_layout += [OVLayoutElem.C_IN, OVLayoutElem.C_OUT]
input_layout += [OVLayoutElem.C_IN, OVLayoutElem.C_OUT]
else:
weights_layout += [OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
input_layout += [OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
else:
weights_layout += [OVLayoutElem.C_IN]
return tuple(weights_layout)
input_layout += [OVLayoutElem.C_IN]
return tuple(input_layout)


def _get_constant_port_id_from_layer_attributes(layer_attributes: OVLayerAttributes) -> int:
Expand Down
27 changes: 24 additions & 3 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from nncf.openvino.graph.layout import OVLayoutElem
from nncf.openvino.graph.layout import get_conv_weights_layout
from nncf.openvino.graph.layout import get_conv_weights_layout_from_node
from nncf.openvino.graph.layout import get_linear_weights_layout
from nncf.openvino.graph.layout import get_linear_activations_layout_from_node
from nncf.openvino.graph.layout import get_linear_input_layout
from nncf.openvino.graph.layout import get_linear_weights_layout_from_node
from nncf.openvino.graph.metatypes.groups import CONV_OPERATIONS
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS
Expand Down Expand Up @@ -452,8 +453,8 @@ def get_weighted_layer_attributes(
return ConvolutionLayerAttributes(**kwargs)
if ov_metatype == OVMatMulMetatype:
weights_shape = attrs["shape"]
weights_layout = get_linear_weights_layout(
weights_shape=weights_shape, transpose=attrs["transpose"], port_id=port_id
weights_layout = get_linear_input_layout(
input_shape=weights_shape, transpose=attrs["transpose"], port_id=port_id
)

kwargs = {
Expand All @@ -468,3 +469,23 @@ def get_weighted_layer_attributes(
}
return LinearLayerAttributes(**kwargs)
return GenericWeightedLayerAttributes(weight_requires_grad=False, weight_shape=attrs.get("shape", None))


def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int:
    """
    Returns the axis number of the activation tensor which corresponds to its channel.

    :param node: NNCFNode instance.
    :param port_id: Port ID for input.
    :param input_shape: Shape of the input.
    :return: Channel axis number.
    """
    # The MatMul layers may transpose inputs internally, so the channel axis
    # is looked up as the position of C_OUT in the computed layout.
    if node.metatype == OVMatMulMetatype:
        matmul_layout = get_linear_activations_layout_from_node(node, port_id, input_shape)
        return matmul_layout.index(OVLayoutElem.C_OUT)

    # In case of the OpenVINO, [N, C, ..] layout applicable for most quantizable layers.
    return 1
Loading

0 comments on commit beb9c04

Please sign in to comment.