diff --git a/dvclive/data/__init__.py b/dvclive/data/__init__.py
index 5a1a8686..74964b4e 100644
--- a/dvclive/data/__init__.py
+++ b/dvclive/data/__init__.py
@@ -1,4 +1,12 @@
 from .image import Image
+from .plot import Calibration, ConfusionMatrix, Det, PrecisionRecall, Roc
 from .scalar import Scalar
 
-DATA_TYPES = [Image, Scalar]
+PLOTS = {
+    "calibration": Calibration,
+    "confusion_matrix": ConfusionMatrix,
+    "det": Det,
+    "precision_recall": PrecisionRecall,
+    "roc": Roc,
+}
+DATA_TYPES = list(PLOTS.values()) + [Scalar, Image]
diff --git a/dvclive/data/base.py b/dvclive/data/base.py
index 8b7f08f1..e71204e3 100644
--- a/dvclive/data/base.py
+++ b/dvclive/data/base.py
@@ -8,10 +8,11 @@ class Data(abc.ABC):
 
     def __init__(self, name: str, output_folder: str) -> None:
         self.name = name
-        self.output_folder: Path = Path(output_folder)
+        self.output_folder: Path = Path(output_folder) / self.subfolder
         self._step: Optional[int] = None
         self.val = None
         self._step_none_logged: bool = False
+        self._dump_kwargs = None
 
     @property
     def step(self) -> int:
@@ -19,9 +20,7 @@ def step(self) -> int:
 
     @step.setter
     def step(self, val: int) -> None:
-        if not self._step_none_logged and val is None:
-            self._step_none_logged = True
-        elif val == self._step:
+        if self._step_none_logged and val == self._step:
             raise DataAlreadyLoggedError(self.name, val)
 
         self._step = val
@@ -31,6 +30,16 @@ def step(self, val: int) -> None:
     def output_path(self) -> Path:
         pass
 
+    @property
+    @abc.abstractmethod
+    def no_step_output_path(self) -> Path:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def subfolder(self):
+        pass
+
     @property
     @abc.abstractmethod
     def summary(self):
@@ -41,6 +50,26 @@ def summary(self):
     def could_log(val: object) -> bool:
         pass
 
-    def dump(self, val, step):
+    def dump(self, val, step, **kwargs):
         self.val = val
         self.step = step
+        self._dump_kwargs = kwargs
+        if not self._step_none_logged and step is None:
+            self._step_none_logged = True
+            self.no_step_dump()
+        elif step == 0:
+            self.first_step_dump()
+        else:
+            self.step_dump()
+
+    @abc.abstractmethod
+    def first_step_dump(self) -> None:
+        pass
+
+    @abc.abstractmethod
+    def no_step_dump(self) -> None:
+        pass
+
+    @abc.abstractmethod
+    def step_dump(self) -> None:
+        pass
diff --git a/dvclive/data/image.py b/dvclive/data/image.py
index 8d2ee7fb..c2f96948 100644
--- a/dvclive/data/image.py
+++ b/dvclive/data/image.py
@@ -5,43 +5,45 @@ class Image(Data):
     suffixes = [".jpg", ".jpeg", ".gif", ".png"]
+    subfolder = "images"
 
-    @staticmethod
-    def could_log(val: object) -> bool:
-        if val.__class__.__module__ == "PIL.Image":
-            return True
-        if val.__class__.__module__ == "numpy":
-            return True
-        return False
+    @property
+    def no_step_output_path(self) -> Path:
+        return self.output_folder / self.name
 
     @property
     def output_path(self) -> Path:
-        if Path(self.name).suffix not in self.suffixes:
-            raise ValueError(
-                f"Invalid image suffix '{Path(self.name).suffix}'"
-                f" Must be one of {self.suffixes}"
-            )
         if self._step is None:
-            output_path = self.output_folder / self.name
+            output_path = self.no_step_output_path
         else:
             output_path = self.output_folder / f"{self._step}" / self.name
         output_path.parent.mkdir(exist_ok=True, parents=True)
         return output_path
 
-    def dump(self, val, step) -> None:
-        if self._step_none_logged and self._step is None:
-            super().dump(val, step)
-            step_none_path = self.output_folder / self.name
-            if step_none_path.exists():
-                step_none_path.rename(self.output_path)
-        else:
-            super().dump(val, step)
-            if val.__class__.__module__ == "numpy":
-                from PIL import Image as ImagePIL
+    @staticmethod
+    def could_log(val: object) -> bool:
+        if val.__class__.__module__ == "PIL.Image":
+            return True
+        if val.__class__.__module__ == "numpy":
+            return True
+        return False
+
+    def first_step_dump(self) -> None:
+        if self.no_step_output_path.exists():
+            self.no_step_output_path.rename(self.output_path)
+
+    def no_step_dump(self) -> None:
+        self.step_dump()
 
-                val = ImagePIL.fromarray(val)
+    def step_dump(self) -> None:
+        if self.val.__class__.__module__ == "numpy":
+            from PIL import Image as ImagePIL
+
+            _val = ImagePIL.fromarray(self.val)
+        else:
+            _val = self.val
 
-            val.save(self.output_path)
+        _val.save(self.output_path)
 
     @property
     def summary(self):
diff --git a/dvclive/data/plot.py b/dvclive/data/plot.py
new file mode 100644
index 00000000..5a706e65
--- /dev/null
+++ b/dvclive/data/plot.py
@@ -0,0 +1,123 @@
+import json
+from pathlib import Path
+
+from .base import Data
+
+
+class Plot(Data):
+    suffixes = [".json"]
+    subfolder = "plots"
+
+    @property
+    def output_path(self) -> Path:
+        _path = self.output_folder / self.name
+        _path.parent.mkdir(exist_ok=True, parents=True)
+        return _path.with_suffix(".json")
+
+    @staticmethod
+    def could_log(val: object) -> bool:
+        if isinstance(val, tuple) and len(val) == 2:
+            return True
+        return False
+
+    @property
+    def no_step_output_path(self) -> Path:
+        return super().no_step_output_path.with_suffix(".json")
+
+    @property
+    def summary(self):
+        return {}
+
+    @staticmethod
+    def write_json(content, output_file):
+        with open(output_file, "w") as f:
+            json.dump(content, f, indent=4)
+
+    def no_step_dump(self) -> None:
+        raise NotImplementedError
+
+    def first_step_dump(self) -> None:
+        raise NotImplementedError(
+            "DVCLive plots can only be used in no-step mode."
+        )
+
+    def step_dump(self) -> None:
+        raise NotImplementedError(
+            "DVCLive plots can only be used in no-step mode."
+        )
+
+
+class Roc(Plot):
+    def no_step_dump(self) -> int:
+        from sklearn import metrics
+
+        fpr, tpr, roc_thresholds = metrics.roc_curve(
+            y_true=self.val[0], y_score=self.val[1], **self._dump_kwargs
+        )
+        roc = {
+            "roc": [
+                {"fpr": fp, "tpr": tp, "threshold": t}
+                for fp, tp, t in zip(fpr, tpr, roc_thresholds)
+            ]
+        }
+        self.write_json(roc, self.output_path)
+
+
+class PrecisionRecall(Plot):
+    def no_step_dump(self) -> int:
+        from sklearn import metrics
+
+        precision, recall, prc_thresholds = metrics.precision_recall_curve(
+            y_true=self.val[0], probas_pred=self.val[1], **self._dump_kwargs
+        )
+
+        prc = {
+            "prc": [
+                {"precision": p, "recall": r, "threshold": t}
+                for p, r, t in zip(precision, recall, prc_thresholds)
+            ]
+        }
+        self.write_json(prc, self.output_path)
+
+
+class Det(Plot):
+    def no_step_dump(self) -> int:
+        from sklearn import metrics
+
+        fpr, fnr, roc_thresholds = metrics.det_curve(
+            y_true=self.val[0], y_score=self.val[1], **self._dump_kwargs
+        )
+
+        det = {
+            "det": [
+                {"fpr": fp, "fnr": fn, "threshold": t}
+                for fp, fn, t in zip(fpr, fnr, roc_thresholds)
+            ]
+        }
+        self.write_json(det, self.output_path)
+
+
+class ConfusionMatrix(Plot):
+    def no_step_dump(self) -> int:
+        cm = [
+            {"actual": str(actual), "predicted": str(predicted)}
+            for actual, predicted in zip(self.val[0], self.val[1])
+        ]
+        self.write_json(cm, self.output_path)
+
+
+class Calibration(Plot):
+    def no_step_dump(self) -> int:
+        from sklearn import calibration
+
+        prob_true, prob_pred = calibration.calibration_curve(
+            y_true=self.val[0], y_prob=self.val[1], **self._dump_kwargs
+        )
+
+        calibration = {
+            "calibration": [
+                {"prob_true": pt, "prob_pred": pp}
+                for pt, pp in zip(prob_true, prob_pred)
+            ]
+        }
+        self.write_json(calibration, self.output_path)
diff --git a/dvclive/data/scalar.py b/dvclive/data/scalar.py
index 5a5b096f..e1075d03 100644
--- a/dvclive/data/scalar.py
+++ b/dvclive/data/scalar.py
@@ -11,6 +11,7 @@ class Scalar(Data):
     suffixes = [".csv", ".tsv"]
+    subfolder = "scalars"
 
     @staticmethod
     def could_log(val: object) -> bool:
@@ -24,23 +25,30 @@ def output_path(self) -> Path:
         _path.parent.mkdir(exist_ok=True, parents=True)
         return _path.with_suffix(".tsv")
 
-    def dump(self, val, step):
-        super().dump(val, step)
+    @property
+    def no_step_output_path(self) -> Path:
+        return self.output_path
+
+    def first_step_dump(self) -> None:
+        self.step_dump()
+
+    def no_step_dump(self) -> None:
+        pass
 
-        if step is not None:
-            ts = int(time.time() * 1000)
-            d = OrderedDict(
-                [("timestamp", ts), ("step", self.step), (self.name, self.val)]
-            )
+    def step_dump(self) -> None:
+        ts = int(time.time() * 1000)
+        d = OrderedDict(
+            [("timestamp", ts), ("step", self.step), (self.name, self.val)]
+        )
 
-            existed = self.output_path.exists()
-            with open(self.output_path, "a") as fobj:
-                writer = csv.DictWriter(fobj, d.keys(), delimiter="\t")
+        existed = self.output_path.exists()
+        with open(self.output_path, "a") as fobj:
+            writer = csv.DictWriter(fobj, d.keys(), delimiter="\t")
 
-                if not existed:
-                    writer.writeheader()
+            if not existed:
+                writer.writeheader()
 
-                writer.writerow(d)
+            writer.writerow(d)
 
     @property
     def summary(self):
diff --git a/dvclive/error.py b/dvclive/error.py
index 4e7d31a1..51ee7aff 100644
--- a/dvclive/error.py
+++ b/dvclive/error.py
@@ -26,6 +26,17 @@ def __init__(self, name, val):
         super().__init__(f"Data '{name}' has not supported type {val}")
 
 
+class InvalidPlotTypeError(DvcLiveError):
+    def __init__(self, name):
+        from .data import PLOTS
+
+        self.name = name
+        super().__init__(
+            f"Plot type '{name}' is not supported."
+            f"\nSupported types are: {list(PLOTS)}"
+        )
+
+
 class DataAlreadyLoggedError(DvcLiveError):
     def __init__(self, name, step):
         self.name = name
diff --git a/dvclive/live.py b/dvclive/live.py
index 73f47469..0345bd20 100644
--- a/dvclive/live.py
+++ b/dvclive/live.py
@@ -3,12 +3,17 @@
 import os
 import shutil
 from collections import OrderedDict
+from itertools import chain
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
 
-from .data import DATA_TYPES
+from .data import DATA_TYPES, PLOTS, Image, Scalar
 from .dvc import make_checkpoint, make_html
-from .error import ConfigMismatchError, InvalidDataTypeError
+from .error import (
+    ConfigMismatchError,
+    InvalidDataTypeError,
+    InvalidPlotTypeError,
+)
 from .utils import nested_update
 
 logger = logging.getLogger(__name__)
@@ -36,7 +41,9 @@ def __init__(
         self._path = self.DEFAULT_DIR
 
         self._step: Optional[int] = None
-        self._data: Dict[str, Any] = OrderedDict()
+        self._scalars: Dict[str, Any] = OrderedDict()
+        self._images: Dict[str, Any] = OrderedDict()
+        self._plots: Dict[str, Any] = OrderedDict()
 
         if self._resume:
             self._step = self.read_step()
@@ -47,11 +54,10 @@ def __init__(
 
         self._init_paths()
 
     def _cleanup(self):
         for data_type in DATA_TYPES:
-            for suffix in data_type.suffixes:
-                for data_file in Path(self.dir).rglob(f"*{suffix}"):
-                    data_file.unlink()
+            shutil.rmtree(
+                Path(self.dir) / data_type.subfolder, ignore_errors=True
+            )
 
         if os.path.exists(self.summary_path):
             os.remove(self.summary_path)
@@ -117,7 +123,11 @@ def set_step(self, step: int) -> None:
         if self._step is None:
            self._step = 0
             self._init_paths()
-            for data in self._data.values():
+            for data in chain(
+                self._scalars.values(),
+                self._images.values(),
+                self._plots.values(),
+            ):
                 data.dump(data.val, self._step)
 
         if self._summary:
             self.make_summary()
@@ -134,28 +144,51 @@ def next_step(self):
         self.set_step(self.get_step() + 1)
 
     def log(self, name: str, val: Union[int, float]):
-        data = None
-        if name in self._data:
-            data = self._data[name]
-        else:
-            for data_type in DATA_TYPES:
-                if data_type.could_log(val):
-                    data = data_type(name, self.dir)
-                    self._data[name] = data
-        if data is None:
+        if not Scalar.could_log(val):
             raise InvalidDataTypeError(name, type(val))
 
+        if name in self._scalars:
+            data = self._scalars[name]
+        else:
+            data = Scalar(name, self.dir)
+            self._scalars[name] = data
+
         data.dump(val, self._step)
 
         if self._summary:
             self.make_summary()
 
+    def log_image(self, name: str, val):
+        if not Image.could_log(val):
+            raise InvalidDataTypeError(name, type(val))
+
+        if name in self._images:
+            data = self._images[name]
+        else:
+            data = Image(name, self.dir)
+            self._images[name] = data
+
+        data.dump(val, self._step)
+
+    def log_plot(self, name, labels, predictions, **kwargs):
+        val = (labels, predictions)
+
+        if name in self._plots:
+            data = self._plots[name]
+        elif name in PLOTS and PLOTS[name].could_log(val):
+            data = PLOTS[name](name, self.dir)
+            self._plots[name] = data
+        else:
+            raise InvalidPlotTypeError(name)
+
+        data.dump(val, self._step, **kwargs)
+
     def make_summary(self):
         summary_data = {}
         if self._step is not None:
             summary_data["step"] = self.get_step()
 
-        for data in self._data.values():
+        for data in self._scalars.values():
             summary_data = nested_update(summary_data, data.summary)
 
         with open(self.summary_path, "w") as f:
diff --git a/setup.py b/setup.py
index d84ef395..10d1a87d 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,8 @@ def run(self):
         _build_py.run(self)
 
 
+image = ["pillow"]
+plots = ["scikit-learn"]
 mmcv = ["mmcv"]
["mmcv"] tf = ["tensorflow"] xgb = ["xgboost"] @@ -44,9 +46,8 @@ def run(self): catalyst = ["catalyst"] fastai = ["fastai"] pl = ["pytorch_lightning"] -image = ["pillow"] -all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl + image +all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl + plots tests_requires = [ "pylint==2.5.3", @@ -56,7 +57,6 @@ def run(self): "pytest-cov>=2.12.1", "pytest-mock>=3.6.1", "pandas>=1.3.1", - "sklearn", "funcy>=1.14", "dvc>=2.0.0", ] + all_libs @@ -79,8 +79,10 @@ def run(self): "huggingface": hugginface, "catalyst": catalyst, "fastai": fastai, - "image": image, "pytorch_lightning": pl, + "sklearn": plots, + "image": image, + "plots": plots, }, keywords="data-science metrics machine-learning developer-tools ai", python_requires=">=3.6", diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index 666bb9b5..248a0ecb 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader from dvclive.catalyst import DvcLiveCallback +from dvclive.data import Scalar # pylint: disable=redefined-outer-name, unused-argument @@ -59,8 +60,8 @@ def test_catalyst_callback(tmp_dir, runner, loaders): assert os.path.exists("dvclive") - train_path = tmp_dir / "dvclive/train" - valid_path = tmp_dir / "dvclive/valid" + train_path = tmp_dir / "dvclive" / Scalar.subfolder / "train" + valid_path = tmp_dir / "dvclive" / Scalar.subfolder / "valid" assert train_path.is_dir() assert valid_path.is_dir() diff --git a/tests/test_data/test_image.py b/tests/test_data/test_image.py index e102b634..98b89a2c 100644 --- a/tests/test_data/test_image.py +++ b/tests/test_data/test_image.py @@ -1,56 +1,51 @@ -import os - import numpy as np import pytest from PIL import Image # pylint: disable=unused-argument from dvclive import Live -from tests.test_main import _parse_json +from dvclive.data import Image as LiveImage def test_PIL(tmp_dir): dvclive = Live() img = Image.new("RGB", (500, 500), (250, 250, 250)) - dvclive.log("image.png", img) - - assert (tmp_dir / dvclive.dir / "image.png").exists() - summary = _parse_json("dvclive.json") + dvclive.log_image("image.png", img) - assert summary["image.png"] == os.path.join(dvclive.dir, "image.png") + assert (tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png").exists() def test_invalid_extension(tmp_dir): dvclive = Live() img = Image.new("RGB", (500, 500), (250, 250, 250)) with pytest.raises(ValueError): - dvclive.log("image.foo", img) + dvclive.log_image("image.foo", img) @pytest.mark.parametrize("shape", [(500, 500), (500, 500, 3), (500, 500, 4)]) def test_numpy(tmp_dir, shape): dvclive = Live() img = np.ones(shape, np.uint8) * 255 - dvclive.log("image.png", img) + dvclive.log_image("image.png", img) - assert (tmp_dir / dvclive.dir / "image.png").exists() + assert (tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png").exists() def test_step_formatting(tmp_dir): dvclive = Live() img = np.ones((500, 500, 3), np.uint8) for _ in range(3): - dvclive.log("image.png", img) + dvclive.log_image("image.png", img) dvclive.next_step() for step in range(3): - assert (tmp_dir / dvclive.dir / str(step) / "image.png").exists() - - summary = _parse_json("dvclive.json") - - assert summary["image.png"] == os.path.join( - dvclive.dir, str(step), "image.png" - ) + assert ( + tmp_dir + / dvclive.dir + / LiveImage.subfolder + / str(step) + / "image.png" + ).exists() def test_step_rename(tmp_dir, mocker): @@ -59,13 +54,30 @@ def test_step_rename(tmp_dir, mocker): rename = 
     dvclive = Live()
     img = np.ones((500, 500, 3), np.uint8)
-    dvclive.log("image.png", img)
-    assert (tmp_dir / dvclive.dir / "image.png").exists()
+    dvclive.log_image("image.png", img)
+    assert (tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png").exists()
 
     dvclive.next_step()
 
-    assert not (tmp_dir / dvclive.dir / "image.png").exists()
-    assert (tmp_dir / dvclive.dir / "0" / "image.png").exists()
+    assert not (
+        tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png"
+    ).exists()
+    assert (
+        tmp_dir / dvclive.dir / LiveImage.subfolder / "0" / "image.png"
+    ).exists()
     rename.assert_called_once_with(
-        Path(dvclive.dir) / "image.png", Path(dvclive.dir) / "0" / "image.png"
+        Path(dvclive.dir) / LiveImage.subfolder / "image.png",
+        Path(dvclive.dir) / LiveImage.subfolder / "0" / "image.png",
     )
+
+
+def test_cleanup(tmp_dir):
+    dvclive = Live()
+    img = np.ones((500, 500, 3), np.uint8)
+    dvclive.log_image("image.png", img)
+
+    assert (tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png").exists()
+
+    Live()
+
+    assert not (tmp_dir / dvclive.dir / LiveImage.subfolder).exists()
diff --git a/tests/test_data/test_plot.py b/tests/test_data/test_plot.py
new file mode 100644
index 00000000..c2401b32
--- /dev/null
+++ b/tests/test_data/test_plot.py
@@ -0,0 +1,139 @@
+import json
+
+import pytest
+from sklearn import calibration, metrics
+
+from dvclive import Live
+from dvclive.data.plot import Plot
+
+# pylint: disable=redefined-outer-name, unused-argument
+
+
+@pytest.fixture
+def y_true_y_pred_y_score():
+    from sklearn.datasets import make_classification
+    from sklearn.ensemble import RandomForestClassifier
+    from sklearn.model_selection import train_test_split
+
+    X, y = make_classification(random_state=0)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    clf = RandomForestClassifier(random_state=0)
+    clf.fit(X_train, y_train)
+
+    y_pred = clf.predict(X_test)
+    y_score = clf.predict_proba(X_test)[:, 1]
+
+    return y_test, y_pred, y_score
+
+
+def test_log_calibration_curve(tmp_dir, y_true_y_pred_y_score, mocker):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, _, y_score = y_true_y_pred_y_score
+
+    spy = mocker.spy(calibration, "calibration_curve")
+
+    live.log_plot("calibration", y_true, y_score)
+
+    spy.assert_called_once_with(y_true, y_score)
+
+    assert (out / "calibration.json").exists()
+
+
+def test_log_det_curve(tmp_dir, y_true_y_pred_y_score, mocker):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, _, y_score = y_true_y_pred_y_score
+
+    spy = mocker.spy(metrics, "det_curve")
+
+    live.log_plot("det", y_true, y_score)
+
+    spy.assert_called_once_with(y_true, y_score)
+    assert (out / "det.json").exists()
+
+
+def test_log_roc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, _, y_score = y_true_y_pred_y_score
+
+    spy = mocker.spy(metrics, "roc_curve")
+
+    live.log_plot("roc", y_true, y_score)
+
+    spy.assert_called_once_with(y_true, y_score)
+    assert (out / "roc.json").exists()
+
+
+def test_log_prc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, _, y_score = y_true_y_pred_y_score
+
+    spy = mocker.spy(metrics, "precision_recall_curve")
+
+    live.log_plot("precision_recall", y_true, y_score)
+
+    spy.assert_called_once_with(y_true, y_score)
+    assert (out / "precision_recall.json").exists()
+
+
+def test_log_confusion_matrix(tmp_dir, y_true_y_pred_y_score, mocker):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, y_pred, _ = y_true_y_pred_y_score
+
+    live.log_plot("confusion_matrix", y_true, y_pred)
+
+    cm = json.loads((out / "confusion_matrix.json").read_text())
+
+    assert isinstance(cm, list)
+    assert isinstance(cm[0], dict)
+    assert cm[0]["actual"] == str(y_true[0])
+    assert cm[0]["predicted"] == str(y_pred[0])
+
+
+def test_step_exception(tmp_dir, y_true_y_pred_y_score):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, y_pred, _ = y_true_y_pred_y_score
+
+    live.log_plot("confusion_matrix", y_true, y_pred)
+    assert (out / "confusion_matrix.json").exists()
+
+    with pytest.raises(NotImplementedError):
+        live.next_step()
+
+
+def test_dump_kwargs(tmp_dir, y_true_y_pred_y_score, mocker):
+    live = Live()
+
+    y_true, _, y_score = y_true_y_pred_y_score
+
+    spy = mocker.spy(metrics, "roc_curve")
+
+    live.log_plot("roc", y_true, y_score, drop_intermediate=True)
+
+    spy.assert_called_once_with(y_true, y_score, drop_intermediate=True)
+
+
+def test_cleanup(tmp_dir, y_true_y_pred_y_score):
+    live = Live()
+    out = tmp_dir / live.dir / Plot.subfolder
+
+    y_true, y_pred, _ = y_true_y_pred_y_score
+
+    live.log_plot("confusion_matrix", y_true, y_pred)
+
+    assert (out / "confusion_matrix.json").exists()
+
+    Live()
+
+    assert not (tmp_dir / live.dir / Plot.subfolder).exists()
diff --git a/tests/test_fastai.py b/tests/test_fastai.py
index f43e3f14..1a466884 100644
--- a/tests/test_fastai.py
+++ b/tests/test_fastai.py
@@ -12,6 +12,7 @@
     untar_data,
 )
 
+from dvclive.data.scalar import Scalar
 from dvclive.fastai import DvcLiveCallback
 
 # pylint: disable=redefined-outer-name, unused-argument
 
@@ -46,12 +47,12 @@ def test_fastai_callback(tmp_dir, data_loader):
 
     assert os.path.exists("dvclive")
 
-    train_path = tmp_dir / "dvclive/train"
-    valid_path = tmp_dir / "dvclive/valid"
+    train_path = tmp_dir / "dvclive" / Scalar.subfolder / "train"
+    valid_path = tmp_dir / "dvclive" / Scalar.subfolder / "valid"
 
     assert train_path.is_dir()
     assert valid_path.is_dir()
-    assert (tmp_dir / "dvclive/accuracy.tsv").exists()
+    assert (tmp_dir / "dvclive" / Scalar.subfolder / "accuracy.tsv").exists()
 
 
 def test_fastai_model_file(tmp_dir, data_loader):
diff --git a/tests/test_huggingface.py b/tests/test_huggingface.py
index f5c8b620..f36ca74b 100644
--- a/tests/test_huggingface.py
+++ b/tests/test_huggingface.py
@@ -10,6 +10,7 @@
     TrainingArguments,
 )
 
+from dvclive.data.scalar import Scalar
 from dvclive.huggingface import DvcLiveCallback
 from tests.test_main import read_logs
 
@@ -78,7 +79,7 @@ def test_huggingface_integration(tmp_dir, model, args, data, tokenizer):
 
     assert os.path.exists("dvclive")
 
-    logs, _ = read_logs("dvclive")
+    logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
 
     assert len(logs) == 10
     assert "eval_matthews_correlation" in logs
diff --git a/tests/test_keras.py b/tests/test_keras.py
index 8436651d..4943691e 100644
--- a/tests/test_keras.py
+++ b/tests/test_keras.py
@@ -2,6 +2,7 @@
 
 import pytest
 
+from dvclive.data.scalar import Scalar
 from dvclive.keras import DvcLiveCallback
 from tests.test_main import read_logs
 
@@ -45,7 +46,7 @@ def test_keras_callback(tmp_dir, xor_model, capture_wrap):
     )
 
     assert os.path.exists("dvclive")
-    logs, _ = read_logs("dvclive")
+    logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
 
     assert "accuracy" in logs
diff --git a/tests/test_lgbm.py b/tests/test_lgbm.py
index 812dce46..5a431d32 100644
--- a/tests/test_lgbm.py
+++ b/tests/test_lgbm.py
@@ -8,6 +8,7 @@
 from sklearn import datasets
 from sklearn.model_selection import train_test_split
 
+from dvclive.data.scalar import Scalar
 from dvclive.lgbm import DvcLiveCallback
 from tests.test_main import read_logs
 
@@ -44,7 +45,7 @@ def test_lgbm_integration(tmp_dir, model_params, iris_data):
 
     assert os.path.exists("dvclive")
 
-    logs, _ = read_logs("dvclive")
+    logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
 
     assert len(logs) == 1
     assert len(first(logs.values())) == 5
diff --git a/tests/test_lightning.py b/tests/test_lightning.py
index 1fa73106..9a42d3a4 100644
--- a/tests/test_lightning.py
+++ b/tests/test_lightning.py
@@ -9,6 +9,7 @@
 from torchvision import transforms
 from torchvision.datasets import MNIST
 
+from dvclive.data.scalar import Scalar
 from dvclive.lightning import DvcLiveLogger
 from tests.test_main import read_logs
 
@@ -84,7 +85,6 @@ def test_lightning_integration(tmp_dir):
     model = LitMNIST()
     # init logger
     dvclive_logger = DvcLiveLogger("test_run", path="logs")
-    print(dvclive_logger.version)
     trainer = Trainer(
         logger=dvclive_logger, max_epochs=1, checkpoint_callback=False
     )
@@ -93,7 +93,7 @@ def test_lightning_integration(tmp_dir):
     assert os.path.exists("logs")
     assert not os.path.exists("DvcLiveLogger")
 
-    logs, _ = read_logs("logs")
+    logs, _ = read_logs(tmp_dir / "logs" / Scalar.subfolder)
 
     assert len(logs) == 3
     assert "train_loss_step" in logs
diff --git a/tests/test_main.py b/tests/test_main.py
index 2e39c9dc..455b0147 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -7,6 +7,7 @@
 from funcy import last
 
 from dvclive import Live, env
+from dvclive.data import Scalar
 
 # pylint: disable=unused-argument
 from dvclive.dvc import SIGNAL_FILE
@@ -25,7 +26,7 @@ def read_logs(path: str):
         metric_name = str(metric_file).replace(str(path) + os.path.sep, "")
         metric_name = metric_name.replace(".tsv", "")
         history[metric_name] = _parse_tsv(metric_file)
-    latest = _parse_json(str(path) + ".json")
+    latest = _parse_json(str(path.parent) + ".json")
     return history, latest
 
 
@@ -76,8 +77,8 @@ def test_logging_step(tmp_dir, path):
     dvclive = Live(path)
     dvclive.log("m1", 1)
     dvclive.next_step()
-    assert (tmp_dir / path).is_dir()
-    assert (tmp_dir / path / "m1.tsv").is_file()
+    assert (tmp_dir / dvclive.dir).is_dir()
+    assert (tmp_dir / dvclive.dir / Scalar.subfolder / "m1.tsv").is_file()
     assert (tmp_dir / dvclive.summary_path).is_file()
 
     s = _parse_json(dvclive.summary_path)
@@ -88,16 +89,18 @@ def test_logging_step(tmp_dir, path):
 
 def test_nested_logging(tmp_dir):
     dvclive = Live("logs", summary=True)
+    out = tmp_dir / dvclive.dir / Scalar.subfolder
+
     dvclive.log("train/m1", 1)
     dvclive.log("val/val_1/m1", 1)
     dvclive.log("val/val_1/m2", 1)
 
     dvclive.next_step()
 
-    assert (tmp_dir / "logs" / "val" / "val_1").is_dir()
-    assert (tmp_dir / "logs" / "train" / "m1.tsv").is_file()
-    assert (tmp_dir / "logs" / "val" / "val_1" / "m1.tsv").is_file()
-    assert (tmp_dir / "logs" / "val" / "val_1" / "m2.tsv").is_file()
+    assert (out / "val" / "val_1").is_dir()
+    assert (out / "train" / "m1.tsv").is_file()
+    assert (out / "val" / "val_1" / "m1.tsv").is_file()
+    assert (out / "val" / "val_1" / "m2.tsv").is_file()
 
     summary = _parse_json(dvclive.summary_path)
 
@@ -146,14 +149,14 @@ def test_cleanup(tmp_dir, summary, html):
 
     (tmp_dir / "logs" / "some_user_file.txt").touch()
 
-    assert (tmp_dir / "logs" / "m1.tsv").is_file()
+    assert (tmp_dir / dvclive.dir / Scalar.subfolder / "m1.tsv").is_file()
     assert (tmp_dir / dvclive.summary_path).is_file() == summary
     assert html_path.is_file() == html
 
     dvclive = Live("logs", summary=summary)
 
     assert (tmp_dir / "logs" / "some_user_file.txt").is_file()
-    assert not (tmp_dir / "logs" / "m1.tsv").is_file()
+    assert not (tmp_dir / dvclive.dir / Scalar.subfolder).exists()
     assert (tmp_dir / dvclive.summary_path).is_file() == summary
     assert not (html_path).is_file()
@@ -165,12 +168,14 @@
 def test_continue(tmp_dir, resume, steps, metrics):
     dvclive = Live("logs")
 
+    out = tmp_dir / dvclive.dir / Scalar.subfolder
+
     for metric in [0.9, 0.8]:
         dvclive.log("metric", metric)
         dvclive.next_step()
 
-    assert read_history("logs", "metric") == ([0, 1], [0.9, 0.8])
-    assert read_latest("logs", "metric") == (1, 0.8)
+    assert read_history(out, "metric") == ([0, 1], [0.9, 0.8])
+    assert read_latest(out, "metric") == (1, 0.8)
 
     dvclive = Live("logs", resume=resume)
 
@@ -178,8 +183,8 @@ def test_continue(tmp_dir, resume, steps, metrics):
         dvclive.log("metric", new_metric)
         dvclive.next_step()
 
-    assert read_history("logs", "metric") == (steps, metrics)
-    assert read_latest("logs", "metric") == (last(steps), last(metrics))
+    assert read_history(out, "metric") == (steps, metrics)
+    assert read_latest(out, "metric") == (last(steps), last(metrics))
 
 
 def test_resume_on_first_init(tmp_dir):
@@ -201,9 +206,11 @@ def test_require_step_update(tmp_dir, metric):
         dvclive.log(metric, 2.0)
 
 
-def test_custom_steps(tmp_dir, mocker):
+def test_custom_steps(tmp_dir):
     dvclive = Live("logs")
 
+    out = tmp_dir / dvclive.dir / Scalar.subfolder
+
     steps = [0, 62, 1000]
     metrics = [0.9, 0.8, 0.7]
 
@@ -211,12 +218,14 @@ def test_custom_steps(tmp_dir, mocker):
         dvclive.set_step(step)
         dvclive.log("m", metric)
 
-    assert read_history("logs", "m") == (steps, metrics)
-    assert read_latest("logs", "m") == (last(steps), last(metrics))
+    assert read_history(out, "m") == (steps, metrics)
+    assert read_latest(out, "m") == (last(steps), last(metrics))
 
 
 def test_log_reset_with_set_step(tmp_dir):
     dvclive = Live()
+    out = tmp_dir / dvclive.dir / Scalar.subfolder
+
     for i in range(3):
         dvclive.set_step(i)
         dvclive.log("train_m", 1)
@@ -225,10 +234,10 @@ def test_log_reset_with_set_step(tmp_dir):
         dvclive.set_step(i)
         dvclive.log("val_m", 1)
 
-    assert read_history("dvclive", "train_m") == ([0, 1, 2], [1, 1, 1])
-    assert read_history("dvclive", "val_m") == ([0, 1, 2], [1, 1, 1])
-    assert read_latest("dvclive", "train_m") == (2, 1)
-    assert read_latest("dvclive", "val_m") == (2, 1)
+    assert read_history(out, "train_m") == ([0, 1, 2], [1, 1, 1])
+    assert read_history(out, "val_m") == ([0, 1, 2], [1, 1, 1])
+    assert read_latest(out, "train_m") == (2, 1)
+    assert read_latest(out, "val_m") == (2, 1)
 
 
 @pytest.mark.parametrize("html", [True, False])
@@ -293,10 +302,12 @@ def test_get_step_custom_steps(tmp_dir):
 
 def test_get_step_control_flow(tmp_dir):
     dvclive = Live()
+    out = tmp_dir / dvclive.dir / Scalar.subfolder
+
     while dvclive.get_step() < 10:
         dvclive.log("i", dvclive.get_step())
         dvclive.next_step()
 
-    steps, values = read_history("dvclive", "i")
+    steps, values = read_history(out, "i")
     assert steps == list(range(10))
     assert values == [float(x) for x in range(10)]
diff --git a/tests/test_mmcv.py b/tests/test_mmcv.py
index aa9d3628..2ca84bbf 100644
--- a/tests/test_mmcv.py
+++ b/tests/test_mmcv.py
@@ -4,6 +4,7 @@
 from mmcv.runner import build_runner
 
 import dvclive
+from dvclive.data.scalar import Scalar
 from tests.test_main import read_logs
 
 # pylint: disable=unused-argument
 
@@ -31,7 +32,7 @@ def test_mmcv_hook(tmp_dir, mocker):
     assert set_step.call_count == 6
     assert log.call_count == 12
 
-    logs, _ = read_logs("dvclive")
+    logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
 
     assert "learning_rate" in logs
     assert "momentum" in logs
diff --git a/tests/test_xgboost.py b/tests/test_xgboost.py
index bf711534..061fefef 100644
--- a/tests/test_xgboost.py
+++ b/tests/test_xgboost.py
@@ -7,6 +7,7 @@
 from funcy import first
 from sklearn import datasets
 
+from dvclive.data.scalar import Scalar
 from dvclive.xgb import DvcLiveCallback
 from tests.test_main import read_logs
 
@@ -37,7 +38,7 @@ def test_xgb_integration(tmp_dir, train_params, iris_data):
 
     assert os.path.exists("dvclive")
 
-    logs, _ = read_logs("dvclive")
+    logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
 
     assert len(logs) == 1
     assert len(first(logs.values())) == 5
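Usage of the `log_plot` API introduced by this patch, as a minimal sketch based on the new tests above. The label and score arrays below are placeholder data, and scikit-learn must be installed for the sklearn-backed plot types ("roc", "precision_recall", "det", "calibration"):

import numpy as np
from dvclive import Live

# Placeholder labels and scores; in practice these come from a fitted model.
y_true = np.array([0, 1, 1, 0, 1, 0])
y_score = np.array([0.2, 0.9, 0.7, 0.4, 0.6, 0.1])

live = Live()

# Plots only work in no-step mode, so log each plot once and do not call next_step().
# Extra keyword arguments are forwarded to the underlying scikit-learn function.
live.log_plot("roc", y_true, y_score, drop_intermediate=True)
live.log_plot("confusion_matrix", y_true, (y_score > 0.5).astype(int))

# Results are written under the "plots" subfolder of the Live directory,
# e.g. dvclive/plots/roc.json and dvclive/plots/confusion_matrix.json.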