Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTQ] Add support of arbitrary batch size for PTQ #2197

Merged
merged 131 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
131 commits
Select commit Hold shift + click to select a range
1ce23c5
draft
kshpv Oct 13, 2023
8184df6
check on Nones
kshpv Oct 13, 2023
2e3f507
update aggregator with keep_dims=True
kshpv Oct 18, 2023
b5d15cd
typhints
kshpv Oct 18, 2023
8cb2391
Merge remote-tracking branch 'remote/develop' into torch_batch_size
kshpv Oct 18, 2023
1034acd
fix OV tests; update collectors
kshpv Oct 19, 2023
8b526c5
fix tests
kshpv Oct 20, 2023
e51bdb8
Merge remote-tracking branch 'remote/develop' into torch_batch_size
kshpv Nov 6, 2023
37684bd
add aggregation axes for OV; comment input check
kshpv Nov 7, 2023
18d931d
add test for OV and Torch
kshpv Nov 8, 2023
605a325
add batch_size param to conformance test
kshpv Nov 9, 2023
fb16b99
hardcode for CI run
kshpv Nov 9, 2023
cd60fa3
hardcode batch size = 10 for calibrate.py
kshpv Nov 10, 2023
f3bda28
Merge remote-tracking branch 'remote/develop' into torch_batch_size
kshpv Dec 18, 2023
cc621ab
merge
kshpv Dec 18, 2023
d2a9b00
update aggregator
kshpv Dec 20, 2023
5ffdf10
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Dec 20, 2023
d95be5d
revert unneseccary changes
kshpv Dec 20, 2023
cd68684
add logging; add torch data for OVEngine
kshpv Dec 20, 2023
4a009f3
refactor method get axes
kshpv Dec 21, 2023
c2659b3
fix OV tests
kshpv Dec 21, 2023
3a13f00
fix Torch tests
kshpv Jan 4, 2024
2347170
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 15, 2024
880073b
logic of warning message inside StatisticsAggregator
kshpv Jan 15, 2024
e9062a5
remove _check_input_data_format in OVEngine
kshpv Jan 15, 2024
8770ca4
get_channel_agnostic_reduction_axes to common
kshpv Jan 15, 2024
9556c49
use get_channel_agnostic_reduction_axes for Torch
kshpv Jan 15, 2024
cb90e77
use get_channel_agnostic_reduction_axes for ONNX
kshpv Jan 15, 2024
d9167e5
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 17, 2024
cd10c57
draft
kshpv Jan 17, 2024
16cc9db
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 17, 2024
11b538a
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 17, 2024
426ec04
fix test
kshpv Jan 18, 2024
21b0963
align reduction shape and aggregation shape
kshpv Jan 18, 2024
e90ca32
get_channel_agnostic_reduction_axes -> get_reduction_axes
kshpv Jan 18, 2024
f078a78
upd get_reduction_aggregation_axes
kshpv Jan 18, 2024
e4c57cd
upd aggregator
kshpv Jan 18, 2024
d226074
fix OV test
kshpv Jan 18, 2024
7d8ecd4
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 18, 2024
f502de5
fix ONNX test
kshpv Jan 18, 2024
83d03cb
tests
kshpv Jan 18, 2024
fbfe587
fix torch tests
kshpv Jan 18, 2024
0ae6ac4
fix tests
kshpv Jan 18, 2024
496339f
common tests
kshpv Jan 18, 2024
bcce584
add docs
kshpv Jan 18, 2024
e5950e0
comment
kshpv Jan 18, 2024
1d9ac7a
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 19, 2024
41f27b5
rollback changes for torch possible impact qat
kshpv Jan 19, 2024
51f3dd9
upd conformance
kshpv Jan 19, 2024
3a8de2f
upd calibrate.py
kshpv Jan 19, 2024
946523d
add get_reduction_aggregation_axes for PTRangeInitCollectorParams
kshpv Jan 19, 2024
1732d70
non returning None for get_reduction_aggregation_axes
kshpv Jan 19, 2024
1e96318
comments
kshpv Jan 19, 2024
03afe91
comments
kshpv Jan 19, 2024
bf792fb
describe comment
kshpv Jan 19, 2024
f98aea2
description x2
kshpv Jan 19, 2024
fbd05f9
description x3
kshpv Jan 19, 2024
e80bab1
apply suggestion
kshpv Jan 23, 2024
9c1648d
comments
kshpv Jan 24, 2024
df8ad03
add default scenario when batch_size=1 or None
kshpv Jan 25, 2024
f4db2bb
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 25, 2024
f4dfd1c
rollback scales changes
kshpv Jan 26, 2024
4a44a1c
fix tests
kshpv Jan 26, 2024
d4bfaca
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 26, 2024
f77f59b
fix OV test
kshpv Jan 26, 2024
43fd729
add warning for model_type=transformer
kshpv Jan 29, 2024
c20f7d3
fix torch test
kshpv Jan 29, 2024
52203f0
fix torch tests
kshpv Jan 29, 2024
9dd02b9
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 29, 2024
48c8426
final fix torch test
kshpv Jan 30, 2024
3fe8a37
comments
kshpv Jan 30, 2024
d228589
comments x2
kshpv Jan 30, 2024
b7de564
comments x3
kshpv Jan 30, 2024
67e4c7d
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 30, 2024
489d603
fix tests after merge
kshpv Jan 30, 2024
3b9fb6f
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 30, 2024
120ee1a
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 31, 2024
1f0cb94
improve test
kshpv Jan 31, 2024
532e8eb
fix test
kshpv Feb 6, 2024
38d71b8
upd fbs method calculations
kshpv Feb 6, 2024
c490362
revert changes with statistics collection
kshpv Feb 7, 2024
b778c0c
updates aggregators, reducers for BC and FBC
kshpv Feb 13, 2024
1a96012
upd torch mean_per_channel
kshpv Feb 14, 2024
2f89913
fix BC
kshpv Feb 14, 2024
f69acbd
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Feb 14, 2024
74594c7
fixes after merge
kshpv Feb 14, 2024
d760caf
Fix BC calculations
kshpv Feb 15, 2024
50ac6b4
revert FBC and BC changes
kshpv Feb 20, 2024
00f7979
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Feb 20, 2024
532ca55
fix merge
kshpv Feb 20, 2024
54f8ca3
fix revert typo
kshpv Feb 20, 2024
7637d4b
fix export of torch model
kshpv Feb 21, 2024
976255f
comments
kshpv Feb 23, 2024
8951e3c
more comments
kshpv Feb 26, 2024
0d72557
make bs=128 for Torch sample
kshpv Feb 26, 2024
0f8a438
fix channel alighnment + comments
kshpv Feb 27, 2024
78d4d6c
comments
kshpv Feb 28, 2024
34c9960
update typehints; revert changes in OV sample and apply to Torch
kshpv Feb 28, 2024
354505a
typo
kshpv Feb 28, 2024
97cb07f
some code improvements
kshpv Feb 28, 2024
2cc8b81
logging
kshpv Feb 28, 2024
e3a3291
remove iterations_number calculation in Aggregator
kshpv Mar 1, 2024
8cb7c60
update tests
kshpv Mar 1, 2024
ae772aa
reaname parameter
kshpv Mar 1, 2024
41c76fe
apply comments
kshpv Mar 1, 2024
321c65a
polishing
kshpv Mar 1, 2024
f19fd71
add test
kshpv Mar 4, 2024
4996333
small fixes
kshpv Mar 4, 2024
5e8bce7
polishing
kshpv Mar 5, 2024
9ba5700
conformance adoption for any batch_size; better logging
kshpv Mar 6, 2024
a0f5fe9
add dynamic_batch_shape option to conformance
kshpv Mar 6, 2024
7562d19
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 6, 2024
ee6c14b
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 7, 2024
deb0b51
polishing test
kshpv Mar 7, 2024
64cbd99
fix calibrate.py
kshpv Mar 7, 2024
6119511
new polishing
kshpv Mar 7, 2024
e010087
remove warnings about bathc_size>1 in aggregator
kshpv Mar 8, 2024
54319cb
add baatch_size logging in quantize_impl()
kshpv Mar 8, 2024
4e90c65
add IF op to batch_size warning metatypes list
kshpv Mar 8, 2024
d04ba75
put logs from minmax to quantize_impl
kshpv Mar 8, 2024
6e54d07
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 8, 2024
6048155
rm typos
kshpv Mar 8, 2024
09783c4
typehints
kshpv Mar 8, 2024
da05b93
revert debug message minmax
kshpv Mar 8, 2024
291110e
typo
kshpv Mar 8, 2024
2676fbd
add model_param is_batch_size_supported to conformance; make all mode…
kshpv Mar 18, 2024
3ae9d28
add example in Readme
kshpv Mar 18, 2024
5efcdb5
comments
kshpv Mar 20, 2024
d8ea324
iterations_number -> stat_subset_size
kshpv Mar 20, 2024
7288924
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 21, 2024
88653aa
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> fl
]
),
)
val_data_loader = torch.utils.data.DataLoader(val_dataset)
batch_size = 128
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

torch_model = models.mobilenet_v2(num_classes=DATASET_CLASSES)
torch_model = load_checkpoint(torch_model)
Expand Down Expand Up @@ -140,8 +141,10 @@ def transform_fn(data_item: Tuple[torch.Tensor, int], device: torch.device) -> t
# item and prepare model input data. The quantize method uses a small subset
# (default: 300 samples) of the calibration dataset.

# Recalculation default subset_size parameter based on batch_size.
subset_size = 300 // batch_size
calibration_dataset = nncf.Dataset(val_data_loader, partial(transform_fn, device=device))
torch_quantized_model = nncf.quantize(torch_model, calibration_dataset)
torch_quantized_model = nncf.quantize(torch_model, calibration_dataset, subset_size=subset_size)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved

###############################################################################
# Benchmark performance, calculate compression rate and validate accuracy
Expand Down
18 changes: 17 additions & 1 deletion nncf/common/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
from typing import List, Set
from typing import List, Set, Tuple, Union

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
Expand Down Expand Up @@ -114,3 +114,19 @@ def get_number_of_quantized_ops(
else:
nodes_to_see.extend(graph.get_next_nodes(node))
return len(quantized_ops)


def get_reduction_axes(
channel_axes: Union[List[int], Tuple[int, ...]], shape: Union[List[int], Tuple[int, ...]]
) -> Tuple[int, ...]:
"""
Returns filtered reduction axes without axes that correspond to channels.

:param channel_axes: Channel axes.
:param shape: Shape that need to be filtered.
:return: Reduction axes.
"""
reduction_axes = list(range(len(shape)))
for channel_axis in sorted(channel_axes, reverse=True):
del reduction_axes[channel_axis]
return tuple(reduction_axes)
53 changes: 52 additions & 1 deletion nncf/common/quantization/initialization/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple, Union

from nncf.common.graph.utils import get_reduction_axes
from nncf.common.initialization.dataloader import NNCFDataLoader
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerGroup
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.config.schemata.defaults import NUM_INIT_SAMPLES
from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes


class RangeInitConfig:
Expand Down Expand Up @@ -204,3 +207,51 @@ def use_means_of_mins(self) -> bool:
@property
def use_means_of_maxs(self) -> bool:
return not self._is_weights and not self._is_per_channel

def _get_reduction_axes(
self,
shape_to_reduce: Union[Tuple[int, ...], List[int]],
quantization_axes: Union[Tuple[int, ...], List[int]],
aggregation_axes: Union[Tuple[int, ...], List[int]],
):
"""
Returns axes for a reducer regarding aggregation axes. As aggregator takes axes counting from stacked tensors,
from these axes only tensor related axes should be used for reducer.

:param shape_to_reduce: Shape of a reduced tensor.
:param quantization_axes: Axes of quantization.
:param aggregation_axes: Axes of aggregator which is applied onto reduced tensor.
:return: Axes for reducer.
"""
axes_to_keep = set(el - 1 for el in aggregation_axes if el != 0)
axes_to_keep.update(quantization_axes)
return get_reduction_axes(axes_to_keep, shape_to_reduce)
KodiaqQ marked this conversation as resolved.
Show resolved Hide resolved

def _get_aggregation_axes(self, batchwise_statistics: bool) -> Tuple[int, ...]:
"""
Returns axes for aggregator.

:param batchwise_statistics: Determines whether quantizer statistics should be calculated
for each item of the batch or for the entire batch.
:return Tuple[int]: Aggregation axes.
"""
return (0, 1) if batchwise_statistics else (0,)

def get_reduction_aggregation_axes(
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
self,
shape_to_reduce: Union[Tuple[int, ...], List[int]],
quantization_axes: Union[Tuple[int, ...], List[int]],
batchwise_statistics: bool,
) -> Tuple[ReductionAxes, AggregationAxes]:
"""
Calculates the reduction axes, aggregation axes for the tensor.
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved

:param shape_to_reduce: Shape of the tensor.
:param quantization_axes: Quantization axes if per-channel quantization.
:param batchwise_statistics: Determines whether quantizer statistics should be calculated
for each item of the batch or for the entire batch.
:return: Reduction axes and aggregation axes.
"""
aggregation_axes = self._get_aggregation_axes(batchwise_statistics)
reduction_axes = self._get_reduction_axes(shape_to_reduce, quantization_axes, aggregation_axes)
return reduction_axes, aggregation_axes
41 changes: 27 additions & 14 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from abc import ABC
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, TypeVar
from typing import Any, Dict, Optional, TypeVar

import nncf
from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.logger import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
Expand All @@ -25,6 +26,13 @@
TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")

EMPTY_DATASET_ERROR = (
"Calibration dataset must not be empty. Please provide calibration dataset with at least one sample."
)
ITERATIONS_NUMBER_WARNING = (
"The number of iterations for statistics collection is bigger than the length of the dataset."
)


class StatisticsAggregator(ABC):
"""
Expand All @@ -36,6 +44,20 @@ def __init__(self, dataset: Dataset):
self.stat_subset_size = None
self.statistic_points = StatisticPointsContainer()

def _get_iterations_number(self) -> Optional[int]:
"""
Returns number of iterations, output number is less than min(self.stat_subset_size, dataset_length).

:return: Number of iterations for statistics collection.
"""
dataset_length = self.dataset.get_length()
if dataset_length and self.stat_subset_size:
if self.stat_subset_size > dataset_length:
nncf_logger.warning(ITERATIONS_NUMBER_WARNING)
return dataset_length
return self.stat_subset_size
return dataset_length or self.stat_subset_size

def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
Collects statistics for registered StatisticPoints.
Expand All @@ -46,34 +68,25 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
if not self.statistic_points:
return

model_transformer = factory.ModelTransformerFactory.create(model)

merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
engine = factory.EngineFactory.create(model_with_outputs)

dataset_length = self.dataset.get_length()
total = (
min(dataset_length or self.stat_subset_size, self.stat_subset_size)
if self.stat_subset_size is not None
else None
)
iterations_number = self._get_iterations_number()
empty_statistics = True
for input_data in track(
islice(self.dataset.get_inference_data(), self.stat_subset_size),
total=total,
islice(self.dataset.get_inference_data(), iterations_number),
total=self.stat_subset_size,
description="Statistics collection",
):
outputs = engine.infer(input_data)
processed_outputs = self._process_outputs(outputs)
self._register_statistics(processed_outputs, merged_statistics)
empty_statistics = False
if empty_statistics:
raise nncf.ValidationError(
"Calibration dataset must not be empty. Please provide calibration dataset with at least one sample."
)
raise nncf.ValidationError(EMPTY_DATASET_ERROR)

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
Expand Down
11 changes: 11 additions & 0 deletions nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def get_length(self) -> Optional[int]:
return self._data_source.__len__()
return None

def get_batch_size(self) -> Optional[int]:
"""
Tries to fetch batch size of the underlying dataset.
:return: The value of batch_size or _batch_size attributes of the data_source if exist, and None otherwise.
"""
if hasattr(self._data_source, "batch_size"): # Torch dataloader
return self._data_source.batch_size
if hasattr(self._data_source, "_batch_size"): # TF dataloader
return self._data_source._batch_size
return None
KodiaqQ marked this conversation as resolved.
Show resolved Hide resolved


class DataProvider(Generic[DataItem, ModelInput]):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import TensorDataType
from nncf.experimental.tensor.functions import numeric as fns


def mean_per_channel(x: Tensor, axis: int) -> Tensor:
def mean_per_channel(x: Tensor, axis: int, dtype: Optional[TensorDataType] = None) -> Tensor:
"""
Computes the mean of elements across given channel dimension of Tensor.

:param x: Tensor to reduce.
:param axis: The channel dimensions to reduce.
:param dtype: Type to use in computing the mean.
:return: Reduced Tensor.
"""
if len(x.shape) < 3:
return fns.mean(x, axis=0)
return fns.mean(x, axis=0, dtype=dtype)

pos_axis = axis + x.ndim if axis < 0 else axis
if pos_axis < 0 or pos_axis >= x.ndim:
raise ValueError(f"axis {axis} is out of bounds for array of dimension {x.ndim}")
axis = tuple(i for i in range(x.ndim) if i != pos_axis)
return fns.mean(x, axis=axis)
return fns.mean(x, axis=axis, dtype=dtype)
7 changes: 5 additions & 2 deletions nncf/experimental/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,16 +355,19 @@ def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[

@functools.singledispatch
@tensor_guard
def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor:
def mean(
a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, dtype: TensorDataType = None
) -> Tensor:
"""
Compute the arithmetic mean along the specified axis.

:param a: Array containing numbers whose mean is desired.
:param axis: Axis or axes along which the means are computed.
:param keepdims: Destination positions for each of the original axes. These must also be unique.
:param dtype: Type to use in computing the mean.
:return: Array with moved axes.
"""
return Tensor(mean(a.data, axis, keepdims))
return Tensor(mean(a.data, axis, keepdims, dtype))


@functools.singledispatch
Expand Down
10 changes: 8 additions & 2 deletions nncf/experimental/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,14 @@ def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int


@register_numpy_types(numeric.mean)
def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray:
return np.array(np.mean(a, axis=axis, keepdims=keepdims))
def _(
a: Union[np.ndarray, np.generic],
axis: Union[int, Tuple[int, ...]] = None,
keepdims: bool = False,
dtype: Optional[TensorDataType] = None,
) -> np.ndarray:
dtype = DTYPE_MAP[dtype] if dtype else None
return np.array(np.mean(a, axis=axis, keepdims=keepdims, dtype=dtype))
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved


@register_numpy_types(numeric.round)
Expand Down
10 changes: 8 additions & 2 deletions nncf/experimental/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,14 @@ def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[i


@numeric.mean.register(torch.Tensor)
def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor:
return torch.mean(a, dim=axis, keepdim=keepdims)
def _(
a: torch.Tensor,
axis: Union[int, Tuple[int, ...]] = None,
keepdims: bool = False,
dtype: Optional[TensorDataType] = None,
) -> torch.Tensor:
dtype = DTYPE_MAP[dtype] if dtype else None
return torch.mean(a, dim=axis, keepdim=keepdims, dtype=dtype)


@numeric.round.register(torch.Tensor)
Expand Down
9 changes: 9 additions & 0 deletions nncf/onnx/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,12 @@
onnx_metatypes.ONNXQuantizeLinearMetatype,
onnx_metatypes.ONNXDequantizeLinearMetatype,
]

# These metatypes mix outputs for different samples into one axis.
# If reducers and aggregators collect statistics at the output of the following operations,
# assuming that 0-axis is batch axis, they get only 1 value instead of batch_size values.
# It could lead to inaccurate/incorrect statistics result.
OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS = [
onnx_metatypes.ONNXROIAlignMetatype,
onnx_metatypes.ONNXEmbeddingMetatype,
]
8 changes: 4 additions & 4 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class ONNXGemmMetatype(ONNXOpWithWeightsMetatype):
name = "GemmOp"
op_names = ["Gemm"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1
weight_channel_axis = -1 # For port_id=1
kshpv marked this conversation as resolved.
Show resolved Hide resolved
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
Expand All @@ -125,7 +125,7 @@ class ONNXMatMulMetatype(ONNXOpMetatype):
name = "MatMulOp"
op_names = ["MatMul"]
hw_config_names = [HWConfigOpName.MATMUL]
weight_channel_axis = -1
weight_channel_axis = -1 # For port_id=1
weight_port_ids = None
bias_port_id = 2
possible_weight_ports = [0, 1]
Expand Down Expand Up @@ -446,8 +446,8 @@ class ONNXScatterNDMetatype(ONNXOpMetatype):


@ONNX_OPERATION_METATYPES.register()
class ONNXRoiAlignMetatype(ONNXOpMetatype):
name = "RoiAlignOp"
class ONNXROIAlignMetatype(ONNXOpMetatype):
name = "ROIAlignOp"
op_names = ["RoiAlign"]


Expand Down
Loading
Loading