From 53fafc75fc776c9e4cf16a8802bebc4efacf690d Mon Sep 17 00:00:00 2001 From: mdtanker Date: Thu, 18 Jul 2024 15:13:22 -0600 Subject: [PATCH] feat: add functions for optuna logback to warn about best parameter values being at limits. --- src/invert4geom/optimization.py | 123 ++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/src/invert4geom/optimization.py b/src/invert4geom/optimization.py index 061a2fd8..ed9ad9d2 100644 --- a/src/invert4geom/optimization.py +++ b/src/invert4geom/optimization.py @@ -76,6 +76,129 @@ def logging_callback( ) +def warn_limits_better_than_trial_1_param( + study: optuna.study.Study, + trial: optuna.trial.FrozenTrial, +) -> None: + """ + custom optuna callback, warn if limits provide better score than current trial + + Parameters + ---------- + study : optuna.study.Study + optuna study + trial : optuna.trial.FrozenTrial + current trial + """ + # exit if one of first 2 trials (lower and upper limits) + if trial.number < 2: + return + + # get scores of lower and upper limits + # this assumes that the first two trials are the lower and upper limits set by + # study.enqueue_trial() + lower_limit_score = study.trials[0].value + upper_limit_score = study.trials[1].value + msg = ( + "Current trial (#%s, %s) has a worse score (%s) than either of the lower " + "(%s) or upper (%s) parameter value limits, it might be best to stop the " + "study and expand the limits." + ) + # if study direction is minimize + if study.direction == optuna.study.StudyDirection.MINIMIZE: + # if current trial is worse than either limit, log a warning + if trial.values[0] > max(lower_limit_score, upper_limit_score): + logging.warning( + msg, + trial.number, + trial.params, + trial.values[0], + lower_limit_score, + upper_limit_score, + ) + else: + pass + + # if study direction is maximize + if study.direction == optuna.study.StudyDirection.MAXIMIZE: + # if current trial is worse than either limit, log a warning + if trial.values[0] < min(lower_limit_score, upper_limit_score): + logging.warning( + msg, + trial.number, + trial.params, + trial.values[0], + lower_limit_score, + upper_limit_score, + ) + else: + pass + + +def warn_limits_better_than_trial_multi_params( + study: optuna.study.Study, + trial: optuna.trial.FrozenTrial, +) -> None: + """ + custom optuna callback, warn if limits provide better score than current trial for + multiple parameter optimization + + Parameters + ---------- + study : optuna.study.Study + optuna study + trial : optuna.trial.FrozenTrial + current trial + """ + + # number of parameters in the study + num_params = len(trial.params) + + # get number of combos (2 params->4 trials, 3 params->8 trials etc.) + num_combos = 2**num_params + + # exit if one of enqueued trials + if trial.number < num_combos: + return + + # get scores of combos of upper and lower limits of both parameters + # this assumes that the first four trials are set by study.enqueue_trial() + scores = [] + for i in range(num_combos): + scores.append(study.trials[i].value) + + msg = ( + "Current trial (#%s, %s) has a worse score (%s) than any of the combinations " + "of parameter value limits, it might be best to stop the study and expand the " + "limits." + ) + # if study direction is minimize + if study.direction == optuna.study.StudyDirection.MINIMIZE: + # if current trial is worse than either limit, log a warning + if trial.values[0] > max(scores): + logging.warning( + msg, + trial.number, + trial.params, + trial.values[0], + ) + else: + pass + + # if study direction is maximize + if study.direction == optuna.study.StudyDirection.MAXIMIZE: + # if current trial is worse than either limit, log a warning + if trial.values[0] < min(scores): + logging.warning( + msg, + trial.number, + trial.params, + trial.values[0], + ) + else: + pass + + def available_cpu_count() -> typing.Any: """ Number of available virtual or physical CPUs on this system, i.e.