Skip to content

Commit

Permalink
live: Revisit output names and structure. (#322)
Browse files Browse the repository at this point in the history
* live: Revisit output names and structure.

Applied #246 (comment)

Closes #246

* Rename `log_plot` to `log_sklearn_plot`.

* Rename `plot` -> `sklearn`

* Rename `sklearn` -> `sklearn_plot`
  • Loading branch information
daavoo committed Oct 18, 2022
1 parent 2d61715 commit 876cf20
Show file tree
Hide file tree
Showing 20 changed files with 184 additions and 156 deletions.
12 changes: 9 additions & 3 deletions src/dvclive/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from .image import Image
from .plot import Calibration, ConfusionMatrix, Det, PrecisionRecall, Roc
from .scalar import Scalar
from .metric import Metric
from .sklearn_plot import (
Calibration,
ConfusionMatrix,
Det,
PrecisionRecall,
Roc,
)
from .utils import NumpyEncoder # noqa: F401

PLOTS = {
Expand All @@ -10,4 +16,4 @@
"precision_recall": PrecisionRecall,
"roc": Roc,
}
DATA_TYPES = (*PLOTS.values(), Scalar, Image)
DATA_TYPES = (*PLOTS.values(), Metric, Image)
4 changes: 2 additions & 2 deletions src/dvclive/data/scalar.py → src/dvclive/data/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from .utils import NUMPY_SCALARS


class Scalar(Data):
class Metric(Data):
suffixes = [".csv", ".tsv"]
subfolder = "scalars"
subfolder = "metrics"

@staticmethod
def could_log(val: object) -> bool:
Expand Down
14 changes: 7 additions & 7 deletions src/dvclive/data/plot.py → src/dvclive/data/sklearn_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from .base import Data


class Plot(Data):
class SKLearnPlot(Data):
suffixes = [".json"]
subfolder = "plots"
subfolder = "sklearn"

@property
def output_path(self) -> Path:
Expand Down Expand Up @@ -51,7 +51,7 @@ def get_properties():
raise NotImplementedError


class Roc(Plot):
class Roc(SKLearnPlot):
@staticmethod
def get_properties():
return {
Expand Down Expand Up @@ -79,7 +79,7 @@ def no_step_dump(self) -> None:
self.write_json(roc, self.output_path)


class PrecisionRecall(Plot):
class PrecisionRecall(SKLearnPlot):
@staticmethod
def get_properties():
return {
Expand Down Expand Up @@ -108,7 +108,7 @@ def no_step_dump(self) -> None:
self.write_json(prc, self.output_path)


class Det(Plot):
class Det(SKLearnPlot):
@staticmethod
def get_properties():
return {
Expand Down Expand Up @@ -137,7 +137,7 @@ def no_step_dump(self) -> None:
self.write_json(det, self.output_path)


class ConfusionMatrix(Plot):
class ConfusionMatrix(SKLearnPlot):
@staticmethod
def get_properties():
return {
Expand All @@ -159,7 +159,7 @@ def no_step_dump(self) -> None:
self.write_json(cm, self.output_path)


class Calibration(Plot):
class Calibration(SKLearnPlot):
@staticmethod
def get_properties():
return {
Expand Down
28 changes: 14 additions & 14 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ruamel.yaml.representer import RepresenterError

from . import env
from .data import DATA_TYPES, PLOTS, Image, NumpyEncoder, Scalar
from .data import DATA_TYPES, PLOTS, Image, Metric, NumpyEncoder
from .dvc import make_checkpoint
from .error import (
ConfigMismatchError,
Expand Down Expand Up @@ -109,10 +109,10 @@ def __init__(
def _cleanup(self):
for data_type in DATA_TYPES:
shutil.rmtree(
Path(self.dir) / data_type.subfolder, ignore_errors=True
Path(self.plots_path) / data_type.subfolder, ignore_errors=True
)

for f in (self.summary_path, self.report_path, self.params_path):
for f in (self.metrics_path, self.report_path, self.params_path):
if os.path.exists(f):
os.remove(f)

Expand Down Expand Up @@ -153,12 +153,12 @@ def params_path(self):
return os.path.join(self.dir, "params.yaml")

@property
def exists(self):
return os.path.isdir(self.dir)
def metrics_path(self):
return os.path.join(self.dir, "metrics.json")

@property
def summary_path(self):
return str(self.dir) + ".json"
def plots_path(self):
return os.path.join(self.dir, "plots")

def get_step(self) -> int:
return self._step or 0
Expand Down Expand Up @@ -194,13 +194,13 @@ def next_step(self):
self.set_step(self.get_step() + 1)

def log(self, name: str, val: Union[int, float]):
if not Scalar.could_log(val):
if not Metric.could_log(val):
raise InvalidDataTypeError(name, type(val))

if name in self._scalars:
data = self._scalars[name]
else:
data = Scalar(name, self.dir)
data = Metric(name, self.plots_path)
self._scalars[name] = data

data.dump(val, self._step)
Expand All @@ -215,19 +215,19 @@ def log_image(self, name: str, val):
if name in self._images:
data = self._images[name]
else:
data = Image(name, self.dir)
data = Image(name, self.plots_path)
self._images[name] = data

data.dump(val, self._step)
logger.debug(f"Logged {name}: {val}")

def log_plot(self, name, labels, predictions, **kwargs):
def log_sklearn_plot(self, name, labels, predictions, **kwargs):
val = (labels, predictions)

if name in self._plots:
data = self._plots[name]
elif name in PLOTS and PLOTS[name].could_log(val):
data = PLOTS[name](name, self.dir)
data = PLOTS[name](name, self.plots_path)
self._plots[name] = data
else:
raise InvalidPlotTypeError(name)
Expand Down Expand Up @@ -268,7 +268,7 @@ def make_summary(self):
for data in self._scalars.values():
summary_data = nested_update(summary_data, data.summary)

with open(self.summary_path, "w", encoding="utf-8") as f:
with open(self.metrics_path, "w", encoding="utf-8") as f:
json.dump(summary_data, f, indent=4, cls=NumpyEncoder)

def make_report(self):
Expand All @@ -287,7 +287,7 @@ def make_checkpoint(self):
make_checkpoint()

def read_step(self):
if Path(self.summary_path).exists():
if Path(self.metrics_path).exists():
latest = self.read_latest()
return latest.get("step", 0)
return 0
Expand Down
36 changes: 18 additions & 18 deletions src/dvclive/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@
from dvc_render.table import TableRenderer
from dvc_render.vega import VegaRenderer

from dvclive.data import PLOTS, Image, Scalar
from dvclive.data.plot import Plot
from dvclive.data import PLOTS, Image, Metric
from dvclive.data.sklearn_plot import SKLearnPlot
from dvclive.serialize import load_yaml
from dvclive.utils import parse_tsv

if TYPE_CHECKING:
from dvclive import Live


def get_scalar_renderers(scalars_folder):
def get_scalar_renderers(metrics_path):
renderers = []
for suffix in Scalar.suffixes:
for file in Path(scalars_folder).rglob(f"*{suffix}"):
for suffix in Metric.suffixes:
for file in metrics_path.rglob(f"*{suffix}"):
data = parse_tsv(file)
for row in data:
row["rev"] = "workspace"

y = file.relative_to(scalars_folder).with_suffix("")
y = file.relative_to(metrics_path).with_suffix("")
y = y.as_posix()

name = file.relative_to(scalars_folder.parent).with_suffix("")
name = file.relative_to(metrics_path.parent).with_suffix("")
name = name.as_posix()
name = name.replace(scalars_folder.name, "static")
name = name.replace(metrics_path.name, "static")

properties = {"x": "step", "y": y}
renderers.append(VegaRenderer(data, name, **properties))
Expand All @@ -57,7 +57,7 @@ def get_image_renderers(images_folder):

def get_plot_renderers(plots_folder):
renderers = []
for suffix in Plot.suffixes:
for suffix in SKLearnPlot.suffixes:
for file in Path(plots_folder).rglob(f"*{suffix}"):
name = file.stem
data = json.loads(file.read_text())
Expand All @@ -71,12 +71,12 @@ def get_plot_renderers(plots_folder):


def get_metrics_renderers(dvclive_summary):
summary_path = Path(dvclive_summary)
if summary_path.exists():
metrics_path = Path(dvclive_summary)
if metrics_path.exists():
return [
TableRenderer(
[json.loads(summary_path.read_text(encoding="utf-8"))],
summary_path.name,
[json.loads(metrics_path.read_text(encoding="utf-8"))],
metrics_path.name,
)
]
return []
Expand All @@ -95,14 +95,14 @@ def get_params_renderers(dvclive_params):


def make_report(dvclive: "Live"):
dvclive_path = Path(dvclive.dir)
plots_path = Path(dvclive.plots_path)

renderers = []
renderers.extend(get_params_renderers(dvclive.params_path))
renderers.extend(get_metrics_renderers(dvclive.summary_path))
renderers.extend(get_scalar_renderers(dvclive_path / Scalar.subfolder))
renderers.extend(get_image_renderers(dvclive_path / Image.subfolder))
renderers.extend(get_plot_renderers(dvclive_path / Plot.subfolder))
renderers.extend(get_metrics_renderers(dvclive.metrics_path))
renderers.extend(get_scalar_renderers(plots_path / Metric.subfolder))
renderers.extend(get_image_renderers(plots_path / Image.subfolder))
renderers.extend(get_plot_renderers(plots_path / SKLearnPlot.subfolder))

if dvclive.report_mode == "html":
render_html(renderers, dvclive.report_path, refresh_seconds=5)
Expand Down
6 changes: 3 additions & 3 deletions src/dvclive/studio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from os import getenv

from dvclive.env import STUDIO_ENDPOINT
from dvclive.utils import parse_scalars
from dvclive.utils import parse_metrics


def _get_unsent_datapoints(plot, latest_step):
Expand All @@ -28,14 +28,14 @@ def _to_dvc_format(plots):


def _get_updates(live):
plots, metrics = parse_scalars(live)
plots, metrics = parse_metrics(live)
latest_step = live._latest_studio_step # pylint: disable=protected-access

for name, plot in plots.items():
datapoints = _get_unsent_datapoints(plot, latest_step)
plots[name] = _cast_to_numbers(datapoints)

metrics = {live.summary_path: {"data": metrics}}
metrics = {live.metrics_path: {"data": metrics}}
plots = _to_dvc_format(plots)
return metrics, plots

Expand Down
12 changes: 6 additions & 6 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ def parse_json(path):
return json.load(fd)


def parse_scalars(live):
from .data import Scalar
def parse_metrics(live):
from .data import Metric

live_dir = Path(live.dir)
plots_path = Path(live.plots_path)
history = {}
for suffix in Scalar.suffixes:
for scalar_file in live_dir.rglob(f"*{suffix}"):
for suffix in Metric.suffixes:
for scalar_file in plots_path.rglob(f"*{suffix}"):
history[str(scalar_file)] = parse_tsv(scalar_file)
latest = parse_json(live.summary_path)
latest = parse_json(live.metrics_path)
return history, latest
6 changes: 3 additions & 3 deletions tests/test_catalyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dvclive import Live
from dvclive.catalyst import DvcLiveCallback
from dvclive.data import Scalar
from dvclive.data import Metric

# pylint: disable=redefined-outer-name, unused-argument

Expand Down Expand Up @@ -67,8 +67,8 @@ def test_catalyst_callback(tmp_dir, runner, runner_params):

assert os.path.exists("dvclive")

train_path = tmp_dir / "dvclive" / Scalar.subfolder / "train"
valid_path = tmp_dir / "dvclive" / Scalar.subfolder / "valid"
train_path = tmp_dir / "dvclive" / "plots" / Metric.subfolder / "train"
valid_path = tmp_dir / "dvclive" / "plots" / Metric.subfolder / "valid"

assert train_path.is_dir()
assert valid_path.is_dir()
Expand Down
Loading

0 comments on commit 876cf20

Please sign in to comment.