diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py index ae9c2e536b2..894fe9a2a39 100644 --- a/nncf/common/tensor_statistics/collectors.py +++ b/nncf/common/tensor_statistics/collectors.py @@ -326,6 +326,42 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: :return: Reduced NNCFTensor. """ + @staticmethod + @abstractmethod + def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor: + """ + Returns an array with axes transposed. + + :param x: The input tensor. + :param axes: Tuple which contains a permutation of [0,1,…,N-1] where N is the number of axes of x. + The ith axis of the returned array will correspond to the axis numbered axes[i] of the input. + :return: x with its axes permuted. + """ + + @staticmethod + @abstractmethod + def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor: + """ + Gives a new shape to an array without changing its data. + + :param x: The input tensor. + :param shape: New shape for the input tensor. The new shape should be compatible with the original shape. + One shape dimension can be -1. In this case, the value is inferred + from the length of the array and remaining dimensions. + :return: Reshaped x. + """ + + @staticmethod + @abstractmethod + def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor: + """ + Join a sequence of arrays along an existing axis. + + :param x: The list of input tensors to join. + :param axis: The axis along which the arrays will be joined. + :return: The concatenated array. 
+ """ + @staticmethod def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor: """ diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 68cd4442c0a..5b9d0ae1282 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -739,20 +739,23 @@ def _register_reduced_input_impl(self, x: TensorType) -> None: return self._container.append(x) def _aggregate_impl(self) -> Dict[str, NNCFTensor]: - stacked_val = self._tensor_processor.stack(self._container) + stacked_val, shape_after_aggregation = _moveaxes_flatten_cat( + self._container, [x - 1 for x in self._aggregation_axes if x > 0], self._tensor_processor + ) mask = self._tensor_processor.zero_elements(stacked_val) - median_per_ch = self._tensor_processor.masked_median( - stacked_val, mask=mask, axis=self._aggregation_axes, keepdims=True - ) + median_per_ch = self._tensor_processor.masked_median(stacked_val, mask=mask, axis=0, keepdims=True) mad_values = self._tensor_processor.median( self._tensor_processor.abs(self._tensor_processor.sub(stacked_val, median_per_ch)), - axis=self._aggregation_axes, - keepdims=self._keepdims, + axis=0, + keepdims=False, ) - if not self._keepdims: - median_per_ch = self._tensor_processor.squeeze(median_per_ch, self._aggregation_axes) + if self._keepdims: + median_per_ch = self._tensor_processor.reshape(median_per_ch, shape_after_aggregation) + mad_values = self._tensor_processor.reshape(mad_values, shape_after_aggregation) + else: + median_per_ch = self._tensor_processor.squeeze(median_per_ch, 0) return { MedianMADTensorStatistic.MEDIAN_VALUES_STAT: median_per_ch.tensor, MedianMADTensorStatistic.MAD_VALUES_STAT: mad_values.tensor, @@ -777,17 +780,60 @@ def _register_reduced_input_impl(self, x: TensorType) -> None: return self._container.append(x) def _aggregate_impl(self) -> Dict[float, NNCFTensor]: - stacked_val = 
self._tensor_processor.stack(self._container) + stacked_val, shape_after_aggregation = _moveaxes_flatten_cat( + self._container, [x - 1 for x in self._aggregation_axes if x > 0], self._tensor_processor + ) percentiles = self._tensor_processor.percentile( - stacked_val, self._percentiles_to_collect, axis=self._aggregation_axes, keepdims=self._keepdims + stacked_val, self._percentiles_to_collect, axis=0, keepdims=False ) retval = {} for idx, percentile in enumerate(self._percentiles_to_collect): - retval[percentile] = percentiles[idx].tensor + value = percentiles[idx] + if self._keepdims: + value = self._tensor_processor.reshape(value, shape_after_aggregation) + retval[percentile] = value.tensor return retval +def _moveaxes_flatten_cat( + tensor_list: List[NNCFTensor], aggregation_axes: Tuple[int, ...], tensor_processor: NNCFCollectorTensorProcessor +) -> Tuple[NNCFTensor, Tuple[int, ...]]: + """ + Moves aggregation axes to the beginning of the tensor shape for each tensor from the list, flattens + and concatenates them along dimension 0. Computes target shape for the processed tensor + after an aggregation function is applied to it. Target shape preserves original order + of dimensions and replaces aggregated dimensions by 1. + + :param tensor_list: NNCFTensor list to process. + :param aggregation_axes: Aggregation axes to move, flatten and concatenate. + :param tensor_processor: Backend-specific tensor processor instance. + :return: Tuple of the processed tensor and + target shape for the processed tensor after an aggregation function is applied to it. 
+ """ + tensor_shape = list(tensor_list[0].shape) + + # Transpose dims to move aggregation axes forward + transpose_dims = list(range(len(tensor_shape))) + for idx, axis in enumerate(aggregation_axes): + transpose_dims[axis], transpose_dims[idx] = transpose_dims[idx], transpose_dims[axis] + + # Shape to flatten aggregation axes + reshape_shape = [ + -1, + ] + [ + tensor_shape[dim] for dim in transpose_dims + ][len(aggregation_axes) :] + + reshaped_tensors = [] + for tensor in tensor_list: + transposed_t = tensor_processor.transpose(tensor, transpose_dims) + reshaped_tensors.append(tensor_processor.reshape(transposed_t, reshape_shape)) + + shape_after_aggregation = tuple(1 if idx in aggregation_axes else dim for idx, dim in enumerate(tensor_shape)) + return tensor_processor.cat(reshaped_tensors, axis=0), shape_after_aggregation + + AGGREGATORS_MAP = { AggregatorType.MIN: MinAggregator, AggregatorType.MAX: MaxAggregator, diff --git a/nncf/onnx/statistics/collectors.py b/nncf/onnx/statistics/collectors.py index 2ce98915edc..9804bd1d3e5 100644 --- a/nncf/onnx/statistics/collectors.py +++ b/nncf/onnx/statistics/collectors.py @@ -139,6 +139,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: t = x.reshape(x.shape[0], x.shape[1], -1) return ONNXNNCFTensor(np.mean(t, axis=(0, 2))) + @staticmethod + def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor: + return ONNXNNCFTensor(np.transpose(x.tensor, axes)) + + @staticmethod + def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor: + return ONNXNNCFTensor(np.reshape(x.tensor, shape)) + + @staticmethod + def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor: + x = [t.tensor for t in x] + return ONNXNNCFTensor(np.concatenate(x, axis)) + @staticmethod def batch_mean(x: NNCFTensor) -> NNCFTensor: return ONNXNNCFTensor(np.mean(x.tensor, axis=0, keepdims=True)) diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 4672541d86b..005744fa577 100644 --- 
a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -117,6 +117,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: t = x.reshape(x.shape[0], x.shape[1], -1) return OVNNCFTensor(np.mean(t, axis=(0, 2))) + @staticmethod + def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor: + return OVNNCFTensor(np.transpose(x.tensor, axes)) + + @staticmethod + def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor: + return OVNNCFTensor(np.reshape(x.tensor, shape)) + + @staticmethod + def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor: + x = [t.tensor for t in x] + return OVNNCFTensor(np.concatenate(x, axis)) + @staticmethod def batch_mean(x: NNCFTensor) -> NNCFTensor: return OVNNCFTensor(np.mean(x.tensor, axis=0, keepdims=True)) diff --git a/nncf/tensorflow/tensor_statistics/collectors.py b/nncf/tensorflow/tensor_statistics/collectors.py index a8c3da70a8b..812d1073fb3 100644 --- a/nncf/tensorflow/tensor_statistics/collectors.py +++ b/nncf/tensorflow/tensor_statistics/collectors.py @@ -128,6 +128,18 @@ def percentile( def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: raise NotImplementedError() + @staticmethod + def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor: + raise NotImplementedError() + + @staticmethod + def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor: + raise NotImplementedError() + + @staticmethod + def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor: + raise NotImplementedError() + @staticmethod def sub(a: NNCFTensor, b: NNCFTensor) -> NNCFTensor: raise NotImplementedError() diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index 4089fb77f2e..72ecbf54d9f 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -82,7 +82,7 @@ def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims device = x.tensor.device result = 
torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims)) return PTNNCFTensor(result.type(x.tensor.dtype).to(device)) - return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims).values) + return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims)) @classmethod def masked_mean( @@ -123,6 +123,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: def batch_mean(x: NNCFTensor) -> NNCFTensor: return PTNNCFTensor(torch.mean(x.tensor, axis=0, keepdims=True)) + @staticmethod + def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor: + return PTNNCFTensor(torch.permute(x.tensor, axes)) + + @staticmethod + def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor: + return PTNNCFTensor(torch.reshape(x.tensor, shape)) + + @staticmethod + def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor: + x = [t.tensor for t in x] + return PTNNCFTensor(torch.cat(x, axis)) + @staticmethod def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor: return PTNNCFTensor(torch.logical_or(input_.tensor, other.tensor)) diff --git a/tests/common/experimental/test_reducers_and_aggregators.py b/tests/common/experimental/test_reducers_and_aggregators.py index e307fb73f84..3ad32aa3f3e 100644 --- a/tests/common/experimental/test_reducers_and_aggregators.py +++ b/tests/common/experimental/test_reducers_and_aggregators.py @@ -326,6 +326,19 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT)) + @pytest.fixture( + name="aggregator_cls", + params=[ + MedianAbsoluteDeviationAggregator, + partial( + PercentileAggregator, + percentiles_to_collect=[5, 10, 90, 95], + ), + ], + ) + def aggregator_cls_fixture(self, request): + return request.param + REF_MAD_PERCENTILE_REF_VALUES = { MedianAbsoluteDeviationAggregator: { None: { @@ -363,20 +376,10 @@ def test_mean_median_agggregators(self, aggregator_cls, 
refs, tensor_processor, }, } - @pytest.mark.parametrize( - "aggregator_cls", - [ - MedianAbsoluteDeviationAggregator, - partial( - PercentileAggregator, - percentiles_to_collect=[5, 10, 90, 95], - ), - ], - ) @pytest.mark.parametrize("aggregation_axes", [None, (0,), (0, 1)]) def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggregation_axes): aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes) - input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float32) + input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) for i in range(9): aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT)) @@ -386,6 +389,32 @@ def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggr for k, v in ref_values.items(): assert self.all_close(ret_val[k], self.cast_tensor(v, Dtype.FLOAT)) + REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS = { + MedianAbsoluteDeviationAggregator: { + "median_values": np.array([28.5, 35.5, 43.5]), + "mad_values": np.array([24.0, 24.0, 24.0]), + }, + PercentileAggregator: { + 5: np.array([0.95, 5.95, 9.95]), + 10: np.array([1.9, 7.9, 15.5]), + 90: np.array([75.1, 83.1, 91.1]), + 95: np.array([77.05, 85.05, 93.05]), + }, + } + + def test_mad_percentile_aggregators_different_sizes(self, aggregator_cls, tensor_processor): + aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=(0, 1, 3)) + for shape in ((2, 3, 4), (4, 3, 8)): + aggregator.register_reduced_input( + self.get_nncf_tensor(np.arange(np.prod(shape)).reshape(shape), Dtype.FLOAT) + ) + ret_val = aggregator.aggregate() + + ref_values = self.REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS[aggregator.__class__] + assert len(ret_val) == len(ref_values) + for k, v in ref_values.items(): + assert self.all_close(ret_val[k], self.cast_tensor(v, Dtype.FLOAT)) + @pytest.mark.parametrize( "reducer_name", ["noop", "min", "max", "abs_max", "mean", "quantile", "abs_quantile", 
"batch_mean", "mean_per_ch"], diff --git a/tests/openvino/native/quantization/test_reducers_and_aggregators.py b/tests/openvino/native/quantization/test_reducers_and_aggregators.py index 7b13174e961..d5645d2de3a 100644 --- a/tests/openvino/native/quantization/test_reducers_and_aggregators.py +++ b/tests/openvino/native/quantization/test_reducers_and_aggregators.py @@ -35,6 +35,10 @@ def tensor_processor(self): return OVNNCFCollectorTensorProcessor def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): + if dtype is Dtype.INTEGER: + x = x.astype(np.int64) + if dtype is Dtype.FLOAT: + x = x.astype(np.float32) return OVNNCFTensor(x) @pytest.fixture(scope="module")