Skip to content

Commit

Permalink
Merge branch 'main' into general-pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
hollandjg authored Jul 7, 2023
2 parents 2107781 + 6816a4f commit 4263d79
Showing 1 changed file with 129 additions and 119 deletions.
248 changes: 129 additions & 119 deletions src/autora/workflow/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from .protocol import SupportsControllerState

# Change default plot styles
rcParams["axes.spines.top"] = False
rcParams["axes.spines.right"] = False
rcParams["legend.frameon"] = False
controller_plotting_rc_context = {
"axes.spines.top": False,
"axes.spines.right": False,
"legend.frameon": False,
}


def _get_variable_index(
Expand Down Expand Up @@ -297,54 +299,61 @@ def plot_results_panel_2d(
shape = (1, n_cycles_to_plot)
else:
shape = (int(np.ceil(n_cycles_to_plot / wrap)), wrap)
fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"])
# Place axis object in an array if plotting single panel
if shape == (1, 1):
axs = np.array([axs])

# Loop by panel
for i, ax in enumerate(axs.flat):
if i + 1 <= n_cycles_to_plot:
# Get index of cycle to plot
i_cycle = cycle_idx[i]

# ---Plot observed data---
# Independent variable values
x_vals = df_observed.loc[:, iv[0]]
# Dependent values masked by current cycle vs previous data
dv_previous = np.ma.masked_where(
df_observed["cycle"] >= i_cycle, df_observed[dv[0]]
)
dv_current = np.ma.masked_where(
df_observed["cycle"] != i_cycle, df_observed[dv[0]]
)
# Plotting scatter
ax.scatter(x_vals, dv_previous, **d_kw["scatter_previous_kw"])
ax.scatter(x_vals, dv_current, **d_kw["scatter_current_kw"])

# ---Plot Model---
conditions = condition_space[:, iv[0]]
ax.plot(conditions, l_predictions[i_cycle], **d_kw["plot_model_kw"])

# Label Panels
ax.text(
0.05, 1, f"Cycle {i_cycle}", ha="left", va="top", transform=ax.transAxes
)

else:
ax.axis("off")
with plt.rc_context(controller_plotting_rc_context):
fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"])
# Place axis object in an array if plotting single panel
if shape == (1, 1):
axs = np.array([axs])

# Loop by panel
for i, ax in enumerate(axs.flat):
if i + 1 <= n_cycles_to_plot:
# Get index of cycle to plot
i_cycle = cycle_idx[i]

# ---Plot observed data---
# Independent variable values
x_vals = df_observed.loc[:, iv[0]]
# Dependent values masked by current cycle vs previous data
dv_previous = np.ma.masked_where(
df_observed["cycle"] >= i_cycle, df_observed[dv[0]]
)
dv_current = np.ma.masked_where(
df_observed["cycle"] != i_cycle, df_observed[dv[0]]
)
# Plotting scatter
ax.scatter(x_vals, dv_previous, **d_kw["scatter_previous_kw"])
ax.scatter(x_vals, dv_current, **d_kw["scatter_current_kw"])

# ---Plot Model---
conditions = condition_space[:, iv[0]]
ax.plot(conditions, l_predictions[i_cycle], **d_kw["plot_model_kw"])

# Label Panels
ax.text(
0.05,
1,
f"Cycle {i_cycle}",
ha="left",
va="top",
transform=ax.transAxes,
)

# Super Labels
fig.supxlabel(iv_label, y=0.07)
fig.supylabel(dv_label)
else:
ax.axis("off")

# Super Labels
fig.supxlabel(iv_label, y=0.07)
fig.supylabel(dv_label)

# Legend
fig.legend(
["Previous Data", "New Data", "Model"],
ncols=3,
bbox_to_anchor=(0.5, 0),
loc="lower center",
)
# Legend
fig.legend(
["Previous Data", "New Data", "Model"],
ncols=3,
bbox_to_anchor=(0.5, 0),
loc="lower center",
)

return fig

Expand Down Expand Up @@ -444,60 +453,60 @@ def plot_results_panel_3d(
shape = (1, n_cycles)
else:
shape = (int(np.ceil(n_cycles / wrap)), wrap)
with plt.rc_context(controller_plotting_rc_context):
fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"])

# Loop by panel
for i, ax in enumerate(axs.flat):
if i + 1 <= n_cycles:
# ---Plot observed data---
# Independent variable values
l_x = [df_observed.loc[:, s[0]] for s in iv]
# Dependent values masked by current cycle vs previous data
dv_previous = np.ma.masked_where(
df_observed["cycle"] >= i, df_observed[dv[0]]
)
dv_current = np.ma.masked_where(
df_observed["cycle"] != i, df_observed[dv[0]]
)
# Plotting scatter
ax.scatter(*l_x, dv_previous, **d_kw["scatter_previous_kw"])
ax.scatter(*l_x, dv_current, **d_kw["scatter_current_kw"])

fig, axs = plt.subplots(*shape, **d_kw["subplot_kw"])

# Loop by panel
for i, ax in enumerate(axs.flat):
if i + 1 <= n_cycles:
# ---Plot observed data---
# Independent variable values
l_x = [df_observed.loc[:, s[0]] for s in iv]
# Dependent values masked by current cycle vs previous data
dv_previous = np.ma.masked_where(
df_observed["cycle"] >= i, df_observed[dv[0]]
)
dv_current = np.ma.masked_where(
df_observed["cycle"] != i, df_observed[dv[0]]
)
# Plotting scatter
ax.scatter(*l_x, dv_previous, **d_kw["scatter_previous_kw"])
ax.scatter(*l_x, dv_current, **d_kw["scatter_current_kw"])

# ---Plot Model---
ax.plot_surface(
x1, x2, l_predictions[i].reshape(x1.shape), **d_kw["surface_kw"]
)
# ---Labels---
# Title
ax.set_title(f"Cycle {i}")

# Axis
ax.set_xlabel(iv_labels[0])
ax.set_ylabel(iv_labels[1])
ax.set_zlabel(dv_label)

# Viewing angle
if view:
ax.view_init(*view)
# ---Plot Model---
ax.plot_surface(
x1, x2, l_predictions[i].reshape(x1.shape), **d_kw["surface_kw"]
)
# ---Labels---
# Title
ax.set_title(f"Cycle {i}")

else:
ax.axis("off")

# Legend
handles, labels = axs.flatten()[0].get_legend_handles_labels()
legend_elements = [
handles[0],
handles[1],
Patch(facecolor=handles[2].get_facecolors()[0]),
]
fig.legend(
handles=legend_elements,
labels=labels,
ncols=3,
bbox_to_anchor=(0.5, 0),
loc="lower center",
)
# Axis
ax.set_xlabel(iv_labels[0])
ax.set_ylabel(iv_labels[1])
ax.set_zlabel(dv_label)

# Viewing angle
if view:
ax.view_init(*view)

else:
ax.axis("off")

# Legend
handles, labels = axs.flatten()[0].get_legend_handles_labels()
legend_elements = [
handles[0],
handles[1],
Patch(facecolor=handles[2].get_facecolors()[0]),
]
fig.legend(
handles=legend_elements,
labels=labels,
ncols=3,
bbox_to_anchor=(0.5, 0),
loc="lower center",
)

return fig

Expand Down Expand Up @@ -590,24 +599,25 @@ def plot_cycle_score(
else:
l_scores = cycle_specified_score(scorer, state, X, y_true, **scorer_kw)

# Plotting
fig, ax = plt.subplots(figsize=figsize)
ax.plot(np.arange(len(state.models)), l_scores, **plot_kw)

# Adjusting axis limits
if ylim:
ax.set_ylim(*ylim)
if xlim:
ax.set_xlim(*xlim)

# Labeling
ax.set_xlabel(x_label)
if y_label is None:
if scorer is not None:
y_label = scorer.__name__
else:
y_label = "Score"
ax.set_ylabel(y_label)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
with plt.rc_context(controller_plotting_rc_context):
# Plotting
fig, ax = plt.subplots(figsize=figsize)
ax.plot(np.arange(len(state.models)), l_scores, **plot_kw)

# Adjusting axis limits
if ylim:
ax.set_ylim(*ylim)
if xlim:
ax.set_xlim(*xlim)

# Labeling
ax.set_xlabel(x_label)
if y_label is None:
if scorer is not None:
y_label = scorer.__name__
else:
y_label = "Score"
ax.set_ylabel(y_label)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

return fig

0 comments on commit 4263d79

Please sign in to comment.