From 0f5715b9ad96e6ccfcd0710bdf62c24f7b3e8f80 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Mon, 5 Aug 2024 08:40:43 -0400 Subject: [PATCH] support lists for log_plot y val (#837) --- src/dvclive/live.py | 5 +++-- src/dvclive/plots/custom.py | 4 ++-- tests/plots/test_custom.py | 27 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 8f49fc99..c0b4aa81 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -562,7 +562,7 @@ def log_plot( name: str, datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]], x: str, - y: str, + y: Union[str, list[str]], template: Optional[str] = "linear", title: Optional[str] = None, x_label: Optional[str] = None, @@ -579,7 +579,8 @@ def log_plot( datapoints (pd.DataFrame | np.ndarray | List[Dict]): Pandas DataFrame, Numpy Array or List of dictionaries containing the data for the plot. x (str): name of the key (present in the dictionaries) to use as the x axis. - y (str): name of the key (present in the dictionaries) to use the y axis. + y (str | list[str]): name of the key or keys (present in the + dictionaries) to use the y axis. template (str): name of the `DVC plots template` to use. Defaults to `"linear"`. title (str): title to be displayed. Defaults to diff --git a/src/dvclive/plots/custom.py b/src/dvclive/plots/custom.py index f7e55563..0ea15272 100644 --- a/src/dvclive/plots/custom.py +++ b/src/dvclive/plots/custom.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Union from dvclive.serialize import dump_json @@ -15,7 +15,7 @@ def __init__( name: str, output_folder: str, x: str, - y: str, + y: Union[str, list[str]], template: Optional[str], title: Optional[str] = None, x_label: Optional[str] = None, diff --git a/tests/plots/test_custom.py b/tests/plots/test_custom.py index 17b18f2b..349c726a 100644 --- a/tests/plots/test_custom.py +++ b/tests/plots/test_custom.py @@ -29,3 +29,30 @@ def test_log_custom_plot(tmp_dir): "x_label": "x_label", "y_label": "y_label", } + + +def test_log_custom_plot_multi_y(tmp_dir): + live = Live() + out = tmp_dir / live.plots_dir / CustomPlot.subfolder + + datapoints = [{"x": 1, "y1": 2, "y2": 3}, {"x": 4, "y1": 5, "y2": 6}] + live.log_plot( + "custom_linear", + datapoints, + x="x", + y=["y1", "y2"], + template="linear", + title="custom_title", + x_label="x_label", + y_label="y_label", + ) + + assert json.loads((out / "custom_linear.json").read_text()) == datapoints + assert live._plots["custom_linear"].plot_config == { + "template": "linear", + "title": "custom_title", + "x": "x", + "y": ["y1", "y2"], + "x_label": "x_label", + "y_label": "y_label", + }