Skip to content

Commit

Permalink
Introduce public summary. Remove "no step" / "step" logic from plot…
Browse files Browse the repository at this point in the history
…s. (#331)

It was initially introduced for supporting different logging format between step and not step updates.

For `live.log_image`, "step" mode now overwrites the path instead of creating subfolder by step.

For `live.log`, the "no step" was meant to not generate the `.tsv` file but only the `.json`.
Added a public property `summary` so "no step" scenarios can work as follows:

```
live = Live()

live.summary["foo"] = 1
live.make_summary()
```

Closes #326

Apply suggestions from code review

Co-authored-by: Paweł Redzyński <[email protected]>

Co-authored-by: Paweł Redzyński <[email protected]>
  • Loading branch information
daavoo and pared committed Oct 31, 2022
1 parent 05c910c commit 9157570
Show file tree
Hide file tree
Showing 25 changed files with 255 additions and 378 deletions.
76 changes: 0 additions & 76 deletions src/dvclive/data/base.py

This file was deleted.

50 changes: 0 additions & 50 deletions src/dvclive/data/image.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/dvclive/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def __init__(self, name, val):

class InvalidPlotTypeError(DvcLiveError):
def __init__(self, name):
from .data import PLOTS
from .plots import SKLEARN_PLOTS

self.name = name
super().__init__(
f"Plot type '{name}' is not supported."
f"\nSupported types are: {list(PLOTS)}"
f"\nSupported types are: {list(SKLEARN_PLOTS)}"
)


Expand Down
58 changes: 24 additions & 34 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,21 @@
import logging
import os
import shutil
from collections import OrderedDict
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from ruamel.yaml.representer import RepresenterError

from . import env
from .data import DATA_TYPES, PLOTS, Image, Metric, NumpyEncoder
from .dvc import make_checkpoint
from .error import (
InvalidDataTypeError,
InvalidParameterTypeError,
InvalidPlotTypeError,
)
from .plots import PLOT_TYPES, SKLEARN_PLOTS, Image, Metric, NumpyEncoder
from .report import make_report
from .serialize import dump_yaml, load_yaml
from .serialize import dump_json, dump_yaml, load_yaml
from .studio import post_to_studio
from .utils import (
env2bool,
Expand Down Expand Up @@ -66,11 +64,12 @@ def __init__(
self.report_mode: Optional[str] = report
self.report_path = ""

self.summary: Dict[str, Any] = {}
self._step: Optional[int] = None
self._scalars: Dict[str, Any] = OrderedDict()
self._images: Dict[str, Any] = OrderedDict()
self._plots: Dict[str, Any] = OrderedDict()
self._params: Dict[str, Any] = OrderedDict()
self._metrics: Dict[str, Any] = {}
self._images: Dict[str, Any] = {}
self._plots: Dict[str, Any] = {}
self._params: Dict[str, Any] = {}

self._init_paths()

Expand All @@ -90,7 +89,6 @@ def __init__(
self._cleanup()

self._latest_studio_step = self.get_step()

if self.report_mode == "studio":
from scmrepo.git import Git

Expand All @@ -104,9 +102,9 @@ def __init__(
self.report_mode = None

def _cleanup(self):
for data_type in DATA_TYPES:
for plot_type in PLOT_TYPES:
shutil.rmtree(
Path(self.plots_path) / data_type.subfolder, ignore_errors=True
Path(self.plots_path) / plot_type.subfolder, ignore_errors=True
)

for f in (self.metrics_path, self.report_path, self.params_path):
Expand Down Expand Up @@ -138,12 +136,6 @@ def get_step(self) -> int:
def set_step(self, step: int) -> None:
if self._step is None:
self._step = 0
for data in chain(
self._scalars.values(),
self._images.values(),
self._plots.values(),
):
data.dump(data.val, self._step)
self.make_summary()

if self.report_mode == "studio":
Expand All @@ -165,14 +157,16 @@ def log(self, name: str, val: Union[int, float]):
if not Metric.could_log(val):
raise InvalidDataTypeError(name, type(val))

if name in self._scalars:
data = self._scalars[name]
if name in self._metrics:
data = self._metrics[name]
else:
data = Metric(name, self.plots_path)
self._scalars[name] = data
self._metrics[name] = data

data.dump(val, self._step)
data.step = self.get_step()
data.dump(val)

self.summary = nested_update(self.summary, data.to_summary(val))
self.make_summary()
logger.debug(f"Logged {name}: {val}")

Expand All @@ -186,7 +180,8 @@ def log_image(self, name: str, val):
data = Image(name, self.plots_path)
self._images[name] = data

data.dump(val, self._step)
data.step = self.get_step()
data.dump(val)
logger.debug(f"Logged {name}: {val}")

def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
Expand All @@ -195,13 +190,14 @@ def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
name = name or kind
if name in self._plots:
data = self._plots[name]
elif kind in PLOTS and PLOTS[kind].could_log(val):
data = PLOTS[kind](name, self.plots_path)
elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val):
data = SKLEARN_PLOTS[kind](name, self.plots_path)
self._plots[name] = data
else:
raise InvalidPlotTypeError(name)

data.dump(val, self._step, **kwargs)
data.step = self.get_step()
data.dump(val, **kwargs)
logger.debug(f"Logged {name}")

def _read_params(self):
Expand All @@ -211,7 +207,7 @@ def _read_params(self):

def _dump_params(self):
try:
dump_yaml(self.params_path, self._params)
dump_yaml(self._params, self.params_path)
except RepresenterError as exc:
raise InvalidParameterTypeError(exc.args) from exc

Expand All @@ -230,15 +226,9 @@ def log_param(
self.log_params({name: val})

def make_summary(self):
summary_data = {}
if self._step is not None:
summary_data["step"] = self.get_step()

for data in self._scalars.values():
summary_data = nested_update(summary_data, data.summary)

with open(self.metrics_path, "w", encoding="utf-8") as f:
json.dump(summary_data, f, indent=4, cls=NumpyEncoder)
self.summary["step"] = self.get_step()
dump_json(self.summary, self.metrics_path, cls=NumpyEncoder)

def make_report(self):
if self.report_mode is not None:
Expand Down
12 changes: 3 additions & 9 deletions src/dvclive/data/__init__.py → src/dvclive/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from .image import Image
from .metric import Metric
from .sklearn_plot import (
Calibration,
ConfusionMatrix,
Det,
PrecisionRecall,
Roc,
)
from .sklearn import Calibration, ConfusionMatrix, Det, PrecisionRecall, Roc
from .utils import NumpyEncoder # noqa: F401

PLOTS = {
SKLEARN_PLOTS = {
"calibration": Calibration,
"confusion_matrix": ConfusionMatrix,
"det": Det,
"precision_recall": PrecisionRecall,
"roc": Roc,
}
DATA_TYPES = (*PLOTS.values(), Metric, Image)
PLOT_TYPES = (*SKLEARN_PLOTS.values(), Metric, Image)
41 changes: 41 additions & 0 deletions src/dvclive/plots/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import abc
from pathlib import Path
from typing import Optional

from dvclive.error import DataAlreadyLoggedError


class Data(abc.ABC):
def __init__(self, name: str, output_folder: str) -> None:
self.name = name
self.output_folder: Path = Path(output_folder) / self.subfolder
self._step: Optional[int] = None

@property
def step(self) -> Optional[int]:
return self._step

@step.setter
def step(self, val: int) -> None:
if val == self._step:
raise DataAlreadyLoggedError(self.name, val)
self._step = val

@property
@abc.abstractmethod
def output_path(self) -> Path:
pass

@property
@abc.abstractmethod
def subfolder(self):
pass

@staticmethod
@abc.abstractmethod
def could_log(val) -> bool:
pass

@abc.abstractmethod
def dump(self, val, **kwargs):
pass
31 changes: 31 additions & 0 deletions src/dvclive/plots/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path

from .base import Data


class Image(Data):
suffixes = [".jpg", ".jpeg", ".gif", ".png"]
subfolder = "images"

@property
def output_path(self) -> Path:
_path = self.output_folder / self.name
_path.parent.mkdir(exist_ok=True, parents=True)
return _path

@staticmethod
def could_log(val: object) -> bool:
if val.__class__.__module__ == "PIL.Image":
return True
if val.__class__.__module__ == "numpy":
return True
return False

def dump(self, val, **kwargs) -> None:
if val.__class__.__module__ == "numpy":
from PIL import Image as ImagePIL

pil_image = ImagePIL.fromarray(val)
else:
pil_image = val
pil_image.save(self.output_path)
Loading

0 comments on commit 9157570

Please sign in to comment.