Skip to content

Commit

Permalink
Optuna GridSearch (#630)
Browse files Browse the repository at this point in the history
* Optuna

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Optuna - Categorical Distibution

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Add optuna to dependency and regenerate poetry

* Add optuna to dependency and regenerate poetry

* Add optuna to dependency and regenerate poetry

* Add test on within Session

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Add test on within Session

* Add test on within Session

* ehn: common function with dict

* ehn: moving function to util

* fix: correcting the what news file.

* Add test benchmark and raise an issue if the conversion didn't worked and exposed time_out parameter

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* FIX: italian to eng

* EHN: making optuna optional

* FIX: fixing the workflow files

* FIX: changing the optuna file

* FIX: including optuna for the windows

---------

Signed-off-by: Bru <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: bruAristimunha <[email protected]>
  • Loading branch information
3 people authored Jul 15, 2024
1 parent 81129f5 commit a840f7d
Show file tree
Hide file tree
Showing 11 changed files with 949 additions and 557 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ jobs:

- name: Install dependencies
if: (steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --with docs --extras deeplearning
run: poetry install --no-interaction --no-root --with docs --extras deeplearning --extras optuna

- name: Install library
run: poetry install --no-interaction --with docs --extras deeplearning
run: poetry install --no-interaction --with docs --extras deeplearning --extras optuna

- name: Build docs
run: |
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test-devel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ jobs:
if: |
(runner.os != 'Windows') &&
(steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --extras deeplearning
run: poetry install --no-interaction --no-root --extras deeplearning --extras optuna

- name: Install library (Linux/OSX)
if: ${{ runner.os != 'Windows' }}
run: poetry install --no-interaction --extras deeplearning
run: poetry install --no-interaction --extras deeplearning --extras optuna

- name: Install library (Windows)
if: ${{ runner.os == 'Windows' }}
run: poetry install --no-interaction
run: poetry install --no-interaction --extras optuna

- name: Run tests
run: |
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,15 @@ jobs:
if: |
(runner.os != 'Windows') &&
(steps.cached-poetry-dependencies.outputs.cache-hit != 'true')
run: poetry install --no-interaction --no-root --extras deeplearning
run: poetry install --no-interaction --no-root --extras deeplearning --extras optuna

- name: Install library (Linux/OSX)
if: ${{ runner.os != 'Windows' }}
run: poetry install --no-interaction --extras deeplearning
run: poetry install --no-interaction --extras deeplearning --extras optuna

- name: Install library (Windows)
if: ${{ runner.os == 'Windows' }}
run: poetry install --no-interaction
run: poetry install --no-interaction --extras optuna

- name: Run tests
run: |
Expand Down
6 changes: 5 additions & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@ Develop branch

Enhancements
~~~~~~~~~~~~
- Add possibility to use OptunaGridSearch (:gh:`630` by `Igor Carrara`_)
- Add scripts to upload results on PapersWithCode (:gh:`561` by `Pierre Guetschel`_)
- Centralize dataset summary tables in CSV files (:gh:`635` by `Pierre Guetschel`_)
- Add new dataset :class:`moabb.datasets.Liu2024` dataset (:gh:`619` by `Taha Habib`_)
- Increasing the version in the pre-commit config (:gh:`631` by pre-commit bot)



Bugs
~~~~
- Fix caching in the workflows (:gh:`632` by `Pierre Guetschel`_)

API changes
~~~~~~~~~~~
- None
- Include optuna as soft-dependency in the benchmark function and in the base of evaluation (:gh:`630` by `Igor Carrara`_)



Version - 1.1.0 (Stable - PyPi)
---------------------------------
Expand Down
8 changes: 8 additions & 0 deletions moabb/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def benchmark( # noqa: C901
exclude_datasets=None,
n_splits=None,
cache_config=None,
optuna=False,
):
"""Run benchmarks for selected pipelines and datasets.
Expand Down Expand Up @@ -102,6 +103,7 @@ def benchmark( # noqa: C901
and exclude_datasets are specified, raise an error.
exclude_datasets: list of str or Dataset object
Datasets to exclude from the benchmark run
optuna: Enable Optuna for the hyperparameter search
Returns
-------
Expand All @@ -110,7 +112,11 @@ def benchmark( # noqa: C901
Notes
-----
.. versionadded:: 1.1.1
Includes the possibility to use Optuna for hyperparameter search.
.. versionadded:: 0.5.0
Create the function to run the benchmark
"""
# set logs
if evaluations is None:
Expand Down Expand Up @@ -182,6 +188,7 @@ def benchmark( # noqa: C901
return_epochs=True,
n_splits=n_splits,
cache_config=cache_config,
optuna=optuna,
)
paradigm_results = context.process(
pipelines=ppl_with_epochs, param_grid=param_grid
Expand All @@ -202,6 +209,7 @@ def benchmark( # noqa: C901
overwrite=overwrite,
n_splits=n_splits,
cache_config=cache_config,
optuna=optuna,
)
paradigm_results = context.process(
pipelines=ppl_with_array, param_grid=param_grid
Expand Down
47 changes: 45 additions & 2 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import logging
from abc import ABC, abstractmethod
from warnings import warn

import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

from moabb.analysis import Results
from moabb.datasets.base import BaseDataset
from moabb.evaluations.utils import _convert_sklearn_params_to_optuna
from moabb.paradigms.base import BaseParadigm


log = logging.getLogger(__name__)

# Making the optuna soft dependency
try:
from optuna.integration import OptunaSearchCV

optuna_available = True
except ImportError:
optuna_available = False

if optuna_available:
search_methods = {"grid": GridSearchCV, "optuna": OptunaSearchCV}
else:
search_methods = {"grid": GridSearchCV}


class BaseEvaluation(ABC):
"""Base class that defines necessary operations for an evaluation.
Expand Down Expand Up @@ -53,11 +68,19 @@ class BaseEvaluation(ABC):
Save model after training, for each fold of cross-validation if needed
cache_config: bool, default=None
Configuration for caching of datasets. See :class:`moabb.datasets.base.CacheConfig` for details.
optuna:bool, default=False
If optuna is enable it will change the GridSearch to a RandomizedGridSearch with 15 minutes of cut off time.
This option is compatible with list of entries of type None, bool, int, float and string
time_out: default=60*15
Cut off time for the optuna search expressed in seconds, the default value is 15 minutes.
Only used with optuna equal to True.
Notes
-----
.. versionadded:: 1.1.0
n_splits, save_model, cache_config parameters.
.. versionadded:: 1.1.1
optuna, time_out parameters.
"""

def __init__(
Expand All @@ -77,6 +100,8 @@ def __init__(
n_splits=None,
save_model=False,
cache_config=None,
optuna=False,
time_out=60 * 15,
):
self.random_state = random_state
self.n_jobs = n_jobs
Expand All @@ -88,6 +113,16 @@ def __init__(
self.n_splits = n_splits
self.save_model = save_model
self.cache_config = cache_config
self.optuna = optuna
self.time_out = time_out

if self.optuna and not optuna_available:
raise ImportError("Optuna is not available. Please install it first.")
if (self.time_out != 60 * 15) and not self.optuna:
warn(
"time_out parameter is only used when optuna is enabled. "
"Ignoring time_out parameter."
)
# check paradigm
if not isinstance(paradigm, BaseParadigm):
raise (ValueError("paradigm must be an Paradigm instance"))
Expand Down Expand Up @@ -261,19 +296,27 @@ def is_valid(self, dataset):
"""

def _grid_search(self, param_grid, name, grid_clf, inner_cv):
extra_params = {}
if param_grid is not None:
if name in param_grid:
search = GridSearchCV(
if self.optuna:
search = search_methods["optuna"]
param_grid[name] = _convert_sklearn_params_to_optuna(param_grid[name])
extra_params["timeout"] = self.time_out
else:
search = search_methods["grid"]

search = search(
grid_clf,
param_grid[name],
refit=True,
cv=inner_cv,
n_jobs=self.n_jobs,
scoring=self.paradigm.scoring,
return_train_score=True,
**extra_params,
)
return search

else:
return grid_clf

Expand Down
41 changes: 41 additions & 0 deletions moabb/evaluations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
from sklearn.pipeline import Pipeline


try:
from optuna.distributions import CategoricalDistribution

optuna_available = True
except ImportError:
optuna_available = False


def _check_if_is_keras_model(model):
"""Check if the model is a Keras model.
Expand Down Expand Up @@ -212,3 +220,36 @@ def create_save_path(
return str(path_save)
else:
print("No hdf5_path provided, models will not be saved.")


def _convert_sklearn_params_to_optuna(param_grid: dict) -> dict:
"""
Function to convert the parameter in Optuna format. This function will
create a categorical distribution of values from the list of values
provided in the parameter grid.
Parameters
----------
param_grid:
Dictionary with the parameters to be converted.
Returns
-------
optuna_params: dict
Dictionary with the parameters converted to Optuna format.
"""
if not optuna_available:
raise ImportError(
"Optuna is not available. Please install it optuna " "and optuna-integration."
)
else:
optuna_params = {}
for key, value in param_grid.items():
try:
if isinstance(value, list):
optuna_params[key] = CategoricalDistribution(value)
else:
optuna_params[key] = value
except Exception as e:
raise ValueError(f"Conversion failed for parameter {key}: {e}")
return optuna_params
10 changes: 10 additions & 0 deletions moabb/tests/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ def test_include_exclude(self):
overwrite=True,
)

def test_optuna(self):
res = benchmark(
pipelines=str(self.pp_dir),
evaluations=["WithinSession"],
paradigms=["FakeImageryParadigm"],
overwrite=True,
optuna=True,
)
self.assertEqual(len(res), 40)


if __name__ == "__main__":
unittest.main()
25 changes: 25 additions & 0 deletions moabb/tests/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def setUp(self):
datasets=[dataset],
hdf5_path="res_test",
save_model=True,
optuna=False,
)

def test_mne_labels(self):
Expand Down Expand Up @@ -138,6 +139,30 @@ def test_eval_grid_search(self):
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_eval_grid_search_optuna(self):
# Test grid search
param_grid = {"C": {"csp__metric": ["euclid", "riemann"]}}
process_pipeline = self.eval.paradigm.make_process_pipelines(dataset)[0]

self.eval.optuna = True

results = [
r
for r in self.eval.evaluate(
dataset,
pipelines,
param_grid=param_grid,
process_pipeline=process_pipeline,
)
]

self.eval.optuna = False

# We should get 4 results, 2 sessions 2 subjects
self.assertEqual(len(results), 4)
# We should have 9 columns in the results data frame
self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)

def test_within_session_evaluation_save_model(self):
res_test_path = "./res_test"

Expand Down
Loading

0 comments on commit a840f7d

Please sign in to comment.