From 54c59242d5584f0f9df441854a063a07766a4583 Mon Sep 17 00:00:00 2001 From: zina-cs <109593976+zina-cs@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:04:15 +0400 Subject: [PATCH] Update aggregator.py --- nncf/common/tensor_statistics/aggregator.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py index 8c007d94a62..56691e5abde 100644 --- a/nncf/common/tensor_statistics/aggregator.py +++ b/nncf/common/tensor_statistics/aggregator.py @@ -17,13 +17,13 @@ from nncf.common import factory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.data.dataset import DataItem from nncf.data.dataset import Dataset from nncf.data.dataset import ModelInput -from nncf.common.logging import nncf_logger TensorType = TypeVar("TensorType") TModel = TypeVar("TModel") @@ -69,11 +69,8 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None: transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics) model_with_outputs: TModel = model_transformer.transform(transformation_layout) engine = factory.EngineFactory.create(model_with_outputs) - iterations_number = self._get_iterations_number() - processed_samples = 0 - for input_data in track( # type: ignore islice(self.dataset.get_inference_data(), iterations_number), total=iterations_number, @@ -83,10 +80,8 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None: processed_outputs = self._process_outputs(outputs) self._register_statistics(processed_outputs, merged_statistics) processed_samples += 1 - if processed_samples == 0: raise nncf.ValidationError(EMPTY_DATASET_ERROR) - if self.stat_subset_size is not None and self.stat_subset_size > processed_samples: nncf_logger.warning( f"Dataset contains only {processed_samples} samples, "