Skip to content

Commit

Permalink
add log_param/log_params
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrifiro authored and daavoo committed Sep 23, 2022
1 parent fc76857 commit 27a6c0c
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 6 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ package_dir=
packages = find:
install_requires=
dvc_render[table]>=0.0.8
ruamel.yaml>=0.17.11

[options.extras_require]
tests =
Expand Down
26 changes: 25 additions & 1 deletion src/dvclive/error.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
from .live import Live
Expand Down Expand Up @@ -44,3 +44,27 @@ def __init__(self, name, step):
super().__init__(
f"Data '{name}' has already been logged with step '{step}'"
)


class ParameterAlreadyLoggedError(DvcLiveError):
    """Error raised when a parameter name is logged a second time.

    The message names the rejected ``name=val`` pair. When an earlier value
    is known (not ``None``) and differs from ``val``, it is appended to the
    message as the previous value.
    """

    def __init__(
        self, name: str, val: Any, previous_val: Optional[Any] = None
    ):
        self.name = name
        self.val = val
        self.previous_val = previous_val

        head = f"Parameter '{name}={val}' has already been logged"
        # Mention the earlier value only when it is known and actually
        # differs from the value being logged now.
        if previous_val is not None and val != previous_val:
            tail = f" (previous value={previous_val})."
        else:
            tail = "."
        super().__init__(head + tail)


class InvalidParameterTypeError(DvcLiveError):
    """Error raised for a parameter value whose type is not supported.

    The message reports ``type(val)``, so callers should pass the offending
    value itself.
    NOTE(review): a caller passing an exception's ``args`` tuple would make
    the message report the tuple type - confirm the call sites.
    """

    def __init__(self, val: Any):
        self.val = val
        message = f"Parameter type {type(val)} is not supported."
        super().__init__(message)
61 changes: 58 additions & 3 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,33 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

from ruamel.yaml.representer import RepresenterError

from . import env
from .data import DATA_TYPES, PLOTS, Image, NumpyEncoder, Scalar
from .dvc import make_checkpoint
from .error import (
ConfigMismatchError,
InvalidDataTypeError,
InvalidParameterTypeError,
InvalidPlotTypeError,
ParameterAlreadyLoggedError,
)
from .report import make_report
from .serialize import dump_yaml, load_yaml
from .utils import env2bool, nested_update, open_file_in_browser

# Import-time logging setup: configure the root logger with defaults and
# create the package logger named "dvclive".
logging.basicConfig()
logger = logging.getLogger("dvclive")
# The level is read from the environment variable named by
# env.DVCLIVE_LOGLEVEL and falls back to "INFO" when it is unset.
logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "INFO").upper())


# Type of a loggable parameter value: a scalar (int, float, str, bool) or a
# dict with str keys whose values are themselves ParamLike.
# Recursive type aliases are not yet supported by mypy (as of 0.971),
# so we set type: ignore for ParamLike.
# See https://github.com/python/mypy/issues/731#issuecomment-1213482527
ParamLike = Union[int, float, str, bool, Dict[str, "ParamLike"]] # type: ignore # noqa


class Live:
DEFAULT_DIR = "dvclive"

Expand Down Expand Up @@ -54,6 +65,8 @@ def __init__(
if self._path is None:
self._path = self.DEFAULT_DIR

self._params_path = os.path.join(self._path, "params.yaml")

if self._report is not None:
if not self.report_path:
self.report_path = os.path.join(self.dir, f"report.{report}")
Expand All @@ -64,23 +77,25 @@ def __init__(
self._scalars: Dict[str, Any] = OrderedDict()
self._images: Dict[str, Any] = OrderedDict()
self._plots: Dict[str, Any] = OrderedDict()
self._params: Dict[str, Any] = OrderedDict()

self._init_paths()
if self._resume:
self._read_params()
self._step = self.read_step()
if self._step != 0:
self._step += 1
logger.info(f"Resumed from step {self._step}")
else:
self._cleanup()
self._init_paths()

def _cleanup(self):
for data_type in DATA_TYPES:
shutil.rmtree(
Path(self.dir) / data_type.subfolder, ignore_errors=True
)

for f in {self.summary_path, self.report_path}:
for f in (self.summary_path, self.report_path, self.params_path):
if os.path.exists(f):
os.remove(f)

Expand Down Expand Up @@ -116,6 +131,10 @@ def init_from_env(self) -> None:
def dir(self):
return self._path

@property
def params_path(self):
return self._params_path

@property
def exists(self):
return os.path.isdir(self.dir)
Expand All @@ -130,7 +149,6 @@ def get_step(self) -> int:
def set_step(self, step: int) -> None:
if self._step is None:
self._step = 0
self._init_paths()
for data in chain(
self._scalars.values(),
self._images.values(),
Expand Down Expand Up @@ -191,6 +209,43 @@ def log_plot(self, name, labels, predictions, **kwargs):
data.dump(val, self._step, **kwargs)
logger.debug(f"Logged {name}")

def _read_params(self):
    """Merge parameters saved in ``params_path`` into ``self._params``.

    Does nothing when the file does not exist. An empty YAML document is
    loaded as ``None``; it is treated as "no parameters" rather than being
    handed to ``dict.update``, which would raise ``TypeError``.
    """
    if not os.path.isfile(self.params_path):
        return

    params = load_yaml(self.params_path)
    # Skip empty/None documents; a non-empty mapping is merged as before.
    if params:
        self._params.update(params)

def _dump_params(self):
    """Write the in-memory parameters to ``params_path`` as YAML.

    Raises:
        InvalidParameterTypeError: when the YAML representer rejects a
            value; the original ``RepresenterError`` is chained as cause.
    """
    target, payload = self.params_path, self._params
    try:
        dump_yaml(target, payload)
    except RepresenterError as exc:
        # NOTE(review): ``exc.args`` is a tuple, so an error message built
        # from ``type(...)`` of this argument would name the tuple type,
        # not the offending value's type - confirm this is intended.
        raise InvalidParameterTypeError(exc.args) from exc

def log_params(self, params: Dict[str, ParamLike]):
    """Record a dict of parameters and write all parameters to yaml.

    When resuming at a non-zero step, nothing is logged. Otherwise every
    name must be new: the first name already present in ``self._params``
    raises ``ParameterAlreadyLoggedError`` before anything is stored.
    """
    resumed_past_start = self._resume and self.get_step()
    if resumed_past_start:
        logger.info(
            "Resuming previous dvclive session, not logging params."
        )
        return

    # Validate the whole batch first so a duplicate leaves state untouched.
    for key in params:
        if key not in self._params:
            continue
        raise ParameterAlreadyLoggedError(
            key, params[key], self._params[key]
        )

    self._params.update(params)
    self._dump_params()
    logger.debug(f"Logged {params} parameters to {self.params_path}")

def log_param(
self,
name: str,
val: ParamLike,
):
"""Saves the given parameter value to yaml"""
self.log_params({name: val})

def make_summary(self):
summary_data = {}
if self._step is not None:
Expand Down
42 changes: 42 additions & 0 deletions src/dvclive/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from collections import OrderedDict

from dvclive.error import DvcLiveError


class YAMLError(DvcLiveError):
    """Base class for the YAML-related errors defined in this module."""


class YAMLFileCorruptedError(YAMLError):
    """Error raised when a YAML file cannot be parsed.

    Args:
        path: location of the file that failed to load; also kept on the
            instance as ``path``.
    """

    def __init__(self, path):
        self.path = path
        # Pass one formatted message, like the other DvcLiveError
        # subclasses. Giving ``path`` and the text as two positional
        # arguments makes ``str(exc)`` render as a tuple.
        super().__init__(f"'{path}': YAML file structure is corrupted")


def load_yaml(path, typ="safe"):
    """Parse the YAML file at ``path`` and return its content.

    Args:
        path: file to read; it is decoded as UTF-8.
        typ: loader type passed to ``ruamel.yaml.YAML`` ("safe" by default).

    Returns:
        Whatever the loader produces for the document.

    Raises:
        YAMLFileCorruptedError: when ruamel.yaml fails to parse the text;
            the parser error is chained as the cause.
    """
    # Imported locally; aliased so it does not shadow this module's YAMLError.
    from ruamel.yaml import YAML
    from ruamel.yaml import YAMLError as _YAMLError

    yaml = YAML(typ=typ)
    with open(path, encoding="utf-8") as fd:
        text = fd.read()

    # Only the parse step is guarded, and the file is already closed.
    try:
        return yaml.load(text)
    except _YAMLError as exc:
        raise YAMLFileCorruptedError(path) from exc


def _get_yaml():
    """Create the ruamel.yaml instance used for dumping.

    Flow style is disabled, and ``OrderedDict`` is registered on the
    instance's representer class so it is represented as a normal dict.
    """
    from ruamel.yaml import YAML

    dumper = YAML()
    dumper.default_flow_style = False

    # tell Dumper to represent OrderedDict as normal dict
    representer_cls = dumper.Representer
    representer_cls.add_representer(
        OrderedDict, representer_cls.represent_dict
    )
    return dumper


def dump_yaml(path, data):
    """Serialize ``data`` as YAML into the file at ``path``.

    The file is opened for writing with UTF-8 encoding, so existing content
    is replaced.
    """
    serializer = _get_yaml()
    with open(path, mode="w", encoding="utf-8") as stream:
        serializer.dump(data, stream)
Loading

0 comments on commit 27a6c0c

Please sign in to comment.