From df3b70f8310526b19603c3e1c45c27e13f101d45 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 18 Sep 2023 09:08:00 +0200 Subject: [PATCH] Added rich-based progress bar (#2132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Changes Added `nncf.common.logging.track_progress.track()` method to replace `tqdm` for quantization algorithms bars. This method is an almost exact copy of `rich.progress.track()` method, but with an addition of an iteration counter. This is kind of a hack, but `rich` does not provide a way to add custom `Column` objects to the `track()` method. By default `rich.progress.track()` renders progress bar as ![image](https://github.com/openvinotoolkit/nncf/assets/23343961/7ca21a3c-f3dd-4a93-967c-6d1ac019017b) With an addition of customizations this becomes ![ac437dfc-de24-46fc-9538-3bdfee2e3ac9](https://github.com/openvinotoolkit/nncf/assets/23343961/ac5d1e1c-1454-48c2-a8d0-3cd6fe465c26) With this change the quantization pipeline looks like ![image](https://github.com/openvinotoolkit/nncf/assets/23343961/7a42d991-0df1-41a0-b177-621bed44b157) For iterables without length the progress bar displays as ![image](https://github.com/openvinotoolkit/nncf/assets/23343961/c4c59d9d-1b08-477b-bfec-4984ebb06189) How it looks in a notebook: ![image](https://github.com/openvinotoolkit/nncf/assets/23343961/12f8a786-dc59-4cd0-b4fa-cefefd92b0dd) ### Reason for changes - User experience improvement - Avoiding multi-line logs in CI produced by `tqdm`, for example: Statistics collection: 0%| | 0/300 [00:00 Text: + if task.total is None: + return Text("") + text = f"{int(task.completed)}/{int(task.total)}" + if task.finished: + return Text(text, style="progress.elapsed") + return Text(text, style="progress.remaining") + + +class SeparatorColumn(ProgressColumn): + def __init__(self, table_column: Optional[Column] = None, disable_if_no_total: bool = False) -> None: + super().__init__(table_column) + self.disable_if_no_total = disable_if_no_total + + def render(self, task: Task) -> Text: + if self.disable_if_no_total and task.total is None: + return Text("") + return Text("•") + + +class track: + def __init__( + self, + sequence: Optional[Union[Sequence[ProgressType], Iterable[ProgressType]]] = None, + description: str = "Working...", + total: Optional[float] = None, + auto_refresh: bool = True, + console: Optional[Console] = None, + transient: bool = False, + get_time: Optional[Callable[[], float]] = None, + refresh_per_second: float = 10, + style: StyleType = "bar.back", + complete_style: StyleType = "bar.complete", + finished_style: StyleType = "bar.finished", + pulse_style: StyleType = "bar.pulse", + update_period: float = 0.1, + disable: bool = False, + show_speed: bool = True, + ): + """ + Track progress by iterating over a sequence. + + This function is very similar to rich.progress.track(), but with some customizations. + + :param sequence: An iterable (must support "len") you wish to iterate over. + :param description: Description of the task to show next to the progress bar. Defaults to "Working". + :param total: Total number of steps. Default is len(sequence). + :param auto_refresh: Automatic refresh. Disable to force a refresh after each iteration. Default is True. + :param transient: Clear the progress on exit. Defaults to False. + :param get_time: A callable that gets the current time, or None to use Console.get_time. Defaults to None. + :param console: Console to write to. Default creates an internal Console instance. + :param refresh_per_second: Number of times per second to refresh the progress information. Defaults to 10. + :param style: Style for the bar background. Defaults to "bar.back". + :param complete_style: Style for the completed bar. Defaults to "bar.complete". + :param finished_style: Style for a finished bar. Defaults to "bar.finished". + :param pulse_style: Style for pulsing bars. Defaults to "bar.pulse". + :param update_period: Minimum time (in seconds) between calls to update(). Defaults to 0.1. + :param disable: Disable display of progress. + :param show_speed: Show speed if the total isn't known. Defaults to True. + :return: An iterable of the values in the sequence. + """ + + self.sequence = sequence + self.total = total + self.description = description + self.update_period = update_period + self.task = None + + self.columns: List[ProgressColumn] = ( + [TextColumn("[progress.description]{task.description}")] if description else [] + ) + self.columns.extend( + ( + BarColumn( + style=style, + complete_style=complete_style, + finished_style=finished_style, + pulse_style=pulse_style, + ), + TaskProgressColumn(show_speed=show_speed), + IterationsColumn(), + SeparatorColumn(), + TimeElapsedColumn(), + SeparatorColumn(disable_if_no_total=True), # disable because time remaining will be empty + TimeRemainingColumn(), + ) + ) + self.progress = Progress( + *self.columns, + auto_refresh=auto_refresh, + console=console, + transient=transient, + get_time=get_time, + refresh_per_second=refresh_per_second or 10, + disable=disable, + ) + + def __iter__(self) -> Iterable[ProgressType]: + with self.progress: + yield from self.progress.track( + self.sequence, total=self.total, description=self.description, update_period=self.update_period + ) + + def __enter__(self): + self.progress.start() + self.task = self.progress.add_task(self.description, total=self.total) + return self + + def __exit__(self, *args): + self.task = None + self.progress.stop() diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py index 444c9581d55..6c4925ad458 100644 --- a/nncf/common/tensor_statistics/aggregator.py +++ b/nncf/common/tensor_statistics/aggregator.py @@ -13,11 +13,10 @@ from itertools import islice from typing import Any, Dict, TypeVar -from tqdm.auto import tqdm - from nncf.common import factory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.logging.track_progress import track from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.data.dataset import Dataset @@ -60,10 +59,10 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None: if self.stat_subset_size is not None else None ) - for input_data in tqdm( + for input_data in track( islice(self.dataset.get_inference_data(), self.stat_subset_size), total=total, - desc="Statistics collection", + description="Statistics collection", ): outputs = engine.infer(input_data) processed_outputs = self._process_outputs(outputs) diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py index 10bc844ec2c..c3e1e10a6f6 100644 --- a/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -13,7 +13,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar import numpy as np -from tqdm.auto import tqdm from nncf import Dataset from nncf import nncf_logger @@ -26,6 +25,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationCommand from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.logging.track_progress import track from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -159,8 +159,8 @@ def apply( # for which we will create a subgraph for inference and collection of statistics. subgraphs_data = [self._get_subgraph_data_for_node(node, nncf_graph) for node in nodes_with_bias] - for position, (node, subgraph_data) in tqdm( - list(enumerate(zip(nodes_with_bias, subgraphs_data))), desc="Applying Bias correction" + for position, (node, subgraph_data) in track( + list(enumerate(zip(nodes_with_bias, subgraphs_data))), description="Applying Bias correction" ): node_name = node.node_name diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py index 08bf7731e18..c5cf65dfefd 100644 --- a/nncf/quantization/algorithms/channel_alignment/algorithm.py +++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py @@ -9,10 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, TypeVar +from typing import Dict, List, Optional, Tuple, TypeVar import numpy as np -from tqdm.auto import tqdm from nncf import Dataset from nncf.common.factory import CommandCreatorFactory @@ -23,6 +22,7 @@ from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.logging.track_progress import track from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -104,7 +104,7 @@ def apply( def filter_func(point: StatisticPoint) -> bool: return self._algorithm_key in point.algorithm_to_tensor_collectors and point.target_point == target_point - for conv_in, add_in, conv_out in tqdm(self._get_node_pairs(graph), desc="Channel alignment"): + for conv_in, add_in, conv_out in track(self._get_node_pairs(graph), description="Channel alignment"): target_point, node_in = self._get_target_point_and_node_in(conv_in, add_in) tensor_collectors = list( statistic_points.get_algo_statistics_for_node(node_in.node_name, filter_func, self._algorithm_key) diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 8192eebb460..c83fe3b8e26 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -11,8 +11,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union -from tqdm.auto import tqdm - from nncf import Dataset from nncf.common.factory import EngineFactory from nncf.common.factory import ModelTransformerFactory @@ -22,6 +20,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger +from nncf.common.logging.track_progress import track from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -139,7 +138,7 @@ def apply( # for which we should update bias and new bias values. node_and_new_bias_value = [] - for node, bias_value in tqdm(node_and_bias_value, desc="Applying Fast Bias correction"): + for node, bias_value in track(node_and_bias_value, description="Applying Fast Bias correction"): node_name = node.node_name if not self._backend_entity.is_quantized_weights(node, graph): diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index 5be71625a3a..a9ccdef9e10 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -24,8 +24,6 @@ from copy import deepcopy from typing import Dict, List, Optional, Tuple, TypeVar -from tqdm.auto import tqdm - from nncf import Dataset from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph @@ -33,6 +31,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger +from nncf.common.logging.track_progress import track from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -111,7 +110,7 @@ def apply( node_groups = self._group_nodes_by_source(nodes_to_smooth_data, graph) best_scale = None - for group_id, nodes in tqdm(node_groups.items(), desc="Applying Smooth Quant"): + for group_id, nodes in track(node_groups.items(), description="Applying Smooth Quant"): best_ratio = 0.0 empty_statistic = False for node_to_smooth in nodes: diff --git a/setup.py b/setup.py index 2058599c505..78d6a03c626 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def find_version(*file_paths): # Using 2.x versions of pyparsing seems to fix the issue. # Ticket: 69520 "pyparsing<3.0", + "rich>=13.5.2", "scikit-learn>=0.24.0", "scipy>=1.3.2", "texttable>=1.6.3",