Skip to content

Commit

Permalink
live: Add log_image and log_plot.
Browse files Browse the repository at this point in the history
Decouple data type logging into separated methods.

Use subfolders for each data type.

Raise NotImplementedError in `log_plot` when using steps.
  • Loading branch information
daavoo committed Feb 2, 2022
1 parent 4aa81dd commit c69af09
Show file tree
Hide file tree
Showing 19 changed files with 508 additions and 123 deletions.
10 changes: 9 additions & 1 deletion dvclive/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
from .image import Image
from .plot import Calibration, ConfusionMatrix, Det, PrecisionRecall, Roc
from .scalar import Scalar

DATA_TYPES = [Image, Scalar]
PLOTS = {
"calibration": Calibration,
"confusion_matrix": ConfusionMatrix,
"det": Det,
"precision_recall": PrecisionRecall,
"roc": Roc,
}
DATA_TYPES = list(PLOTS.values()) + [Scalar, Image]
39 changes: 34 additions & 5 deletions dvclive/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,19 @@
class Data(abc.ABC):
def __init__(self, name: str, output_folder: str) -> None:
self.name = name
self.output_folder: Path = Path(output_folder)
self.output_folder: Path = Path(output_folder) / self.subfolder
self._step: Optional[int] = None
self.val = None
self._step_none_logged: bool = False
self._dump_kwargs = None

@property
def step(self) -> int:
return self._step

@step.setter
def step(self, val: int) -> None:
if not self._step_none_logged and val is None:
self._step_none_logged = True
elif val == self._step:
if self._step_none_logged and val == self._step:
raise DataAlreadyLoggedError(self.name, val)

self._step = val
Expand All @@ -31,6 +30,16 @@ def step(self, val: int) -> None:
def output_path(self) -> Path:
pass

@property
@abc.abstractmethod
def no_step_output_path(self) -> Path:
pass

@property
@abc.abstractmethod
def subfolder(self):
pass

@property
@abc.abstractmethod
def summary(self):
Expand All @@ -41,6 +50,26 @@ def summary(self):
def could_log(val: object) -> bool:
pass

def dump(self, val, step):
def dump(self, val, step, **kwargs):
self.val = val
self.step = step
self._dump_kwargs = kwargs
if not self._step_none_logged and step is None:
self._step_none_logged = True
self.no_step_dump()
elif step == 0:
self.first_step_dump()
else:
self.step_dump()

@abc.abstractmethod
def first_step_dump(self) -> None:
pass

@abc.abstractmethod
def no_step_dump(self) -> None:
pass

@abc.abstractmethod
def step_dump(self) -> None:
pass
52 changes: 27 additions & 25 deletions dvclive/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,45 @@

class Image(Data):
suffixes = [".jpg", ".jpeg", ".gif", ".png"]
subfolder = "images"

@staticmethod
def could_log(val: object) -> bool:
if val.__class__.__module__ == "PIL.Image":
return True
if val.__class__.__module__ == "numpy":
return True
return False
@property
def no_step_output_path(self) -> Path:
return self.output_folder / self.name

@property
def output_path(self) -> Path:
if Path(self.name).suffix not in self.suffixes:
raise ValueError(
f"Invalid image suffix '{Path(self.name).suffix}'"
f" Must be one of {self.suffixes}"
)
if self._step is None:
output_path = self.output_folder / self.name
output_path = self.no_step_output_path
else:
output_path = self.output_folder / f"{self._step}" / self.name
output_path.parent.mkdir(exist_ok=True, parents=True)
return output_path

def dump(self, val, step) -> None:
if self._step_none_logged and self._step is None:
super().dump(val, step)
step_none_path = self.output_folder / self.name
if step_none_path.exists():
step_none_path.rename(self.output_path)
else:
super().dump(val, step)
if val.__class__.__module__ == "numpy":
from PIL import Image as ImagePIL
@staticmethod
def could_log(val: object) -> bool:
if val.__class__.__module__ == "PIL.Image":
return True
if val.__class__.__module__ == "numpy":
return True
return False

def first_step_dump(self) -> None:
if self.no_step_output_path.exists():
self.no_step_output_path.rename(self.output_path)

def no_step_dump(self) -> None:
self.step_dump()

val = ImagePIL.fromarray(val)
def step_dump(self) -> None:
if self.val.__class__.__module__ == "numpy":
from PIL import Image as ImagePIL

_val = ImagePIL.fromarray(self.val)
else:
_val = self.val

val.save(self.output_path)
_val.save(self.output_path)

@property
def summary(self):
Expand Down
123 changes: 123 additions & 0 deletions dvclive/data/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import json
from pathlib import Path

from .base import Data


class Plot(Data):
suffixes = [".json"]
subfolder = "plots"

@property
def output_path(self) -> Path:
_path = self.output_folder / self.name
_path.parent.mkdir(exist_ok=True, parents=True)
return _path.with_suffix(".json")

@staticmethod
def could_log(val: object) -> bool:
if isinstance(val, tuple) and len(val) == 2:
return True
return False

@property
def no_step_output_path(self) -> Path:
return super().no_step_output_path.with_suffix(".json")

@property
def summary(self):
return {}

@staticmethod
def write_json(content, output_file):
with open(output_file, "w") as f:
json.dump(content, f, indent=4)

def no_step_dump(self) -> None:
raise NotImplementedError

def first_step_dump(self) -> None:
raise NotImplementedError(
"DVCLive plots can only be used in no-step mode."
)

def step_dump(self) -> None:
raise NotImplementedError(
"DVCLive plots can only be used in no-step mode."
)


class Roc(Plot):
def no_step_dump(self) -> int:
from sklearn import metrics

fpr, tpr, roc_thresholds = metrics.roc_curve(
y_true=self.val[0], y_score=self.val[1], **self._dump_kwargs
)
roc = {
"roc": [
{"fpr": fp, "tpr": tp, "threshold": t}
for fp, tp, t in zip(fpr, tpr, roc_thresholds)
]
}
self.write_json(roc, self.output_path)


class PrecisionRecall(Plot):
def no_step_dump(self) -> int:
from sklearn import metrics

precision, recall, prc_thresholds = metrics.precision_recall_curve(
y_true=self.val[0], probas_pred=self.val[1], **self._dump_kwargs
)

prc = {
"prc": [
{"precision": p, "recall": r, "threshold": t}
for p, r, t in zip(precision, recall, prc_thresholds)
]
}
self.write_json(prc, self.output_path)


class Det(Plot):
def no_step_dump(self) -> int:
from sklearn import metrics

fpr, fnr, roc_thresholds = metrics.det_curve(
y_true=self.val[0], y_score=self.val[1], **self._dump_kwargs
)

det = {
"det": [
{"fpr": fp, "fnr": fn, "threshold": t}
for fp, fn, t in zip(fpr, fnr, roc_thresholds)
]
}
self.write_json(det, self.output_path)


class ConfusionMatrix(Plot):
def no_step_dump(self) -> int:
cm = [
{"actual": str(actual), "predicted": str(predicted)}
for actual, predicted in zip(self.val[0], self.val[1])
]
self.write_json(cm, self.output_path)


class Calibration(Plot):
def no_step_dump(self) -> int:
from sklearn import calibration

prob_true, prob_pred = calibration.calibration_curve(
y_true=self.val[0], y_prob=self.val[1], **self._dump_kwargs
)

calibration = {
"calibration": [
{"prob_true": pt, "prob_pred": pp}
for pt, pp in zip(prob_true, prob_pred)
]
}
self.write_json(calibration, self.output_path)
34 changes: 21 additions & 13 deletions dvclive/data/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

class Scalar(Data):
suffixes = [".csv", ".tsv"]
subfolder = "scalars"

@staticmethod
def could_log(val: object) -> bool:
Expand All @@ -24,23 +25,30 @@ def output_path(self) -> Path:
_path.parent.mkdir(exist_ok=True, parents=True)
return _path.with_suffix(".tsv")

def dump(self, val, step):
super().dump(val, step)
@property
def no_step_output_path(self) -> Path:
return self.output_path

def first_step_dump(self) -> None:
self.step_dump()

def no_step_dump(self) -> None:
pass

if step is not None:
ts = int(time.time() * 1000)
d = OrderedDict(
[("timestamp", ts), ("step", self.step), (self.name, self.val)]
)
def step_dump(self) -> None:
ts = int(time.time() * 1000)
d = OrderedDict(
[("timestamp", ts), ("step", self.step), (self.name, self.val)]
)

existed = self.output_path.exists()
with open(self.output_path, "a") as fobj:
writer = csv.DictWriter(fobj, d.keys(), delimiter="\t")
existed = self.output_path.exists()
with open(self.output_path, "a") as fobj:
writer = csv.DictWriter(fobj, d.keys(), delimiter="\t")

if not existed:
writer.writeheader()
if not existed:
writer.writeheader()

writer.writerow(d)
writer.writerow(d)

@property
def summary(self):
Expand Down
11 changes: 11 additions & 0 deletions dvclive/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ def __init__(self, name, val):
super().__init__(f"Data '{name}' has not supported type {val}")


class InvalidPlotTypeError(DvcLiveError):
def __init__(self, name):
from .data import PLOTS

self.name = name
super().__init__(
f"Plot type '{name}' is not supported."
f"\nSupported types are: {list(PLOTS)}"
)


class DataAlreadyLoggedError(DvcLiveError):
def __init__(self, name, step):
self.name = name
Expand Down
Loading

0 comments on commit c69af09

Please sign in to comment.