Skip to content

Commit

Permalink
Restore total of status bar (#2606)
Browse files Browse the repository at this point in the history
### Changes

Restore total number of iterations for StatisticsAggregator after
#2197

### Reason for changes

N/A

### Related tickets

136892

### Tests

Manually tested on yolov8 sample for OV.
Before:

![image](https://github.com/openvinotoolkit/nncf/assets/32935044/6a1292ef-cf00-4571-bbe7-92acc36b44f4)

After:

![image](https://github.com/openvinotoolkit/nncf/assets/32935044/aa5c1f6d-c4f8-4d6c-86d0-15a34bc9f86e)
  • Loading branch information
kshpv authored Mar 27, 2024
1 parent 0696166 commit 06f8a1e
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.logger import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
Expand All @@ -29,9 +28,6 @@
EMPTY_DATASET_ERROR = (
"Calibration dataset must not be empty. Please provide calibration dataset with at least one sample."
)
ITERATIONS_NUMBER_WARNING = (
"The number of iterations for statistics collection is bigger than the length of the dataset."
)


class StatisticsAggregator(ABC):
Expand All @@ -46,16 +42,13 @@ def __init__(self, dataset: Dataset):

def _get_iterations_number(self) -> Optional[int]:
"""
Returns number of iterations, output number is less than min(self.stat_subset_size, dataset_length).
Returns number of iterations.
:return: Number of iterations for statistics collection.
"""
dataset_length = self.dataset.get_length()
if dataset_length and self.stat_subset_size:
if self.stat_subset_size > dataset_length:
nncf_logger.warning(ITERATIONS_NUMBER_WARNING)
return dataset_length
return self.stat_subset_size
return min(dataset_length, self.stat_subset_size)
return dataset_length or self.stat_subset_size

def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
Expand All @@ -78,7 +71,7 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
empty_statistics = True
for input_data in track(
islice(self.dataset.get_inference_data(), iterations_number),
total=self.stat_subset_size,
total=iterations_number,
description="Statistics collection",
):
outputs = engine.infer(input_data)
Expand Down

0 comments on commit 06f8a1e

Please sign in to comment.