Skip to content

Commit

Permalink
Avoid duplicate execution, resolved an issue which sometimes led to h…
Browse files Browse the repository at this point in the history
…yperparamtuner tuning returning more results than the search space size
  • Loading branch information
fjwillemsen committed Oct 22, 2024
1 parent e04906e commit 6062e50
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
1 change: 1 addition & 0 deletions kernel_tuner/backends/hypertuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, iterations):
self.observers = [ScoreObserver(self)]
self.name = platform.processor()
self.max_threads = 1024
self.last_score = None

# set the environment options
env = dict()
Expand Down
29 changes: 26 additions & 3 deletions kernel_tuner/hyper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
"""Module for functions related to hyperparameter optimization."""


from pathlib import Path
from random import randint

import kernel_tuner


def get_random_unique_filename(prefix='', suffix=''):
    """Return a random ``Path`` in the current directory that does not yet exist.

    Args:
        prefix: string prepended to the random name component (e.g. ``'temp_'``).
        suffix: string appended to the random name component (e.g. ``'.json'``).

    Returns:
        A ``pathlib.Path`` that did not exist at the time of the check.

    Note:
        There is an inherent time-of-check/time-of-use race: another process
        could create the file between the ``exists()`` check and first use.
        Callers needing a hard guarantee should use the ``tempfile`` module.
    """
    # Local import keeps the public module namespace unchanged.
    from uuid import uuid4

    def randpath():
        # A uuid4 hex string is practically collision-free, unlike the previous
        # 4-digit random integer (only 9000 candidates, which made collisions
        # likely and allowed the loop below to spin forever once exhausted).
        return Path(f"{prefix}{uuid4().hex}{suffix}")

    path = randpath()
    while path.exists():  # astronomically unlikely, but preserve the original guarantee
        path = randpath()
    return path

def tune_hyper_params(target_strategy: str, hyper_params: dict, *args, **kwargs):
"""Tune hyperparameters for a given strategy and kernel.
Expand Down Expand Up @@ -46,8 +57,10 @@ def tune_hyper_params(target_strategy: str, hyper_params: dict, *args, **kwargs)
if "iterations" in kwargs:
iterations = kwargs['iterations']
del kwargs['iterations']
if "cache" in kwargs:
del kwargs['cache']

# pass a temporary cache file to avoid duplicate execution
cachefile = get_random_unique_filename('temp_', '.json')
kwargs['cache'] = str(cachefile)

def put_if_not_present(target_dict, key, value):
target_dict[key] = value if key not in target_dict else target_dict[key]
Expand All @@ -59,8 +72,18 @@ def put_if_not_present(target_dict, key, value):
kwargs['verify'] = None
arguments = [target_strategy]

return kernel_tuner.tune_kernel('hyperparamtuning', None, [], arguments, hyper_params, *args, lang='Hypertuner',
# execute the hyperparameter tuning
result, env = kernel_tuner.tune_kernel('hyperparamtuning', None, [], arguments, hyper_params, *args, lang='Hypertuner',
objective='score', objective_higher_is_better=True, iterations=iterations, **kwargs)

# remove the temporary cachefile and return only unique results in order
cachefile.unlink()
result_unique = dict()
for r in result:
config_id = ",".join(str(r[k]) for k in hyper_params.keys())
if config_id not in result_unique:
result_unique[config_id] = r
return list(result_unique.values()), env

if __name__ == "__main__": # TODO remove in production
# hyperparams = {
Expand Down
2 changes: 1 addition & 1 deletion kernel_tuner/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def __deepcopy__(self, _):
All strategies support the following two options:
1. "max_fevals": the maximum number of unique valid function evaluations (i.e. compiling and
benchmarking a kernel configuration the strategy is allowed to perform as part of the optimization.
benchmarking a kernel configuration) the strategy is allowed to perform as part of the optimization.
Note that some strategies implement a default max_fevals of 100.
2. "time_limit": the maximum amount of time in seconds the strategy is allowed to spend on trying to
Expand Down
4 changes: 2 additions & 2 deletions test/test_hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ def test_hyper(env):

target_strategy = "genetic_algorithm"

result, env = tune_hyper_params(target_strategy, hyper_params, iterations=1, verbose=True, cache=cache_filename)
assert len(result) >= 2 # Look into why the hyperparamtuner returns more results than the searchspace size
result, env = tune_hyper_params(target_strategy, hyper_params, iterations=1, verbose=True)
assert len(result) == 2
assert 'best_config' in env

0 comments on commit 6062e50

Please sign in to comment.