Skip to content

Commit

Permalink
First version of cut tuning with optuna
Browse files Browse the repository at this point in the history
* introduce various functionality, such as utils, I/O etc
  the structure is not cast in stone of course, first attempt
* wrap multiprocessing in one function
* offer the user a user_config dictionary at each trial which can be
  used to extract some (global) use case-specific configuration
* allow 2 different signatures for objective
* start making possible the usage of different samplers

* cut tuning example
  4 scripts provided to produce reference, optimisation, evaluation and
  closure (see README in respective directory src/o2tuner_run)

  **NOTE** The example is preliminary, might need improvements

  Most functionality was added according to the needs of this example.

* set max-line for flake8 to 150 (same as pylint)

Overall, everything preliminary and up to discussion.
  • Loading branch information
Benedikt Volkel committed Jun 20, 2022
1 parent 494f3e9 commit 74ca78e
Show file tree
Hide file tree
Showing 16 changed files with 891 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ disable=useless-object-inheritance,missing-function-docstring

[DESIGN]
max-args=10
max-locals=40
max-locals=40
max-attributes=10
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ python_requires = >=3.8

[options.packages.find]
where = src

[flake8]
max-line-length = 150
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def __call__(self):
# these using the following syntax, for example:
# $ pip install -e .[dev,test]
extras_require={
"dev": []
"dev": [],
"test": ["pylint", "pytest"]
},

# Although 'package_data' is the preferred approach, in some case you may need to place data
Expand Down
44 changes: 38 additions & 6 deletions src/o2tuner/backends.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,62 @@
"""
Mainstream and custom backends should be placed here
"""
import sys
from inspect import signature


import optuna
from o2tuner.utils import load_or_create_study


class OptunaHandler(object):
    """
    Handler based on Optuna backend

    Bundles a study, the user's objective function, an optional sampler and an
    optional user configuration. Expected call order:
    set_sampler() (optional) -> initialise() -> set_objective() -> optimise()
    """

    def __init__(self, db_study_name=None, db_storage=None, workdir=None, user_config=None) -> None:
        """
        Constructor

        Args:
            db_study_name: name of the study (in the database)
            db_storage: database storage
            workdir: working directory (if any); stored as user attribute "cwd" of the study
            user_config: optional user configuration passed down to each call of a 2-argument objective
        """
        # user objective function
        self._objective = None
        # chosen sampler (can be None, optuna will use TPE then)
        self._sampler = None
        # our study object
        self._study = None
        # number of trials to be done in given study
        self._n_trials = None
        # name of study (in database)
        self.db_study_name = db_study_name
        # database storage
        self.db_storage = db_storage
        # working directory (if any)
        self.workdir = workdir
        # optional user configuration that will be passed down to each call of the objective
        self.user_config = user_config

    def initialise(self, n_trials=100):
        """
        Set the number of trials and load or create the study

        The sampler known at this point is handed to the study, so set_sampler()
        has to be called beforehand for the sampler to have any effect.
        """
        self._n_trials = n_trials
        self._study = load_or_create_study(self.db_study_name, self.db_storage, self._sampler)

        if self.workdir:
            self._study.set_user_attr("cwd", self.workdir)

    def optimise(self):
        """
        Run the optimisation for the configured number of trials

        Prints an error and returns without doing anything if initialise() or
        set_objective() have not been called before.
        """
        if not self._n_trials or not self._objective or self._study is None:
            print("ERROR: Not initialised: Number of trials and objective function need to be set")
            return
        self._study.optimize(self._objective, n_trials=self._n_trials)

    def set_objective(self, objective):
        """
        Register the objective function

        Accepted signatures are objective(trial) and objective(trial, user_config).
        Any other number of parameters terminates the program with exit code 1.
        """
        sig = signature(objective)
        n_params = len(sig.parameters)
        if n_params > 2 or not n_params:
            print("Invalid signature of objective funtion. Need either 1 argument (only trial object) or 2 arguments (trial object and user_config)")
            sys.exit(1)
        if n_params == 1:
            self._objective = objective
        else:
            # Additionally pass the user config
            self._objective = lambda trial: objective(trial, self.user_config)

    def set_sampler(self, sampler):
        """
        Set the sampler that initialise() hands to the study
        """
        self._sampler = sampler
73 changes: 73 additions & 0 deletions src/o2tuner/inspector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Inspection and visualisation of optuna studies
"""
from os.path import join
import matplotlib.pyplot as plt

import optuna.visualization.matplotlib as ovm

from o2tuner.io import parse_yaml, make_dir
from o2tuner.utils import load_or_create_study
from o2tuner.sampler import construct_sampler


class OptunaInspector:  # pylint: disable=too-few-public-methods
    """
    Inspect optuna study
    """

    def __init__(self, config_path=None):
        """
        Constructor

        Args:
            config_path: optional path to a YAML configuration used by load() when no path is passed there
        """
        self._study = None
        self._config_path = config_path
        self._config = None

    def load_impl(self, config_path):
        """
        Implementation of loading a study

        Reads the YAML configuration, constructs the sampler from its "sampler"
        entry and loads or creates the study given by "name" and "storage" under
        its "study" entry.
        """
        config = parse_yaml(config_path)
        sampler = construct_sampler(config.get("sampler", None))
        storage = config.get("study", {})
        self._study = load_or_create_study(storage.get("name"), storage.get("storage"), sampler)
        self._config = config

    def load(self, config_path=None):
        """
        Loading wrapper

        An explicitly given config_path takes precedence over the one passed to the constructor.

        Returns:
            True if a configuration path was available and loading was attempted, False otherwise
        """
        config_path = config_path or self._config_path
        if not config_path:
            print("WARNING: Nothing to load, no configuration given")
            return False
        self.load_impl(config_path)
        return True

    def visualise(self, out_dir=None):
        """
        Doing a simple visualisation using optuna's built-in functionality

        Args:
            out_dir: directory to write the plots to; defaults to "visualisation"
                inside the configured "workdir" (or the current directory)
        """
        if not self._study:
            print("WARNING: No study loaded to visualise")
            return

        out_dir = out_dir or join(self._config.get("workdir", "./"), "visualisation")
        make_dir(out_dir)

        def to_figure_and_save(axis, path):
            # Save the figure the axes already live in instead of re-parenting the axes into a new figure.
            # That keeps companion axes of the plot and makes sure the figure created by optuna gets closed.
            fig = axis.get_figure()
            fig.savefig(path)
            plt.close(fig)

        plots = (
            (ovm.plot_optimization_history, "ax_plot_optimization_history.png"),
            (ovm.plot_parallel_coordinate, "ax_plot_parallel_coordinate.png"),
            (ovm.plot_param_importances, "ax_plot_param_importances.png"),
        )
        for plot_func, file_name in plots:
            to_figure_and_save(plot_func(self._study), join(out_dir, file_name))
149 changes: 149 additions & 0 deletions src/o2tuner/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""
Common I/O functionality
"""

import sys
from os.path import expanduser, isfile, isdir, exists, abspath, join
from os import makedirs, remove, listdir
from shutil import rmtree
import json
import yaml


############################
# STANDARD FILE SYSTEM I/O #
############################
def exists_file(path):
    """Tell whether path points to an existing regular file

    Thin wrapper around python's os.path.isfile, kept as a single place to add
    further functionality if needed.
    """
    is_regular_file = isfile(path)
    return is_regular_file


def exists_dir(path):
    """Tell whether path points to an existing directory

    Thin wrapper around python's os.path.isdir, kept as a single place to add
    further functionality if needed.
    """
    is_directory = isdir(path)
    return is_directory


def make_dir(path):
    """Create a directory including all missing parent directories

    An already existing directory is left untouched (a warning is printed).
    If a file exists at path, the program is terminated with exit code 1.
    """
    if not exists(path):
        # nothing there yet, make the whole path structure
        makedirs(path)
        return

    if exists_file(path):
        # something exists and it is actually a file instead of a directory, fail here
        print(f"Attempted to create directory {path}. However, a file seems to exist there, quitting")
        sys.exit(1)

    # an existing directory only deserves a warning
    print(f"WARNING: The directory {path} already exists, not overwriting")


def remove_any_impl(to_be_removed, keep=None):
    """Remove a file or directory, optionally sparing some direct entries of a directory

    Args:
        to_be_removed: path of the file or directory to remove; a non-existing path is ignored
        keep: optional list of paths to be kept. Every direct entry of to_be_removed that
            is one of these paths, or a parent directory of one of them, is left untouched
            as a whole. If nothing is to be kept, to_be_removed itself is removed; otherwise
            the directory to_be_removed itself stays in place.
    """
    if not exists(to_be_removed):
        return

    if isfile(to_be_removed):
        remove(to_be_removed)
        return

    if not keep:
        rmtree(to_be_removed)
        return

    # at this point we are sure that it is a directory and that certain things should be kept
    for name in listdir(to_be_removed):
        entry = join(to_be_removed, name)
        # join(entry, "") appends the path separator, so that only real sub-paths match.
        # A plain string prefix test would also keep "foo" when only "foobar" was requested.
        entry_as_parent = join(entry, "")
        if any(k == entry or k.startswith(entry_as_parent) for k in keep):
            continue
        # nothing to be kept in here, so remove it entirely instead of leaving an empty directory behind
        remove_any_impl(entry)


def remove_dir(to_be_removed, keep=None):
    """
    Remove a file or directory but keep certain things inside on request

    Args:
        to_be_removed: path to remove, converted to an absolute path first
        keep: optional list of paths relative to to_be_removed that should be kept
    """
    target = abspath(to_be_removed)
    # paths to keep are given relative to the target, make them absolute as well
    keep_abs = [join(target, item) for item in keep] if keep else keep
    remove_any_impl(target, keep_abs)


########
# YAML #
########
class NoAliasDumperYAML(yaml.SafeDumper):
    """
    SafeDumper that never emits anchors/aliases (refs) in dumped YAML

    See https://stackoverflow.com/questions/13518819/avoid-references-in-pyyaml
    """

    def ignore_aliases(self, data):
        # returning True for every object makes the dumper write it out in full each time
        return True


def parse_yaml(path):
    """
    Wrap YAML reading

    Expands "~" in path and returns the content loaded with yaml.safe_load.
    If the file cannot be opened or parsed, an error is printed and the program
    is terminated with exit code 1.

    See also (on avoiding refs when dumping)
    https://stackoverflow.com/questions/13518819/avoid-references-in-pyyaml
    """
    path = expanduser(path)
    try:
        with open(path, encoding="utf8") as stream:
            content = yaml.safe_load(stream)
    except (OSError, IOError, yaml.YAMLError) as exc:
        print(f"ERROR: Cannot parse YAML from {path} due to\n{exc}")
        sys.exit(1)
    return content


def dump_yaml(to_yaml, path, *, no_refs=False):
    """
    Wrap YAML writing

    Args:
        to_yaml: object to be written
        path: destination file, "~" is expanded
        no_refs: if True, dump with NoAliasDumperYAML, otherwise use yaml.safe_dump

    If the file cannot be written, an error is printed and the program is
    terminated with exit code 1.
    """
    path = expanduser(path)
    try:
        with open(path, 'w', encoding="utf8") as stream:
            if not no_refs:
                yaml.safe_dump(to_yaml, stream)
            else:
                yaml.dump_all([to_yaml], stream, Dumper=NoAliasDumperYAML)
    except (OSError, IOError, yaml.YAMLError) as eexc:
        print(f"ERROR: Cannot write YAML to {path} due to\n{eexc}")
        sys.exit(1)


########
# JSON #
########
def parse_json(filepath):
    """
    Wrap JSON reading, needed for interfacing with O2 cut configuration files

    Expands "~" in filepath and returns the parsed content. If the file does not
    exist, an error is printed and the program is terminated with exit code 1.
    """
    filepath = expanduser(filepath)
    if exists_file(filepath):
        with open(filepath, "r", encoding="utf8") as config_file:
            return json.load(config_file)

    print(f"ERROR: JSON file {filepath} does not exist.")
    sys.exit(1)


def dump_json(to_json, path):
    """
    Wrap JSON writing, needed for interfacing with O2 cut configuration files

    Expands "~" in path and writes to_json there with an indentation of 2.
    """
    destination = expanduser(path)
    with open(destination, mode='w', encoding="utf8") as out_file:
        json.dump(to_json, out_file, indent=2)
76 changes: 76 additions & 0 deletions src/o2tuner/optimise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Optimise hook, wrapping parallelisation
"""

from time import sleep
from math import floor
from multiprocessing import Process

from o2tuner.io import make_dir, parse_yaml
from o2tuner.backends import OptunaHandler
from o2tuner.sampler import construct_sampler


def optimise_run(objective, optuna_storage_config, sampler, n_trials, workdir, user_config):
    """
    Run one of those per job

    Args:
        objective: user objective function
        optuna_storage_config: dictionary from which "name" and "storage" of the study are read
        sampler: sampler to be used for the study (can be None)
        n_trials: number of trials to be done by this job
        workdir: working directory
        user_config: optional user configuration made available to the objective

    Returns:
        0
    """
    handler = OptunaHandler(optuna_storage_config.get("name", None), optuna_storage_config.get("storage", None), workdir, user_config)
    # The sampler has to be set before initialise() since that is where the study is loaded or created with it.
    # Setting it afterwards would silently ignore the requested sampler.
    handler.set_sampler(sampler)
    handler.initialise(n_trials)
    handler.set_objective(objective)
    handler.optimise()
    return 0


def optimise(objective, optuna_config, *, user_config=None, init_func=None):
    """
    This is the entry point function for all optimisation runs

    Args:
        objective: user objective function, run in each job
        optuna_config: path to a YAML file; the keys read are "trials" (default 100), "jobs" (default 1),
            "study" (with "name" and "storage"), "sampler" and "workdir" (default "o2tuner_optimise")
        user_config: optional path to a YAML file with a user configuration which is
            handed to init_func and to each job
        init_func: optional callable, called once with the parsed user configuration before the jobs start

    Returns:
        0 once all jobs have finished
    """

    # read in the configurations
    optuna_config = parse_yaml(optuna_config)

    trials = optuna_config.get("trials", 100)
    jobs = optuna_config.get("jobs", 1)

    # investigate storage properties
    optuna_storage_config = optuna_config.get("study", {})

    if not optuna_storage_config.get("name", None) or not optuna_storage_config.get("storage", None):
        # we reduce the number of jobs to 1. Either missing the table name or the storage path will anyway always lead to a new study
        print("WARNING: No storage provided, running only one job.")
        jobs = 1

    if jobs < 1:
        # guard against a division by zero (or an empty list of jobs) below
        print(f"WARNING: Invalid number of jobs ({jobs}), running only one job.")
        jobs = 1
    elif 0 < trials < jobs:
        # more jobs than trials would leave some jobs without any trial to run
        print(f"WARNING: More jobs ({jobs}) than trials ({trials}) requested, running only {trials} job(s).")
        jobs = trials

    trials_list = floor(trials / jobs)
    trials_list = [trials_list] * jobs
    # add the left-over trials simply to the last job for now
    trials_list[-1] += trials - sum(trials_list)

    # Digest user configuration and call an initialisation if requested
    user_config = parse_yaml(user_config) if user_config else None
    if init_func:
        init_func(user_config)

    print(f"Number of jobs: {jobs}\nNumber of trials: {trials}")

    sampler = construct_sampler(optuna_config.get("sampler", None))

    workdir = optuna_config.get("workdir", "o2tuner_optimise")
    make_dir(workdir)

    procs = []
    for n_trials_job in trials_list:
        proc = Process(target=optimise_run, args=(objective, optuna_storage_config, sampler, n_trials_job, workdir, user_config))
        proc.start()
        procs.append(proc)
        # staggered start of the jobs
        # NOTE(review): presumably so that the study exists in the storage before the next job accesses it - confirm
        sleep(5)

    # Block until all jobs are done; join() returns as soon as a job has finished, no polling needed
    for proc in procs:
        proc.join()
    return 0
Loading

0 comments on commit 74ca78e

Please sign in to comment.