Skip to content

Commit

Permalink
feat: add functions for optuna logback to warn about best parameter v…
Browse files Browse the repository at this point in the history
…alues being at limits.
  • Loading branch information
mdtanker committed Aug 2, 2024
1 parent 89d4403 commit 53fafc7
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions src/invert4geom/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,129 @@ def logging_callback(
)


def warn_limits_better_than_trial_1_param(
study: optuna.study.Study,
trial: optuna.trial.FrozenTrial,
) -> None:
"""
custom optuna callback, warn if limits provide better score than current trial
Parameters
----------
study : optuna.study.Study
optuna study
trial : optuna.trial.FrozenTrial
current trial
"""
# exit if one of first 2 trials (lower and upper limits)
if trial.number < 2:
return

# get scores of lower and upper limits
# this assumes that the first two trials are the lower and upper limits set by
# study.enqueue_trial()
lower_limit_score = study.trials[0].value
upper_limit_score = study.trials[1].value
msg = (
"Current trial (#%s, %s) has a worse score (%s) than either of the lower "
"(%s) or upper (%s) parameter value limits, it might be best to stop the "
"study and expand the limits."
)
# if study direction is minimize
if study.direction == optuna.study.StudyDirection.MINIMIZE:
# if current trial is worse than either limit, log a warning
if trial.values[0] > max(lower_limit_score, upper_limit_score):
logging.warning(
msg,
trial.number,
trial.params,
trial.values[0],
lower_limit_score,
upper_limit_score,
)
else:
pass

# if study direction is maximize
if study.direction == optuna.study.StudyDirection.MAXIMIZE:
# if current trial is worse than either limit, log a warning
if trial.values[0] < min(lower_limit_score, upper_limit_score):
logging.warning(
msg,
trial.number,
trial.params,
trial.values[0],
lower_limit_score,
upper_limit_score,
)
else:
pass


def warn_limits_better_than_trial_multi_params(
study: optuna.study.Study,
trial: optuna.trial.FrozenTrial,
) -> None:
"""
custom optuna callback, warn if limits provide better score than current trial for
multiple parameter optimization
Parameters
----------
study : optuna.study.Study
optuna study
trial : optuna.trial.FrozenTrial
current trial
"""

# number of parameters in the study
num_params = len(trial.params)

# get number of combos (2 params->4 trials, 3 params->8 trials etc.)
num_combos = 2**num_params

# exit if one of enqueued trials
if trial.number < num_combos:
return

# get scores of combos of upper and lower limits of both parameters
# this assumes that the first four trials are set by study.enqueue_trial()
scores = []
for i in range(num_combos):
scores.append(study.trials[i].value)

msg = (
"Current trial (#%s, %s) has a worse score (%s) than any of the combinations "
"of parameter value limits, it might be best to stop the study and expand the "
"limits."
)
# if study direction is minimize
if study.direction == optuna.study.StudyDirection.MINIMIZE:
# if current trial is worse than either limit, log a warning
if trial.values[0] > max(scores):
logging.warning(
msg,
trial.number,
trial.params,
trial.values[0],
)
else:
pass

# if study direction is maximize
if study.direction == optuna.study.StudyDirection.MAXIMIZE:
# if current trial is worse than either limit, log a warning
if trial.values[0] < min(scores):
logging.warning(
msg,
trial.number,
trial.params,
trial.values[0],
)
else:
pass


def available_cpu_count() -> typing.Any:
"""
Number of available virtual or physical CPUs on this system, i.e.
Expand Down

0 comments on commit 53fafc7

Please sign in to comment.