diff --git a/tests/plots/test_controller.py b/tests/plots/test_controller.py index 303d041f..1d3e6a98 100644 --- a/tests/plots/test_controller.py +++ b/tests/plots/test_controller.py @@ -62,9 +62,9 @@ def test_realizations_plot_representation(): realization_df, x_axis, assets.ERTSTYLE["ensemble-selector"]["color_wheel"][0] ) assert len(plots) == 20 - for plot in plots: + for idx, plot in enumerate(plots): np.testing.assert_equal(x_axis, plot.repr.x) - np.testing.assert_equal(plot.repr.y, realization_df[plot.name].values) + np.testing.assert_equal(plot.repr.y, realization_df[idx].values) def test_realizations_statistics_plot_representation(): diff --git a/webviz_ert/controllers/multi_response_controller.py b/webviz_ert/controllers/multi_response_controller.py index 81542713..3342f7dd 100644 --- a/webviz_ert/controllers/multi_response_controller.py +++ b/webviz_ert/controllers/multi_response_controller.py @@ -43,7 +43,9 @@ def _get_realizations_plots( x_axis=x_axis, y_axis=realizations_df[realization].values, text=f"Realization: {realization} Ensemble: {ensemble_name}", - name=realization, + name=ensemble_name, + legendgroup=ensemble_name, + showlegend=False if idx > 0 else True, **_style, ) realizations_data.append(plot) @@ -81,7 +83,9 @@ def _get_realizations_statistics_plots( def _get_observation_plots( - observation_df: pd.DataFrame, metadata: Optional[List[str]] = None + observation_df: pd.DataFrame, + metadata: Optional[List[str]] = None, + ensemble: str = "", ) -> PlotModel: data = observation_df["values"] stds = observation_df["std"] @@ -97,7 +101,7 @@ def _get_observation_plots( x_axis=x_axis, y_axis=data, text=attributes, - name="Observation", + name=f"Observation_{ensemble}", error_y=dict( type="data", # value of error bar given in data coordinates array=stds.values, @@ -135,7 +139,8 @@ def _create_response_plot( ) if response.observations: observations = [ - _get_observation_plots(obs.data_df()) for obs in response.observations + _get_observation_plots(obs.data_df(), ensemble=ensemble_name) + for obs in response.observations ] else: observations = [] diff --git a/webviz_ert/models/plot_model.py b/webviz_ert/models/plot_model.py index 6a342cb3..277b99b1 100644 --- a/webviz_ert/models/plot_model.py +++ b/webviz_ert/models/plot_model.py @@ -200,16 +200,18 @@ class PlotModel: def __init__(self, **kwargs: Any): self._x_axis = kwargs["x_axis"] self._y_axis = kwargs["y_axis"] - self._text = kwargs["text"] if "text" in kwargs else None + self._text = kwargs.get("text") self._name = kwargs["name"] self._mode = kwargs["mode"] self._line = kwargs["line"] self._marker = kwargs["marker"] self._error_y = kwargs.get("error_y") self._hoverlabel = kwargs.get("hoverlabel") - self._meta = kwargs["meta"] if "meta" in kwargs else None + self._meta = kwargs.get("meta") self._xaxis = kwargs.get("xaxis") self.selected = True + self.legendgroup = kwargs.get("legendgroup") + self.showlegend = kwargs.get("showlegend", True) @property def repr(self) -> Union[go.Scattergl, go.Scatter]: @@ -224,7 +226,10 @@ def repr(self) -> Union[go.Scattergl, go.Scatter]: connectgaps=True, hoverlabel=self._hoverlabel, meta=self._meta, + showlegend=self.showlegend, ) + if self.legendgroup: + repr_dict["legendgroup"] = self.legendgroup if self._line: repr_dict["line"] = self._line if self._marker: