diff --git a/src/dvclive/data/sklearn_plot.py b/src/dvclive/data/sklearn_plot.py index e227e8cd..af24cb13 100644 --- a/src/dvclive/data/sklearn_plot.py +++ b/src/dvclive/data/sklearn_plot.py @@ -10,7 +10,8 @@ class SKLearnPlot(Data): @property def output_path(self) -> Path: - _path = Path(f"{self.output_folder / self.name}.json") + _name = self.name.replace(".json", "") + _path = Path(f"{self.output_folder / _name}.json") _path.parent.mkdir(exist_ok=True, parents=True) return _path diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 36f064d2..dd11cab3 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -221,13 +221,14 @@ def log_image(self, name: str, val): data.dump(val, self._step) logger.debug(f"Logged {name}: {val}") - def log_sklearn_plot(self, name, labels, predictions, **kwargs): + def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs): val = (labels, predictions) + name = name or kind if name in self._plots: data = self._plots[name] - elif name in PLOTS and PLOTS[name].could_log(val): - data = PLOTS[name](name, self.plots_path) + elif kind in PLOTS and PLOTS[kind].could_log(val): + data = PLOTS[kind](name, self.plots_path) self._plots[name] = data else: raise InvalidPlotTypeError(name) diff --git a/tests/test_data/test_plot.py b/tests/test_data/test_plot.py index d00a5d6e..93ce9635 100644 --- a/tests/test_data/test_plot.py +++ b/tests/test_data/test_plot.py @@ -137,3 +137,19 @@ def test_cleanup(tmp_dir, y_true_y_pred_y_score): Live() assert not (tmp_dir / live.plots_path / SKLearnPlot.subfolder).exists() + + +def test_custom_name(tmp_dir, y_true_y_pred_y_score): + live = Live() + out = tmp_dir / live.plots_path / SKLearnPlot.subfolder + + y_true, y_pred, _ = y_true_y_pred_y_score + + live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="train/cm") + live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="val/cm") + # ".json" should be stripped from the name + live.log_sklearn_plot("confusion_matrix", y_true, y_pred, name="cm.json") + + assert (out / "train" / "cm.json").exists() + assert (out / "val" / "cm.json").exists() + assert (out / "cm.json").exists()