diff --git a/pyproject.toml b/pyproject.toml index 9617acec..93f7a557 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ addopts = "-ra" [tool.coverage.run] branch = true -source = ["dvclive", "tests"] +source = ["src/dvclive", "tests"] [tool.coverage.paths] source = ["src", "*/site-packages"] diff --git a/setup.cfg b/setup.cfg index 1093ff5b..bfc01418 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,6 +29,7 @@ package_dir= packages = find: install_requires= dvc_render[table]>=0.0.8 + ruamel.yaml>=0.17.11 [options.extras_require] tests = diff --git a/src/dvclive/error.py b/src/dvclive/error.py index 0e4c3452..4429ecbd 100644 --- a/src/dvclive/error.py +++ b/src/dvclive/error.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from .live import Live @@ -44,3 +44,27 @@ def __init__(self, name, step): super().__init__( f"Data '{name}' has already been logged with step '{step}'" ) + + +class ParameterAlreadyLoggedError(DvcLiveError): + def __init__( + self, name: str, val: Any, previous_val: Optional[Any] = None + ): + self.name = name + self.val = val + self.previous_val = previous_val + super().__init__( + f"Parameter '{name}={val}' has already been logged" + + ( + f" (previous value={self.previous_val})." + if self.previous_val is not None + and self.val != self.previous_val + else "." + ) + ) + + +class InvalidParameterTypeError(DvcLiveError): + def __init__(self, val: Any): + self.val = val + super().__init__(f"Parameter type {type(val)} is not supported.") diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 4aae343b..f9a74036 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -7,15 +7,20 @@ from pathlib import Path from typing import Any, Dict, Optional, Union +from ruamel.yaml.representer import RepresenterError + from . import env from .data import DATA_TYPES, PLOTS, Image, NumpyEncoder, Scalar from .dvc import make_checkpoint from .error import ( ConfigMismatchError, InvalidDataTypeError, + InvalidParameterTypeError, InvalidPlotTypeError, + ParameterAlreadyLoggedError, ) from .report import make_report +from .serialize import dump_yaml, load_yaml from .utils import env2bool, nested_update, open_file_in_browser logging.basicConfig() @@ -23,6 +28,12 @@ logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "INFO").upper()) +# Recursive type aliases are not yet supported by mypy (as of 0.971), +# so we set type: ignore for ParamLike. +# See https://github.com/python/mypy/issues/731#issuecomment-1213482527 +ParamLike = Union[int, float, str, bool, Dict[str, "ParamLike"]] # type: ignore # noqa + + class Live: DEFAULT_DIR = "dvclive" @@ -54,6 +65,8 @@ def __init__( if self._path is None: self._path = self.DEFAULT_DIR + self._params_path = os.path.join(self._path, "params.yaml") + if self._report is not None: if not self.report_path: self.report_path = os.path.join(self.dir, f"report.{report}") @@ -64,15 +77,17 @@ def __init__( self._scalars: Dict[str, Any] = OrderedDict() self._images: Dict[str, Any] = OrderedDict() self._plots: Dict[str, Any] = OrderedDict() + self._params: Dict[str, Any] = OrderedDict() + self._init_paths() if self._resume: + self._read_params() self._step = self.read_step() if self._step != 0: self._step += 1 logger.info(f"Resumed from step {self._step}") else: self._cleanup() - self._init_paths() def _cleanup(self): for data_type in DATA_TYPES: @@ -80,7 +95,7 @@ def _cleanup(self): Path(self.dir) / data_type.subfolder, ignore_errors=True ) - for f in {self.summary_path, self.report_path}: + for f in (self.summary_path, self.report_path, self.params_path): if os.path.exists(f): os.remove(f) @@ -116,6 +131,10 @@ def init_from_env(self) -> None: def dir(self): return self._path + @property + def params_path(self): + return self._params_path + @property def exists(self): return os.path.isdir(self.dir) @@ -130,7 +149,6 @@ def get_step(self) -> int: def set_step(self, step: int) -> None: if self._step is None: self._step = 0 - self._init_paths() for data in chain( self._scalars.values(), self._images.values(), @@ -191,6 +209,43 @@ def log_plot(self, name, labels, predictions, **kwargs): data.dump(val, self._step, **kwargs) logger.debug(f"Logged {name}") + def _read_params(self): + if os.path.isfile(self.params_path): + params = load_yaml(self.params_path) + self._params.update(params) + + def _dump_params(self): + try: + dump_yaml(self.params_path, self._params) + except RepresenterError as exc: + raise InvalidParameterTypeError(exc.args) from exc + + def log_params(self, params: Dict[str, ParamLike]): + """Saves the given set of parameters (dict) to yaml""" + if self._resume and self.get_step(): + logger.info( + "Resuming previous dvclive session, not logging params." + ) + return + + for param_name, param_value in params.items(): + if param_name in self._params: + raise ParameterAlreadyLoggedError( + param_name, param_value, self._params[param_name] + ) + + self._params.update(params) + self._dump_params() + logger.debug(f"Logged {params} parameters to {self.params_path}") + + def log_param( + self, + name: str, + val: ParamLike, + ): + """Saves the given parameter value to yaml""" + self.log_params({name: val}) + def make_summary(self): summary_data = {} if self._step is not None: diff --git a/src/dvclive/serialize.py b/src/dvclive/serialize.py new file mode 100644 index 00000000..3a286e5a --- /dev/null +++ b/src/dvclive/serialize.py @@ -0,0 +1,42 @@ +from collections import OrderedDict + +from dvclive.error import DvcLiveError + + +class YAMLError(DvcLiveError): + pass + + +class YAMLFileCorruptedError(YAMLError): + def __init__(self, path): + super().__init__(path, "YAML file structure is corrupted") + + +def load_yaml(path, typ="safe"): + from ruamel.yaml import YAML + from ruamel.yaml import YAMLError as _YAMLError + + yaml = YAML(typ=typ) + with open(path, encoding="utf-8") as fd: + try: + return yaml.load(fd.read()) + except _YAMLError: + raise YAMLFileCorruptedError(path) + + +def _get_yaml(): + from ruamel.yaml import YAML + + yaml = YAML() + yaml.default_flow_style = False + + # tell Dumper to represent OrderedDict as normal dict + yaml_repr_cls = yaml.Representer + yaml_repr_cls.add_representer(OrderedDict, yaml_repr_cls.represent_dict) + return yaml + + +def dump_yaml(path, data): + yaml = _get_yaml() + with open(path, "w", encoding="utf-8") as fd: + yaml.dump(data, fd) diff --git a/tests/test_main.py b/tests/test_main.py index d0c8201e..d57f0363 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,7 @@ # pylint: disable=protected-access import json import os +import re from pathlib import Path import pytest @@ -14,7 +15,10 @@ ConfigMismatchError, DataAlreadyLoggedError, InvalidDataTypeError, + InvalidParameterTypeError, + ParameterAlreadyLoggedError, ) +from dvclive.serialize import load_yaml from dvclive.utils import parse_tsv @@ -63,6 +67,149 @@ def test_logging_no_step(tmp_dir): assert "step" not in s +@pytest.mark.parametrize( + "param_name,param_value", + [ + ("param_string", "value"), + ("param_int", 42), + ("param_float", 42.0), + ("param_bool_true", True), + ("param_bool_false", False), + ("param_list", [1, 2, 3]), + ( + "param_dict_simple", + {"str": "value", "int": 42, "bool": True, "list": [1, 2, 3]}, + ), + ( + "param_dict_nested", + { + "str": "value", + "int": 42, + "bool": True, + "list": [1, 2, 3], + "dict": {"nested-str": "value", "nested-int": 42}, + }, + ), + ], +) +def test_log_param(tmp_dir, param_name, param_value): + dvclive = Live() + + dvclive.log_param(param_name, param_value) + + s = load_yaml(dvclive.params_path) + assert s[param_name] == param_value + + +def test_log_param_already_logged(tmp_dir): + dvclive = Live() + + dvclive.log_param("param", 42) + with pytest.raises( + ParameterAlreadyLoggedError, + match="Parameter 'param=42' has already been logged.", + ): + dvclive.log_param("param", 42) + + with pytest.raises( + ParameterAlreadyLoggedError, + match=re.escape( + "Parameter 'param=1' has already been logged (previous value=42)." + ), + ): + dvclive.log_param("param", 1) + + +def test_log_params(tmp_dir): + dvclive = Live() + params = { + "param_string": "value", + "param_int": 42, + "param_float": 42.0, + "param_bool_true": True, + "param_bool_false": False, + } + + dvclive.log_params(params) + + s = load_yaml(dvclive.params_path) + assert s == params + + +@pytest.mark.parametrize("resume", (False, True)) +def test_log_params_resume(tmp_dir, resume): + dvclive = Live(resume=resume) + dvclive.log_param("param", 42) + + dvclive = Live(resume=resume) + assert ("param" in dvclive._params) == resume + + if resume: + with pytest.raises(ParameterAlreadyLoggedError): + dvclive.log_param("param", 42) + + +@pytest.mark.parametrize("resume", (False, True)) +def test_log_params_resume_log(tmp_dir, resume, caplog, mocker): + dvclive = Live(resume=resume) + dvclive.log_param("param", 42) + + dvclive = Live(resume=resume) + assert ("param" in dvclive._params) == resume + dvclive.next_step() + + if resume: + spy = mocker.spy(dvclive, "_dump_params") + dvclive.log_param("param", 42) + assert ( + caplog.messages[-1] + == "Resuming previous dvclive session, not logging params." + ) + assert not spy.called + + +def test_log_params_already_logged(tmp_dir): + dvclive = Live() + params = { + "param_string": "string_value", + "param_int": 42, + "param_float": 42.0, + "param_bool_true": True, + "param_bool_false": False, + } + + dvclive.log_params(params) + assert os.path.isfile(dvclive.params_path) + + with pytest.raises( + ParameterAlreadyLoggedError, + match="Parameter 'param_string=string_value' has already been logged.", + ): + + dvclive.log_param("param_string", "string_value") + + with pytest.raises( + ParameterAlreadyLoggedError, + match=re.escape( + "Parameter 'param_string=other_value' has already been logged (previous value=string_value)." # noqa + ), + ): + + dvclive.log_param("param_string", "other_value") + + +def test_log_param_custom_obj(tmp_dir): + dvclive = Live("logs") + + class Dummy: + val = 42 + + param_value = Dummy() + + with pytest.raises(InvalidParameterTypeError): + dvclive.log_param("param_complex", param_value) + + @pytest.mark.parametrize("path", ["logs", os.path.join("subdir", "logs")]) def test_logging_step(tmp_dir, path): dvclive = Live(path) @@ -72,7 +219,7 @@ def test_logging_step(tmp_dir, path): assert (tmp_dir / dvclive.dir / Scalar.subfolder / "m1.tsv").is_file() assert (tmp_dir / dvclive.summary_path).is_file() - s = _parse_json(dvclive.summary_path) + s = load_yaml(dvclive.summary_path) assert s["m1"] == 1 assert s["step"] == 0 @@ -93,7 +240,7 @@ def test_nested_logging(tmp_dir): assert (out / "val" / "val_1" / "m1.tsv").is_file() assert (out / "val" / "val_1" / "m2.tsv").is_file() - summary = _parse_json(dvclive.summary_path) + summary = load_yaml(dvclive.summary_path) assert summary["train"]["m1"] == 1 assert summary["val"]["val_1"]["m1"] == 1 @@ -127,6 +274,16 @@ def test_cleanup(tmp_dir, html): assert not (html_path).is_file() +def test_cleanup_params(tmp_dir): + dvclive = Live("logs") + dvclive.log_param("param", 42) + + assert os.path.isfile(dvclive.params_path) + + dvclive = Live("logs") + assert not os.path.exists(dvclive.params_path) + + @pytest.mark.parametrize( "resume, steps, metrics", [(True, [0, 1, 2, 3], [0.9, 0.8, 0.7, 0.6]), (False, [0, 1], [0.7, 0.6])],