diff --git a/src/autora/workflow/plotting.py b/src/autora/workflow/plotting.py index 9060eb17..11b29522 100644 --- a/src/autora/workflow/plotting.py +++ b/src/autora/workflow/plotting.py @@ -12,9 +12,11 @@ from .protocol import SupportsControllerState # Change default plot styles -rcParams["axes.spines.top"] = False -rcParams["axes.spines.right"] = False -rcParams["legend.frameon"] = False +controller_plotting_rc_context = { + "axes.spines.top": False, + "axes.spines.right": False, + "legend.frameon": False, +} def _get_variable_index( @@ -297,54 +299,61 @@ def plot_results_panel_2d( shape = (1, n_cycles_to_plot) else: shape = (int(np.ceil(n_cycles_to_plot / wrap)), wrap) - fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"]) - # Place axis object in an array if plotting single panel - if shape == (1, 1): - axs = np.array([axs]) - - # Loop by panel - for i, ax in enumerate(axs.flat): - if i + 1 <= n_cycles_to_plot: - # Get index of cycle to plot - i_cycle = cycle_idx[i] - - # ---Plot observed data--- - # Independent variable values - x_vals = df_observed.loc[:, iv[0]] - # Dependent values masked by current cycle vs previous data - dv_previous = np.ma.masked_where( - df_observed["cycle"] >= i_cycle, df_observed[dv[0]] - ) - dv_current = np.ma.masked_where( - df_observed["cycle"] != i_cycle, df_observed[dv[0]] - ) - # Plotting scatter - ax.scatter(x_vals, dv_previous, **d_kw["scatter_previous_kw"]) - ax.scatter(x_vals, dv_current, **d_kw["scatter_current_kw"]) - - # ---Plot Model--- - conditions = condition_space[:, iv[0]] - ax.plot(conditions, l_predictions[i_cycle], **d_kw["plot_model_kw"]) - - # Label Panels - ax.text( - 0.05, 1, f"Cycle {i_cycle}", ha="left", va="top", transform=ax.transAxes - ) - else: - ax.axis("off") + with plt.rc_context(controller_plotting_rc_context): + fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"]) + # Place axis object in an array if plotting single panel + if shape == (1, 1): + axs = np.array([axs]) + + # Loop by panel + for i, ax in enumerate(axs.flat): + if i + 1 <= n_cycles_to_plot: + # Get index of cycle to plot + i_cycle = cycle_idx[i] + + # ---Plot observed data--- + # Independent variable values + x_vals = df_observed.loc[:, iv[0]] + # Dependent values masked by current cycle vs previous data + dv_previous = np.ma.masked_where( + df_observed["cycle"] >= i_cycle, df_observed[dv[0]] + ) + dv_current = np.ma.masked_where( + df_observed["cycle"] != i_cycle, df_observed[dv[0]] + ) + # Plotting scatter + ax.scatter(x_vals, dv_previous, **d_kw["scatter_previous_kw"]) + ax.scatter(x_vals, dv_current, **d_kw["scatter_current_kw"]) + + # ---Plot Model--- + conditions = condition_space[:, iv[0]] + ax.plot(conditions, l_predictions[i_cycle], **d_kw["plot_model_kw"]) + + # Label Panels + ax.text( + 0.05, + 1, + f"Cycle {i_cycle}", + ha="left", + va="top", + transform=ax.transAxes, + ) - # Super Labels - fig.supxlabel(iv_label, y=0.07) - fig.supylabel(dv_label) + else: + ax.axis("off") + + # Super Labels + fig.supxlabel(iv_label, y=0.07) + fig.supylabel(dv_label) - # Legend - fig.legend( - ["Previous Data", "New Data", "Model"], - ncols=3, - bbox_to_anchor=(0.5, 0), - loc="lower center", - ) + # Legend + fig.legend( + ["Previous Data", "New Data", "Model"], + ncols=3, + bbox_to_anchor=(0.5, 0), + loc="lower center", + ) return fig @@ -444,60 +453,60 @@ def plot_results_panel_3d( shape = (1, n_cycles) else: shape = (int(np.ceil(n_cycles / wrap)), wrap) + with plt.rc_context(controller_plotting_rc_context): + fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"]) + + # Loop by panel + for i, ax in enumerate(axs.flat): + if i + 1 <= n_cycles: + # ---Plot observed data--- + # Independent variable values + l_x = [df_observed.loc[:, s[0]] for s in iv] + # Dependent values masked by current cycle vs previous data + dv_previous = np.ma.masked_where( + df_observed["cycle"] >= i, df_observed[dv[0]] + ) + dv_current = np.ma.masked_where( + df_observed["cycle"] != i, df_observed[dv[0]] + ) + # Plotting scatter + ax.scatter(*l_x, dv_previous, **d_kw["scatter_previous_kw"]) + ax.scatter(*l_x, dv_current, **d_kw["scatter_current_kw"]) - fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"]) - - # Loop by panel - for i, ax in enumerate(axs.flat): - if i + 1 <= n_cycles: - # ---Plot observed data--- - # Independent variable values - l_x = [df_observed.loc[:, s[0]] for s in iv] - # Dependent values masked by current cycle vs previous data - dv_previous = np.ma.masked_where( - df_observed["cycle"] >= i, df_observed[dv[0]] - ) - dv_current = np.ma.masked_where( - df_observed["cycle"] != i, df_observed[dv[0]] - ) - # Plotting scatter - ax.scatter(*l_x, dv_previous, **d_kw["scatter_previous_kw"]) - ax.scatter(*l_x, dv_current, **d_kw["scatter_current_kw"]) - - # ---Plot Model--- - ax.plot_surface( - x1, x2, l_predictions[i].reshape(x1.shape), **d_kw["surface_kw"] - ) - # ---Labels--- - # Title - ax.set_title(f"Cycle {i}") - - # Axis - ax.set_xlabel(iv_labels[0]) - ax.set_ylabel(iv_labels[1]) - ax.set_zlabel(dv_label) - - # Viewing angle - if view: - ax.view_init(*view) + # ---Plot Model--- + ax.plot_surface( + x1, x2, l_predictions[i].reshape(x1.shape), **d_kw["surface_kw"] + ) + # ---Labels--- + # Title + ax.set_title(f"Cycle {i}") - else: - ax.axis("off") - - # Legend - handles, labels = axs.flatten()[0].get_legend_handles_labels() - legend_elements = [ - handles[0], - handles[1], - Patch(facecolor=handles[2].get_facecolors()[0]), - ] - fig.legend( - handles=legend_elements, - labels=labels, - ncols=3, - bbox_to_anchor=(0.5, 0), - loc="lower center", - ) + # Axis + ax.set_xlabel(iv_labels[0]) + ax.set_ylabel(iv_labels[1]) + ax.set_zlabel(dv_label) + + # Viewing angle + if view: + ax.view_init(*view) + + else: + ax.axis("off") + + # Legend + handles, labels = axs.flatten()[0].get_legend_handles_labels() + legend_elements = [ + handles[0], + handles[1], + Patch(facecolor=handles[2].get_facecolors()[0]), + ] + fig.legend( + handles=legend_elements, + labels=labels, + ncols=3, + bbox_to_anchor=(0.5, 0), + loc="lower center", + ) return fig @@ -590,24 +599,25 @@ def plot_cycle_score( else: l_scores = cycle_specified_score(scorer, state, X, y_true, **scorer_kw) - # Plotting - fig, ax = plt.subplots(figsize=figsize) - ax.plot(np.arange(len(state.models)), l_scores, **plot_kw) - - # Adjusting axis limits - if ylim: - ax.set_ylim(*ylim) - if xlim: - ax.set_xlim(*xlim) - - # Labeling - ax.set_xlabel(x_label) - if y_label is None: - if scorer is not None: - y_label = scorer.__name__ - else: - y_label = "Score" - ax.set_ylabel(y_label) - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + with plt.rc_context(controller_plotting_rc_context): + # Plotting + fig, ax = plt.subplots(figsize=figsize) + ax.plot(np.arange(len(state.models)), l_scores, **plot_kw) + + # Adjusting axis limits + if ylim: + ax.set_ylim(*ylim) + if xlim: + ax.set_xlim(*xlim) + + # Labeling + ax.set_xlabel(x_label) + if y_label is None: + if scorer is not None: + y_label = scorer.__name__ + else: + y_label = "Score" + ax.set_ylabel(y_label) + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) return fig