Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Population size control #25

Merged
merged 11 commits into from
Sep 1, 2022
16 changes: 8 additions & 8 deletions modcma/modularcmaes.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,19 +457,19 @@ def evaluate_bbob(
"""
# This speeds up the import, this import is quite slow, so import it lazy here
# pylint: disable=import-outside-toplevel
from IOHexperimenter import IOH_function, IOH_logger
import ioh

evals, fopts = np.array([]), np.array([])
if seed:
np.random.seed(seed)
fitness_func = IOH_function(
fid, dim, instance, target_precision=target_precision, suite="BBOB"
fitness_func = ioh.get_problem(
fid, dimension=dim, instance=instance
)

if logging:
data_location = data_folder if os.path.isdir(data_folder) else os.getcwd()
logger = IOH_logger(data_location, f"{label}F{fid}_{dim}D")
fitness_func.add_logger(logger)
logger = ioh.logger.Analyzer(root=data_location, folder_name=f"{label}F{fid}_{dim}D")
fitness_func.attach_logger(logger)

print(
f"Optimizing function {fid} in {dim}D for target "
Expand All @@ -479,11 +479,11 @@ def evaluate_bbob(
for idx in range(iterations):
if idx > 0:
fitness_func.reset()
target = fitness_func.get_target()
target = fitness_func.objective.y + target_precision

optimizer = ModularCMAES(fitness_func, dim, target=target, **kwargs).run()
evals = np.append(evals, fitness_func.evaluations)
fopts = np.append(fopts, fitness_func.best_so_far_precision)
evals = np.append(evals, fitness_func.state.evaluations)
fopts = np.append(fopts, fitness_func.state.current_best_internal.y)

result_string = (
"FCE:\t{:10.8f}\t{:10.4f}\n"
Expand Down
25 changes: 20 additions & 5 deletions modcma/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class Parameters(AnnotatedStruct):
n_generations: int = None
lambda_: int = None
mu: int = None
sigma0: float = 0.5
sigma0: float = 0.2
a_tpa: float = 0.5
b_tpa: float = 0.0
cs: float = None
Expand Down Expand Up @@ -384,8 +384,7 @@ def init_local_restart_parameters(self) -> None:

TODO: check if we can move this to separate object.
"""
if len(self.restarts) == 0:
self.restarts.append(self.t)

self.max_iter = 100 + 50 * (self.d + 3) ** 2 / np.sqrt(self.lambda_)
self.nbin = 10 + int(np.ceil(30 * self.d / self.lambda_))
self.n_stagnation = min(int(120 + (30 * self.d / self.lambda_)), 20000)
Expand Down Expand Up @@ -463,9 +462,9 @@ def init_dynamic_parameters(self) -> None:
Examples of such parameters are the Covariance matrix C and its
eigenvectors and the learning rate sigma.
"""
self.sigma = np.float64(self.sigma0)
self.sigma = np.float64(self.sigma0) * (self.ub[0,0] - self.lb[0,0])
if hasattr(self, "m") or self.x0 is None:
self.m = np.float64(np.random.uniform(self.lb, self.ub, (self.d, 1))
self.m = np.float64(np.random.uniform(self.lb, self.ub, (self.d, 1)))
else:
self.m = np.float64(self.x0.copy())
self.m_old = np.empty((self.d, 1), dtype=np.float64)
Expand Down Expand Up @@ -628,6 +627,10 @@ def adapt_evolution_paths(self) -> None:
def perform_local_restart(self) -> None:
"""Method performing local restart, if a restart strategy is specified."""
if self.local_restart:

if len(self.restarts) == 0:
self.restarts.append(self.t)

if self.local_restart == "IPOP" and self.mu > 512:
self.mu *= self.ipop_factor
self.lambda_ *= self.ipop_factor
Expand Down Expand Up @@ -743,6 +746,7 @@ def save(self, filename: str = "parameters.pkl") -> None:

def record_statistics(self) -> None:
"""Method for recording metadata."""
# if self.local_restart or self.compute_termination_criteria:
self.flat_fitnesses.append(
self.population.f[0] == self.population.f[self.flat_fitness_index]
)
Expand Down Expand Up @@ -841,6 +845,17 @@ def update(self, parameters: dict, reset_default_modules=False):
self.init_selection_parameters()
self.init_adaptation_parameters()
self.init_local_restart_parameters()

def update_popsize(self, lambda_new):
"""Manually control the population size."""
if self.local_restart is not None:
warnings.warn("Modification of population size is disabled when local restart startegies are used")
return
self.lambda_ = lambda_new
self.mu = lambda_new//2
self.init_selection_parameters()
self.init_adaptation_parameters()
self.init_local_restart_parameters()


class BIPOPParameters(AnnotatedStruct):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
IOHexperimenter>=0.2.8
ioh>=0.3.2.8.3
numba>=0.52.0
numpy>=1.18.5
scipy>=1.4.1
41 changes: 41 additions & 0 deletions tests/create_expected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Script to re-initialize the expected function values in expected.json."""

from modcma import ModularCMAES, parameters, utils
import ioh
import numpy as np
import json

def run_bbob_function(module, value, fid):
"""Runs the specified version of ModularCMAES on the bbob-function and returns its output."""
np.random.seed(42)
dim = 2
budget = 20
iid = 1
f = ioh.get_problem(fid, dimension=dim, instance=iid)
p = parameters.Parameters(
dim, budget=budget, **{module: value}
)
ModularCMAES(f, parameters=p).run()
return f.state.current_best_internal.y

def create_expected_dict():
"""Creates the dictionary containing the expected final function values."""
BBOB_2D_PER_MODULE_20_ITER = dict()
for module in parameters.Parameters.__modules__:
m = getattr(parameters.Parameters, module)
if type(m) == utils.AnyOf:
for o in filter(None, m.options):
BBOB_2D_PER_MODULE_20_ITER[f"{module}_{o}"] = np.zeros(24)
for fid in range(1, 25):
BBOB_2D_PER_MODULE_20_ITER[f"{module}_{o}"][fid - 1] = run_bbob_function(module, o, fid)

elif type(m) == utils.InstanceOf:
BBOB_2D_PER_MODULE_20_ITER[f"{module}_{True}"] = np.zeros(24)
for fid in range(1, 25):
BBOB_2D_PER_MODULE_20_ITER[f"{module}_{True}"][fid - 1] = run_bbob_function(module, True, fid)
return BBOB_2D_PER_MODULE_20_ITER

if __name__ == "__main__":
BBOB_2D_PER_MODULE_20_ITER = create_expected_dict()
with open("expected.json", "w") as f:
json.dump(BBOB_2D_PER_MODULE_20_ITER, f)
1 change: 1 addition & 0 deletions tests/expected.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"active_True": [0.5079023784181027, 3994.4755180670527, 12.043295738621183, 4.886212069485996, 0.0, 304.44598143380546, 94.62349805255676, 36.18253759553309, 19.85070276244657, 36645.87647130221, 13315.51512952369, 2683.703847500958, 44.81379842261522, 0.4609332855389666, 9.711117955511472, 1.759139939785059, 1.3933170719608428, 20.10914118263286, 0.4559623215635096, 0.9345626367105797, 1.990063610751629, 8.05905284872558, 21.25735111797754, 10.91219903774687], "elitist_True": [0.6699833162594511, 1227.9721054463228, 16.869910137817453, 19.572392428411835, 0.0, 212.3024228554834, 95.34366236043742, 7.623121487433538, 3.161557968895952, 14355.480702885847, 29507.974482338563, 2683.703847500958, 28.411660356917345, 0.6299432344321189, 7.265400966461772, 14.434537776429282, 1.4118850095554634, 23.630235905875463, 0.4559623215635096, 2.1620882643582515, 1.990963814929167, 7.845446138574841, 7.628611542889807, 14.194368416235939], "orthogonal_True": [0.25534271814624465, 118072.30539981391, 2.0096992931251494, 10.626106984940332, 2.8786810407197265, 270686.4186762697, 95.44019339797313, 12.550406947379894, 0.33525044424635597, 25908.565659852564, 1237306.2027346096, 1702.7335324051267, 70.43902585372217, 0.6976581726574429, 12.52074418264813, 12.81291626426489, 58.44704396496336, 66.6243546212341, 2.667737341562461, 2.158758789228943, 2.7378732590422024, 2.0770051329721895, 7.955480122219026, 14.581514536704073], "sequential_True": [19.5719406285756, 7210188.745246256, 13.681793553365424, 23.313513761722504, 0.0, 52255.4145901049, 93.54091559010935, 184.20999101158142, 384.33208310135626, 23103333.812051296, 32971525.026216447, 2683.703847500958, 214.39433012072365, 57.622398804713555, 25.18091794870781, 20.16251348556511, 1.9775544663974183, 52.20467342507745, 0.4559623215635096, 259.10398673490283, 2.0601778166852416, 11.496305678387635, 8.260780485902643, 43.60153110381705], "threshold_convergence_True": [0.7347847130472607, 3621.5783803371305, 9.440988022081923, 17.588360915885023, 0.0, 387.4010412103078, 95.44019339797313, 24.543790706144733, 11.6043233309905, 419398.0095371446, 3430.404715228205, 926.4455952079577, 28.411660356917345, 0.6893117820736764, 32.31733995435986, 15.751386747152557, 1.4118850095554634, 12.429946026578554, 0.4559623215635096, 3.055334914776176, 0.8924655543998508, 1.9930860112728006, 6.085273695307758, 7.74642446736389], "step_size_adaptation_csa": [0.6699833162594511, 4858.983093703771, 13.575313361038532, 20.615929060719242, 0.0, 314.32588235229133, 95.00863388735908, 7.623121487433538, 2.40985693994124, 8829.9538591471, 75722.58537541116, 2683.703847500958, 28.411660356917345, 0.46048487751008504, 23.76532915356332, 20.16251348556511, 1.4118850095554634, 22.517169808626214, 0.4559623215635096, 1.3074796822154253, 2.011681868741659, 10.719130303012593, 17.153972345540915, 15.81948309019006], "step_size_adaptation_tpa": [0.03767014607534176, 2175.9911995416614, 13.567807397656335, 11.874870433498288, 0.0, 1071.7434978918488, 94.32587149819872, 7.124026768321718, 9.334989121212201, 266.0481801806988, 252465.61185230396, 691.4718777227441, 38.68678171583922, 0.2866364961210628, 11.326830236653606, 15.751386747152557, 3.855740332598697, 32.942464249892616, 0.4559623215635096, 1.9737590640108762, 2.5904060477232917, 0.00110904003110431, 9.77176743348317, 3.9402019732311335], "step_size_adaptation_msr": [0.21287712512225715, 295.54884244593984, 30.872441630043006, 2.438539936320332, 0.0, 291.62475583684414, 95.00863388735908, 7.124026768321718, 9.334989121212201, 1927.7199333449996, 252293.37460247002, 2683.703847500958, 38.68678171583922, 0.1609115280864648, 23.76532915356332, 15.751386747152557, 3.453551937059726, 31.36699937737331, 0.4559623215635096, 2.101822038572848, 0.13402589068051024, 1.7724620177602137, 5.498985829997645, 10.501507178287895], "step_size_adaptation_xnes": [3.685803014042311, 53903.67207319866, 5.424387513620404, 8.476988158726103, 0.0, 70.60878462167778, 93.54091559010935, 3.8843531249961387, 255.5973743918305, 216076.36001404427, 13572.632799955189, 6.099589325005635, 364.92895618522044, 4.2640726765416295, 3.5220968308719414, 15.751386747152557, 4.013689588890096, 20.10958871116135, 0.4559623215635096, 1.9622096635746784, 4.6098062526230805, 0.07660928447563826, 7.97304712403716, 6.572006292778018], "step_size_adaptation_m-xnes": [0.21287712512225715, 2175.9911995416614, 20.43694180566843, 11.874870433498288, 0.0, 34.18405821417568, 95.00863388735908, 7.124026768321718, 9.334989121212201, 1927.7199333449996, 256392.77602054115, 1447.1339175020744, 38.68678171583922, 0.5070656236522031, 23.76532915356332, 15.751386747152557, 2.8882839166644008, 31.36699937737331, 0.4559623215635096, 2.3635090154834093, 0.4682737151651404, 0.43657925301521344, 7.819215103667412, 13.559901346380887], "step_size_adaptation_lp-xnes": [1.7332064445719153, 4419.683662806729, 17.098151479760155, 14.562651567781499, 4.766267911300538, 113.70160820754681, 93.54091559010935, 4.863385150913383, 4.659487885926415, 4387.270026268735, 105506.49942948109, 2512.7600761088174, 26.127030760680967, 0.7712613615349453, 5.842531481068716, 8.687032329156176, 7.279161607717865, 61.20592146441295, 3.744194191668644, 2.4622889158859524, 2.8965324473324405, 12.262312434816277, 17.73216844338432, 3.4960227654881573], "step_size_adaptation_psr": [0.21287712512225715, 1088.5057791931495, 16.688455670695674, 5.5288647474243735, 0.0, 266.943391560345, 95.00863388735908, 7.124026768321718, 9.334989121212201, 1927.7199333449996, 207576.52536598608, 2683.703847500958, 38.68678171583922, 0.36530618515212254, 23.76532915356332, 15.751386747152557, 3.855740332598697, 31.36699937737331, 0.4559623215635096, 3.8792443115816027, 0.3441163839687159, 8.044092332563334, 3.6171907727443897, 3.655695515119866], "mirrored_mirrored": [0.7554264630844963, 30.995928352373653, 10.727147924627236, 28.35505849865668, 0.0, 73.03264789844737, 98.3482403109842, 19.205395254336498, 10.73932209031653, 306.7532919121426, 47432.05526640019, 2683.703847500958, 199.99567611651995, 1.976899305627097, 8.448239046275782, 10.619873784979125, 2.0468365502110966, 3.1760484231995507, 0.4559623215635096, 3.141781349728834, 2.5738103658736153, 2.7415047738254796, 11.184397856398496, 13.905359243715296], "mirrored_mirrored pairwise": [0.03373835261991253, 981.7390996048255, 10.727147924627236, 10.296884748016868, 0.0, 73.03264789844737, 93.54091559010935, 19.205395254336498, 10.73932209031653, 191683.40442372032, 157018.28884094275, 2683.703847500958, 199.99567611651995, 1.976899305627097, 29.38549651637714, 17.063577075453246, 6.758992843888698, 32.39714226445396, 0.4559623215635096, 4.628179175342729, 1.2275961378528137, 2.7415047738254796, 1.3838182113641495, 15.305137607192382], "base_sampler_gaussian": [0.6699833162594511, 4858.983093703771, 13.575313361038532, 20.615929060719242, 0.0, 314.32588235229133, 95.00863388735908, 7.623121487433538, 2.40985693994124, 8829.9538591471, 75722.58537541116, 2683.703847500958, 28.411660356917345, 0.46048487751008504, 23.76532915356332, 20.16251348556511, 1.4118850095554634, 22.517169808626214, 0.4559623215635096, 1.3074796822154253, 2.011681868741659, 10.719130303012593, 17.153972345540915, 15.81948309019006], "base_sampler_sobol": [1.9382794743647251, 109862.7926336296, 13.944304419608123, 11.193683455264821, 0.0, 118.86959643555977, 95.44019339797313, 16.763635893103455, 2.740797812510343, 10861.51714525837, 330815.36309132184, 2.247461255023546, 91.12343223127826, 0.17212699679736815, 15.0833956786349, 14.972757376351383, 4.563680366683454, 13.656010721207792, 0.5370518570202893, 3.7387173359156574, 4.083582516166707, 0.5337698024248725, 12.081102335875753, 16.254121099911366], "base_sampler_halton": [0.05313470065815114, 39526.522004619226, 0.8917398679522033, 14.397195194898835, 0.0, 28.54887919616731, 95.00863388735908, 2.720260852291106, 1.568529005755524, 3817.305420396278, 3037.041404951838, 2438.638535172438, 6.0599988960457, 0.7756151204183547, 10.384034775256028, 12.984958685857494, 2.65445820393611, 21.224991223612605, 0.06917463900662568, 2.6908223459686598, 6.306229527285339, 1.8896710371272878, 3.848846366265107, 23.027563725186226], "weights_option_default": [0.6699833162594511, 4858.983093703771, 13.575313361038532, 20.615929060719242, 0.0, 314.32588235229133, 95.00863388735908, 7.623121487433538, 2.40985693994124, 8829.9538591471, 75722.58537541116, 2683.703847500958, 28.411660356917345, 0.46048487751008504, 23.76532915356332, 20.16251348556511, 1.4118850095554634, 22.517169808626214, 0.4559623215635096, 1.3074796822154253, 2.011681868741659, 10.719130303012593, 17.153972345540915, 15.81948309019006], "weights_option_equal": [0.6983733129352282, 53903.67207319866, 16.490693161011574, 19.983425781178084, 0.0, 1446.9386344210945, 98.3482403109842, 1.3914507176564086, 9.829561857450923, 175618.54116703727, 297974.9703969646, 2683.703847500958, 84.70468494483656, 0.26559027142197433, 11.095075704114704, 12.546429428472235, 4.124331898352289, 8.982949573227335, 0.3552769868031991, 5.125690097532708, 2.0431895662294055, 1.9396308505454316, 6.357120407299142, 16.044865034726843], "weights_option_1/2^lambda": [0.29537552250716453, 14961.853448878097, 24.071007496709008, 15.36361770066693, 0.0, 219.60735320690065, 95.00863388735908, 12.504251100709112, 10.055671647360219, 658.2473391082603, 158260.00841298988, 2683.703847500958, 41.970757571346695, 0.3473139627356672, 18.68616700712326, 12.985714196316431, 4.074463999656349, 14.549555170427439, 0.4559623215635096, 2.5857528075054406, 3.0145388205852823, 1.8648570790235144, 16.61377903494341, 4.6556683869958135], "local_restart_IPOP": [0.6699833162594511, 4858.983093703771, 13.575313361038532, 20.615929060719242, 0.0, 314.32588235229133, 95.00863388735908, 7.623121487433538, 2.40985693994124, 8829.9538591471, 75722.58537541116, 2683.703847500958, 28.411660356917345, 0.46048487751008504, 23.76532915356332, 20.16251348556511, 1.4118850095554634, 22.517169808626214, 0.4559623215635096, 1.3074796822154253, 2.011681868741659, 10.719130303012593, 17.153972345540915, 15.81948309019006], "local_restart_BIPOP": [0.6699833162594511, 4858.983093703771, 13.575313361038532, 20.615929060719242, 0.0, 314.32588235229133, 95.00863388735908, 7.623121487433538, 2.40985693994124, 8829.9538591471, 75722.58537541116, 2683.703847500958, 28.411660356917345, 0.46048487751008504, 23.76532915356332, 20.16251348556511, 1.4118850095554634, 22.517169808626214, 0.4559623215635096, 1.3074796822154253, 2.011681868741659, 10.719130303012593, 17.153972345540915, 15.81948309019006], "bound_correction_saturate": [0.6699833162594511, 4858.983093703771, 13.575313361038532, 20.615929060719242, 0.0, 33.09501100232642, 95.00863388735908, 0.025008654234555786, 48.30190325684145, 7345.169565828286, 39778.9260263472, 99.69985510848147, 81.24353869555779, 0.21972787749128886, 23.76532915356332, 12.874535394768513, 0.9011977159517143, 11.737632986991574, 0.6914037267515631, 1.3074796822154253, 0.8684924540993583, 0.49728904266780494, 7.630280738995454, 13.415576504682278], "bound_correction_unif_resample": [0.07835166888299767, 833.1172241710973, 19.627044667914628, 9.628926397266468, 5.160304518362757, 9.813141025564207, 98.3482403109842, 1.7357869403053625, 1.0497014040510548, 41678.76392402864, 168961.6744824513, 240.67153518558044, 3.835466543661913, 0.6842126813587452, 13.958153917707243, 3.678177569860251, 2.4855552649703743, 31.962561311838957, 0.10223176862010597, 3.162540017099194, 2.0966511126841403, 0.2943473719176201, 5.120439721241175, 12.236723816743432], "bound_correction_COTN": [0.16760938470242162, 4351.211417663171, 6.047980325222314, 21.912573748869775, 3.1244128771674013, 16.390734449771912, 98.3482403109842, 5.83832165250612, 0.14735284098644652, 190.06470116101258, 401.02023186453124, 2683.703847500958, 123.96634855859521, 0.7422687242316163, 4.997713381178798, 9.253215577323422, 3.996596680370051, 0.6749752706464678, 1.593872207016755, 2.732266723801698, 2.2807913444204497, 0.0009542510237096509, 5.825774943800214, 11.291346927128165], "bound_correction_toroidal": [1.4444217300032234, 668.4603295428242, 13.575313361038532, 20.615929060719242, 2.0223088241078173, 36.79824179639371, 95.00863388735908, 13.504258712369403, 55.92343205593169, 52508.86982339246, 7782.459127523065, 1.1032079131590031, 142.56590632579102, 1.8515935221861943, 23.76532915356332, 13.68135734438805, 3.160513609285335, 12.298476154812828, 0.15035979308807157, 1.3074796822154253, 4.6098062526230805, 2.216804022881296, 4.91679849370395, 7.427104342941419], "bound_correction_mirror": [1.1660440655024709, 1651.4439597423843, 8.374450285848416, 16.89173845810526, 5.8032753382871505, 33.55789312438171, 97.45871992360821, 7.7492720281054694, 1.0486600710191158, 25873.673262665754, 900.4477757873151, 1133.3973327215838, 70.13442419670197, 1.1138637242611467, 17.042022539322453, 3.9534263555903633, 1.2140101666427907, 4.101698783919703, 0.21716792040892763, 3.0087533050106425, 1.9277425466961398, 1.8126859185350808, 10.029522452683748, 14.76909640109017]}
Loading