Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First version of cut tuning with optuna #11

Merged
merged 1 commit into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ disable=useless-object-inheritance,missing-function-docstring

[DESIGN]
max-args=10
max-locals=40
max-locals=40
max-attributes=10
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
optuna >= 2.10.0
colorama >= 0.4.3
click >= 8.0.3
click >= 8.0.3
matplotlib >= 3.5.1
psutil >= 5.9.0
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ python_requires = >=3.8

[options.packages.find]
where = src

[flake8]
max-line-length = 150
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,16 @@ def __call__(self):
# List run-time dependencies here. These will be installed by pip when your project is
# installed. For an analysis of "install_requires" vs pip's requirements files see:
# https://packaging.python.org/en/latest/requirements.html
install_requires=["optuna", "colorama", "click"],
install_requires=["optuna", "colorama", "click", "matplotlib", "psutil"],

python_requires=">=3.8",

# List additional groups of dependencies here (e.g. development dependencies). You can install
# these using the following syntax, for example:
# $ pip install -e .[dev,test]
extras_require={
"dev": []
"dev": [],
"test": ["pylint", "pytest"]
},

# Although 'package_data' is the preferred approach, in some case you may need to place data
Expand Down
44 changes: 38 additions & 6 deletions src/o2tuner/backends.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,62 @@
"""
Mainstream and custom backends should be placed here
"""
import sys
from inspect import signature


import optuna
from o2tuner.utils import load_or_create_study


class OptunaHandler(object):
    """
    Handler based on Optuna backend

    Intended call order: construct, optionally set_sampler(), set_objective(), initialise(), optimise().
    """

    def __init__(self, db_study_name=None, db_storage=None, workdir=None, user_config=None) -> None:
        """
        Constructor

        Args:
            db_study_name: name of the study, forwarded to load_or_create_study in initialise()
            db_storage: storage of the study, forwarded to load_or_create_study in initialise()
            workdir: optional working directory; when set it is stored in the study's user attribute "cwd"
            user_config: optional user configuration handed to 2-argument objectives on every call
        """
        # user objective function
        self._objective = None
        # chosen sampler (can be None, optuna will use TPE then)
        self._sampler = None
        # our study object
        self._study = None
        # number of trials to be done in given study
        self._n_trials = None
        # name of study (in database)
        self.db_study_name = db_study_name
        # database storage
        self.db_storage = db_storage
        # working directory (if any)
        self.workdir = workdir
        # optional user configuration that will be passed down to each call of the objective
        self.user_config = user_config

    def initialise(self, n_trials=100):
        """
        Set the number of trials and obtain the study object

        The sampler is read here, so set_sampler() only has an effect when called before this method.
        """
        self._n_trials = n_trials
        self._study = load_or_create_study(self.db_study_name, self.db_storage, self._sampler)

        if self.workdir:
            self._study.set_user_attr("cwd", self.workdir)

    def optimise(self):
        """
        Run the optimisation; only prints an error and returns if trials or objective are not set
        """
        if not self._n_trials or not self._objective:
            print("ERROR: Not initialised: Number of trials and objective function need to be set")
            return
        self._study.optimize(self._objective, n_trials=self._n_trials)

    def set_objective(self, objective):
        """
        Register the user objective

        The objective must take either 1 argument (trial) or 2 arguments (trial, user_config).
        Any other signature terminates the program via sys.exit(1).
        """
        n_params = len(signature(objective).parameters)
        if n_params > 2 or not n_params:
            print("Invalid signature of objective funtion. Need either 1 argument (only trial object) or 2 arguments (trial object and user_config)")
            sys.exit(1)
        # only store the objective once its signature has been accepted
        if n_params == 1:
            self._objective = objective
        else:
            # Additionally pass the user config; it is looked up at call time, not at registration time
            self._objective = lambda trial: objective(trial, self.user_config)

    def set_sampler(self, sampler):
        """
        Store the sampler that initialise() passes on when obtaining the study
        """
        self._sampler = sampler
73 changes: 73 additions & 0 deletions src/o2tuner/inspector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Inspection and visualisation of optuna studies
"""
from os.path import join
import matplotlib.pyplot as plt

import optuna.visualization.matplotlib as ovm

from o2tuner.io import parse_yaml, make_dir
from o2tuner.utils import load_or_create_study
from o2tuner.sampler import construct_sampler


class OptunaInspector:  # pylint: disable=too-few-public-methods
    """
    Inspect optuna study
    """

    def __init__(self, config_path=None):
        """
        Constructor

        Args:
            config_path: optional default path of the YAML configuration used by load()
        """
        self._study = None
        self._config_path = config_path
        self._config = None

    def load_impl(self, config_path):
        """
        Implementation of loading a study

        Reads the YAML configuration, builds the sampler from its "sampler" entry and obtains
        the study from the "name" and "storage" keys of its "study" entry.
        """
        config = parse_yaml(config_path)
        sampler = construct_sampler(config.get("sampler", None))
        storage = config.get("study", {})
        self._study = load_or_create_study(storage.get("name"), storage.get("storage"), sampler)
        self._config = config

    def load(self, config_path=None):
        """
        Loading wrapper

        An explicitly given config_path takes precedence over the one passed to the constructor.

        Returns:
            True if a configuration path was available and loaded, False otherwise
        """
        config_path = config_path or self._config_path
        if not config_path:
            print("WARNING: Nothing to load, no configuration given")
            return False
        self.load_impl(config_path)
        return True

    def visualise(self, out_dir=None):
        """
        Doing a simple visualisation using optuna's built-in functionality

        Args:
            out_dir: output directory; defaults to "visualisation" inside the configured "workdir" (or "./")
        """
        if self._study is None:
            print("WARNING: No study loaded to visualise")
            return

        out_dir = out_dir or join(self._config.get("workdir", "./"), "visualisation")
        make_dir(out_dir)

        def to_figure_and_save(axis, path):
            # Save through the figure that already owns the axes instead of moving the axes into a
            # fresh figure: this keeps everything drawn on that figure and lets us close it afterwards.
            fig = axis.figure
            fig.savefig(path)
            plt.close(fig)

        ax_plot_optimization_history = ovm.plot_optimization_history(self._study)
        to_figure_and_save(ax_plot_optimization_history, join(out_dir, "ax_plot_optimization_history.png"))
        ax_plot_parallel_coordinate = ovm.plot_parallel_coordinate(self._study)
        to_figure_and_save(ax_plot_parallel_coordinate, join(out_dir, "ax_plot_parallel_coordinate.png"))
        ax_plot_param_importances = ovm.plot_param_importances(self._study)
        to_figure_and_save(ax_plot_param_importances, join(out_dir, "ax_plot_param_importances.png"))
149 changes: 149 additions & 0 deletions src/o2tuner/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""
Common I/O functionality
"""

import sys
from os.path import expanduser, isfile, isdir, exists, abspath, join
from os import makedirs, remove, listdir
from shutil import rmtree
import json
import yaml


############################
# STANDARD FILE SYSTEM I/O #
############################
def exists_file(path):
    """Tell whether path points to a regular file

    Thin wrapper around os.path.isfile, kept as a single place to hook in extra checks later.
    """
    is_regular_file = isfile(path)
    return is_regular_file


def exists_dir(path):
    """Tell whether path points to a directory

    Thin wrapper around os.path.isdir, kept as a single place to hook in extra checks later.
    """
    is_directory = isdir(path)
    return is_directory


def make_dir(path):
    """create a directory including all missing parent directories

    If path already exists as a file, an error is printed and the program exits with code 1.
    If path already exists otherwise, only a warning is printed and nothing is changed.
    """
    if exists(path):
        if exists_file(path):
            # if exists and if that is actually a file instead of a directory, fail here...
            print(f"Attempted to create directory {path}. However, a file seems to exist there, quitting")
            sys.exit(1)
        # ...otherwise just warn.
        print(f"WARNING: The directory {path} already exists, not overwriting")
        return
    # make the whole path structure; exist_ok avoids a crash when something else creates
    # the directory between the check above and this call
    makedirs(path, exist_ok=True)


def remove_any_impl(to_be_removed, keep=None):
    """
    Remove a file or a directory, optionally sparing some direct entries of the directory

    Args:
        to_be_removed: path of the file or directory to remove; a non-existing path is ignored
        keep: optional list of paths to spare, written with the same prefix as to_be_removed
            (remove_dir builds them by joining onto the absolute directory path)

    A direct entry of the directory is left untouched when it is one of the kept paths or a parent
    directory of one. All other entries are removed completely.
    """
    if not exists(to_be_removed):
        return

    if isfile(to_be_removed):
        remove(to_be_removed)
        return

    if not keep:
        rmtree(to_be_removed)
        return

    # at this point we are sure that it is a directory and that certain things should be kept
    for entry in listdir(to_be_removed):
        entry_path = join(to_be_removed, entry)
        # join(..., "") appends the path separator so that only whole path components match,
        # e.g. keeping "foobar" must not spare a sibling called "foo"
        entry_as_parent = join(entry_path, "")
        if any(k == entry_path or k.startswith(entry_as_parent) for k in keep):
            continue
        # nothing below this entry is to be kept, so drop it entirely (no empty directory left behind)
        remove_any_impl(entry_path)


def remove_dir(to_be_removed, keep=None):
    """
    Delete a file or directory; entries listed in keep (relative to it) are spared
    """
    root = abspath(to_be_removed)
    # entries to keep are given relative to the directory, so anchor them at its absolute path
    anchored_keep = [join(root, entry) for entry in keep] if keep else keep
    remove_any_impl(root, anchored_keep)


########
# YAML #
########
class NoAliasDumperYAML(yaml.SafeDumper):
    """
    Safe YAML dumper that writes no references (anchors/aliases)
    https://stackoverflow.com/questions/13518819/avoid-references-in-pyyaml
    """

    def ignore_aliases(self, data):
        """Always answer True, i.e. aliases are ignored for any data"""
        return True


def parse_yaml(path):
    """
    Read a YAML file (with "~" expanded in path) and return the parsed content

    On any file or YAML error, a message is printed and the program exits with code 1.
    """
    path = expanduser(path)
    try:
        with open(path, encoding="utf8") as in_file:
            content = yaml.safe_load(in_file)
    except (OSError, yaml.YAMLError) as exc:
        # IOError is an alias of OSError, so OSError alone covers both
        print(f"ERROR: Cannot parse YAML from {path} due to\n{exc}")
        sys.exit(1)
    return content


def dump_yaml(to_yaml, path, *, no_refs=False):
    """
    Write to_yaml as YAML to path (with "~" expanded)

    With no_refs=True the NoAliasDumperYAML dumper is used, otherwise yaml.safe_dump.
    On any file or YAML error, a message is printed and the program exits with code 1.
    """
    path = expanduser(path)
    try:
        with open(path, 'w', encoding="utf8") as out_file:
            if not no_refs:
                yaml.safe_dump(to_yaml, out_file)
            else:
                yaml.dump_all([to_yaml], out_file, Dumper=NoAliasDumperYAML)
    except (OSError, yaml.YAMLError) as eexc:
        # IOError is an alias of OSError, so OSError alone covers both
        print(f"ERROR: Cannot write YAML to {path} due to\n{eexc}")
        sys.exit(1)


########
# JSON #
########
def parse_json(filepath):
    """
    Read a JSON file (with "~" expanded in filepath) and return the parsed content

    Needed for interfacing with O2 cut configuration files.
    If the file does not exist, an error is printed and the program exits with code 1.
    """
    filepath = expanduser(filepath)
    if exists_file(filepath):
        with open(filepath, "r", encoding="utf8") as config_file:
            parsed = json.load(config_file)
        return parsed
    print(f"ERROR: JSON file {filepath} does not exist.")
    sys.exit(1)


def dump_json(to_json, path):
    """
    Write to_json as JSON (indented by 2 spaces) to path (with "~" expanded)

    Needed for interfacing with O2 cut configuration files.
    """
    target = expanduser(path)
    with open(target, 'w', encoding="utf8") as out_file:
        json.dump(to_json, out_file, indent=2)
Loading