From 41d12d7936c7a781a30da27e7fd12faa00b34a1e Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Wed, 8 Feb 2023 11:54:17 +0100 Subject: [PATCH] report: Add `notebook` mode. (#432) Renders existing HTML report inside an IFrame and updates it on each next_step. Closes #309 --- setup.cfg | 1 + src/dvclive/lightning.py | 2 ++ src/dvclive/live.py | 35 +++++++++++++++++++++++++++++------ src/dvclive/report.py | 13 +++++++++++++ src/dvclive/utils.py | 21 ++++++++++++++++++++- tests/test_report.py | 25 +++++++++++++++++++++++++ 6 files changed, 90 insertions(+), 7 deletions(-) diff --git a/setup.cfg b/setup.cfg index 104f8099..6df35185 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,6 +56,7 @@ tests = %(plots)s %(dvc)s %(markdown)s + ipython dev = %(tests)s %(all)s diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 874fd8b1..b701a1eb 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -37,6 +37,8 @@ def __init__( self._live_init["dir"] = dir self._experiment = experiment self._version = run_name + # Force Live instantiation + self.experiment # noqa pylint: disable=pointless-statement @property def name(self): diff --git a/src/dvclive/live.py b/src/dvclive/live.py index ba7c2573..0165de8b 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -18,10 +18,16 @@ ) from .error import InvalidDataTypeError, InvalidParameterTypeError, InvalidPlotTypeError from .plots import PLOT_TYPES, SKLEARN_PLOTS, Image, Metric, NumpyEncoder -from .report import make_report +from .report import BLANK_NOTEBOOK_REPORT, make_report from .serialize import dump_json, dump_yaml, load_yaml from .studio import get_studio_updates -from .utils import env2bool, matplotlib_installed, nested_update, open_file_in_browser +from .utils import ( + env2bool, + inside_notebook, + matplotlib_installed, + nested_update, + open_file_in_browser, +) try: from dvc_studio_client.env import STUDIO_TOKEN @@ -62,6 +68,7 @@ def __init__( os.makedirs(self.dir, exist_ok=True) self._report_mode: Optional[str] = report + self._report_notebook = None self._init_report() if self._resume: @@ -176,8 +183,20 @@ def _init_report(self): self._report_mode = "md" else: self._report_mode = "html" - elif self._report_mode not in {None, "html", "md"}: - raise ValueError("`report` can only be `None`, `auto`, `html` or `md`") + elif self._report_mode == "notebook": + if inside_notebook(): + from IPython.display import HTML, display + + self._report_mode = "notebook" + self._report_notebook = display( + HTML(BLANK_NOTEBOOK_REPORT), display_id=True + ) + else: + self._report_mode = "html" + elif self._report_mode not in {None, "html", "notebook", "md"}: + raise ValueError( + "`report` can only be `None`, `auto`, `html`, `notebook` or `md`" + ) logger.debug(f"{self._report_mode=}") @property @@ -202,8 +221,12 @@ def plots_dir(self) -> str: @property def report_file(self) -> Optional[str]: - if self._report_mode in ("html", "md"): - return os.path.join(self.dir, f"report.{self._report_mode}") + if self._report_mode in ("html", "md", "notebook"): + if self._report_mode == "notebook": + suffix = "html" + else: + suffix = self._report_mode + return os.path.join(self.dir, f"report.{suffix}") return None @property diff --git a/src/dvclive/report.py b/src/dvclive/report.py index d9e7762b..88b931fa 100644 --- a/src/dvclive/report.py +++ b/src/dvclive/report.py @@ -20,6 +20,13 @@ # noqa pylint: disable=protected-access +BLANK_NOTEBOOK_REPORT = """ +
+DVCLive Report +
+""" + + def get_scalar_renderers(metrics_path): renderers = [] for suffix in Metric.suffixes: @@ -123,6 +130,12 @@ def make_report(live: "Live"): if live._report_mode == "html": render_html(renderers, live.report_file, refresh_seconds=5) + elif live._report_mode == "notebook": + from IPython.display import IFrame + + render_html(renderers, live.report_file) + if live._report_notebook is not None: + live._report_notebook.update(IFrame(live.report_file, "100%", 700)) elif live._report_mode == "md": render_markdown(renderers, live.report_file) else: diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 9db1f57e..2a47acc9 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -7,6 +7,8 @@ from pathlib import Path from platform import uname +# noqa pylint: disable=unused-import + def nested_set(d, keys, value): """Set d[keys[0]]...[keys[-1]] to `value`. @@ -117,9 +119,26 @@ def parse_metrics(live): def matplotlib_installed() -> bool: - # noqa pylint: disable=unused-import try: import matplotlib # noqa: F401 except ImportError: return False return True + + +def inside_notebook() -> bool: + try: + from google import colab # noqa: F401 + + return True + except ImportError: + pass + try: + shell = get_ipython().__class__.__name__ # type: ignore[name-defined] + except NameError: + return False + if shell == "ZMQInteractiveShell": + import IPython + + return IPython.__version__ >= "6.0.0" + return False diff --git a/tests/test_report.py b/tests/test_report.py index e7abce5f..616ef29f 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -2,6 +2,7 @@ import os import pytest +from IPython import display from PIL import Image from dvclive import Live @@ -167,3 +168,27 @@ def test_get_plot_renderers(tmp_dir, mocker): {"actual": "1", "rev": "workspace", "predicted": "1"}, ] assert plot_renderer.properties == ConfusionMatrix.get_properties() + + +def test_report_auto_doesnt_set_notebook(tmp_dir, mocker): + mocker.patch("dvclive.live.inside_notebook", return_value=True) + live = Live() + assert live._report_mode != "notebook" + + +def test_report_notebook_fallsback_to_html(tmp_dir, mocker): + mocker.patch("dvclive.live.inside_notebook", return_value=False) + spy = mocker.spy(display, "display") + live = Live(report="notebook") + assert live._report_mode == "html" + assert not spy.called + + +def test_report_notebook(tmp_dir, mocker): + mocker.patch("dvclive.live.inside_notebook", return_value=True) + mocked_display = mocker.MagicMock() + mocker.patch("IPython.display.display", return_value=mocked_display) + live = Live(report="notebook") + assert live._report_mode == "notebook" + live.make_report() + assert mocked_display.update.called