diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py
index 8c007d94a62..56691e5abde 100644
--- a/nncf/common/tensor_statistics/aggregator.py
+++ b/nncf/common/tensor_statistics/aggregator.py
@@ -17,13 +17,13 @@
 from nncf.common import factory
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.common.logging import nncf_logger
 from nncf.common.logging.track_progress import track
 from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.data.dataset import DataItem
 from nncf.data.dataset import Dataset
 from nncf.data.dataset import ModelInput
-from nncf.common.logging import nncf_logger
 
 TensorType = TypeVar("TensorType")
 TModel = TypeVar("TModel")
@@ -69,11 +69,8 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
         transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
         model_with_outputs: TModel = model_transformer.transform(transformation_layout)
         engine = factory.EngineFactory.create(model_with_outputs)
-
         iterations_number = self._get_iterations_number()
-
         processed_samples = 0
-
         for input_data in track(  # type: ignore
             islice(self.dataset.get_inference_data(), iterations_number),
             total=iterations_number,
@@ -83,10 +80,8 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
             processed_outputs = self._process_outputs(outputs)
             self._register_statistics(processed_outputs, merged_statistics)
             processed_samples += 1
-
         if processed_samples == 0:
             raise nncf.ValidationError(EMPTY_DATASET_ERROR)
-
         if self.stat_subset_size is not None and self.stat_subset_size > processed_samples:
             nncf_logger.warning(
                 f"Dataset contains only {processed_samples} samples, "