Skip to content

Commit

Permalink
Voronoi differential evolution, first draft (#1523)
Browse files Browse the repository at this point in the history
* firstdraft

* fix

* go

* oncemore

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
teytaud authored Jun 20, 2023
1 parent c84e9a0 commit 62427d8
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 7 deletions.
48 changes: 45 additions & 3 deletions nevergrad/optimization/differentialevolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@


class Crossover:
def __init__(self, random_state: np.random.RandomState, crossover: tp.Union[str, float]):
def __init__(
self,
random_state: np.random.RandomState,
crossover: tp.Union[str, float],
parameter: tp.Optional[p.Parameter] = None,
):
self.CR = 0.5
self.crossover = crossover
self.random_state = random_state
if isinstance(crossover, float):
self.CR = crossover
elif crossover == "random":
self.CR = self.random_state.uniform(0.0, 1.0)
elif crossover not in ["twopoints", "onepoint", "rotated_twopoints"]:
elif crossover not in ["twopoints", "onepoint", "rotated_twopoints", "voronoi"]:
raise ValueError(f'Unknown crossover "{crossover}"')
self.shape = np.array(parameter.value).shape if parameter is not None else None

def apply(self, donor: np.ndarray, individual: np.ndarray) -> None:
dim = donor.size
Expand All @@ -32,6 +38,8 @@ def apply(self, donor: np.ndarray, individual: np.ndarray) -> None:
return self.rotated_twopoints(donor, individual)
elif self.crossover == "onepoint" and dim >= 3:
return self.onepoint(donor, individual)
elif self.crossover == "voronoi":
return self.voronoi(donor, individual)
else:
return self.variablewise(donor, individual)

Expand Down Expand Up @@ -69,6 +77,29 @@ def rotated_twopoints(self, donor: np.ndarray, individual: np.ndarray) -> None:
assert bounds[1] < donor.size + 1
donor[bounds[0] : bounds[1]] = individual[bounds2[0] : bounds2[1]]

def voronoi(self, donor: np.ndarray, individual: np.ndarray) -> None:
shape = self.shape
if shape is None or len(shape) < 2:
warnings.warn("Voronoi DE needs a shape.")
self.twopoints(donor, individual)
return
local_donor = donor.reshape(shape)
local_individual = individual.reshape(shape)
x1 = np.array([np.random.randint(shape[i]) for i in range(len(shape))])
x2 = np.array([np.random.randint(shape[i]) for i in range(len(shape))])
x3 = np.array([np.random.randint(shape[i]) for i in range(len(shape))])
x4 = np.array([np.random.randint(shape[i]) for i in range(len(shape))])
it = np.nditer(local_donor, flags=["multi_index"])
for _ in it:
d1 = np.linalg.norm(np.array(it.multi_index) - x1)
d2 = np.linalg.norm(np.array(it.multi_index) - x2)
d3 = np.linalg.norm(np.array(it.multi_index) - x3)
d4 = np.linalg.norm(np.array(it.multi_index) - x4)
if min([d1, d2, d3]) > d4:
local_donor[it.multi_index] = local_individual[it.multi_index]
donor[:] = local_donor.flatten()[:]
individual[:] = local_individual.flatten()[:]


class _DE(base.Optimizer):
"""Differential evolution.
Expand Down Expand Up @@ -156,6 +187,13 @@ def _internal_ask_candidate(self) -> p.Parameter:
candidate = self.parametrization.sample()
elif self.sampler is not None:
candidate = self.sampler.ask()
elif self._config.crossover == "voronoi":
new_guy = (
self.scale * self._rng.normal(0, 1, self.dimension)
if len(self.population) > self.llambda / 6
else self.scale * self._rng.normal() * np.ones(self.dimension)
)
candidate = self.parametrization.spawn_child().set_standardized_data(new_guy)
else:
new_guy = self.scale * self._rng.normal(0, 1, self.dimension)
candidate = self.parametrization.spawn_child().set_standardized_data(new_guy)
Expand Down Expand Up @@ -192,7 +230,9 @@ def _internal_ask_candidate(self) -> p.Parameter:
if co == "parametrization":
candidate.recombine(self.parametrization.spawn_child().set_standardized_data(donor))
else:
crossovers = Crossover(self._rng, 1.0 / self.dimension if co == "dimension" else co)
crossovers = Crossover(
self._rng, 1.0 / self.dimension if co == "dimension" else co, self.parametrization
)
crossovers.apply(donor, data)
candidate.set_standardized_data(donor, reference=self.parametrization)
return candidate
Expand Down Expand Up @@ -322,6 +362,7 @@ def __init__(
"dimension",
"random",
"parametrization",
"voronoi",
]
self.initialization = initialization
self.scale = scale
Expand All @@ -337,6 +378,7 @@ def __init__(

DE = DifferentialEvolution().set_name("DE", register=True)
TwoPointsDE = DifferentialEvolution(crossover="twopoints").set_name("TwoPointsDE", register=True)
VoronoiDE = DifferentialEvolution(crossover="voronoi").set_name("VoronoiDE", register=True)
RotatedTwoPointsDE = DifferentialEvolution(crossover="rotated_twopoints").set_name(
"RotatedTwoPointsDE", register=True
)
Expand Down
9 changes: 5 additions & 4 deletions nevergrad/optimization/optimizerlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,10 @@ def __init__(
self._popsize = (
max(num_workers, 4 + int(self._config.popsize_factor * np.log(self.dimension)))
if pop is None
else pop
else max(pop, num_workers)
)
if self._config.elitist:
self._popsize = max(self._popsize, self.num_workers + 1)
# internal attributes
self._to_be_asked: tp.Deque[np.ndarray] = deque()
self._to_be_told: tp.List[p.Parameter] = []
Expand Down Expand Up @@ -1717,10 +1719,9 @@ def __init__(
super().__init__(parametrization, budget=budget, num_workers=num_workers)
self.frequency_ratio = frequency_ratio
self.algorithm = algorithm
elitist = self.dimension < 3
if multivariate_optimizer is None:
multivariate_optimizer = (
ParametrizedCMA(elitist=(self.dimension < 3)) if self.dimension > 1 else OnePlusOne
)
multivariate_optimizer = ParametrizedCMA(elitist=elitist) if self.dimension > 1 else OnePlusOne
self._optim = multivariate_optimizer(
self.parametrization, budget, num_workers
) # share parametrization and its rng
Expand Down
69 changes: 69 additions & 0 deletions nevergrad/optimization/test_optimizerlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
import sys
import time
Expand All @@ -20,6 +21,7 @@
import numpy as np
import pandas as pd
from scipy import stats
from scipy.ndimage import gaussian_filter
from bayes_opt.util import acq_max

# from bayes_opt.util import NotUniqueError
Expand Down Expand Up @@ -959,6 +961,73 @@ def test_smoother() -> None:
)


@pytest.mark.parametrize("n", [5, 10, 15, 25, 40]) # type: ignore
@pytest.mark.parametrize("b_per_dim", [1, 10, 20]) # type: ignore
def test_voronoide(n, b_per_dim) -> None:
if n < 25 or b_per_dim < 1 and not os.environ.get("CIRCLECI", False): # Outside CircleCI, only the big.
raise SkipTest("Only big things outside CircleCI.")

list_optims = ["CMA", "DE", "PSO", "RandomSearch", "TwoPointsDE", "OnePlusOne"]
if os.environ.get("CIRCLECI", False) and (n > 10 or n * b_per_dim > 100): # In CircleCI, only the small.
raise SkipTest("Topology optimization too slow in CircleCI")
if os.environ.get("CIRCLECI", False) or (n < 10 or b_per_dim < 20):
list_optims = ["CMA", "PSO", "OnePlusOne"]
if n > 20:
list_optims = ["DE", "TwoPointsDE"]
fails = {}
for o in list_optims:
fails[o] = 0
size = n * n
sqrtsize = n
b = b_per_dim * size # budget
nw = 20 # parallel workers

num_tests = 20
array = ng.p.Array(shape=(n, n), lower=-1.0, upper=1.0)
for idx in range(num_tests):
xa = idx % 3
xb = 2 - xa
xs = 1.5 * (
np.array([float(np.cos(xa * i + xb * j) < 0.0) for i in range(n) for j in range(n)]).reshape(n, n)
- 0.5
)
if (idx // 3) % 2 > 0:
xs = np.transpose(xs)
if (idx // 6) % 2 > 0:
xs = -xs

def f(x):
# return np.linalg.norm(x - xs) + np.linalg.norm(x - gaussian_filter(x, sigma=1))
return (
5.0 * np.sum(np.abs(x - xs) > 0.3) / size
+ 13.0 * np.linalg.norm(x - gaussian_filter(x, sigma=3)) / sqrtsize
)

VoronoiDE = ng.optimizers.VoronoiDE(array, budget=b, num_workers=nw)
vde = f(VoronoiDE.minimize(f).value)
for o in list_optims:
try:
other = ng.optimizers.registry[o](array, budget=b, num_workers=nw)
val = f(other.minimize(f).value)
except:
print(f"crash in {o}")
val = float(1.0e7)
# print(o, val / vde)
if val < vde:
fails[o] += 1
# Remove both lines below. TODO
# ratio = min([(idx + 1 - fails[o]) / (0.001 + fails[o]) for o in list_optims])
# print(f"temporary: {ratio}", idx + 1, fails, f"({n}-{b_per_dim})")
ratio = min([(num_tests - fails[o]) / (0.001 + fails[o]) for o in list_optims])
print(f"VoronoiDE for DO: {ratio}", num_tests, fails, f"({n}-{b_per_dim})")

for o in list_optims:
ratio = 3.0 if "DE" not in "o" else 2.0
assert (
num_tests - fails[o] > ratio * fails[o]
), f"Failure {o}: {fails[o]} / {num_tests} ({n}-{b_per_dim})"


def test_weighted_moo_de() -> None:
for _ in range(1): # Yes this is cheaper.
D = 2
Expand Down

0 comments on commit 62427d8

Please sign in to comment.