diff --git a/src/dvclive/data/base.py b/src/dvclive/data/base.py deleted file mode 100644 index df0c6fdf..00000000 --- a/src/dvclive/data/base.py +++ /dev/null @@ -1,76 +0,0 @@ -import abc -from pathlib import Path -from typing import Any, Dict, List, Optional - -from dvclive.error import DataAlreadyLoggedError - - -class Data(abc.ABC): - def __init__(self, name: str, output_folder: str) -> None: - self.name = name - self.output_folder: Path = Path(output_folder) / self.subfolder - self._step: Optional[int] = None - self.val: Optional[List[Any]] = None - self._step_none_logged: bool = False - self._dump_kwargs: Optional[Dict[str, Any]] = None - - @property - def step(self) -> Optional[int]: - return self._step - - @step.setter - def step(self, val: int) -> None: - if self._step_none_logged and val == self._step: - raise DataAlreadyLoggedError(self.name, val) - - self._step = val - - @property - @abc.abstractmethod - def output_path(self) -> Path: - pass - - @property - @abc.abstractmethod - def no_step_output_path(self) -> Path: - pass - - @property - @abc.abstractmethod - def subfolder(self): - pass - - @property - @abc.abstractmethod - def summary(self): - pass - - @staticmethod - @abc.abstractmethod - def could_log(val: object) -> bool: - pass - - def dump(self, val: List[Any], step: Optional[int], **kwargs): - assert val is not None - self.val = val - self.step = step - self._dump_kwargs = kwargs - if not self._step_none_logged and step is None: - self._step_none_logged = True - self.no_step_dump() - elif step == 0: - self.first_step_dump() - else: - self.step_dump() - - @abc.abstractmethod - def first_step_dump(self) -> None: - pass - - @abc.abstractmethod - def no_step_dump(self) -> None: - pass - - @abc.abstractmethod - def step_dump(self) -> None: - pass diff --git a/src/dvclive/data/image.py b/src/dvclive/data/image.py deleted file mode 100644 index c2f96948..00000000 --- a/src/dvclive/data/image.py +++ /dev/null @@ -1,50 +0,0 @@ -from pathlib import Path - -from .base import Data - - -class Image(Data): - suffixes = [".jpg", ".jpeg", ".gif", ".png"] - subfolder = "images" - - @property - def no_step_output_path(self) -> Path: - return self.output_folder / self.name - - @property - def output_path(self) -> Path: - if self._step is None: - output_path = self.no_step_output_path - else: - output_path = self.output_folder / f"{self._step}" / self.name - output_path.parent.mkdir(exist_ok=True, parents=True) - return output_path - - @staticmethod - def could_log(val: object) -> bool: - if val.__class__.__module__ == "PIL.Image": - return True - if val.__class__.__module__ == "numpy": - return True - return False - - def first_step_dump(self) -> None: - if self.no_step_output_path.exists(): - self.no_step_output_path.rename(self.output_path) - - def no_step_dump(self) -> None: - self.step_dump() - - def step_dump(self) -> None: - if self.val.__class__.__module__ == "numpy": - from PIL import Image as ImagePIL - - _val = ImagePIL.fromarray(self.val) - else: - _val = self.val - - _val.save(self.output_path) - - @property - def summary(self): - return {self.name: str(self.output_path)} diff --git a/src/dvclive/error.py b/src/dvclive/error.py index 9c863fe3..24c13e14 100644 --- a/src/dvclive/error.py +++ b/src/dvclive/error.py @@ -14,12 +14,12 @@ def __init__(self, name, val): class InvalidPlotTypeError(DvcLiveError): def __init__(self, name): - from .data import PLOTS + from .plots import SKLEARN_PLOTS self.name = name super().__init__( f"Plot type '{name}' is not supported." - f"\nSupported types are: {list(PLOTS)}" + f"\nSupported types are: {list(SKLEARN_PLOTS)}" ) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 61377585..133a0a4a 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -2,23 +2,21 @@ import logging import os import shutil -from collections import OrderedDict -from itertools import chain from pathlib import Path from typing import Any, Dict, List, Optional, Union from ruamel.yaml.representer import RepresenterError from . import env -from .data import DATA_TYPES, PLOTS, Image, Metric, NumpyEncoder from .dvc import make_checkpoint from .error import ( InvalidDataTypeError, InvalidParameterTypeError, InvalidPlotTypeError, ) +from .plots import PLOT_TYPES, SKLEARN_PLOTS, Image, Metric, NumpyEncoder from .report import make_report -from .serialize import dump_yaml, load_yaml +from .serialize import dump_json, dump_yaml, load_yaml from .studio import post_to_studio from .utils import ( env2bool, @@ -66,11 +64,12 @@ def __init__( self.report_mode: Optional[str] = report self.report_path = "" + self.summary: Dict[str, Any] = {} self._step: Optional[int] = None - self._scalars: Dict[str, Any] = OrderedDict() - self._images: Dict[str, Any] = OrderedDict() - self._plots: Dict[str, Any] = OrderedDict() - self._params: Dict[str, Any] = OrderedDict() + self._metrics: Dict[str, Any] = {} + self._images: Dict[str, Any] = {} + self._plots: Dict[str, Any] = {} + self._params: Dict[str, Any] = {} self._init_paths() @@ -90,7 +89,6 @@ def __init__( self._cleanup() self._latest_studio_step = self.get_step() - if self.report_mode == "studio": from scmrepo.git import Git @@ -104,9 +102,9 @@ def __init__( self.report_mode = None def _cleanup(self): - for data_type in DATA_TYPES: + for plot_type in PLOT_TYPES: shutil.rmtree( - Path(self.plots_path) / data_type.subfolder, ignore_errors=True + Path(self.plots_path) / plot_type.subfolder, ignore_errors=True ) for f in (self.metrics_path, self.report_path, self.params_path): @@ -138,12 +136,6 @@ def get_step(self) -> int: def set_step(self, step: int) -> None: if self._step is None: self._step = 0 - for data in chain( - self._scalars.values(), - self._images.values(), - self._plots.values(), - ): - data.dump(data.val, self._step) self.make_summary() if self.report_mode == "studio": @@ -165,14 +157,16 @@ def log(self, name: str, val: Union[int, float]): if not Metric.could_log(val): raise InvalidDataTypeError(name, type(val)) - if name in self._scalars: - data = self._scalars[name] + if name in self._metrics: + data = self._metrics[name] else: data = Metric(name, self.plots_path) - self._scalars[name] = data + self._metrics[name] = data - data.dump(val, self._step) + data.step = self.get_step() + data.dump(val) + self.summary = nested_update(self.summary, data.to_summary(val)) self.make_summary() logger.debug(f"Logged {name}: {val}") @@ -186,7 +180,8 @@ def log_image(self, name: str, val): data = Image(name, self.plots_path) self._images[name] = data - data.dump(val, self._step) + data.step = self.get_step() + data.dump(val) logger.debug(f"Logged {name}: {val}") def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs): @@ -195,13 +190,14 @@ def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs): name = name or kind if name in self._plots: data = self._plots[name] - elif kind in PLOTS and PLOTS[kind].could_log(val): - data = PLOTS[kind](name, self.plots_path) + elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val): + data = SKLEARN_PLOTS[kind](name, self.plots_path) self._plots[name] = data else: raise InvalidPlotTypeError(name) - data.dump(val, self._step, **kwargs) + data.step = self.get_step() + data.dump(val, **kwargs) logger.debug(f"Logged {name}") def _read_params(self): @@ -211,7 +207,7 @@ def _read_params(self): def _dump_params(self): try: - dump_yaml(self.params_path, self._params) + dump_yaml(self._params, self.params_path) except RepresenterError as exc: raise InvalidParameterTypeError(exc.args) from exc @@ -230,15 +226,9 @@ def log_param( self.log_params({name: val}) def make_summary(self): - summary_data = {} if self._step is not None: - summary_data["step"] = self.get_step() - - for data in self._scalars.values(): - summary_data = nested_update(summary_data, data.summary) - - with open(self.metrics_path, "w", encoding="utf-8") as f: - json.dump(summary_data, f, indent=4, cls=NumpyEncoder) + self.summary["step"] = self.get_step() + dump_json(self.summary, self.metrics_path, cls=NumpyEncoder) def make_report(self): if self.report_mode is not None: diff --git a/src/dvclive/data/__init__.py b/src/dvclive/plots/__init__.py similarity index 60% rename from src/dvclive/data/__init__.py rename to src/dvclive/plots/__init__.py index 0829d446..2b7ec00d 100644 --- a/src/dvclive/data/__init__.py +++ b/src/dvclive/plots/__init__.py @@ -1,19 +1,13 @@ from .image import Image from .metric import Metric -from .sklearn_plot import ( - Calibration, - ConfusionMatrix, - Det, - PrecisionRecall, - Roc, -) +from .sklearn import Calibration, ConfusionMatrix, Det, PrecisionRecall, Roc from .utils import NumpyEncoder # noqa: F401 -PLOTS = { +SKLEARN_PLOTS = { "calibration": Calibration, "confusion_matrix": ConfusionMatrix, "det": Det, "precision_recall": PrecisionRecall, "roc": Roc, } -DATA_TYPES = (*PLOTS.values(), Metric, Image) +PLOT_TYPES = (*SKLEARN_PLOTS.values(), Metric, Image) diff --git a/src/dvclive/plots/base.py b/src/dvclive/plots/base.py new file mode 100644 index 00000000..8473bacb --- /dev/null +++ b/src/dvclive/plots/base.py @@ -0,0 +1,41 @@ +import abc +from pathlib import Path +from typing import Optional + +from dvclive.error import DataAlreadyLoggedError + + +class Data(abc.ABC): + def __init__(self, name: str, output_folder: str) -> None: + self.name = name + self.output_folder: Path = Path(output_folder) / self.subfolder + self._step: Optional[int] = None + + @property + def step(self) -> Optional[int]: + return self._step + + @step.setter + def step(self, val: int) -> None: + if val == self._step: + raise DataAlreadyLoggedError(self.name, val) + self._step = val + + @property + @abc.abstractmethod + def output_path(self) -> Path: + pass + + @property + @abc.abstractmethod + def subfolder(self): + pass + + @staticmethod + @abc.abstractmethod + def could_log(val) -> bool: + pass + + @abc.abstractmethod + def dump(self, val, **kwargs): + pass diff --git a/src/dvclive/plots/image.py b/src/dvclive/plots/image.py new file mode 100644 index 00000000..2e4a14c3 --- /dev/null +++ b/src/dvclive/plots/image.py @@ -0,0 +1,31 @@ +from pathlib import Path + +from .base import Data + + +class Image(Data): + suffixes = [".jpg", ".jpeg", ".gif", ".png"] + subfolder = "images" + + @property + def output_path(self) -> Path: + _path = self.output_folder / self.name + _path.parent.mkdir(exist_ok=True, parents=True) + return _path + + @staticmethod + def could_log(val: object) -> bool: + if val.__class__.__module__ == "PIL.Image": + return True + if val.__class__.__module__ == "numpy": + return True + return False + + def dump(self, val, **kwargs) -> None: + if val.__class__.__module__ == "numpy": + from PIL import Image as ImagePIL + + pil_image = ImagePIL.fromarray(val) + else: + pil_image = val + pil_image.save(self.output_path) diff --git a/src/dvclive/data/metric.py b/src/dvclive/plots/metric.py similarity index 80% rename from src/dvclive/data/metric.py rename to src/dvclive/plots/metric.py index c1d965f6..2ac5cf16 100644 --- a/src/dvclive/data/metric.py +++ b/src/dvclive/plots/metric.py @@ -29,33 +29,19 @@ def output_path(self) -> Path: _path.parent.mkdir(exist_ok=True, parents=True) return _path - @property - def no_step_output_path(self) -> Path: - return self.output_path - - def first_step_dump(self) -> None: - self.step_dump() - - def no_step_dump(self) -> None: - pass - - def step_dump(self) -> None: + def dump(self, val, **kwargs) -> None: ts = int(time.time() * 1000) d = OrderedDict( - [("timestamp", ts), ("step", self.step), (self.name, self.val)] + [("timestamp", ts), ("step", self.step), (self.name, val)] ) - existed = self.output_path.exists() with open(self.output_path, "a", encoding="utf-8") as fobj: writer = csv.DictWriter(fobj, d.keys(), delimiter="\t") - if not existed: writer.writeheader() - writer.writerow(d) - @property - def summary(self): + def to_summary(self, val): d = {} - nested_set(d, os.path.normpath(self.name).split(os.path.sep), self.val) + nested_set(d, os.path.normpath(self.name).split(os.path.sep), val) return d diff --git a/src/dvclive/data/sklearn_plot.py b/src/dvclive/plots/sklearn.py similarity index 66% rename from src/dvclive/data/sklearn_plot.py rename to src/dvclive/plots/sklearn.py index af24cb13..1e86a24e 100644 --- a/src/dvclive/data/sklearn_plot.py +++ b/src/dvclive/plots/sklearn.py @@ -1,6 +1,7 @@ -import json from pathlib import Path +from dvclive.serialize import dump_json + from .base import Data @@ -21,32 +22,6 @@ def could_log(val: object) -> bool: return True return False - @property - def no_step_output_path(self) -> Path: - return super().no_step_output_path.with_suffix(".json") - - @property - def summary(self): - return {} - - @staticmethod - def write_json(content, output_file): - with open(output_file, "w", encoding="utf-8") as f: - json.dump(content, f, indent=4) - - def no_step_dump(self) -> None: - raise NotImplementedError - - def first_step_dump(self) -> None: - raise NotImplementedError( - "DVCLive plots can only be used in no-step mode." - ) - - def step_dump(self) -> None: - raise NotImplementedError( - "DVCLive plots can only be used in no-step mode." - ) - @staticmethod def get_properties(): raise NotImplementedError @@ -63,13 +38,11 @@ def get_properties(): "y_label": "True Positive Rate", } - def no_step_dump(self) -> None: - assert self.val is not None - + def dump(self, val, **kwargs) -> None: from sklearn import metrics fpr, tpr, roc_thresholds = metrics.roc_curve( - y_true=self.val[0], y_score=self.val[1], **self._dump_kwargs + y_true=val[0], y_score=val[1], **kwargs ) roc = { "roc": [ @@ -77,7 +50,7 @@ def no_step_dump(self) -> None: for fp, tp, t in zip(fpr, tpr, roc_thresholds) ] } - self.write_json(roc, self.output_path) + dump_json(roc, self.output_path) class PrecisionRecall(SKLearnPlot): @@ -91,13 +64,11 @@ def get_properties(): "y_label": "Precision", } - def no_step_dump(self) -> None: - assert self.val is not None - + def dump(self, val, **kwargs) -> None: from sklearn import metrics precision, recall, prc_thresholds = metrics.precision_recall_curve( - y_true=self.val[0], probas_pred=self.val[1], **self._dump_kwargs + y_true=val[0], probas_pred=val[1], **kwargs ) prc = { @@ -106,7 +77,7 @@ def no_step_dump(self) -> None: for p, r, t in zip(precision, recall, prc_thresholds) ] } - self.write_json(prc, self.output_path) + dump_json(prc, self.output_path) class Det(SKLearnPlot): @@ -120,13 +91,11 @@ def get_properties(): "y_label": "False Negative Rate", } - def no_step_dump(self) -> None: - assert self.val is not None - + def dump(self, val, **kwargs) -> None: from sklearn import metrics fpr, fnr, roc_thresholds = metrics.det_curve( - y_true=self.val[0], y_score=self.val[1], **self._dump_kwargs + y_true=val[0], y_score=val[1], **kwargs ) det = { @@ -135,7 +104,7 @@ def no_step_dump(self) -> None: for fp, fn, t in zip(fpr, fnr, roc_thresholds) ] } - self.write_json(det, self.output_path) + dump_json(det, self.output_path) class ConfusionMatrix(SKLearnPlot): @@ -150,14 +119,12 @@ def get_properties(): "y_label": "Predicted Label", } - def no_step_dump(self) -> None: - assert self.val is not None - + def dump(self, val, **kwargs) -> None: cm = [ {"actual": str(actual), "predicted": str(predicted)} - for actual, predicted in zip(self.val[0], self.val[1]) + for actual, predicted in zip(val[0], val[1]) ] - self.write_json(cm, self.output_path) + dump_json(cm, self.output_path) class Calibration(SKLearnPlot): @@ -171,19 +138,17 @@ def get_properties(): "y_label": "Fraction of Positives", } - def no_step_dump(self) -> None: - assert self.val is not None - + def dump(self, val, **kwargs) -> None: from sklearn import calibration prob_true, prob_pred = calibration.calibration_curve( - y_true=self.val[0], y_prob=self.val[1], **self._dump_kwargs + y_true=val[0], y_prob=val[1], **kwargs ) - calibration = { + _calibration = { "calibration": [ {"prob_true": pt, "prob_pred": pp} for pt, pp in zip(prob_true, prob_pred) ] } - self.write_json(calibration, self.output_path) + dump_json(_calibration, self.output_path) diff --git a/src/dvclive/data/utils.py b/src/dvclive/plots/utils.py similarity index 100% rename from src/dvclive/data/utils.py rename to src/dvclive/plots/utils.py diff --git a/src/dvclive/report.py b/src/dvclive/report.py index 54f3cc7e..0082dba9 100644 --- a/src/dvclive/report.py +++ b/src/dvclive/report.py @@ -8,8 +8,8 @@ from dvc_render.table import TableRenderer from dvc_render.vega import VegaRenderer -from dvclive.data import PLOTS, Image, Metric -from dvclive.data.sklearn_plot import SKLearnPlot +from dvclive.plots import SKLEARN_PLOTS, Image, Metric +from dvclive.plots.sklearn import SKLearnPlot from dvclive.serialize import load_yaml from dvclive.utils import parse_tsv @@ -38,12 +38,12 @@ def get_scalar_renderers(metrics_path): def get_image_renderers(images_folder): - dvclive_path = images_folder.parent + plots_path = images_folder.parent.parent renderers = [] for suffix in Image.suffixes: all_images = Path(images_folder).rglob(f"*{suffix}") for file in sorted(all_images): - src = str(file.relative_to(dvclive_path)) + src = str(file.relative_to(plots_path)) name = str(file.relative_to(images_folder)) data = [ { @@ -65,7 +65,7 @@ def get_plot_renderers(plots_folder): data = data[name] for row in data: row["rev"] = "workspace" - properties = PLOTS[name].get_properties() + properties = SKLEARN_PLOTS[name].get_properties() renderers.append(VegaRenderer(data, name, **properties)) return renderers diff --git a/src/dvclive/serialize.py b/src/dvclive/serialize.py index 3a286e5a..019923c4 100644 --- a/src/dvclive/serialize.py +++ b/src/dvclive/serialize.py @@ -1,3 +1,4 @@ +import json from collections import OrderedDict from dvclive.error import DvcLiveError @@ -36,7 +37,12 @@ def _get_yaml(): return yaml -def dump_yaml(path, data): +def dump_yaml(content, output_file): yaml = _get_yaml() - with open(path, "w", encoding="utf-8") as fd: - yaml.dump(data, fd) + with open(output_file, "w", encoding="utf-8") as fd: + yaml.dump(content, fd) + + +def dump_json(content, output_file, indent=4, **kwargs): + with open(output_file, "w", encoding="utf-8") as f: + json.dump(content, f, indent=indent, **kwargs) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 73ebc23e..0ddecb83 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -23,6 +23,7 @@ def nested_set(d, keys, value): for key in keys[:-1]: d = d.setdefault(key, {}) d[keys[-1]] = value + return d def nested_update(d, u): @@ -108,7 +109,7 @@ def parse_json(path): def parse_metrics(live): - from .data import Metric + from .plots import Metric plots_path = Path(live.plots_path) history = {} diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py new file mode 100644 index 00000000..310fd127 --- /dev/null +++ b/tests/plots/test_image.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest +from PIL import Image + +# pylint: disable=unused-argument +from dvclive import Live +from dvclive.plots import Image as LiveImage + + +def test_PIL(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + live.log_image("image.png", img) + + assert ( + tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" + ).exists() + + +def test_invalid_extension(tmp_dir): + live = Live() + img = Image.new("RGB", (10, 10), (250, 250, 250)) + with pytest.raises(ValueError): + live.log_image("image.foo", img) + + +@pytest.mark.parametrize("shape", [(10, 10), (10, 10, 3), (10, 10, 4)]) +def test_numpy(tmp_dir, shape): + live = Live() + img = np.ones(shape, np.uint8) * 255 + live.log_image("image.png", img) + + assert ( + tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" + ).exists() + + +def test_override_on_step(tmp_dir): + live = Live() + + zeros = np.zeros((2, 2, 3), np.uint8) + live.log_image("image.png", zeros) + + live.next_step() + + ones = np.ones((2, 2, 3), np.uint8) + live.log_image("image.png", ones) + + img_path = tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" + assert np.array_equal(np.array(Image.open(img_path)), ones) + + +def test_cleanup(tmp_dir): + live = Live() + img = np.ones((10, 10, 3), np.uint8) + live.log_image("image.png", img) + + assert ( + tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" + ).exists() + + Live() + + assert not (tmp_dir / live.plots_path / LiveImage.subfolder).exists() diff --git a/tests/test_data/test_scalar.py b/tests/plots/test_metric.py similarity index 91% rename from tests/test_data/test_scalar.py rename to tests/plots/test_metric.py index f585e91c..063882d3 100644 --- a/tests/test_data/test_scalar.py +++ b/tests/plots/test_metric.py @@ -5,8 +5,8 @@ # pylint: disable=unused-argument from dvclive import Live -from dvclive.data.metric import Metric -from dvclive.data.utils import NUMPY_INTS, NUMPY_SCALARS +from dvclive.plots.metric import Metric +from dvclive.plots.utils import NUMPY_INTS, NUMPY_SCALARS from dvclive.utils import parse_tsv diff --git a/tests/test_data/test_plot.py b/tests/plots/test_sklearn.py similarity index 89% rename from tests/test_data/test_plot.py rename to tests/plots/test_sklearn.py index 93ce9635..085916b9 100644 --- a/tests/test_data/test_plot.py +++ b/tests/plots/test_sklearn.py @@ -4,7 +4,7 @@ from sklearn import calibration, metrics from dvclive import Live -from dvclive.data.sklearn_plot import SKLearnPlot +from dvclive.plots.sklearn import SKLearnPlot # pylint: disable=redefined-outer-name, unused-argument @@ -99,19 +99,6 @@ def test_log_confusion_matrix(tmp_dir, y_true_y_pred_y_score, mocker): assert cm[0]["predicted"] == str(y_pred[0]) -def test_step_exception(tmp_dir, y_true_y_pred_y_score): - live = Live() - out = tmp_dir / live.plots_path / SKLearnPlot.subfolder - - y_true, y_pred, _ = y_true_y_pred_y_score - - live.log_sklearn_plot("confusion_matrix", y_true, y_pred) - assert (out / "confusion_matrix.json").exists() - - with pytest.raises(NotImplementedError): - live.next_step() - - def test_dump_kwargs(tmp_dir, y_true_y_pred_y_score, mocker): live = Live() @@ -124,6 +111,22 @@ def test_dump_kwargs(tmp_dir, y_true_y_pred_y_score, mocker): spy.assert_called_once_with(y_true, y_score, drop_intermediate=True) +def test_override_on_step(tmp_dir): + live = Live() + + live.log_sklearn_plot("confusion_matrix", [0, 0], [0, 0]) + live.next_step() + live.log_sklearn_plot("confusion_matrix", [0, 0], [1, 1]) + + plot_path = tmp_dir / live.plots_path / SKLearnPlot.subfolder + plot_path = plot_path / "confusion_matrix.json" + + assert json.loads(plot_path.read_text()) == [ + {"actual": "0", "predicted": "1"}, + {"actual": "0", "predicted": "1"}, + ] + + def test_cleanup(tmp_dir, y_true_y_pred_y_score): live = Live() out = tmp_dir / live.plots_path / SKLearnPlot.subfolder diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index a3bb4aac..b4703b83 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -7,7 +7,7 @@ from dvclive import Live from dvclive.catalyst import DvcLiveCallback -from dvclive.data import Metric +from dvclive.plots import Metric # pylint: disable=redefined-outer-name, unused-argument diff --git a/tests/test_data/test_image.py b/tests/test_data/test_image.py deleted file mode 100644 index 093060b4..00000000 --- a/tests/test_data/test_image.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -import pytest -from PIL import Image - -# pylint: disable=unused-argument -from dvclive import Live -from dvclive.data import Image as LiveImage - - -def test_PIL(tmp_dir): - live = Live() - img = Image.new("RGB", (500, 500), (250, 250, 250)) - live.log_image("image.png", img) - - assert ( - tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" - ).exists() - - -def test_invalid_extension(tmp_dir): - live = Live() - img = Image.new("RGB", (500, 500), (250, 250, 250)) - with pytest.raises(ValueError): - live.log_image("image.foo", img) - - -@pytest.mark.parametrize("shape", [(500, 500), (500, 500, 3), (500, 500, 4)]) -def test_numpy(tmp_dir, shape): - live = Live() - img = np.ones(shape, np.uint8) * 255 - live.log_image("image.png", img) - - assert ( - tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" - ).exists() - - -def test_step_formatting(tmp_dir): - live = Live() - img = np.ones((500, 500, 3), np.uint8) - for _ in range(3): - live.log_image("image.png", img) - live.next_step() - - for step in range(3): - assert ( - tmp_dir - / live.plots_path - / LiveImage.subfolder - / str(step) - / "image.png" - ).exists() - - -def test_step_rename(tmp_dir, mocker): - from pathlib import Path - - rename = mocker.spy(Path, "rename") - live = Live() - img = np.ones((500, 500, 3), np.uint8) - live.log_image("image.png", img) - assert ( - tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" - ).exists() - - live.next_step() - - assert not ( - tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" - ).exists() - assert ( - tmp_dir / live.plots_path / LiveImage.subfolder / "0" / "image.png" - ).exists() - rename.assert_called_once_with( - Path(live.plots_path) / LiveImage.subfolder / "image.png", - Path(live.plots_path) / LiveImage.subfolder / "0" / "image.png", - ) - - -def test_cleanup(tmp_dir): - live = Live() - img = np.ones((500, 500, 3), np.uint8) - live.log_image("image.png", img) - - assert ( - tmp_dir / live.plots_path / LiveImage.subfolder / "image.png" - ).exists() - - Live() - - assert not (tmp_dir / live.plots_path / LiveImage.subfolder).exists() diff --git a/tests/test_fastai.py b/tests/test_fastai.py index f7cacf83..7eb0c6ce 100644 --- a/tests/test_fastai.py +++ b/tests/test_fastai.py @@ -11,8 +11,8 @@ ) from dvclive import Live -from dvclive.data.metric import Metric from dvclive.fastai import DvcLiveCallback +from dvclive.plots.metric import Metric # pylint: disable=redefined-outer-name, unused-argument diff --git a/tests/test_huggingface.py b/tests/test_huggingface.py index 2e1ce141..1ffae80e 100644 --- a/tests/test_huggingface.py +++ b/tests/test_huggingface.py @@ -12,8 +12,8 @@ ) from dvclive import Live -from dvclive.data.metric import Metric from dvclive.huggingface import DvcLiveCallback +from dvclive.plots.metric import Metric from dvclive.utils import parse_metrics # pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter diff --git a/tests/test_keras.py b/tests/test_keras.py index 33c9fe75..32b11bc8 100644 --- a/tests/test_keras.py +++ b/tests/test_keras.py @@ -3,8 +3,8 @@ import pytest from dvclive import Live -from dvclive.data.metric import Metric from dvclive.keras import DvcLiveCallback +from dvclive.plots.metric import Metric from dvclive.utils import parse_metrics # pylint: disable=unused-argument, no-name-in-module, redefined-outer-name diff --git a/tests/test_lightning.py b/tests/test_lightning.py index 52dd6b22..d24ca05d 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -8,8 +8,8 @@ from torch.optim import Adam from torch.utils.data import DataLoader, Dataset -from dvclive.data.metric import Metric from dvclive.lightning import DvcLiveLogger +from dvclive.plots.metric import Metric from dvclive.utils import parse_metrics # pylint: disable=redefined-outer-name, unused-argument diff --git a/tests/test_main.py b/tests/test_main.py index 80620837..05ee21a7 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,17 +1,18 @@ # pylint: disable=protected-access # pylint: disable=unused-argument +import json import os import pytest from funcy import last from dvclive import Live, env -from dvclive.data import Metric from dvclive.error import ( DataAlreadyLoggedError, InvalidDataTypeError, InvalidParameterTypeError, ) +from dvclive.plots import Metric from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics @@ -239,7 +240,7 @@ def test_require_step_update(tmp_dir, metric): with pytest.raises( DataAlreadyLoggedError, - match="has already been logged with step 'None'", + match="has already been logged with step '0'", ): dvclive.log(metric, 2.0) @@ -340,3 +341,17 @@ def test_logger(tmp_dir, mocker, monkeypatch): live = Live(resume=True) logger.info.assert_called_with("Resumed from step 0") + + +def test_make_summary_without_calling_log(tmp_dir): + dvclive = Live() + + dvclive.summary["foo"] = 1.0 + dvclive.make_summary() + + assert json.loads((tmp_dir / dvclive.metrics_path).read_text()) == { + # no `step` + "foo": 1.0 + } + log_file = tmp_dir / dvclive.plots_path / Metric.subfolder / "foo.tsv" + assert not log_file.exists() diff --git a/tests/test_report.py b/tests/test_report.py index 8ed4abe5..01a235e7 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -5,10 +5,10 @@ from PIL import Image from dvclive import Live -from dvclive.data import Image as LiveImage -from dvclive.data import Metric -from dvclive.data.sklearn_plot import ConfusionMatrix, SKLearnPlot from dvclive.env import DVCLIVE_OPEN +from dvclive.plots import Image as LiveImage +from dvclive.plots import Metric +from dvclive.plots.sklearn import ConfusionMatrix, SKLearnPlot from dvclive.report import ( get_image_renderers, get_metrics_renderers, @@ -30,20 +30,18 @@ def test_get_renderers(tmp_dir, mocker): live.log_image("image.png", img) live.next_step() - live.set_step(None) live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) image_renderers = get_image_renderers( tmp_dir / live.plots_path / LiveImage.subfolder ) - assert len(image_renderers) == 2 - image_renderers = sorted( - image_renderers, key=lambda x: x.datapoints[0]["rev"] - ) - for n, renderer in enumerate(image_renderers): - assert renderer.datapoints == [ - {"src": mocker.ANY, "rev": os.path.join(str(n), "image.png")} - ] + assert len(image_renderers) == 1 + assert image_renderers[0].datapoints == [ + { + "src": os.path.join("plots", LiveImage.subfolder, "image.png"), + "rev": "image.png", + } + ] scalar_renderers = get_scalar_renderers( tmp_dir / live.plots_path / Metric.subfolder diff --git a/tests/test_studio.py b/tests/test_studio.py index 4c18deab..44c082f3 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -5,7 +5,7 @@ import pytest from dvclive import Live, env -from dvclive.data import Metric +from dvclive.plots import Metric @pytest.mark.studio