Skip to content

Commit

Permalink
feat: add plotting func for non-grid search 2 parameter CV
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Aug 2, 2024
1 parent ccbd4ba commit b86ebf9
Showing 1 changed file with 111 additions and 0 deletions.
111 changes: 111 additions & 0 deletions src/invert4geom/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,117 @@ def plot_2_parameter_cv_scores(
plt.tight_layout()


def plot_2_parameter_cv_scores_uneven(
    study: optuna.study.Study,
    param_names: tuple[str, str],
    plot_param_names: tuple[str, str] = ("Hyperparameter 1", "Hyperparameter 2"),
    figsize: tuple[float, float] = (5, 3.5),
    cmap: str | None = None,
) -> None:
    """
    Plot the cross-validation scores of a two-parameter study whose trials are
    not arranged on a regular grid. A spline is fit to the scores and gridded
    to give a continuous background, the individual trials are drawn on top as
    gray points, and the trial with the lowest score is marked as the minimum.

    Parameters
    ----------
    study : optuna.study.Study
        study whose trials dataframe (``study.trials_dataframe()``) supplies
        the parameter values and the scores (the "value" column)
    param_names : tuple[str, str]
        names of the two columns of the trials dataframe holding the parameter
        values; the first is plotted on the x axis, the second on the y axis
    plot_param_names : tuple[str, str], optional
        labels for the x and y axes, by default
        ("Hyperparameter 1", "Hyperparameter 2")
    figsize : tuple[float, float], optional
        size of the figure, by default (5, 3.5)
    cmap : str, optional
        matplotlib colormap for the gridded scores, by default seaborn's
        "mako" colormap

    Raises
    ------
    ImportError
        if seaborn or matplotlib are not installed
    ValueError
        if no trial has a finite score and finite values for both parameters
    """
    # Check the optional dependencies before touching any global plotting state
    if sns is None:
        msg = "Missing optional dependency 'seaborn' required for plotting."
        raise ImportError(msg)
    if plt is None:
        msg = "Missing optional dependency 'matplotlib' required for plotting."
        raise ImportError(msg)
    sns.set_theme()

    if cmap is None:
        cmap = sns.color_palette("mako", as_cmap=True)

    x_name, y_name = param_names

    # Sort by score so the first row is the best (lowest) trial. Rows with a
    # missing score or parameter (e.g. trials that did not finish) cannot be
    # used for the spline fit, so they are dropped.
    trials = (
        study.trials_dataframe()
        .sort_values(by="value")[[x_name, y_name, "value"]]
        .dropna()
    )
    if trials.empty:
        msg = "The study has no trials with valid scores and parameter values."
        raise ValueError(msg)

    best = trials.iloc[0]

    # Keep one trial per parameter pair; since the rows are sorted by score,
    # the retained one is the lowest-scoring duplicate.
    unique_trials = trials.drop_duplicates(subset=[x_name, y_name]).reset_index(
        drop=True
    )
    x_values = unique_trials[x_name]
    y_values = unique_trials[y_name]

    x_min, x_max = x_values.min(), x_values.max()
    y_min, y_max = y_values.min(), y_values.max()
    # pad the plotted extent by 10% of the data range on each side
    x_buffer = (x_max - x_min) / 10
    y_buffer = (y_max - y_min) / 10

    # candidate damping values for the spline cross-validation, including None
    dampings = list(np.logspace(-10, -2, num=9))
    dampings.append(None)

    # Temporarily silence INFO-level logging while the spline is fit, and
    # always re-enable it afterwards, even if the fit raises.
    logging.disable(level=logging.INFO)
    try:
        spline = utils.best_spline_cv(
            coordinates=(x_values, y_values),
            data=unique_trials.value,
            dampings=dampings,
        )
    finally:
        logging.disable(level=logging.NOTSET)

    # pad_region takes the padding as (north-south, east-west), i.e. (y, x)
    region = vd.pad_region(
        vd.get_region((x_values, y_values)),
        (y_buffer, x_buffer),
    )
    grid = spline.grid(
        shape=(100, 100),
        region=region,
    ).scalars

    plt.figure(figsize=figsize)
    plt.title("Two parameter cross-validation")

    grid.plot(
        cmap=cmap,
    )

    plt.scatter(
        x_values,
        y_values,
        marker=".",
        color="gray",
        label="Trials",
    )

    plt.plot(
        best[x_name],
        best[y_name],
        "s",
        markersize=10,
        color=sns.color_palette()[3],
        label="Minimum",
    )
    plt.legend(
        loc="upper right",
    )
    plt.xlim([x_min - x_buffer, x_max + x_buffer])
    plt.ylim([y_min - y_buffer, y_max + y_buffer])
    plt.xlabel(plot_param_names[0])
    plt.ylabel(plot_param_names[1])

    plt.tight_layout()


def plot_cv_scores(
scores: list[float],
parameters: list[float],
Expand Down

0 comments on commit b86ebf9

Please sign in to comment.