Skip to content

Commit

Permalink
log_sklearn_plot: Support custom name.
Browse files Browse the repository at this point in the history
Rename existing `name` argument to `kind`. This argument is used to determine the type of plot.
Add new optional `name` argument to define output file.

Closes #323
  • Loading branch information
daavoo committed Oct 25, 2022
1 parent 5a120c2 commit 1cfe551
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/dvclive/data/sklearn_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class SKLearnPlot(Data):

@property
def output_path(self) -> Path:
_path = Path(f"{self.output_folder / self.name}.json")
_name = self.name.replace(".json", "")
_path = Path(f"{self.output_folder / _name}.json")
_path.parent.mkdir(exist_ok=True, parents=True)
return _path

Expand Down
7 changes: 4 additions & 3 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,14 @@ def log_image(self, name: str, val):
data.dump(val, self._step)
logger.debug(f"Logged {name}: {val}")

def log_sklearn_plot(self, name, labels, predictions, **kwargs):
def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
val = (labels, predictions)

name = name or kind
if name in self._plots:
data = self._plots[name]
elif name in PLOTS and PLOTS[name].could_log(val):
data = PLOTS[name](name, self.plots_path)
elif kind in PLOTS and PLOTS[kind].could_log(val):
data = PLOTS[kind](name, self.plots_path)
self._plots[name] = data
else:
raise InvalidPlotTypeError(name)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_data/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,19 @@ def test_cleanup(tmp_dir, y_true_y_pred_y_score):
Live()

assert not (tmp_dir / live.plots_path / SKLearnPlot.subfolder).exists()


def test_custom_name(tmp_dir, y_true_y_pred_y_score):
live = Live()
out = tmp_dir / live.plots_path / SKLearnPlot.subfolder

y_true, y_pred, _ = y_true_y_pred_y_score

live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="train/cm")
live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="val/cm")
# ".json" should be stripped from the name
live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="cm.json")

assert (out / "train" / "cm.json").exists()
assert (out / "val" / "cm.json").exists()
assert (out / "cm.json").exists()

0 comments on commit 1cfe551

Please sign in to comment.