diff --git a/dvc/command/experiments.py b/dvc/command/experiments.py index c72c96ea6e..6521d78d04 100644 --- a/dvc/command/experiments.py +++ b/dvc/command/experiments.py @@ -1,5 +1,4 @@ import argparse -import io import logging from collections import Counter, OrderedDict, defaultdict from collections.abc import Mapping @@ -18,6 +17,9 @@ logger = logging.getLogger(__name__) +SHOW_MAX_WIDTH = 1024 + + def _filter_name(names, label, filter_strs): ret = defaultdict(dict) path_filters = defaultdict(list) @@ -299,8 +301,8 @@ def _parse_filter_list(param_list): return ret -def _show_experiments(all_experiments, console, **kwargs): - from rich.table import Table +def _experiments_table(all_experiments, **kwargs): + from dvc.utils.table import Table include_metrics = _parse_filter_list(kwargs.pop("include_metrics", [])) exclude_metrics = _parse_filter_list(kwargs.pop("exclude_metrics", [])) @@ -316,11 +318,21 @@ def _show_experiments(all_experiments, console, **kwargs): ) table = Table() - table.add_column("Experiment", no_wrap=True) + table.add_column( + "Experiment", no_wrap=True, header_style="black on grey93" + ) if not kwargs.get("no_timestamp", False): - table.add_column("Created") - _add_data_col(table, metric_names, justify="right", no_wrap=True) - _add_data_col(table, param_names, justify="left") + table.add_column("Created", header_style="black on grey93") + _add_data_columns( + table, + metric_names, + justify="right", + no_wrap=True, + header_style="black on cornsilk1", + ) + _add_data_columns( + table, param_names, justify="left", header_style="black on light_cyan1" + ) for base_rev, experiments in all_experiments.items(): for row, _, in _collect_rows( @@ -328,17 +340,20 @@ def _show_experiments(all_experiments, console, **kwargs): ): table.add_row(*row) - console.print(table) + return table -def _add_data_col(table, names, **kwargs): +def _add_data_columns(table, names, **kwargs): count = Counter( name for path in names for name in names[path] for path in names ) + first = True for path in names: for name in names[path]: col_name = name if count[name] == 1 else f"{path}:{name}" + kwargs["collapse"] = False if first else True table.add_column(col_name, **kwargs) + first = False def _format_json(item): @@ -351,8 +366,6 @@ class CmdExperimentsShow(CmdBase): def run(self): from rich.console import Console - from dvc.utils.pager import pager - try: all_experiments = self.repo.experiments.show( all_branches=self.args.all_branches, @@ -368,23 +381,13 @@ def run(self): logger.info(json.dumps(all_experiments, default=_format_json)) return 0 - if self.args.no_pager: - console = Console() - else: - # Note: rich does not currently include a native way to force - # infinite width for use with a pager - console = Console( - file=io.StringIO(), force_terminal=True, width=9999 - ) - if self.args.precision is None: precision = DEFAULT_PRECISION else: precision = self.args.precision - _show_experiments( + table = _experiments_table( all_experiments, - console, include_metrics=self.args.include_metrics, exclude_metrics=self.args.exclude_metrics, include_params=self.args.include_params, @@ -395,8 +398,22 @@ def run(self): precision=precision, ) - if not self.args.no_pager: - pager(console.file.getvalue()) + console = Console() + if self.args.no_pager: + console.print(table) + else: + from dvc.utils.pager import DvcPager + + # NOTE: rich does not have native support for unlimited width + # via pager. we override rich table compression by setting + # console width to the full width of the table + measurement = table.__rich_measure__(console, SHOW_MAX_WIDTH) + console._width = ( # pylint: disable=protected-access + measurement.maximum + ) + with console.pager(pager=DvcPager(), styles=True): + console.print(table) + except DvcException: logger.exception("failed to show experiments") return 1 diff --git a/dvc/utils/pager.py b/dvc/utils/pager.py index 72344b1395..bf33031599 100644 --- a/dvc/utils/pager.py +++ b/dvc/utils/pager.py @@ -5,6 +5,8 @@ import pydoc import sys +from rich.pager import Pager + from dvc.env import DVC_PAGER from dvc.utils import format_link @@ -46,3 +48,8 @@ def find_pager(): def pager(text): find_pager()(text) + + +class DvcPager(Pager): + def show(self, content: str) -> None: + pager(content) diff --git a/dvc/utils/table.py b/dvc/utils/table.py new file mode 100644 index 0000000000..61363f2b65 --- /dev/null +++ b/dvc/utils/table.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List + +from rich.style import StyleType +from rich.table import Column as RichColumn +from rich.table import Table as RichTable + +if TYPE_CHECKING: + from rich.console import ( + Console, + JustifyMethod, + OverflowMethod, + RenderableType, + ) + + +@dataclass +class Column(RichColumn): + collapse: bool = False + + +class Table(RichTable): + def add_column( # pylint: disable=arguments-differ + self, + header: "RenderableType" = "", + footer: "RenderableType" = "", + *, + header_style: StyleType = None, + footer_style: StyleType = None, + style: StyleType = None, + justify: "JustifyMethod" = "left", + overflow: "OverflowMethod" = "ellipsis", + width: int = None, + min_width: int = None, + max_width: int = None, + ratio: int = None, + no_wrap: bool = False, + collapse: bool = False, + ) -> None: + column = Column( # type: ignore[call-arg] + _index=len(self.columns), + header=header, + footer=footer, + header_style=header_style or "", + footer_style=footer_style or "", + style=style or "", + justify=justify, + overflow=overflow, + width=width, + min_width=min_width, + max_width=max_width, + ratio=ratio, + no_wrap=no_wrap, + collapse=collapse, + ) + self.columns.append(column) + + def _calculate_column_widths( + self, console: "Console", max_width: int + ) -> List[int]: + """Calculate the widths of each column, including padding, not + including borders. + + Adjacent collapsed columns will be removed until there is only a single + truncated column remaining. + """ + widths = super()._calculate_column_widths(console, max_width) + last_collapsed = -1 + for i in range(len(self.columns) - 1, -1, -1): + if widths[i] == 1 and self.columns[i].collapse: + if last_collapsed >= 0: + del widths[last_collapsed] + del self.columns[last_collapsed] + if self.box: + max_width += 1 + for column in self.columns[last_collapsed:]: + column._index -= 1 + last_collapsed = i + padding = self._get_padding_width(i) + if ( + self.columns[i].overflow == "ellipsis" + and (sum(widths) + padding) <= max_width + ): + # Set content width to 1 (plus padding) if we can fit a + # single unicode ellipsis in this column + widths[i] = 1 + padding + else: + last_collapsed = -1 + return widths + + def _collapse_widths( + self, widths: List[int], wrapable: List[bool], max_width: int, + ) -> List[int]: + """Collapse columns right-to-left if possible to fit table into + max_width. + + If table is still too wide after collapsing, rich's automatic overflow + handling will be used. + """ + collapsible = [column.collapse for column in self.columns] + total_width = sum(widths) + excess_width = total_width - max_width + if any(collapsible): + for i in range(len(widths) - 1, -1, -1): + if collapsible[i]: + total_width -= widths[i] + excess_width -= widths[i] + widths[i] = 0 + if excess_width <= 0: + break + return super()._collapse_widths(widths, wrapable, max_width) diff --git a/setup.py b/setup.py index b88f1fca84..896591a98c 100644 --- a/setup.py +++ b/setup.py @@ -82,7 +82,7 @@ def run(self): "pygtrie==2.3.2", "dpath>=2.0.1,<3", "shtab>=1.3.4,<2", - "rich>=3.0.5", + "rich>=9.0.0", "dictdiffer>=0.8.1", "python-benedict>=0.21.1", "pyparsing==2.4.7",