# Collect statistics from subset in weight compression (#3061)
### Changes

Set a subset size for collecting statistics, so that at most `subset_size` calibration samples are used.

### Reason for changes

The `subset_size` option is ignored and the whole dataset is used for statistics collection.
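
For context, a minimal sketch of the user-facing call this fix affects; `ov_model` and `calibration_samples` are placeholders, and `subset_size=64` is an arbitrary illustrative value, not taken from this commit:

```python
import nncf
from nncf import CompressWeightsMode, Dataset

# Placeholders: `ov_model` is assumed to be an OpenVINO model and
# `calibration_samples` an iterable of model inputs.
dataset = Dataset(calibration_samples)
compressed_model = nncf.compress_weights(
    ov_model,
    mode=CompressWeightsMode.INT4_ASYM,
    ratio=0.5,  # data-aware mixed precision, so activation statistics are collected
    dataset=dataset,
    subset_size=64,  # with this fix, statistics use at most 64 samples
)
```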

### Related tickets

n/a

### Tests

- [x] unit test for setting subset size
- [x] manual/job/post_training_weight_compression/245
ljaljushkin authored Nov 7, 2024
1 parent defb624 commit 20ab35d
Showing 5 changed files with 46 additions and 49 deletions.
#### nncf/openvino/quantization/quantize_model.py (1 change: 0 additions & 1 deletion)

```diff
@@ -412,7 +412,6 @@ def compress_weights_impl(
        statistics_aggregator,
        model,
        graph,
-       subset_size,
        compression_algorithm,
        matmul_input_to_output_nodes_map,
    )
```
#### nncf/quantization/algorithms/weight_compression/algorithm.py (8 changes: 3 additions & 5 deletions)

```diff
@@ -256,7 +256,7 @@ def __init__(

         primary_config = WeightCompressionConfig(mode=self._mode, group_size=self._group_size)
         criterion_cls = MIXED_PRECISION_CRITERIA.get(self._sensitivity_metric)
-        self._mixed_precision_algo = criterion_cls(primary_config, self._ratio)
+        self._mixed_precision_algo = criterion_cls(primary_config, self._ratio, self._subset_size)
         self._statistics_path = self._advanced_parameters.statistics_path
         if self._gptq:
             gptq_params = self._advanced_parameters.gptq_params
@@ -759,15 +759,13 @@ def get_statistic_points(
         model: TModel,
         graph: NNCFGraph,
         nodes_and_port_ids: Iterable[Tuple[NNCFNode, int]],
-        subset_size: Optional[int] = None,
     ) -> StatisticPointsContainer:
         """
         Returns statistic points, for which StatisticsCollector should collect statistics.

         :param model: Model for statistics collection.
         :param graph: Model graph.
         :param nodes_and_port_ids: Nodes and port ids for which statistics should be collected.
-        :param subset_size: Number of samples to collect.
         :return: Statistic points, for which StatisticsCollector should collect statistics.
         """
         statistic_container = StatisticPointsContainer()
@@ -781,7 +779,7 @@ def get_statistic_points(
                 # size dimension.
                 n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape)
                 stat_collector = self._backend_entity.mean_statistic_collector(
-                    reduction_axes=tuple(range(n_dims - 1)), subset_size=subset_size
+                    reduction_axes=tuple(range(n_dims - 1)), subset_size=self._subset_size
                 )
                 statistic_container.add_statistic_point(
                     StatisticPoint(
@@ -791,7 +789,7 @@ def get_statistic_points(
         # Statistics for mixed precision algorithm
         if self._data_aware_mixed_precision:
             mixed_precision_statistics = self._mixed_precision_algo.get_statistic_points(
-                model, graph, nodes_and_port_ids, self._subset_size
+                model, graph, nodes_and_port_ids
             )
             for points in mixed_precision_statistics.values():
                 for point in points:
```
#### nncf/quantization/algorithms/weight_compression/mixed_precision.py (40 changes: 19 additions & 21 deletions)

```diff
@@ -46,18 +46,16 @@ class MixedPrecisionCriterion(Algorithm):
     for weights based on some criteria.
     """

-    def __init__(
-        self,
-        primary_config: WeightCompressionConfig,
-        ratio: float,
-    ):
+    def __init__(self, primary_config: WeightCompressionConfig, ratio: float, subset_size: Optional[int] = None):
         """
         :param primary_config: Configuration on how to compress (quantize) weights to primary precision.
         :param ratio: The ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
             and the rest to INT8_ASYM).
+        :param subset_size: Size of dataset subset for statistics.
         """
         self._primary_config = primary_config
         self._ratio = ratio
+        self._subset_size = subset_size
         self._algorithm_key = f"MPC_{hash(self)}"
         self._backend_entity = None

@@ -117,15 +115,13 @@ def get_statistic_points(
         model: TModel,
         graph: NNCFGraph,
         nodes_and_port_ids: Iterable[Tuple[NNCFNode, int]],
-        subset_size: Optional[int] = None,
     ) -> StatisticPointsContainer:
         """
         Returns statistic points, for which StatisticsCollector should collect statistics.

         :param model: Model for statistics collection.
         :param graph: Model graph.
         :param nodes_and_port_ids: Nodes and port ids for which statistics should be collected.
-        :param subset_size: Number of samples to collect.
         :return: Statistic points, for which StatisticsCollector should collect statistics.
         """

@@ -201,7 +197,6 @@ def get_statistic_points(
         model: TModel,
         graph: NNCFGraph,
         nodes_and_port_ids: Iterable[Tuple[NNCFNode, int]],
-        subset_size: Optional[int] = None,
     ) -> StatisticPointsContainer:
         raise RuntimeError("No statistics collection intended for data-free mixed precision criterion")

@@ -262,7 +257,6 @@ def get_statistic_points(
         model: TModel,
         graph: NNCFGraph,
         nodes_and_port_ids: Iterable[Tuple[NNCFNode, int]],
-        subset_size: Optional[int] = None,
     ) -> StatisticPointsContainer:
         self._set_backend_entity(model)

@@ -277,7 +271,7 @@ def get_statistic_points(
             statistic_point = self._backend_entity.target_point(
                 TargetType.POST_LAYER_OPERATION, act_node.node_name, port_id=output_port_id
             )
-            stat_collector = self._get_statistic_collector(subset_size=subset_size)
+            stat_collector = self._get_statistic_collector()
             statistic_container.add_statistic_point(
                 StatisticPoint(
                     target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key
@@ -287,11 +281,9 @@ def get_statistic_points(
         return statistic_container

     @abstractmethod
-    def _get_statistic_collector(self, subset_size=None):
+    def _get_statistic_collector(self):
         """
         Get statistic collector

-        :param subset_size: Number of samples to collect
         """

     def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]:
@@ -367,8 +359,8 @@ def _calc_weight_sensitivity(
         decompressed_weight = decompressed_weight.reshape(orig_shape)
         return fns.linalg.norm(decompressed_weight - weight, ord="fro").item()

-    def _get_statistic_collector(self, subset_size=None):
-        return self._backend_entity.hawq_statistic_collector(subset_size)
+    def _get_statistic_collector(self):
+        return self._backend_entity.hawq_statistic_collector()


 @MIXED_PRECISION_CRITERIA.register(SensitivityMetric.MEAN_ACTIVATION_VARIANCE)
@@ -379,9 +371,11 @@ class MeanVarianceCriterion(DataBasedCriterion):

     STAT_KEY = SensitivityMetric.MEAN_ACTIVATION_VARIANCE.value

-    def _get_statistic_collector(self, subset_size=None):
+    def _get_statistic_collector(self):
         # Reducing across the second-last dimension, assuming it is the sequence length dimension
-        return self._backend_entity.mean_variance_statistic_collector(reduction_axes=(-2,), subset_size=subset_size)
+        return self._backend_entity.mean_variance_statistic_collector(
+            reduction_axes=(-2,), subset_size=self._subset_size
+        )


 @MIXED_PRECISION_CRITERIA.register(SensitivityMetric.MAX_ACTIVATION_VARIANCE)
@@ -392,9 +386,11 @@ class MaxVarianceCriterion(DataBasedCriterion):

     STAT_KEY = SensitivityMetric.MAX_ACTIVATION_VARIANCE.value

-    def _get_statistic_collector(self, subset_size=None):
+    def _get_statistic_collector(self):
         # Reducing across the second-last dimension, assuming it is the sequence length dimension
-        return self._backend_entity.max_variance_statistic_collector(reduction_axes=(-2,), subset_size=subset_size)
+        return self._backend_entity.max_variance_statistic_collector(
+            reduction_axes=(-2,), subset_size=self._subset_size
+        )


 @MIXED_PRECISION_CRITERIA.register(SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE)
@@ -405,6 +401,8 @@ class MeanMaxCriterion(DataBasedCriterion):

     STAT_KEY = SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE.value

-    def _get_statistic_collector(self, subset_size=None):
+    def _get_statistic_collector(self):
         # Reducing across the second-last dimension, assuming it is the sequence length dimension
-        return self._backend_entity.mean_abs_max_statistic_collector(reduction_axes=(-2,), subset_size=subset_size)
+        return self._backend_entity.mean_abs_max_statistic_collector(
+            reduction_axes=(-2,), subset_size=self._subset_size
+        )
```
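
The shape of this change, reduced to a standalone sketch: the subset size moves from a per-call argument of `get_statistic_points` to constructor state, so every collector a criterion builds reads the same stored value. The class below is illustrative only, not the NNCF API:

```python
from typing import Optional


class SketchCriterion:
    """Toy stand-in for a MixedPrecisionCriterion subclass."""

    def __init__(self, ratio: float, subset_size: Optional[int] = None):
        self._ratio = ratio
        # Stored once at construction; previously threaded through every call.
        self._subset_size = subset_size

    def get_statistic_points(self):
        # No subset_size parameter anymore; collectors read the stored value.
        return self._make_collector()

    def _make_collector(self):
        return {"reduction_axes": (-2,), "subset_size": self._subset_size}


criterion = SketchCriterion(ratio=0.5, subset_size=64)
assert criterion.get_statistic_points()["subset_size"] == 64
```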
#### nncf/quantization/statistics_caching.py (17 changes: 4 additions & 13 deletions)

```diff
@@ -26,7 +26,6 @@ def register_statistics_for_algorithm(
     aggregator: StatisticsAggregator,
     model: TModel,
     graph: NNCFGraph,
-    subset_size: int,
     compression_algo: WeightCompression,
     matmul_input_to_output_nodes_map: Dict[Tuple[NNCFNode, int], List[NNCFNode]],
 ) -> None:
@@ -36,14 +35,11 @@ def register_statistics_for_algorithm(
     :param aggregator: Aggregator to register statistics.
     :param model: Model being analyzed.
     :param graph: Model's computational graph.
-    :param subset_size: Size of dataset subset for statistics.
     :param compression_algo: WeightCompression algorithm instance.
     :param matmul_input_to_output_nodes_map: A dictionary mapping from a tuple of (activation node, port ID)
         to a list of MatMul nodes that accept the activation as input.
     """
-    statistic_points = compression_algo.get_statistic_points(
-        model, graph, matmul_input_to_output_nodes_map.keys(), subset_size
-    )
+    statistic_points = compression_algo.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
     aggregator.register_statistic_points(statistic_points)


@@ -73,10 +69,8 @@ def _register_mixed_precision(

     for sensitivity in sensitivities:
         criterion_cls = MIXED_PRECISION_CRITERIA.get(sensitivity)
-        mixed_prec_algo = criterion_cls(None, None)
-        statistic_points = mixed_prec_algo.get_statistic_points(
-            model, graph, matmul_input_to_output_nodes_map.keys(), subset_size
-        )
+        mixed_prec_algo = criterion_cls(None, None, subset_size)
+        statistic_points = mixed_prec_algo.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys())
         aggregator.register_statistic_points(statistic_points)


@@ -94,15 +88,12 @@ def register_all_statistics(
     :param aggregator: Aggregator to register statistics.
     :param model: Model being analyzed.
     :param graph: Model's computational graph.
-    :param subset_size: Size of dataset subset for statistics.
     :param compression_algo: WeightCompression algorithm instance.
     :param enable_mixed_precision: Whether to enable mixed precision statistics.
     """
     _, matmul_input_to_output_nodes_map = compression_algo.get_compression_nodes_info(graph)

-    register_statistics_for_algorithm(
-        aggregator, model, graph, subset_size, compression_algo, matmul_input_to_output_nodes_map
-    )
+    register_statistics_for_algorithm(aggregator, model, graph, compression_algo, matmul_input_to_output_nodes_map)

     if enable_mixed_precision:
         _register_mixed_precision(aggregator, model, graph, matmul_input_to_output_nodes_map, subset_size)
```
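
Note that `register_all_statistics` keeps its `subset_size` parameter even though it no longer forwards it to `register_statistics_for_algorithm`: `_register_mixed_precision` still needs it, because that path constructs the sensitivity criteria ad hoc via `criterion_cls(None, None, subset_size)` rather than receiving already configured instances.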
#### tests/openvino/native/quantization/test_weights_compression.py (29 changes: 20 additions & 9 deletions)

```diff
@@ -886,25 +886,36 @@ def test_compression_for_different_dtypes(activation_dtype, weight_dtype):
     check_compressed_matmul_subgraph(scale_multiply_node, activation_dtype, weight_dtype)


-DATASET_SIZE = 129
+DATASET_SIZE = 5


 @pytest.mark.parametrize(
-    ("subset_size", "ref_size"),
+    ("dataset_size", "subset_size", "ref_size"),
     (
-        (1, 1),
-        (5, 5),
-        (130, DATASET_SIZE),
+        (DATASET_SIZE, 1, 1),
+        (DATASET_SIZE, DATASET_SIZE, DATASET_SIZE),
+        (DATASET_SIZE, DATASET_SIZE + 1, DATASET_SIZE),
     ),
 )
-def test_valid_subset_size(mocker, subset_size, ref_size):
+@pytest.mark.parametrize(
+    ("compression_args", "multiplier_of_calls"),
+    (
+        (dict(mode=CompressWeightsMode.INT4_ASYM, ratio=1), 0),  # data-free, no reducers
+        (dict(mode=CompressWeightsMode.INT4_ASYM, ratio=0.5), 1),  # 1 reducer for mixed precision
+        (dict(mode=CompressWeightsMode.INT4_ASYM, ratio=1, awq=True), 2),  # mean & shape reducer for AWQ
+        (dict(mode=CompressWeightsMode.INT4_ASYM, ratio=0.5, awq=True), 3),  # 2 - for AWQ + 1 - for Mixed Precision
+    ),
+)
+def test_number_of_reduced_statistics_for_subset_size(
+    mocker, dataset_size, subset_size, ref_size, compression_args, multiplier_of_calls
+):
     model = IdentityMatmul().ov_model
-    dataset = Dataset([ACTIVATION] * DATASET_SIZE)
+    dataset = Dataset([ACTIVATION] * dataset_size)
     stats_spy = mocker.spy(AggregatorBase, "register_reduced_input")

-    compress_weights(model, mode=CompressWeightsMode.INT4_ASYM, ratio=0.5, dataset=dataset, subset_size=subset_size)
+    compress_weights(model, dataset=dataset, subset_size=subset_size, **compression_args)

-    assert stats_spy.call_count == ref_size
+    assert stats_spy.call_count == ref_size * multiplier_of_calls


 def test_default_subset_value():
```
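
To make the expected counts concrete: with `dataset_size = 5`, `subset_size = 5`, and `ratio=0.5, awq=True`, three reducers each register 5 reduced inputs, so the spy expects `5 * 3 == 15` calls; with `subset_size = 1` it expects `1 * 3 == 3`. Below is a toy sketch of the `mocker.spy` counting pattern the test relies on; the `Aggregator` class here is a stand-in, not NNCF's `AggregatorBase`:

```python
class Aggregator:
    def register_reduced_input(self, sample):
        pass  # a real aggregator would accumulate statistics here


def process(aggregator, samples, subset_size):
    # Only the first `subset_size` samples are registered, mirroring the fix.
    for sample in samples[:subset_size]:
        aggregator.register_reduced_input(sample)


def test_subset_is_respected(mocker):
    spy = mocker.spy(Aggregator, "register_reduced_input")
    process(Aggregator(), samples=list(range(5)), subset_size=3)
    assert spy.call_count == 3  # capped by subset_size, not by len(samples)
```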
