Skip to content

Commit

Permalink
fix: update plotting for optuna optimization results
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Aug 2, 2024
1 parent f7a135d commit f2c60c4
Showing 1 changed file with 39 additions and 27 deletions.
66 changes: 39 additions & 27 deletions src/invert4geom/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1975,13 +1975,17 @@ def optimize_regional_filter(
resulting_grav_df = best_trial.user_attrs.get("results")

if plot is True:
if optimize_on_true_regional_misfit is True:
if study._is_multi_objective() is False: # pylint: disable=protected-access
optuna.visualization.plot_slice(study).show()
else:
optuna.visualization.plot_pareto_front(
study,
target_names=["RMS of residual at constraints", "RMS of residual"],
).show()
p = optuna.visualization.plot_pareto_front(study)
plotting.remove_df_from_hoverdata(p).show()
for i, j in enumerate(study.metric_names):
optuna.visualization.plot_slice(
study,
target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop
target_name=j,
).show()
if plot_grid is True:
resulting_grav_df.set_index(["northing", "easting"]).to_xarray().reg.plot()

Expand Down Expand Up @@ -2105,13 +2109,17 @@ def optimize_regional_trend(
resulting_grav_df = best_trial.user_attrs.get("results")

if plot is True:
if optimize_on_true_regional_misfit is True:
if study._is_multi_objective() is False: # pylint: disable=protected-access
optuna.visualization.plot_slice(study).show()
else:
optuna.visualization.plot_pareto_front(
study,
target_names=["RMS of residual at constraints", "RMS of residual"],
).show()
p = optuna.visualization.plot_pareto_front(study)
plotting.remove_df_from_hoverdata(p).show()
for i, j in enumerate(study.metric_names):
optuna.visualization.plot_slice(
study,
target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop
target_name=j,
).show()
if plot_grid is True:
resulting_grav_df.set_index(["northing", "easting"]).to_xarray().reg.plot()

Expand Down Expand Up @@ -2259,15 +2267,17 @@ def optimize_regional_eq_sources(
resulting_grav_df = best_trial.user_attrs.get("results")

if plot is True:
if optimize_on_true_regional_misfit is True:
if study._is_multi_objective() is False: # pylint: disable=protected-access
optuna.visualization.plot_slice(study).show()
else:
optuna.visualization.plot_pareto_front(
study,
target_names=["RMS of residual at constraints", "RMS of residual"],
).show()
optuna.visualization.plot_param_importances(study).show()

p = optuna.visualization.plot_pareto_front(study)
plotting.remove_df_from_hoverdata(p).show()
for i, j in enumerate(study.metric_names):
optuna.visualization.plot_slice(
study,
target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop
target_name=j,
).show()
if plot_grid is True:
resulting_grav_df.set_index(["northing", "easting"]).to_xarray().reg.plot()

Expand Down Expand Up @@ -2469,13 +2479,17 @@ def optimize_regional_constraint_point_minimization(
resulting_grav_df = best_trial.user_attrs.get("results")

if plot is True:
if optimize_on_true_regional_misfit is True:
if study._is_multi_objective() is False: # pylint: disable=protected-access
optuna.visualization.plot_slice(study).show()
else:
optuna.visualization.plot_pareto_front(
study,
target_names=["RMS of residual at constraints", "RMS of residual"],
).show()
p = optuna.visualization.plot_pareto_front(study)
plotting.remove_df_from_hoverdata(p).show()
for i, j in enumerate(study.metric_names):
optuna.visualization.plot_slice(
study,
target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop
target_name=j,
).show()
if len(study.trials[0].params) > 1:
optuna.visualization.plot_param_importances(study).show()
if isinstance(testing_df, pd.DataFrame) & (plot_grid is True):
Expand Down Expand Up @@ -2560,13 +2574,11 @@ def optimize_regional_constraint_point_minimization_kfolds(
)

if plot is True:
if kwargs.get("optimize_on_true_regional_misfit") is True:
if study._is_multi_objective() is False: # pylint: disable=protected-access
optuna.visualization.plot_slice(study).show()
else:
optuna.visualization.plot_pareto_front(
study,
target_names=["RMS of residual at constraints", "RMS of residual"],
).show()
p = optuna.visualization.plot_pareto_front(study)
plotting.remove_df_from_hoverdata(p).show()
if len(study.trials[0].params) > 1:
optuna.visualization.plot_param_importances(study).show()
if plot_grid is True:
Expand Down

0 comments on commit f2c60c4

Please sign in to comment.