From f2c60c4f2b00c1f96f5008fc8e8ba940a8a4ac6f Mon Sep 17 00:00:00 2001 From: mdtanker Date: Mon, 22 Jul 2024 21:00:20 -0600 Subject: [PATCH] fix: update plotting for optuna optimization results --- src/invert4geom/optimization.py | 66 +++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/src/invert4geom/optimization.py b/src/invert4geom/optimization.py index 3c2569e6..0c17d3a0 100644 --- a/src/invert4geom/optimization.py +++ b/src/invert4geom/optimization.py @@ -1975,13 +1975,17 @@ def optimize_regional_filter( resulting_grav_df = best_trial.user_attrs.get("results") if plot is True: - if optimize_on_true_regional_misfit is True: + if study._is_multi_objective() is False: # pylint: disable=protected-access optuna.visualization.plot_slice(study).show() else: - optuna.visualization.plot_pareto_front( - study, - target_names=["RMS of residual at constraints", "RMS of residual"], - ).show() + p = optuna.visualization.plot_pareto_front(study) + plotting.remove_df_from_hoverdata(p).show() + for i, j in enumerate(study.metric_names): + optuna.visualization.plot_slice( + study, + target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop + target_name=j, + ).show() if plot_grid is True: resulting_grav_df.set_index(["northing", "easting"]).to_xarray().reg.plot() @@ -2105,13 +2109,17 @@ def optimize_regional_trend( resulting_grav_df = best_trial.user_attrs.get("results") if plot is True: - if optimize_on_true_regional_misfit is True: + if study._is_multi_objective() is False: # pylint: disable=protected-access optuna.visualization.plot_slice(study).show() else: - optuna.visualization.plot_pareto_front( - study, - target_names=["RMS of residual at constraints", "RMS of residual"], - ).show() + p = optuna.visualization.plot_pareto_front(study) + plotting.remove_df_from_hoverdata(p).show() + for i, j in enumerate(study.metric_names): + optuna.visualization.plot_slice( + study, + target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop + target_name=j, + ).show() if plot_grid is True: resulting_grav_df.set_index(["northing", "easting"]).to_xarray().reg.plot() @@ -2259,15 +2267,17 @@ def optimize_regional_eq_sources( resulting_grav_df = best_trial.user_attrs.get("results") if plot is True: - if optimize_on_true_regional_misfit is True: + if study._is_multi_objective() is False: # pylint: disable=protected-access optuna.visualization.plot_slice(study).show() else: - optuna.visualization.plot_pareto_front( - study, - target_names=["RMS of residual at constraints", "RMS of residual"], - ).show() - optuna.visualization.plot_param_importances(study).show() - + p = optuna.visualization.plot_pareto_front(study) + plotting.remove_df_from_hoverdata(p).show() + for i, j in enumerate(study.metric_names): + optuna.visualization.plot_slice( + study, + target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop + target_name=j, + ).show() if plot_grid is True: resulting_grav_df.set_index(["northing", "easting"]).to_xarray().reg.plot() @@ -2469,13 +2479,17 @@ def optimize_regional_constraint_point_minimization( resulting_grav_df = best_trial.user_attrs.get("results") if plot is True: - if optimize_on_true_regional_misfit is True: + if study._is_multi_objective() is False: # pylint: disable=protected-access optuna.visualization.plot_slice(study).show() else: - optuna.visualization.plot_pareto_front( - study, - target_names=["RMS of residual at constraints", "RMS of residual"], - ).show() + p = optuna.visualization.plot_pareto_front(study) + plotting.remove_df_from_hoverdata(p).show() + for i, j in enumerate(study.metric_names): + optuna.visualization.plot_slice( + study, + target=lambda t: t.values[i], # noqa: B023 # pylint: disable=cell-var-from-loop + target_name=j, + ).show() if len(study.trials[0].params) > 1: optuna.visualization.plot_param_importances(study).show() if isinstance(testing_df, pd.DataFrame) & (plot_grid is True): @@ -2560,13 +2574,11 @@ def optimize_regional_constraint_point_minimization_kfolds( ) if plot is True: - if kwargs.get("optimize_on_true_regional_misfit") is True: + if study._is_multi_objective() is False: # pylint: disable=protected-access optuna.visualization.plot_slice(study).show() else: - optuna.visualization.plot_pareto_front( - study, - target_names=["RMS of residual at constraints", "RMS of residual"], - ).show() + p = optuna.visualization.plot_pareto_front(study) + plotting.remove_df_from_hoverdata(p).show() if len(study.trials[0].params) > 1: optuna.visualization.plot_param_importances(study).show() if plot_grid is True: