Commit a081c65 (merge)

jrapin committed Jan 28, 2020
2 parents 14180fe + 5052aa5
Showing 10 changed files with 173 additions and 102 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -18,10 +18,10 @@ Alternatively, you can clone the repository and run `pip install -e .` from insi

By default, this only installs requirements for the optimization and instrumentation subpackages. If you are also interested in the benchmarking part,
you should install with the `[benchmark]` flag (example: `pip install 'nevergrad[benchmark]'`), and if you also want the test tools, use
the `[all]` flag (example: `pip install -e '.[all]'`)
the `[all]` flag (example: `pip install -e '.[all]'`).


You can join Nevergrad users Facebook group [here](https://www.facebook.com/groups/nevergradusers/)
You can join Nevergrad users Facebook group [here](https://www.facebook.com/groups/nevergradusers/).


## Goals and structure
@@ -48,7 +48,7 @@ The structure of the package follows its goal, you will therefore find subpackag

The following README is very general; here are links to find more details on:
- [how to perform optimization](docs/optimization.md) using `nevergrad`, including parallelization and a few recommendations on which algorithm should be used depending on the settings
- [how to parametrize](docs/parametrization.md) your problem so that the optimizers are informed of the problem to solve. This also provides a tool to instantiate a script or non-python code in order into a Python function and be able to tune some of its parameters.
- [how to parametrize](docs/parametrization.md) your problem so that the optimizers are informed of the problem to solve. This also provides a tool to instantiate a script or non-python code into a Python function and be able to tune some of its parameters.
- [how to benchmark](docs/benchmarking.md) all optimizers on various test functions.
- [benchmark results](docs/benchmarks.md) of some standard optimizers on simple test cases.
- examples of [optimization for machine learning](docs/machinelearning.md).
71 changes: 24 additions & 47 deletions nevergrad/benchmark/experiments.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions nevergrad/common/tools.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import time
import inspect
import warnings
import itertools
import collections
@@ -218,3 +219,30 @@ def __iter__(self) -> tp.Iterator[X]:

def __len__(self) -> int:
return len(self._data)


def different_from_defaults(instance: tp.Any, check_mismatches: bool = False) -> tp.Dict[str, tp.Any]:
"""Checks which attributes are different from defaults arguments
Parameters
----------
instance: object
the object to change
check_mismatches: bool
checks that the attributes match the parameters
Note
----
This is convenient for short repr of data structures
"""
defaults = {
x: y.default for x, y in inspect.signature(instance.__class__.__init__).parameters.items() if x not in ["self", "__class__"]
}
if check_mismatches:
diff = set(defaults.keys()).symmetric_difference(instance.__dict__.keys())
if diff: # this is to help during development
raise RuntimeError(f"Mismatch between attributes and arguments of {instance}: {diff}")
else:
defaults = {x: y for x, y in defaults.items() if x in instance.__dict__}
# only print non defaults
return {x: instance.__dict__[x] for x, y in defaults.items() if y != instance.__dict__[x] and not x.startswith("_")}
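A quick usage sketch of the new helper (hypothetical Settings class, not part of this commit): only attributes that differ from the __init__ defaults are returned.

# hypothetical example of using the new helper
from nevergrad.common import tools as ngtools

class Settings:
    def __init__(self, speed: int = 1, name: str = "default") -> None:
        self.speed = speed
        self.name = name

print(ngtools.different_from_defaults(Settings(speed=3)))  # {'speed': 3}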
14 changes: 3 additions & 11 deletions nevergrad/optimization/base.py
@@ -5,7 +5,6 @@

import time
import pickle
import inspect
import warnings
from pathlib import Path
from numbers import Real
@@ -14,9 +13,9 @@
from typing import Optional, Tuple, Callable, Any, Dict, List, Union, Deque, Type, Set, TypeVar
import numpy as np
from nevergrad.parametrization import parameter as p
from nevergrad.common import tools as ngtools
from ..common.typetools import ArrayLike as ArrayLike # allows reexport
from ..common.typetools import JobLike, ExecutorLike
from ..common.tools import Sleeper
from ..common.decorators import Registry
from . import utils

@@ -422,7 +421,7 @@ def minimize(
tmp_runnings: List[Tuple[p.Parameter, JobLike[float]]] = []
tmp_finished: Deque[Tuple[p.Parameter, JobLike[float]]] = deque()
# go
sleeper = Sleeper() # manages waiting time depending on execution time of the jobs
sleeper = ngtools.Sleeper() # manages waiting time depending on execution time of the jobs
remaining_budget = self.budget - self.num_ask
first_iteration = True
# multiobjective hack
@@ -586,14 +585,7 @@ class ParametrizedFamily(OptimizerFamily):
_optimizer_class: Optional[Type[Optimizer]] = None

def __init__(self) -> None:
defaults = {
x: y.default for x, y in inspect.signature(self.__class__.__init__).parameters.items() if x not in ["self", "__class__"]
}
diff = set(defaults.keys()).symmetric_difference(self.__dict__.keys())
if diff: # this is to help durring development
raise RuntimeError(f"Mismatch between attributes and arguments of ParametrizedFamily: {diff}")
# only print non defaults
different = {x: self.__dict__[x] for x, y in defaults.items() if y != self.__dict__[x] and not x.startswith("_")}
different = ngtools.different_from_defaults(self, check_mismatches=True)
super().__init__(**different)

def config(self) -> tp.Dict[str, tp.Any]:
136 changes: 101 additions & 35 deletions nevergrad/optimization/optimizerlib.py
@@ -902,19 +902,15 @@ def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None
] # noqa: F405
if budget < 12 * num_workers:
self.optims = [ScrHammersleySearch(self.instrumentation, budget, num_workers)] # noqa: F405
self.who_asked: Dict[Tuple[float, ...], List[int]] = defaultdict(list)

def _internal_ask_candidate(self) -> p.Parameter:
optim_index = self._num_ask % len(self.optims)
candidate = self.optims[optim_index].ask()
data = candidate.get_standardized_data(reference=self.instrumentation)
self.who_asked[tuple(data)] += [optim_index]
candidate._meta["optim_index"] = optim_index
return candidate

def _internal_tell_candidate(self, candidate: p.Parameter, value: float) -> None:
tx = tuple(candidate.get_standardized_data(reference=self.instrumentation))
optim_index = self.who_asked[tx][0]
del self.who_asked[tx][0]
optim_index: int = candidate._meta["optim_index"]
self.optims[optim_index].tell(candidate, value)

def _internal_provide_recommendation(self) -> ArrayLike:
@@ -951,13 +947,11 @@ def intshare(n: int, m: int) -> Tuple[int, ...]:
SQP(self.instrumentation, 1), # noqa: F405
ScrHammersleySearch(self.instrumentation, budget=(budget // len(self.which_optim)) * nw4), # noqa: F405
]
self.who_asked: Dict[Tuple[float, ...], List[int]] = defaultdict(list)

def _internal_ask_candidate(self) -> p.Parameter:
optim_index = self.which_optim[self._num_ask % len(self.which_optim)]
candidate = self.optims[optim_index].ask()
tx = tuple(candidate.get_standardized_data(reference=self.instrumentation))
self.who_asked[tx] += [optim_index]
candidate._meta["optim_index"] = optim_index
return candidate
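The hunks above (and the NGO hunks below) all apply the same refactor: the who_asked reverse-lookup dictionaries keyed on standardized data are replaced by tagging each candidate with the index of the sub-optimizer that produced it. A minimal sketch of the pattern, with simplified signatures:

# sketch of the new ask/tell bookkeeping (assumes self.optims is a list of sub-optimizers)
def ask(self):
    optim_index = self._num_ask % len(self.optims)
    candidate = self.optims[optim_index].ask()
    candidate._meta["optim_index"] = optim_index  # remember who produced it
    return candidate

def tell(self, candidate, value):
    optim_index = candidate._meta["optim_index"]  # no reverse lookup needed
    self.optims[optim_index].tell(candidate, value)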


@@ -979,7 +973,6 @@ def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None
self.optims += [SQP(self.instrumentation, 1)] # noqa: F405
if i > 0:
self.optims[-1].initial_guess = self._rng.normal(0, 1, self.dimension) # type: ignore
self.who_asked: Dict[Tuple[float, ...], List[int]] = defaultdict(list)


@registry.register
@@ -993,7 +986,6 @@ def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None
CMA(self.instrumentation, budget=None, num_workers=num_workers), # share instrumentation and its rng
LhsDE(self.instrumentation, budget=None, num_workers=num_workers),
] # noqa: F405
self.who_asked: Dict[Tuple[float, ...], List[int]] = defaultdict(list)
self.budget_before_choosing = budget // 3
self.best_optim = -1

@@ -1013,8 +1005,7 @@ def _internal_ask_candidate(self) -> p.Parameter:
self.best_optim = optim_index
optim_index = self.best_optim
candidate = self.optims[optim_index].ask()
tx = tuple(candidate.get_standardized_data(reference=self.instrumentation))
self.who_asked[tx] += [optim_index]
candidate._meta["optim_index"] = optim_index
return candidate


@@ -1063,6 +1054,33 @@ def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None
self.budget_before_choosing = budget // 10


@registry.register
class CMandAS3(ASCMADEthird):
"""Competence map, with algorithm selection in one of the cases (3 CMAs)."""

def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None, num_workers: int = 1) -> None:
super().__init__(instrumentation, budget=budget, num_workers=num_workers)
self.optims = [TwoPointsDE(self.instrumentation, budget=None, num_workers=num_workers)] # noqa: F405
assert budget is not None
self.budget_before_choosing = 2 * budget
if budget < 201:
self.optims = [OnePlusOne(self.instrumentation, budget=None, num_workers=num_workers)]
if budget > 50 * self.dimension or num_workers < 30:
if num_workers == 1:
self.optims = [
chainCMAPowell(self.instrumentation, budget=None, num_workers=num_workers), # share instrumentation and its rng
chainCMAPowell(self.instrumentation, budget=None, num_workers=num_workers),
chainCMAPowell(self.instrumentation, budget=None, num_workers=num_workers),
]
else:
self.optims = [
CMA(self.instrumentation, budget=None, num_workers=num_workers), # share instrumentation and its rng
CMA(self.instrumentation, budget=None, num_workers=num_workers),
CMA(self.instrumentation, budget=None, num_workers=num_workers),
]
self.budget_before_choosing = budget // 10
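For context, budget_before_choosing controls when a portfolio commits to its best member: 2 * budget effectively disables switching within the run, while budget // 10 lets the optimizer pick a winner after a tenth of the budget. A rough sketch of that selection rule, with a hypothetical _score helper standing in for the actual loss bookkeeping:

# simplified selection rule, mirroring the _internal_ask_candidate hunk above
def _internal_ask_candidate(self):
    if self._num_ask < self.budget_before_choosing:
        optim_index = self._num_ask % len(self.optims)  # exploration: round-robin
    else:
        if self.best_optim < 0:
            # commit once to the sub-optimizer with the lowest observed loss
            self.best_optim = min(range(len(self.optims)), key=self._score)
        optim_index = self.best_optim
    candidate = self.optims[optim_index].ask()
    candidate._meta["optim_index"] = optim_index
    return candidate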


@registry.register
class CMandAS(CMandAS2):
"""Competence map, with algorithm selection in one of the cases (2 CMAs)."""
@@ -1522,7 +1540,6 @@ class NGO(base.Optimizer):
def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None, num_workers: int = 1) -> None:
super().__init__(instrumentation, budget=budget, num_workers=num_workers)
assert budget is not None
self.who_asked: Dict[Tuple[float, ...], List[int]] = defaultdict(list)
descr = self.instrumentation.descriptors
self.has_noise = not (descr.deterministic and descr.deterministic_function)
self.fully_continuous = descr.continuous
@@ -1567,15 +1584,11 @@ def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None
def _internal_ask_candidate(self) -> p.Parameter:
optim_index = 0
candidate = self.optims[optim_index].ask()
data = candidate.get_standardized_data(reference=self.instrumentation)
self.who_asked[tuple(data)] += [optim_index]
candidate._meta["optim_index"] = optim_index
return candidate

def _internal_tell_candidate(self, candidate: p.Parameter, value: float) -> None:
data = candidate.get_standardized_data(reference=self.instrumentation)
tx = tuple(data)
optim_index = self.who_asked[tx][0]
del self.who_asked[tx][0]
optim_index = candidate._meta["optim_index"]
self.optims[optim_index].tell(candidate, value)

def _internal_provide_recommendation(self) -> ArrayLike:
@@ -1586,26 +1599,79 @@ def _internal_tell_not_asked(self, candidate: p.Parameter, value: float) -> None
raise base.TellNotAskedNotSupportedError


class EMNA_TBPSA(TBPSA):
"""Test-based population-size adaptation with EMNA.
"""

# pylint: disable=too-many-instance-attributes

def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None, num_workers: int = 1) -> None:
super().__init__(instrumentation, budget=budget, num_workers=num_workers)
self.sigma = 1
self.mu = self.dimension
self.llambda = 4 * self.dimension
if num_workers is not None:
self.llambda = max(self.llambda, num_workers)
self.current_center: np.ndarray = np.zeros(self.dimension)
self._loss_record: List[float] = []
# population
self._evaluated_population: List[base.utils.Individual] = []

def _internal_provide_recommendation(self) -> ArrayLike:
return self.current_bests["optimistic"].x # Naive version for now

def _internal_tell_candidate(self, candidate: p.Parameter, value: float) -> None:
self._loss_record += [value]
if len(self._loss_record) >= 5 * self.llambda:
first_fifth = self._loss_record[: self.llambda]
last_fifth = self._loss_record[-self.llambda:]
means = [sum(fitnesses) / float(self.llambda) for fitnesses in [first_fifth, last_fifth]]
stds = [np.std(fitnesses) / np.sqrt(self.llambda - 1) for fitnesses in [first_fifth, last_fifth]]
z = (means[0] - means[1]) / (np.sqrt(stds[0] ** 2 + stds[1] ** 2))
if z < 2.0:
self.mu *= 2
else:
self.mu = int(self.mu * 0.84)
if self.mu < self.dimension:
self.mu = self.dimension
self.llambda = 4 * self.mu
if self.num_workers > 1:
self.llambda = max(self.llambda, self.num_workers)
self.mu = self.llambda // 4
self._loss_record = []
data = candidate.get_standardized_data(reference=self.instrumentation)
particle = base.utils.Individual(data)
particle._parameters = np.array([candidate._meta["sigma"]])
particle.value = value
self._evaluated_population.append(particle)
if len(self._evaluated_population) >= self.llambda:
# Sorting the population.
self._evaluated_population.sort(key=lambda p: p.value)
# Computing the new parent.
self.current_center = sum(p.x for p in self._evaluated_population[: self.mu]) / self.mu # type: ignore
# EMNA update
t1 = [(self._evaluated_population[i].x - self.current_center)**2 for i in range(self.mu)]
self.sigma = np.sqrt(sum(t1) / (self.mu))
imp = max(1, (np.log(self.llambda) / 2)**(1 / self.dimension))
if self.num_workers / self.dimension > 16:
self.sigma /= imp
self._evaluated_population = []
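The adaptation above compares the first and last llambda entries of the loss record with a two-sample z-statistic: no significant improvement (z < 2) doubles the population, otherwise it shrinks toward the dimension. A standalone sketch of that test (hypothetical helper mirroring the code above):

import numpy as np

def population_should_grow(first, last):
    # first/last: equal-length windows of losses (first and last fifths of the record)
    n = len(first)
    means = [np.mean(w) for w in (first, last)]
    stds = [np.std(w) / np.sqrt(n - 1) for w in (first, last)]
    z = (means[0] - means[1]) / np.sqrt(stds[0] ** 2 + stds[1] ** 2)
    return z < 2.0  # no clear progress -> enlarge the population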


@registry.register
class JNGO(NGO):
class Shiva(NGO):
"""Nevergrad optimizer by competence map. You might modify this one for designing youe own competence map."""

def __init__(self, instrumentation: IntOrParameter, budget: Optional[int] = None, num_workers: int = 1) -> None:
super().__init__(instrumentation, budget=budget, num_workers=num_workers)
assert budget is not None
if self.has_noise and self.has_discrete_not_softmax:
self.optims = [DoubleFastGAOptimisticNoisyDiscreteOnePlusOne(self.instrumentation, budget, num_workers)]
if self.has_noise and (self.has_discrete_not_softmax or not self.instrumentation.descriptors.metrizable):
self.optims = [RecombiningPortfolioOptimisticNoisyDiscreteOnePlusOne(self.instrumentation, budget, num_workers)]
else:
if self.has_noise:
self.optims = [TBPSA(self.instrumentation, budget, num_workers)]
else:
if self.has_discrete_not_softmax:
self.optims = [DoubleFastGADiscreteOnePlusOne(self.instrumentation, budget, num_workers)]
if not self.instrumentation.descriptors.metrizable:
if self.dimension < 60:
self.optims = [NGO(self.instrumentation, budget, num_workers)]
else:
if num_workers > budget / 5:
self.optims = [TwoPointsDE(self.instrumentation, budget, num_workers)] # noqa: F405
else:
if num_workers == 1 and budget > 3000:
self.optims = [Powell(self.instrumentation, budget, num_workers)] # noqa: F405
else:
self.optims = [chainCMAwithLHSsqrt(self.instrumentation, budget, num_workers)] # noqa: F405
self.optims = [CMA(self.instrumentation, budget, num_workers)]
else:
self.optims = [NGO(self.instrumentation, budget, num_workers)]
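Since Shiva is registered, it can be used like any other optimizer. A minimal usage sketch (toy objective; the instrumentation-as-dimension API of this release is assumed):

from nevergrad.optimization import optimizerlib

def square(x):
    return sum((xi - 0.5) ** 2 for xi in x)

# Shiva dispatches to a sub-optimizer based on the noise/discreteness
# descriptors of the problem, the budget and the number of workers
optimizer = optimizerlib.Shiva(instrumentation=4, budget=200, num_workers=1)
recommendation = optimizer.minimize(square)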
3 changes: 2 additions & 1 deletion nevergrad/optimization/recorded_recommendations.csv
@@ -34,6 +34,7 @@ CM,1.0082049151,-0.9099785499,-1.025147209,1.2046460074,,,,,,,,,,,,
CMA,1.012515477,-0.9138805701,-1.029555946,1.2098418178,,,,,,,,,,,,
CMandAS,-0.3375952501,-0.5852755939,-0.1149228138,2.2419018641,,,,,,,,,,,,
CMandAS2,1.012515477,-0.9138805701,-1.029555946,1.2098418178,,,,,,,,,,,,
CMandAS3,1.0,0.0,0.0,0.0,,,,,,,,,,,,
CauchyLHSSearch,-0.527971877,1.341890246,2.6790716005,3.5963545262,,,,,,,,,,,,
CauchyOnePlusOne,0.0,0.0,0.0,0.0,,,,,,,,,,,,
CauchyRandomSearch,-0.6941119288,-0.1425497837,-0.4907358842,-0.0426447433,,,,,,,,,,,,
@@ -56,7 +57,6 @@ HaltonSearch,-0.318639364,-0.7647096738,-0.7063025628,1.0675705239,,,,,,,,,,,,
HaltonSearchPlusMiddlePoint,0.0,0.0,0.0,0.0,,,,,,,,,,,,
HammersleySearch,0.2104283942,-1.1503493804,-0.1397102989,0.8416212336,,,,,,,,,,,,
HammersleySearchPlusMiddlePoint,0.5244005127,-1.1503493804,-0.1397102989,0.8416212336,,,,,,,,,,,,
JNGO,-0.4845350361,-0.2374252885,-0.7949197153,1.4124712099,,,,,,,,,,,,
LBO,-0.09816919,-0.2262437547,-1.6286468329,-0.8183425619,,,,,,,,,,,,
LHSSearch,-0.3978418928,0.827925915,1.2070034191,1.3637174061,,,,,,,,,,,,
LargeHaltonSearch,-67.4489750196,43.0727299295,-25.3347103136,-56.5948821933,,,,,,,,,,,,
@@ -157,6 +157,7 @@ ScrHaltonSearch,-0.318639364,-1.2206403488,1.7506860713,0.5659488219,,,,,,,,,,,,
ScrHaltonSearchPlusMiddlePoint,-1.1503493804,1.2206403488,-0.8416212336,1.0675705239,,,,,,,,,,,,
ScrHammersleySearch,1.3829941271,-0.318639364,-1.2206403488,1.7506860713,,,,,,,,,,,,
ScrHammersleySearchPlusMiddlePoint,-1.2815515655,0.0,0.4307272993,0.8416212336,,,,,,,,,,,,
Shiva,0.0,-0.3451057176,-0.1327329683,1.9291307781,,,,,,,,,,,,
SmallHaltonSearchPlusMiddlePoint,0.0031863936,0.0076470967,-0.0175068607,0.0056594882,,,,,,,,,,,,
SmallHammersleySearchPlusMiddlePoint,0.0052440051,-0.0115034938,-0.001397103,0.0084162123,,,,,,,,,,,,
SmallScaleRandomSearchPlusMiddlePoint,0.0101251548,-0.0091386915,-0.0102953021,0.0120979645,,,,,,,,,,,,
5 changes: 2 additions & 3 deletions nevergrad/optimization/test_optimizerlib.py
@@ -16,7 +16,6 @@
import pandas as pd
from bayes_opt.util import acq_max
import nevergrad as ng
from .. import instrumentation as inst
from ..common.typetools import ArrayLike
from ..common import testing
from . import base
@@ -308,7 +307,7 @@ def test_population_pickle(name: str) -> None:

def test_bo_instrumentation_and_parameters() -> None:
# instrumentation
instrumentation = inst.Instrumentation(ng.p.Choice([True, False]))
instrumentation = ng.p.Instrumentation(ng.p.Choice([True, False]))
with pytest.warns(base.InefficientSettingsWarning):
optlib.QRBO(instrumentation, budget=10)
with pytest.warns(None) as record:
@@ -334,7 +333,7 @@ def test_chaining() -> None:


def test_instrumentation_optimizer_reproducibility() -> None:
instrumentation = inst.Instrumentation(ng.p.Array(shape=(1,)), y=ng.p.Choice(list(range(100))))
instrumentation = ng.p.Instrumentation(ng.p.Array(shape=(1,)), y=ng.p.Choice(list(range(100))))
instrumentation.random_state.seed(12)
optimizer = optlib.RandomSearch(instrumentation, budget=10)
recom = optimizer.minimize(_square)
2 changes: 0 additions & 2 deletions nevergrad/parametrization/core.py
@@ -356,11 +356,9 @@ def _compute_descriptors(self) -> utils.Descriptors:

@property
def descriptors(self) -> utils.Descriptors:
print("Computing descriptors for", self.__class__)
if self._descriptors is None:
self._compute_descriptors()
self._descriptors = self._compute_descriptors()
print("Got", self._descriptors)
return self._descriptors
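Beyond dropping the two debug print calls, this hunk makes the caching explicit: the old body called _compute_descriptors() without assigning its return value, whereas the new line stores the result so later accesses reuse it.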

