
Optuna GridSearch #630

Merged
merged 31 commits on Jul 15, 2024
Changes from all commits
31 commits
7119ec5
Optuna
carraraig Jun 27, 2024
c09cb7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
d4cb28e
Optuna - Categorical Distibution
carraraig Jun 27, 2024
1767d6a
Merge remote-tracking branch 'origin/Optuna' into Optuna
carraraig Jun 27, 2024
80940d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
bbdf225
Add optuna to dependency and regenerate poetry
carraraig Jun 27, 2024
0ce8943
Merge remote-tracking branch 'origin/Optuna' into Optuna
carraraig Jun 27, 2024
a99d6b8
Add optuna to dependency and regenerate poetry
carraraig Jun 27, 2024
d61e43a
Add optuna to dependency and regenerate poetry
carraraig Jun 27, 2024
0ad7c12
Add test on within Session
carraraig Jun 27, 2024
25e73c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 27, 2024
d38803d
Add test on within Session
carraraig Jun 28, 2024
bfa3c38
Merge remote-tracking branch 'origin/Optuna' into Optuna
carraraig Jun 28, 2024
c04205a
Add test on within Session
carraraig Jun 28, 2024
d1ba14b
ehn: common function with dict
bruAristimunha Jun 28, 2024
6efdaa4
ehn: moving function to util
bruAristimunha Jun 28, 2024
ae6be13
fix: correcting the what news file.
bruAristimunha Jul 1, 2024
c3a7daf
Merge branch 'develop' into Optuna
bruAristimunha Jul 15, 2024
6f572f5
Add test benchmark and raise an issue if the conversion didn't worked…
carraraig Jul 15, 2024
c107a2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 15, 2024
b160cf6
FIX: italian to eng
bruAristimunha Jul 15, 2024
dc09e6c
EHN: making optuna optional
bruAristimunha Jul 15, 2024
de72e46
FIX: fixing the workflow files
bruAristimunha Jul 15, 2024
c2810dc
FIX: changing the optuna file
bruAristimunha Jul 15, 2024
19aac40
Merge branch 'develop' into Optuna
bruAristimunha Jul 15, 2024
2224ceb
FIX: including optuna for the windows
bruAristimunha Jul 15, 2024
73d18db
Merge remote-tracking branch 'carraraig/Optuna' into Optuna
bruAristimunha Jul 15, 2024
f1a2262
FIX: fix the doc generation
bruAristimunha Jul 15, 2024
1399706
FIX: make sure to not include a warning in all the executions
bruAristimunha Jul 15, 2024
392cb3e
FIX: fixing the workflow file
bruAristimunha Jul 15, 2024
b3ec191
FIX: fixing the doc
bruAristimunha Jul 15, 2024
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
@@ -66,10 +66,10 @@ jobs:

- name: Install dependencies
if: (steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --with docs --extras deeplearning
run: poetry install --no-interaction --no-root --with docs --extras deeplearning --extras optuna

- name: Install library
run: poetry install --no-interaction --with docs --extras deeplearning
run: poetry install --no-interaction --with docs --extras deeplearning --extras optuna

- name: Build docs
run: |
6 changes: 3 additions & 3 deletions .github/workflows/test-devel.yml
@@ -62,15 +62,15 @@ jobs:
if: |
(runner.os != 'Windows') &&
(steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --extras deeplearning
run: poetry install --no-interaction --no-root --extras deeplearning --extras optuna

- name: Install library (Linux/OSX)
if: ${{ runner.os != 'Windows' }}
run: poetry install --no-interaction --extras deeplearning
run: poetry install --no-interaction --extras deeplearning --extras optuna

- name: Install library (Windows)
if: ${{ runner.os == 'Windows' }}
run: poetry install --no-interaction
run: poetry install --no-interaction --extras optuna

- name: Run tests
run: |
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
@@ -62,15 +62,15 @@ jobs:
if: |
(runner.os != 'Windows') &&
(steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --extras deeplearning
run: poetry install --no-interaction --no-root --extras deeplearning --extras optuna

- name: Install library (Linux/OSX)
if: ${{ runner.os != 'Windows' }}
run: poetry install --no-interaction --extras deeplearning
run: poetry install --no-interaction --extras deeplearning --extras optuna

- name: Install library (Windows)
if: ${{ runner.os == 'Windows' }}
run: poetry install --no-interaction
run: poetry install --no-interaction --extras optuna

- name: Run tests
run: |
6 changes: 5 additions & 1 deletion docs/source/whats_new.rst
@@ -17,19 +17,23 @@ Develop branch

Enhancements
~~~~~~~~~~~~
- Add the possibility to use OptunaGridSearch (:gh:`630` by `Igor Carrara`_)
- Add scripts to upload results on PapersWithCode (:gh:`561` by `Pierre Guetschel`_)
- Centralize dataset summary tables in CSV files (:gh:`635` by `Pierre Guetschel`_)
- Add new dataset :class:`moabb.datasets.Liu2024` dataset (:gh:`619` by `Taha Habib`_)
- Increasing the version in the pre-commit config (:gh:`631` by pre-commit bot)



Bugs
~~~~
- Fix caching in the workflows (:gh:`632` by `Pierre Guetschel`_)

API changes
~~~~~~~~~~~
- None
- Include optuna as a soft dependency in the benchmark function and in the evaluation base class (:gh:`630` by `Igor Carrara`_)



Version - 1.1.0 (Stable - PyPi)
---------------------------------
8 changes: 8 additions & 0 deletions moabb/benchmark.py
@@ -45,6 +45,7 @@ def benchmark( # noqa: C901
exclude_datasets=None,
n_splits=None,
cache_config=None,
optuna=False,
):
"""Run benchmarks for selected pipelines and datasets.

@@ -102,6 +103,7 @@ def benchmark( # noqa: C901
and exclude_datasets are specified, raise an error.
exclude_datasets: list of str or Dataset object
Datasets to exclude from the benchmark run
optuna: bool, default=False
Enable Optuna for the hyperparameter search

Returns
-------
@@ -110,7 +112,11 @@

Notes
-----
.. versionadded:: 1.1.1
Add the possibility to use Optuna for the hyperparameter search.

.. versionadded:: 0.5.0
Create the function to run the benchmark
"""
# set logs
if evaluations is None:
@@ -182,6 +188,7 @@ def benchmark( # noqa: C901
return_epochs=True,
n_splits=n_splits,
cache_config=cache_config,
optuna=optuna,
)
paradigm_results = context.process(
pipelines=ppl_with_epochs, param_grid=param_grid
@@ -202,6 +209,7 @@ def benchmark( # noqa: C901
overwrite=overwrite,
n_splits=n_splits,
cache_config=cache_config,
optuna=optuna,
)
paradigm_results = context.process(
pipelines=ppl_with_array, param_grid=param_grid
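For reference, a minimal sketch of how the new optuna flag could be used once this change lands. The pipeline folder and paradigm name below are placeholder values, not taken from this PR.

from moabb.benchmark import benchmark

# Run a within-session benchmark with the Optuna-backed hyperparameter search
# enabled; "./pipelines" and "LeftRightImagery" are example inputs only.
results = benchmark(
    pipelines="./pipelines",
    evaluations=["WithinSession"],
    paradigms=["LeftRightImagery"],
    overwrite=True,
    optuna=True,  # forwarded to the evaluation contexts created above
)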
47 changes: 45 additions & 2 deletions moabb/evaluations/base.py
@@ -1,17 +1,32 @@
import logging
from abc import ABC, abstractmethod
from warnings import warn

import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

from moabb.analysis import Results
from moabb.datasets.base import BaseDataset
from moabb.evaluations.utils import _convert_sklearn_params_to_optuna
from moabb.paradigms.base import BaseParadigm


log = logging.getLogger(__name__)

# Make optuna a soft dependency
try:
from optuna.integration import OptunaSearchCV

optuna_available = True
except ImportError:
optuna_available = False

if optuna_available:
search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}
else:
search_methods = {"grid": GridSearchCV}


class BaseEvaluation(ABC):
"""Base class that defines necessary operations for an evaluation.
@@ -53,11 +68,19 @@ class BaseEvaluation(ABC):
Save model after training, for each fold of cross-validation if needed
cache_config: bool, default=None
Configuration for caching of datasets. See :class:`moabb.datasets.base.CacheConfig` for details.
optuna: bool, default=False
If enabled, the GridSearchCV is replaced by an OptunaSearchCV with a 15-minute cut-off time by default.
This option supports parameter grids whose values are lists of None, bool, int, float and str entries.
time_out: int, default=60*15
Cut-off time for the Optuna search, expressed in seconds; the default is 15 minutes.
Only used when optuna is True.

Notes
-----
.. versionadded:: 1.1.0
n_splits, save_model, cache_config parameters.
.. versionadded:: 1.1.1
optuna, time_out parameters.
"""

def __init__(
@@ -77,6 +100,8 @@ def __init__(
n_splits=None,
save_model=False,
cache_config=None,
optuna=False,
time_out=60 * 15,
):
self.random_state = random_state
self.n_jobs = n_jobs
@@ -88,6 +113,16 @@ def __init__(
self.n_splits = n_splits
self.save_model = save_model
self.cache_config = cache_config
self.optuna = optuna
self.time_out = time_out

if self.optuna and not optuna_available:
raise ImportError("Optuna is not available. Please install it first.")
if (self.time_out != 60 * 15) and not self.optuna:
warn(
"time_out parameter is only used when optuna is enabled. "
"Ignoring time_out parameter."
)
# check paradigm
if not isinstance(paradigm, BaseParadigm):
raise (ValueError("paradigm must be an Paradigm instance"))
@@ -261,19 +296,27 @@ def is_valid(self, dataset):
"""

def _grid_search(self, param_grid, name, grid_clf, inner_cv):
extra_params = {}
if param_grid is not None:
if name in param_grid:
search = GridSearchCV(
if self.optuna:
search = search_methods["optuna"]
param_grid[name] = _convert_sklearn_params_to_optuna(param_grid[name])
extra_params["timeout"] = self.time_out
else:
search = search_methods["grid"]

search = search(
grid_clf,
param_grid[name],
refit=True,
cv=inner_cv,
n_jobs=self.n_jobs,
scoring=self.paradigm.scoring,
return_train_score=True,
**extra_params,
)
return search

else:
return grid_clf

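To illustrate the evaluation-level API changed above, here is a hedged sketch of passing the new optuna and time_out options to a WithinSessionEvaluation; the paradigm and dataset selection are assumptions for illustration, not taken from this PR.

from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

paradigm = LeftRightImagery()
evaluation = WithinSessionEvaluation(
    paradigm=paradigm,
    datasets=paradigm.datasets[:1],  # assumption: any compatible dataset works
    optuna=True,                     # swaps GridSearchCV for OptunaSearchCV
    time_out=5 * 60,                 # stop each Optuna search after 5 minutes
)
# pipelines and param_grid are then passed to evaluation.process() as for a
# regular grid search; list-valued entries of param_grid are converted to
# Optuna CategoricalDistribution objects (see moabb/evaluations/utils.py below).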
41 changes: 41 additions & 0 deletions moabb/evaluations/utils.py
@@ -8,6 +8,14 @@
from sklearn.pipeline import Pipeline


try:
from optuna.distributions import CategoricalDistribution

optuna_available = True
except ImportError:
optuna_available = False


def _check_if_is_keras_model(model):
"""Check if the model is a Keras model.

@@ -212,3 +220,36 @@ def create_save_path(
return str(path_save)
else:
print("No hdf5_path provided, models will not be saved.")


def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict:
"""
Convert a scikit-learn parameter grid to the Optuna format. Each list of
values in the parameter grid is turned into a categorical distribution;
other entries are left unchanged.

Parameters
----------
param_grid:
Dictionary with the parameters to be converted.

Returns
-------
optuna_params: dict
Dictionary with the parameters converted to Optuna format.
"""
if not optuna_available:
raise ImportError(
"Optuna is not available. Please install it optuna " "and optuna-integration."
)
else:
optuna_params = {}
for key, value in param_grid.items():
try:
if isinstance(value, list):
optuna_params[key] = CategoricalDistribution(value)
else:
optuna_params[key] = value
except Exception as e:
raise ValueError(f"Conversion failed for parameter {key}: {e}")
return optuna_params
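To make the conversion helper above concrete, a small hedged example; the step names are invented for illustration and assume optuna is installed.

from moabb.evaluations.utils import _convert_sklearn_params_to_optuna

# A scikit-learn style grid: lists become CategoricalDistribution objects,
# scalar values are passed through unchanged.
sklearn_grid = {
    "csp__metric": ["euclid", "riemann"],
    "svc__C": 1.0,
}
optuna_grid = _convert_sklearn_params_to_optuna(sklearn_grid)
# optuna_grid["csp__metric"] is now a CategoricalDistribution over the two
# metrics, ready to be handed to OptunaSearchCV by BaseEvaluation._grid_search.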
10 changes: 10 additions & 0 deletions moabb/tests/benchmark.py
@@ -72,6 +72,16 @@ def test_include_exclude(self):
overwrite=True,
)

def test_optuna(self):
res = benchmark(
pipelines=str(self.pp_dir),
evaluations=["WithinSession"],
paradigms=["FakeImageryParadigm"],
overwrite=True,
optuna=True,
)
self.assertEqual(len(res), 40)


if __name__ == "__main__":
unittest.main()
25 changes: 25 additions & 0 deletions moabb/tests/evaluations.py
@@ -58,6 +58,7 @@ def setUp(self):
datasets=[dataset],
hdf5_path="res_test",
save_model=True,
optuna=False,
)

def test_mne_labels(self):
@@ -138,6 +139,30 @@ def test_eval_grid_search(self):
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_eval_grid_search_optuna(self):
# Test the Optuna-based hyperparameter search
param_grid = {"C": {"csp__metric": ["euclid", "riemann"]}}
process_pipeline = self.eval.paradigm.make_process_pipelines(dataset)[0]

self.eval.optuna = True

results = [
r
for r in self.eval.evaluate(
dataset,
pipelines,
param_grid=param_grid,
process_pipeline=process_pipeline,
)
]

self.eval.optuna = False

# We should get 4 results, 2 sessions 2 subjects
self.assertEqual(len(results), 4)
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_within_session_evaluation_save_model(self):
res_test_path = "./res_test"
