Skip to content

Commit

Permalink
[PTQ][Experimental stats] MAD and Percentile aggregators accept dynam…
Browse files Browse the repository at this point in the history
…ic tensors (#2222)

Lucky PR 🍀🍀🍀

### Changes

MAD and Percentile aggregators accept tensors that have different sizes along the aggregation axes

### Reason for changes

To align behaviour of the experimental tensor collectors with common tensor collectors

### Related tickets

123519

### Tests
* test_mad_percentile_aggregators_different_sizes
  • Loading branch information
daniil-lyakhov authored Nov 2, 2023
1 parent e047404 commit cb781eb
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 26 deletions.
36 changes: 36 additions & 0 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,42 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
:return: Reduced NNCFTensor.
"""

@staticmethod
@abstractmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
"""
Returns a tensor with its axes permuted.

:param x: The input tensor.
:param axes: Tuple which contains a permutation of [0,1,…,N-1] where N is the number of axes of x.
The ith axis of the returned tensor will correspond to the axis numbered axes[i] of the input.
:return: x with its axes permuted.
"""

@staticmethod
@abstractmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
"""
Gives a new shape to a tensor without changing its data.

:param x: The input tensor.
:param shape: New shape for the input tensor. The new shape should be compatible with the original shape.
At most one shape dimension can be -1. In this case, the value is inferred
from the number of elements in the tensor and the remaining dimensions.
:return: Reshaped x.
"""

@staticmethod
@abstractmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
"""
Joins a sequence of tensors along an existing axis.

:param x: The list of input tensors to join.
:param axis: The axis along which the tensors will be joined.
:return: The concatenated tensor.
"""

@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Expand Down
90 changes: 79 additions & 11 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,24 +735,45 @@ def _aggregation_fn(


class MedianAbsoluteDeviationAggregator(TensorAggregatorBase):
def __init__(
self,
tensor_processor: NNCFCollectorTensorProcessor,
aggregation_axes: Optional[AggregationAxes] = None,
num_samples: Optional[int] = None,
window_size=None,
):
"""
Initializes the aggregator and verifies that the requested aggregation axes are supported.

All arguments are forwarded unchanged to the base class constructor.

:param tensor_processor: Backend-specific tensor processor.
:param aggregation_axes: Axes to aggregate over; forwarded to the base class, which is
expected to populate `self._aggregation_axes` (NOTE(review): the handling of None is
defined by the base class and is not visible here - confirm).
:param num_samples: Forwarded to the base class (presumably the maximum number of
samples to collect - confirm against TensorAggregatorBase).
:param window_size: Forwarded to the base class (presumably the size of the sample
window to keep - confirm against TensorAggregatorBase).
:raises NotImplementedError: If axis 0 is not among `self._aggregation_axes`.
"""
super().__init__(
tensor_processor=tensor_processor,
aggregation_axes=aggregation_axes,
num_samples=num_samples,
window_size=window_size,
)
# Axis 0 has to be aggregated: keeping it is not supported by this aggregator yet.
if 0 not in self._aggregation_axes:
raise NotImplementedError(
"Aggregation without 0 dim is not supported yet for MedianAbsoluteDeviationAggregator"
)

def _register_reduced_input_impl(self, x: TensorType) -> None:
return self._container.append(x)

def _aggregate_impl(self) -> Dict[str, NNCFTensor]:
stacked_val = self._tensor_processor.stack(self._container)
stacked_val, shape_after_aggregation = _moveaxes_flatten_cat(
self._container, [x - 1 for x in self._aggregation_axes if x > 0], self._tensor_processor
)

mask = self._tensor_processor.zero_elements(stacked_val)
median_per_ch = self._tensor_processor.masked_median(
stacked_val, mask=mask, axis=self._aggregation_axes, keepdims=True
)
median_per_ch = self._tensor_processor.masked_median(stacked_val, mask=mask, axis=0, keepdims=True)

mad_values = self._tensor_processor.median(
self._tensor_processor.abs(self._tensor_processor.sub(stacked_val, median_per_ch)),
axis=self._aggregation_axes,
keepdims=self._keepdims,
axis=0,
keepdims=False,
)
if not self._keepdims:
median_per_ch = self._tensor_processor.squeeze(median_per_ch, self._aggregation_axes)
if self._keepdims:
median_per_ch = self._tensor_processor.reshape(median_per_ch, shape_after_aggregation)
mad_values = self._tensor_processor.reshape(mad_values, shape_after_aggregation)
else:
median_per_ch = self._tensor_processor.squeeze(median_per_ch, 0)
return {
MedianMADTensorStatistic.MEDIAN_VALUES_STAT: median_per_ch.tensor,
MedianMADTensorStatistic.MAD_VALUES_STAT: mad_values.tensor,
Expand All @@ -769,6 +790,8 @@ def __init__(
window_size=None,
):
super().__init__(tensor_processor, aggregation_axes=aggregation_axes, num_samples=num_samples)
if 0 not in self._aggregation_axes:
raise NotImplementedError("Aggregation without 0 dim is not supported yet for PercentileAggregator")
self._percentiles_to_collect = percentiles_to_collect
self._window_size = window_size
self._container = deque(maxlen=window_size)
Expand All @@ -777,17 +800,62 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:
return self._container.append(x)

def _aggregate_impl(self) -> Dict[float, NNCFTensor]:
stacked_val = self._tensor_processor.stack(self._container)
stacked_val, shape_after_aggregation = _moveaxes_flatten_cat(
self._container, [x - 1 for x in self._aggregation_axes if x > 0], self._tensor_processor
)

percentiles = self._tensor_processor.percentile(
stacked_val, self._percentiles_to_collect, axis=self._aggregation_axes, keepdims=self._keepdims
stacked_val, self._percentiles_to_collect, axis=0, keepdims=False
)
retval = {}
for idx, percentile in enumerate(self._percentiles_to_collect):
retval[percentile] = percentiles[idx].tensor
value = percentiles[idx]
if self._keepdims:
value = self._tensor_processor.reshape(value, shape_after_aggregation)
retval[percentile] = value.tensor
return retval


def _moveaxes_flatten_cat(
    tensor_list: List[NNCFTensor], aggregation_axes: Tuple[int, ...], tensor_processor: NNCFCollectorTensorProcessor
) -> Tuple[NNCFTensor, Tuple[int, ...]]:
    """
    Moves aggregation axes to the beginning of the tensor shape for each tensor from the list, flattens
    them into a single leading dimension and concatenates the results along dimension 0.
    Also computes the target shape for the processed tensor after an aggregation function
    is applied to it. The target shape preserves the original order of dimensions, replaces
    aggregated dimensions by 1 and has one extra leading dimension of size 1.

    The non-aggregated axes keep their original relative order in the processed tensor, so a value
    reduced over dimension 0 can be reshaped directly to the returned target shape.

    :param tensor_list: NNCFTensor list to process. Tensors may differ in size along the aggregation
        axes only; the sizes of the remaining axes are taken from the first tensor.
    :param aggregation_axes: Aggregation axes to move, flatten and concatenate.
    :param tensor_processor: Backend-specific tensor processor instance.
    :return: Tuple of the processed tensor and
        target shape for the processed tensor after an aggregation function is applied to it.
    """
    tensor_shape = list(tensor_list[0].shape)

    aggregated_axes = list(aggregation_axes)
    aggregated_set = set(aggregated_axes)
    kept_axes = [axis for axis in range(len(tensor_shape)) if axis not in aggregated_set]

    # Aggregation axes go first, then every remaining axis in its original order.
    # Building the permutation explicitly (instead of swapping entries pairwise) keeps the
    # order of the non-aggregated axes intact and does not depend on the aggregation axes being sorted.
    transpose_dims = tuple(aggregated_axes + kept_axes)

    # All aggregation axes collapse into the single leading -1 dimension.
    reshape_shape = (-1,) + tuple(tensor_shape[axis] for axis in kept_axes)

    reshaped_tensors = [
        tensor_processor.reshape(tensor_processor.transpose(tensor, transpose_dims), reshape_shape)
        for tensor in tensor_list
    ]

    shape_after_aggregation = (1,) + tuple(
        1 if axis in aggregated_set else dim for axis, dim in enumerate(tensor_shape)
    )
    return tensor_processor.cat(reshaped_tensors, axis=0), shape_after_aggregation


AGGREGATORS_MAP = {
AggregatorType.MIN: MinAggregator,
AggregatorType.MAX: MaxAggregator,
Expand Down
13 changes: 13 additions & 0 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
t = x.reshape(x.shape[0], x.shape[1], -1)
return ONNXNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
return ONNXNNCFTensor(np.transpose(x.tensor, axes))

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
return ONNXNNCFTensor(np.reshape(x.tensor, shape))

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
x = [t.tensor for t in x]
return ONNXNNCFTensor(np.concatenate(x, axis))

@staticmethod
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return ONNXNNCFTensor(np.mean(x.tensor, axis=0, keepdims=True))
Expand Down
13 changes: 13 additions & 0 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
t = x.reshape(x.shape[0], x.shape[1], -1)
return OVNNCFTensor(np.mean(t, axis=(0, 2)))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
return OVNNCFTensor(np.transpose(x.tensor, axes))

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
return OVNNCFTensor(np.reshape(x.tensor, shape))

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
x = [t.tensor for t in x]
return OVNNCFTensor(np.concatenate(x, axis))

@staticmethod
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return OVNNCFTensor(np.mean(x.tensor, axis=0, keepdims=True))
Expand Down
12 changes: 12 additions & 0 deletions nncf/tensorflow/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ def percentile(
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
raise NotImplementedError()

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
raise NotImplementedError()

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
raise NotImplementedError()

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
raise NotImplementedError()

@staticmethod
def sub(a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
raise NotImplementedError()
Expand Down
22 changes: 20 additions & 2 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def median(x: NNCFTensor, axis: Union[int, Tuple[int, ...], List[int]], keepdims
device = x.tensor.device
result = torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))
return PTNNCFTensor(result.type(x.tensor.dtype).to(device))
return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims).values)
return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims))

@classmethod
def masked_mean(
Expand Down Expand Up @@ -122,6 +122,19 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
def batch_mean(x: NNCFTensor) -> NNCFTensor:
return PTNNCFTensor(torch.mean(x.tensor, axis=0, keepdims=True))

@staticmethod
def transpose(x: NNCFTensor, axes: Tuple[int, ...]) -> NNCFTensor:
return PTNNCFTensor(torch.permute(x.tensor, axes))

@staticmethod
def reshape(x: NNCFTensor, shape: Tuple[int, ...]) -> NNCFTensor:
return PTNNCFTensor(torch.reshape(x.tensor, shape))

@staticmethod
def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor:
x = [t.tensor for t in x]
return PTNNCFTensor(torch.cat(x, axis))

@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
return PTNNCFTensor(torch.logical_or(input_.tensor, other.tensor))
Expand Down Expand Up @@ -165,7 +178,12 @@ def quantile(
np.quantile(tensor.tensor.detach().cpu().numpy(), q=quantile, axis=axis, keepdims=keepdims)
)
else:
result = torch.quantile(tensor.tensor, torch.tensor(quantile).type(tensor.tensor.dtype), axis, keepdims)
result = torch.quantile(
tensor.tensor,
torch.tensor(quantile, dtype=tensor.tensor.dtype, device=tensor.tensor.device),
axis,
keepdims,
)
result = result.type(tensor.tensor.dtype).to(device)
return [PTNNCFTensor(x) for x in result]

Expand Down
61 changes: 48 additions & 13 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,19 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,

assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT))

@pytest.fixture(
name="MAD_precentile_aggregator_cls",
params=[
MedianAbsoluteDeviationAggregator,
partial(
PercentileAggregator,
percentiles_to_collect=[5, 10, 90, 95],
),
],
)
def aggregator_cls_fixture(self, request):
return request.param

REF_MAD_PERCENTILE_REF_VALUES = {
MedianAbsoluteDeviationAggregator: {
None: {
Expand Down Expand Up @@ -361,20 +374,10 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,
},
}

@pytest.mark.parametrize(
"aggregator_cls",
[
MedianAbsoluteDeviationAggregator,
partial(
PercentileAggregator,
percentiles_to_collect=[5, 10, 90, 95],
),
],
)
@pytest.mark.parametrize("aggregation_axes", [None, (0,), (0, 1)])
def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggregation_axes):
aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.float32)
def test_mad_percentile_aggregators(self, MAD_precentile_aggregator_cls, tensor_processor, aggregation_axes):
aggregator = MAD_precentile_aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
for i in range(9):
aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT))

Expand All @@ -384,6 +387,38 @@ def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggr
for k, v in ref_values.items():
assert self.all_close(ret_val[k], self.cast_tensor(v, Dtype.FLOAT))

REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS = {
MedianAbsoluteDeviationAggregator: {
"median_values": np.array([28.5, 35.5, 43.5]),
"mad_values": np.array([24.0, 24.0, 24.0]),
},
PercentileAggregator: {
5: np.array([0.95, 5.95, 9.95]),
10: np.array([1.9, 7.9, 15.5]),
90: np.array([75.1, 83.1, 91.1]),
95: np.array([77.05, 85.05, 93.05]),
},
}

def test_mad_percentile_aggregators_different_sizes(self, MAD_precentile_aggregator_cls, tensor_processor):
aggregator = MAD_precentile_aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=(0, 1, 3))
for shape in ((2, 3, 4), (4, 3, 8)):
aggregator.register_reduced_input(
self.get_nncf_tensor(np.arange(np.prod(shape)).reshape(shape), Dtype.FLOAT)
)
ret_val = aggregator.aggregate()

ref_values = self.REF_MAD_PERCENTILE_REF_VALUES_DYNAMIC_TENSORS[aggregator.__class__]
assert len(ret_val) == len(ref_values)
for k, v in ref_values.items():
assert self.all_close(ret_val[k], self.cast_tensor(v, Dtype.FLOAT))

def test_mad_percentile_aggregators_not_implemented_aggregation_axes(
self, MAD_precentile_aggregator_cls, tensor_processor
):
with pytest.raises(NotImplementedError):
MAD_precentile_aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=(1, 2, 3))

@pytest.mark.parametrize(
"reducer_name",
["noop", "min", "max", "abs_max", "mean", "quantile", "abs_quantile", "batch_mean", "mean_per_ch"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def tensor_processor(self):
return OVNNCFCollectorTensorProcessor

def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None):
if dtype is Dtype.INTEGER:
x = x.astype(np.int64)
if dtype is Dtype.FLOAT:
x = x.astype(np.float32)
return OVNNCFTensor(x)

@pytest.fixture(scope="module")
Expand Down

0 comments on commit cb781eb

Please sign in to comment.