Skip to content

Commit

Permalink
Make RichProgressBar visible for both light and dark background (#20260)
Browse files Browse the repository at this point in the history
  • Loading branch information
tshu-w authored Sep 30, 2024
1 parent 5be58f6 commit 474bdd0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 45 deletions.
37 changes: 6 additions & 31 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,14 @@ class RichProgressBarTheme:
"""

description: Union[str, "Style"] = "white"
description: Union[str, "Style"] = ""
progress_bar: Union[str, "Style"] = "#6206E0"
progress_bar_finished: Union[str, "Style"] = "#6206E0"
progress_bar_pulse: Union[str, "Style"] = "#6206E0"
batch_progress: Union[str, "Style"] = "white"
time: Union[str, "Style"] = "grey54"
processing_speed: Union[str, "Style"] = "grey70"
metrics: Union[str, "Style"] = "white"
batch_progress: Union[str, "Style"] = ""
time: Union[str, "Style"] = "dim"
processing_speed: Union[str, "Style"] = "dim underline"
metrics: Union[str, "Style"] = "italic"
metrics_text_delimiter: str = " "
metrics_format: str = ".3f"

Expand Down Expand Up @@ -280,7 +280,6 @@ def __init__(
self._metric_component: Optional[MetricsTextColumn] = None
self._progress_stopped: bool = False
self.theme = theme
self._update_for_light_colab_theme()

@property
def refresh_rate(self) -> float:
Expand Down Expand Up @@ -318,13 +317,6 @@ def test_progress_bar(self) -> "Task":
assert self.test_progress_bar_id is not None
return self.progress.tasks[self.test_progress_bar_id]

def _update_for_light_colab_theme(self) -> None:
if _detect_light_colab_theme():
attributes = ["description", "batch_progress", "metrics"]
for attr in attributes:
if getattr(self.theme, attr) == "white":
setattr(self.theme, attr, "black")

@override
def disable(self) -> None:
self._enabled = False
Expand Down Expand Up @@ -449,7 +441,7 @@ def on_validation_batch_start(
def _add_task(self, total_batches: Union[int, float], description: str, visible: bool = True) -> "TaskID":
assert self.progress is not None
return self.progress.add_task(
f"[{self.theme.description}]{description}",
f"[{self.theme.description}]{description}" if self.theme.description else description,
total=total_batches,
visible=visible,
)
Expand Down Expand Up @@ -656,20 +648,3 @@ def __getstate__(self) -> Dict:
state["progress"] = None
state["_console"] = None
return state


def _detect_light_colab_theme() -> bool:
"""Detect if it's light theme in Colab."""
try:
import get_ipython
except (NameError, ModuleNotFoundError):
return False
ipython = get_ipython()
if "google.colab" in str(ipython.__class__):
try:
from google.colab import output

return output.eval_js('document.documentElement.matches("[theme=light]")')
except ModuleNotFoundError:
return False
return False
14 changes: 0 additions & 14 deletions tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,20 +308,6 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmp_path):
assert val_bar.total == 4


@RunIf(rich=True)
@mock.patch("lightning.pytorch.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
def test_rich_progress_bar_colab_light_theme_update(*_):
theme = RichProgressBar().theme
assert theme.description == "black"
assert theme.batch_progress == "black"
assert theme.metrics == "black"

theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme
assert theme.description == "blue"
assert theme.batch_progress == "black"
assert theme.metrics == "red"


@RunIf(rich=True)
def test_rich_progress_bar_metric_display_task_id(tmp_path):
class CustomModel(BoringModel):
Expand Down

0 comments on commit 474bdd0

Please sign in to comment.