-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
First version of cut tuning with optuna
* introduce various functionality, such as utils, I/O etc the structure is not cast in stone of course, first attempt * wrap multiprocessing in one function * offer the user a user_config dictionary at each trial which can be used to extract some (global) use case-specific configuration * allow 2 different signatures for objective * start making possible the usage of different samplers * cut tuning example 4 scripts provided to produce reference, optimisation, evaluation and closure (see README in respective directory src/o2tuner_run) **NOTE** The example is preliminary, might need improvements Most functionality was added according to the needs of this example. * set max-line for flake8 to 150 (same as pylint) Overall, everything preliminary and up to discussion.
- Loading branch information
Benedikt Volkel
committed
Jun 20, 2022
1 parent
494f3e9
commit 74ca78e
Showing
16 changed files
with
891 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,3 +22,6 @@ python_requires = >=3.8 | |
|
||
[options.packages.find] | ||
where = src | ||
|
||
[flake8] | ||
max-line-length = 150 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,62 @@ | ||
""" | ||
Mainstream and custom backends should be placed here | ||
""" | ||
import sys | ||
from inspect import signature | ||
|
||
|
||
import optuna | ||
from o2tuner.utils import load_or_create_study | ||
|
||
|
||
class OptunaHandler(object): | ||
""" | ||
Handler based on Optuna backend | ||
""" | ||
|
||
def __init__(self) -> None: | ||
def __init__(self, db_study_name=None, db_storage=None, workdir=None, user_config=None) -> None: | ||
""" | ||
Constructor | ||
""" | ||
# user objective function | ||
self._objective = None | ||
# chosen sampler (can be None, optuna will use TPE then) | ||
self._sampler = None | ||
# our study object | ||
self._study = None | ||
self._n_trials = 100 | ||
# number of trials to be done in given study | ||
self._n_trials = None | ||
# name of study (in database) | ||
self.db_study_name = db_study_name | ||
# database storage | ||
self.db_storage = db_storage | ||
# working directory (if any) | ||
self.workdir = workdir | ||
# optional user configuration that will be passed down to each call of the objective | ||
self.user_config = user_config | ||
|
||
def initialise(self, n_trials=100): | ||
self._n_trials = n_trials | ||
self._study = optuna.create_study() | ||
self._study = load_or_create_study(self.db_study_name, self.db_storage, self._sampler) | ||
|
||
if self.workdir: | ||
self._study.set_user_attr("cwd", self.workdir) | ||
|
||
def optimise(self): | ||
if not self._n_trials or not self._objective: | ||
print("ERROR: Not initialised: Number of trials and objective function need to be set") | ||
return | ||
self._study.optimize(self._objective, n_trials=self._n_trials) | ||
|
||
def set_objective(self, objective): | ||
self._objective = objective | ||
sig = signature(objective) | ||
n_params = len(sig.parameters) | ||
if n_params > 2 or not n_params: | ||
print("Invalid signature of objective funtion. Need either 1 argument (only trial object) or 2 arguments (trial object and user_config)") | ||
sys.exit(1) | ||
if n_params == 1: | ||
self._objective = objective | ||
else: | ||
# Additionally pass the user config | ||
self._objective = lambda trial: objective(trial, self.user_config) | ||
|
||
def set_sampler(self, sampler): | ||
self._sampler = sampler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
""" | ||
Mainstream and custom backends should be placed here | ||
""" | ||
from os.path import join | ||
import matplotlib.pyplot as plt | ||
|
||
import optuna.visualization.matplotlib as ovm | ||
|
||
from o2tuner.io import parse_yaml, make_dir | ||
from o2tuner.utils import load_or_create_study | ||
from o2tuner.sampler import construct_sampler | ||
|
||
|
||
class OptunaInspector: # pylint: disable=too-few-public-methods | ||
""" | ||
Inspect optuna study | ||
""" | ||
|
||
def __init__(self, config_path=None): | ||
""" | ||
Constructor | ||
""" | ||
self._study = None | ||
self._config_path = config_path | ||
self._config = None | ||
|
||
def load_impl(self, config_path): | ||
""" | ||
Implementation of loading a study | ||
""" | ||
config = parse_yaml(config_path) | ||
sampler = construct_sampler(config.get("sampler", None)) | ||
storage = config.get("study", {}) | ||
self._study = load_or_create_study(storage.get("name"), storage.get("storage"), sampler) | ||
self._config = config | ||
|
||
def load(self, config_path=None): | ||
""" | ||
Loading wrapper | ||
""" | ||
if not config_path and not self._config_path: | ||
print("WARNING: Nothing to load, no configuration given") | ||
return False | ||
if config_path: | ||
self.load_impl(config_path) | ||
else: | ||
self.load_impl(self._config_path) | ||
return True | ||
|
||
def visualise(self, out_dir=None): | ||
""" | ||
Doing a simple visualisation using optuna's built-in functionality | ||
""" | ||
if not self._study: | ||
print("WARNING: No study loaded to visualise") | ||
return | ||
|
||
out_dir = out_dir or join(self._config.get("workdir", "./"), "visualisation") | ||
make_dir(out_dir) | ||
|
||
def to_figure_and_save(axis, path): | ||
fig = plt.figure() | ||
axis.figure = fig | ||
fig.add_axes(axis) | ||
fig.savefig(path) | ||
plt.close(fig) | ||
|
||
ax_plot_optimization_history = ovm.plot_optimization_history(self._study) | ||
to_figure_and_save(ax_plot_optimization_history, join(out_dir, "ax_plot_optimization_history.png")) | ||
ax_plot_parallel_coordinate = ovm.plot_parallel_coordinate(self._study) | ||
to_figure_and_save(ax_plot_parallel_coordinate, join(out_dir, "ax_plot_parallel_coordinate.png")) | ||
ax_plot_param_importances = ovm.plot_param_importances(self._study) | ||
to_figure_and_save(ax_plot_param_importances, join(out_dir, "ax_plot_param_importances.png")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
""" | ||
Common I/O functionality | ||
""" | ||
|
||
import sys | ||
from os.path import expanduser, isfile, isdir, exists, abspath, join | ||
from os import makedirs, remove, listdir | ||
from shutil import rmtree | ||
import json | ||
import yaml | ||
|
||
|
||
############################ | ||
# STANDARD FILE SYSTEM I/O # | ||
############################ | ||
def exists_file(path): | ||
"""wrapper around python's os.path.isfile | ||
this has the possibility to add additional functionality if needed | ||
""" | ||
return isfile(path) | ||
|
||
|
||
def exists_dir(path): | ||
"""wrapper around python's os.path.isdir | ||
this has the possibility to add additional functionality if needed | ||
""" | ||
return isdir(path) | ||
|
||
|
||
def make_dir(path): | ||
"""create a directory | ||
""" | ||
if exists(path): | ||
if exists_file(path): | ||
# if exists and if that is actually a file instead of a directory, fail here... | ||
print(f"Attempted to create directory {path}. However, a file seems to exist there, quitting") | ||
sys.exit(1) | ||
# ...otherwise just warn. | ||
print(f"WARNING: The directory {path} already exists, not overwriting") | ||
return | ||
# make the whole path structure | ||
makedirs(path) | ||
|
||
|
||
def remove_any_impl(to_be_removed, keep=None): | ||
|
||
if not exists(to_be_removed): | ||
return | ||
|
||
if isfile(to_be_removed): | ||
remove(to_be_removed) | ||
return | ||
|
||
if not keep: | ||
rmtree(to_be_removed) | ||
return | ||
|
||
# at this point we are sure that it is a directory and that certain things should be kept | ||
contains = [join(to_be_removed, ld) for ld in listdir(to_be_removed)] | ||
|
||
# figure out potential overlap between this path and things to be kept | ||
keep_this = [] | ||
for con in contains: | ||
for k in keep: | ||
if not k.find(con): | ||
# means found at index 0 | ||
keep_this.append(con) | ||
|
||
to_be_removed = [con for con in contains if con not in keep_this] | ||
|
||
for tbr in to_be_removed: | ||
remove_any_impl(tbr, keep_this) | ||
|
||
|
||
def remove_dir(to_be_removed, keep=None): | ||
""" | ||
remove a file or directory but keep certain things inside on request | ||
""" | ||
to_be_removed = abspath(to_be_removed) | ||
if keep: | ||
keep = [join(to_be_removed, k) for k in keep] | ||
remove_any_impl(to_be_removed, keep) | ||
|
||
|
||
######## | ||
# YAML # | ||
######## | ||
class NoAliasDumperYAML(yaml.SafeDumper): | ||
""" | ||
Avoid refs in dumped YAML | ||
""" | ||
def ignore_aliases(self, data): | ||
return True | ||
|
||
|
||
def parse_yaml(path): | ||
""" | ||
Wrap YAML reading | ||
https://stackoverflow.com/questions/13518819/avoid-references-in-pyyaml | ||
""" | ||
path = expanduser(path) | ||
try: | ||
with open(path, encoding="utf8") as in_file: | ||
return yaml.safe_load(in_file) | ||
except (OSError, IOError, yaml.YAMLError) as exc: | ||
print(f"ERROR: Cannot parse YAML from {path} due to\n{exc}") | ||
sys.exit(1) | ||
|
||
|
||
def dump_yaml(to_yaml, path, *, no_refs=False): | ||
""" | ||
Wrap YAML writing | ||
""" | ||
path = expanduser(path) | ||
try: | ||
with open(path, 'w', encoding="utf8") as out_file: | ||
if no_refs: | ||
yaml.dump_all([to_yaml], out_file, Dumper=NoAliasDumperYAML) | ||
else: | ||
yaml.safe_dump(to_yaml, out_file) | ||
except (OSError, IOError, yaml.YAMLError) as eexc: | ||
print(f"ERROR: Cannot write YAML to {path} due to\n{eexc}") | ||
sys.exit(1) | ||
|
||
|
||
######## | ||
# YAML # | ||
######## | ||
def parse_json(filepath): | ||
""" | ||
Wrap JSON reading, needed for interfacing with O2 cut configuration files | ||
""" | ||
filepath = expanduser(filepath) | ||
if not exists_file(filepath): | ||
print(f"ERROR: JSON file {filepath} does not exist.") | ||
sys.exit(1) | ||
with open(filepath, "r", encoding="utf8") as config_file: | ||
return json.load(config_file) | ||
|
||
|
||
def dump_json(to_json, path): | ||
""" | ||
Wrap JSON writing, needed for interfacing with O2 cut configuration files | ||
""" | ||
path = expanduser(path) | ||
with open(path, 'w', encoding="utf8") as config_file: | ||
json.dump(to_json, config_file, indent=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
""" | ||
Optimise hook, wrapping parallelisation | ||
""" | ||
|
||
from time import sleep | ||
from math import floor | ||
from multiprocessing import Process | ||
|
||
from o2tuner.io import make_dir, parse_yaml | ||
from o2tuner.backends import OptunaHandler | ||
from o2tuner.sampler import construct_sampler | ||
|
||
|
||
def optimise_run(objective, optuna_storage_config, sampler, n_trials, workdir, user_config): | ||
""" | ||
Run one of those per job | ||
""" | ||
handler = OptunaHandler(optuna_storage_config.get("name", None), optuna_storage_config.get("storage", None), workdir, user_config) | ||
handler.initialise(n_trials) | ||
handler.set_objective(objective) | ||
handler.set_sampler(sampler) | ||
handler.optimise() | ||
return 0 | ||
|
||
|
||
def optimise(objective, optuna_config, *, user_config=None, init_func=None): | ||
""" | ||
This is the entry point function for all optimisation runs | ||
args and kwargs will be forwarded to the objective function | ||
""" | ||
|
||
# read in the configurations | ||
optuna_config = parse_yaml(optuna_config) | ||
|
||
trials = optuna_config.get("trials", 100) | ||
jobs = optuna_config.get("jobs", 1) | ||
|
||
# investigate storage properties | ||
optuna_storage_config = optuna_config.get("study", {}) | ||
|
||
if not optuna_storage_config.get("name", None) or not optuna_storage_config.get("storage", None): | ||
# we reduce the number of jobs to 1. Either missing the table name or the storage path will anyway always lead to a new study | ||
print("WARNING: No storage provided, running only one job.") | ||
jobs = 1 | ||
|
||
trials_list = floor(trials / jobs) | ||
trials_list = [trials_list] * jobs | ||
# add the left-over trials simply to the last job for now | ||
trials_list[-1] += trials - sum(trials_list) | ||
|
||
# Digest user configuration and call an initialisation if requested | ||
user_config = parse_yaml(user_config) if user_config else None | ||
if init_func: | ||
init_func(user_config) | ||
|
||
print(f"Number of jobs: {jobs}\nNumber of trials: {trials}") | ||
|
||
sampler = construct_sampler(optuna_config.get("sampler", None)) | ||
|
||
workdir = optuna_config.get("workdir", "o2tuner_optimise") | ||
make_dir(workdir) | ||
|
||
procs = [] | ||
for trial in trials_list: | ||
procs.append(Process(target=optimise_run, args=(objective, optuna_storage_config, sampler, trial, workdir, user_config))) | ||
procs[-1].start() | ||
sleep(5) | ||
|
||
while True: | ||
is_alive = any(p.is_alive() for p in procs) | ||
if not is_alive: | ||
break | ||
# We assume here that the optimisation might take some time, so we can sleep for a bit | ||
sleep(10) | ||
return 0 |
Oops, something went wrong.