Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTQ][Experimental stats] MAD and Percentile aggregators accept dynamic tensors #2221

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,42 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
:return: Reduced NNCFTensor.
"""

@staticmethod
@abstractmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
    """
    Returns a tensor with its axes permuted.

    :param x: The input tensor.
    :param axes: Tuple which contains a permutation of [0,1,…,N-1] where N is the number of axes of x.
        The ith axis of the returned tensor will correspond to the axis numbered axes[i] of the input.
    :return: x with its axes permuted.
    """

@staticmethod
@abstractmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
    """
    Gives a new shape to a tensor without changing its data.

    :param x: The input tensor.
    :param shape: New shape for the input tensor. The new shape should be compatible with the original shape.
        One shape dimension can be -1. In this case, the value is inferred
        from the number of elements in the tensor and the remaining dimensions.
    :return: Reshaped x.
    """

@staticmethod
@abstractmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
    """
    Joins a sequence of tensors along an existing axis.

    :param x: The list of input tensors to join.
    :param axis: The axis along which the tensors will be joined.
    :return: The concatenated tensor.
    """

@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Expand Down
68 changes: 57 additions & 11 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,20 +739,23 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
return self._container.append(x)

def _aggregate_impl(self) -> Dict[str, NNCFTensor]:
stacked_val = self._tensor_processor.stack(self._container)
stacked_val, shape_after_aggregation = _moveaxes_flatten_cat(
self._container, [x - 1 for x in self._aggregation_axes if x > 0], self._tensor_processor
)

mask = self._tensor_processor.zero_elements(stacked_val)
median_per_ch = self._tensor_processor.masked_median(
stacked_val, mask=mask, axis=self._aggregation_axes, keepdims=True
)
median_per_ch = self._tensor_processor.masked_median(stacked_val, mask=mask, axis=0, keepdims=True)

mad_values = self._tensor_processor.median(
self._tensor_processor.abs(self._tensor_processor.sub(stacked_val, median_per_ch)),
axis=self._aggregation_axes,
keepdims=self._keepdims,
axis=0,
keepdims=False,
)
if not self._keepdims:
median_per_ch = self._tensor_processor.squeeze(median_per_ch, self._aggregation_axes)
if self._keepdims:
median_per_ch = self._tensor_processor.reshape(median_per_ch, shape_after_aggregation)
mad_values = self._tensor_processor.reshape(mad_values, shape_after_aggregation)
else:
median_per_ch = self._tensor_processor.squeeze(median_per_ch, 0)
return {
MedianMADTensorStatistic.MEDIAN_VALUES_STAT: median_per_ch.tensor,
MedianMADTensorStatistic.MAD_VALUES_STAT: mad_values.tensor,
Expand All @@ -777,17 +780,60 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
return self._container.append(x)

def _aggregate_impl(self) -> Dict[float, NNCFTensor]:
stacked_val = self._tensor_processor.stack(self._container)
stacked_val, shape_after_aggregation = _moveaxes_flatten_cat(
self._container, [x - 1 for x in self._aggregation_axes if x > 0], self._tensor_processor
)

percentiles = self._tensor_processor.percentile(
stacked_val, self._percentiles_to_collect, axis=self._aggregation_axes, keepdims=self._keepdims
stacked_val, self._percentiles_to_collect, axis=0, keepdims=False
)
retval = {}
for idx, percentile in enumerate(self._percentiles_to_collect):
retval[percentile] = percentiles[idx].tensor
value = percentiles[idx]
if self._keepdims:
value = self._tensor_processor.reshape(value, shape_after_aggregation)
retval[percentile] = value.tensor
return retval


def _moveaxes_flatten_cat(
    tensor_list: List[NNCFTensor], aggregation_axes: Tuple[int, ...], tensor_processor: NNCFCollectorTensorProcessor
) -> Tuple[NNCFTensor, Tuple[int, ...]]:
    """
    Moves aggregation axes to the beginning of the tensor shape for each tensor from the list, flattens
    them into a single leading dimension and concatenates all tensors along that dimension.
    Computes the target shape for the processed tensor after an aggregation function is applied to it.
    The target shape preserves the original order of dimensions and replaces aggregated dimensions by 1.

    The non-aggregated axes keep their original relative order, so a reduction over dimension 0 of
    the returned tensor can be reshaped directly to the target shape.

    NOTE: every tensor from the list is merged into the same leading dimension, so the list dimension
    itself is always aggregated. This helper is therefore only suitable for collectors
    whose aggregation includes the dimension along which the tensors were collected.

    :param tensor_list: Non-empty NNCFTensor list to process. The shape of the first tensor is used
        as the reference for the non-aggregated dimensions.
    :param aggregation_axes: Aggregation axes to move, flatten and concatenate. Negative axes are
        counted from the last dimension; repeated axes are ignored.
    :param tensor_processor: Backend-specific tensor processor instance.
    :return: Tuple of the processed tensor and
        target shape for the processed tensor after an aggregation function is applied to it.
    :raises IndexError: If an aggregation axis is out of range for the reference tensor shape.
    """
    tensor_shape = list(tensor_list[0].shape)
    ndim = len(tensor_shape)

    # Normalize negative axes and drop duplicates while keeping the caller's order
    normalized_axes = []
    for axis in aggregation_axes:
        if not -ndim <= axis < ndim:
            raise IndexError(f"Aggregation axis {axis} is out of range for a tensor with {ndim} dimensions")
        axis = axis + ndim if axis < 0 else axis
        if axis not in normalized_axes:
            normalized_axes.append(axis)

    # Aggregation axes go first; the remaining axes follow in their original order,
    # otherwise the reduced tensor could not be reshaped to the target shape without mixing values
    kept_axes = [dim for dim in range(ndim) if dim not in normalized_axes]
    transpose_dims = normalized_axes + kept_axes

    # Shape to flatten aggregation axes
    reshape_shape = [-1] + [tensor_shape[dim] for dim in kept_axes]

    reshaped_tensors = [
        tensor_processor.reshape(tensor_processor.transpose(tensor, transpose_dims), reshape_shape)
        for tensor in tensor_list
    ]

    shape_after_aggregation = tuple(1 if idx in normalized_axes else dim for idx, dim in enumerate(tensor_shape))
    return tensor_processor.cat(reshaped_tensors, axis=0), shape_after_aggregation


AGGREGATORS_MAP = {
AggregatorType.MIN: MinAggregator,
AggregatorType.MAX: MaxAggregator,
Expand Down
13 changes: 13 additions & 0 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
t = x.reshape(x.shape[0], x.shape[1], -1)
return ONNXNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
    """
    Permutes the axes of the wrapped array with ``np.transpose``.

    :param x: The input tensor.
    :param axes: Permutation of the axes of x.
    :return: ONNXNNCFTensor wrapping the permuted array.
    """
    permuted = np.transpose(x.tensor, axes)
    return ONNXNNCFTensor(permuted)

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
    """
    Reshapes the wrapped array with ``np.reshape``.

    :param x: The input tensor.
    :param shape: New shape for the wrapped array.
    :return: ONNXNNCFTensor wrapping the reshaped array.
    """
    reshaped = np.reshape(x.tensor, shape)
    return ONNXNNCFTensor(reshaped)

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
    """
    Concatenates the wrapped arrays with ``np.concatenate``.

    :param x: The list of input tensors to join.
    :param axis: The axis along which the arrays are joined.
    :return: ONNXNNCFTensor wrapping the concatenated array.
    """
    arrays = [item.tensor for item in x]
    joined = np.concatenate(arrays, axis)
    return ONNXNNCFTensor(joined)

@staticmethod
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return ONNXNNCFTensor(np.mean(x.tensor, axis=0, keepdims=True))
Expand Down
13 changes: 13 additions & 0 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
t = x.reshape(x.shape[0], x.shape[1], -1)
return OVNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
    """
    Permutes the axes of the wrapped array with ``np.transpose``.

    :param x: The input tensor.
    :param axes: Permutation of the axes of x.
    :return: OVNNCFTensor wrapping the permuted array.
    """
    permuted = np.transpose(x.tensor, axes)
    return OVNNCFTensor(permuted)

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
    """
    Reshapes the wrapped array with ``np.reshape``.

    :param x: The input tensor.
    :param shape: New shape for the wrapped array.
    :return: OVNNCFTensor wrapping the reshaped array.
    """
    reshaped = np.reshape(x.tensor, shape)
    return OVNNCFTensor(reshaped)

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
    """
    Concatenates the wrapped arrays with ``np.concatenate``.

    :param x: The list of input tensors to join.
    :param axis: The axis along which the arrays are joined.
    :return: OVNNCFTensor wrapping the concatenated array.
    """
    arrays = [item.tensor for item in x]
    joined = np.concatenate(arrays, axis)
    return OVNNCFTensor(joined)

@staticmethod
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return OVNNCFTensor(np.mean(x.tensor, axis=0, keepdims=True))
Expand Down
12 changes: 12 additions & 0 deletions nncf/tensorflow/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ def percentile(
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
raise NotImplementedError()

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
    """
    Not supported by this tensor processor.

    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
    """
    Not supported by this tensor processor.

    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
    """
    Not supported by this tensor processor.

    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()

@staticmethod
def sub(a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()
Expand Down
15 changes: 14 additions & 1 deletion nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims
device = x.tensor.device
result = torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))
return PTNNCFTensor(result.type(x.tensor.dtype).to(device))
return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims).values)
return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims))

@classmethod
def masked_mean(
Expand Down Expand Up @@ -123,6 +123,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return PTNNCFTensor(torch.mean(x.tensor, axis=0, keepdims=True))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
    """
    Permutes the dimensions of the wrapped tensor with ``torch.permute``.

    :param x: The input tensor.
    :param axes: Permutation of the dimensions of x.
    :return: PTNNCFTensor wrapping the permuted tensor.
    """
    permuted = torch.permute(x.tensor, axes)
    return PTNNCFTensor(permuted)

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
    """
    Reshapes the wrapped tensor with ``torch.reshape``.

    :param x: The input tensor.
    :param shape: New shape for the wrapped tensor.
    :return: PTNNCFTensor wrapping the reshaped tensor.
    """
    reshaped = torch.reshape(x.tensor, shape)
    return PTNNCFTensor(reshaped)

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
    """
    Concatenates the wrapped tensors with ``torch.cat``.

    :param x: The list of input tensors to join.
    :param axis: The dimension along which the tensors are joined.
    :return: PTNNCFTensor wrapping the concatenated tensor.
    """
    tensors = [item.tensor for item in x]
    joined = torch.cat(tensors, axis)
    return PTNNCFTensor(joined)

@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
return PTNNCFTensor(torch.logical_or(input_.tensor, other.tensor))
Expand Down
51 changes: 40 additions & 11 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,19 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,

assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT))

@pytest.fixture(
    name="aggregator_cls",
    params=[
        MedianAbsoluteDeviationAggregator,
        partial(
            PercentileAggregator,
            percentiles_to_collect=[5, 10, 90, 95],
        ),
    ],
)
def aggregator_cls_fixture(self, request):
    """
    Provides the aggregator factories under test as the ``aggregator_cls`` fixture:
    the MAD aggregator class and a PercentileAggregator partial collecting
    the 5th, 10th, 90th and 95th percentiles.
    """
    return request.param

REF_MAD_PERCENTILE_REF_VALUES = {
MedianAbsoluteDeviationAggregator: {
None: {
Expand Down Expand Up @@ -363,20 +376,10 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,
},
}

@pytest.mark.parametrize(
"aggregator_cls",
[
MedianAbsoluteDeviationAggregator,
partial(
PercentileAggregator,
percentiles_to_collect=[5, 10, 90, 95],
),
],
)
@pytest.mark.parametrize("aggregation_axes", [None, (0,), (0, 1)])
def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggregation_axes):
aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float32)
input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
for i in range(9):
aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT))

Expand All @@ -386,6 +389,32 @@ def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggr
for k, v in ref_values.items():
assert self.all_close(ret_val[k], self.cast_tensor(v, Dtype.FLOAT))

# Reference statistics for test_mad_percentile_aggregators_different_sizes.
# Outer keys are aggregator classes (looked up via aggregator.__class__);
# inner keys are the statistic names / percentiles returned by aggregate().
REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS = {
    MedianAbsoluteDeviationAggregator: {
        "median_values": np.array([28.5, 35.5, 43.5]),
        "mad_values": np.array([24.0, 24.0, 24.0]),
    },
    PercentileAggregator: {
        5: np.array([0.95, 5.95, 9.95]),
        10: np.array([1.9, 7.9, 15.5]),
        90: np.array([75.1, 83.1, 91.1]),
        95: np.array([77.05, 85.05, 93.05]),
    },
}

def test_mad_percentile_aggregators_different_sizes(self, aggregator_cls, tensor_processor):
    """
    Registers two inputs of different shapes and checks the aggregated statistics
    against the reference values for dynamic tensors.
    """
    aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=(0, 1, 3))
    for shape in ((2, 3, 4), (4, 3, 8)):
        data = np.arange(np.prod(shape)).reshape(shape)
        aggregator.register_reduced_input(self.get_nncf_tensor(data, Dtype.FLOAT))
    actual = aggregator.aggregate()

    expected = self.REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS[aggregator.__class__]
    assert len(actual) == len(expected)
    for key, ref in expected.items():
        assert self.all_close(actual[key], self.cast_tensor(ref, Dtype.FLOAT))

@pytest.mark.parametrize(
"reducer_name",
["noop", "min", "max", "abs_max", "mean", "quantile", "abs_quantile", "batch_mean", "mean_per_ch"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def tensor_processor(self):
return OVNNCFCollectorTensorProcessor

def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None):
    """
    Wraps a numpy array into an OVNNCFTensor, casting it first when a dtype is requested.

    :param x: The array to wrap.
    :param dtype: Dtype.INTEGER casts to int64, Dtype.FLOAT casts to float32,
        any other value leaves the array as is.
    :return: OVNNCFTensor wrapping the (possibly cast) array.
    """
    target_dtype = None
    if dtype is Dtype.INTEGER:
        target_dtype = np.int64
    elif dtype is Dtype.FLOAT:
        target_dtype = np.float32

    if target_dtype is not None:
        x = x.astype(target_dtype)
    return OVNNCFTensor(x)

@pytest.fixture(scope="module")
Expand Down
Loading