diff --git a/.gitignore b/.gitignore
index ddc38be..9d540ed 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+examples/experimental
+bbob.py
# Standard ROB excludes
.sync-config.cson
.vim-arsync
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f896be6..2ec91e2 100755
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,28 +1,47 @@
### Work-in-Progress
-- [ ] Make xNES work with all optimizers (currently only GD)
- Implement more strategies
- [ ] Large-scale CMA-ES variants
- [ ] [LM-CMA](https://www.researchgate.net/publication/282612269_LM-CMA_An_alternative_to_L-BFGS_for_large-scale_black_Box_optimization)
- [ ] [VkD-CMA](https://hal.inria.fr/hal-01306551v1/document), [Code](https://gist.github.com/youheiakimoto/2fb26c0ace43c22b8f19c7796e69e108)
- - [ ] [sNES](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) (separable version of xNES)
- - [ ] [ASEBO](https://proceedings.neurips.cc/paper/2019/file/88bade49e98db8790df275fcebb37a13-Paper.pdf)
- [ ] [RBO](http://proceedings.mlr.press/v100/choromanski20a/choromanski20a.pdf)
- Encoding methods - via special reshape wrappers
- [ ] Discrete Cosine Transform
- [ ] Wavelet Based Encoding (van Steenkiste, 2016)
- - [ ] Hypernetworks (Ha - start with simple MLP)
+ - [ ] CNN Hypernetwork (Ha - start with simple MLP)
### [v0.1.0] - [TBD]
##### Added
- Adds a `total_env_steps` counter to both `GymFitness` and `BraxFitness` for easier sample efficiency comparability with RL algorithms.
+- Support for new strategies/genetic algorithms
+ - SAMR-GA (Clune et al., 2008)
+ - GESMR-GA (Kumar et al., 2022)
+ - SNES (Wierstra et al., 2014)
+ - DES (Lange et al., 2022)
+ - Guided ES (Maheswaranathan et al., 2018)
+ - ASEBO (Choromanski et al., 2019)
+ - CR-FM-NES (Nomura & Ono, 2022)
+ - MR15-GA (Rechenberg, 1978)
+- Adds the full set of low-dimensional BBOB functions (`BBOBFitness`); see the usage sketch below
+- Adds 2D visualizer animating sampled points (`BBOBVisualizer`)
+- Adds `Evosax2JAXWrapper` to wrap all evosax strategies
+- Adds Adan optimizer (Xie et al., 2022)
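+
+A usage sketch for `BBOBFitness` (hypothetical: the constructor arguments and the `rollout` signature are assumed to mirror the removed `ClassicFitness` interface):
+
+```Python
+import jax
+from evosax.problems import BBOBFitness
+
+rng = jax.random.PRNGKey(0)
+# Assumed interface: BBOB function name + dimensionality.
+evaluator = BBOBFitness("Sphere", num_dims=2)
+x = jax.random.normal(rng, (20, 2))  # population of 20 candidates
+fitness = evaluator.rollout(rng, x)
+```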
+
+##### Changed
+
+- `ParameterReshaper` can now be applied directly from within the strategy. Simply provide a `pholder_params` pytree at strategy instantiation (and no `num_dims`).
+- `FitnessShaper` can likewise be applied from within the strategy. This makes it easier to track the best-performing member across generations and addresses issue #32. Simply provide the fitness shaping settings as arguments to the strategy (`maximize`, `centered_rank`, ...); see the sketch below.
+- Removes the Brax fitness problem (use the EvoJAX version instead)
+- Adds lrate and sigma schedules to strategy instantiation
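+
+A minimal sketch of the new instantiation pattern (the `pholder_params`, `maximize` and `centered_rank` arguments are documented above; the surrounding `default_params`/`initialize`/`ask` calls follow the standard `ask`/`tell` API):
+
+```Python
+import jax
+import jax.numpy as jnp
+from evosax import CMA_ES
+
+rng = jax.random.PRNGKey(0)
+# Any parameter pytree works as a placeholder - only the shapes matter.
+pholder_params = {"w": jnp.zeros((4, 2)), "b": jnp.zeros(2)}
+
+# Reshaping and fitness shaping are handled inside the strategy.
+strategy = CMA_ES(
+    popsize=16,
+    pholder_params=pholder_params,
+    maximize=True,
+    centered_rank=True,
+)
+es_params = strategy.default_params
+state = strategy.initialize(rng, es_params)
+x, state = strategy.ask(rng, state, es_params)  # candidates as pytrees
+```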
##### Fixed
- Fixed reward masking in `GymFitness`. Using `jnp.sum(dones) >= 1` for cumulative return computation zeros out the final timestep, which is wrong. That's why there were problems with sparse reward gym environments (e.g. Mountain Car).
+- Fixed PGPE sample indexing.
+- Fixed weight decay, which was incorrectly multiplied by -1 when maximizing.
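+
+For illustration, a sketch of the corrected masking logic (mirroring the rollout pattern used elsewhere in this codebase; names are illustrative):
+
+```Python
+import jax
+import jax.numpy as jnp
+
+def masked_return_step(carry, step):
+    """One `lax.scan` step of masked cumulative return computation."""
+    cum_reward, valid_mask = carry
+    reward, done = step
+    # Credit the terminating timestep before zeroing the mask -
+    # a `jnp.sum(dones) >= 1` style mask wrongly dropped this reward.
+    cum_reward = cum_reward + reward * valid_mask
+    valid_mask = valid_mask * (1 - done)
+    return (cum_reward, valid_mask), valid_mask
+
+rewards = jnp.array([1.0, 1.0, 1.0])
+dones = jnp.array([0.0, 0.0, 1.0])  # episode terminates at the last step
+(ret, _), _ = jax.lax.scan(masked_return_step, (0.0, 1.0), (rewards, dones))
+# ret == 3.0 - the final reward is counted exactly once
+```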
### [v0.0.9] - 15/06/2022
diff --git a/README.md b/README.md
index d440cb2..67477a1 100755
--- a/README.md
+++ b/README.md
@@ -36,14 +36,14 @@ state.best_member, state.best_fitness
| OpenES | [Salimans et al. (2017)](https://arxiv.org/pdf/1703.03864.pdf) | [`OpenES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/open_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb)
| PGPE | [Sehnke et al. (2010)](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=A64D1AE8313A364B814998E9E245B40A?doi=10.1.1.180.7104&rep=rep1&type=pdf) | [`PGPE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pgpe.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/02_mlp_control.ipynb)
| ARS | [Mania et al. (2018)](https://arxiv.org/pdf/1803.07055.pdf) | [`ARS`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ars.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/00_getting_started.ipynb)
-| CMA-ES | [Hansen & Ostermeier (2001)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf) | [`CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
-| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
-| Simple Genetic | [Such et al. (2017)](https://arxiv.org/abs/1712.06567) | [`SimpleGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
-| x-NES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`xNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/xnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
-| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
-| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| ESMC | [Merchant et al. (2021)](https://proceedings.mlr.press/v139/merchant21a.html) | [`ESMC`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/esmc.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| Persistent ES | [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a.html) | [`PersistentES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/persistent_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb)
-| Population-Based Training | [Jaderberg et al. (2017)](https://arxiv.org/abs/1711.09846) | [`PBT`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pbt.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb)
+| xNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`xNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/xnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| SNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`SNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/snes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| CR-FM-NES | [Nomura & Ono (2022)](https://arxiv.org/abs/2201.11422) | [`CR_FM_NES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cr_fm_nes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| Guided ES | [Maheswaranathan et al. (2018)](https://arxiv.org/abs/1806.10230) | [`GuidedES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/guided_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| ASEBO | [Choromanski et al. (2019)](https://arxiv.org/abs/1903.04268) | [`ASEBO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/asebo.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| CMA-ES | [Hansen & Ostermeier (2001)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf) | [`CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| Sep-CMA-ES | [Ros & Hansen (2008)](https://hal.inria.fr/inria-00287367/document) | [`Sep_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sep_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| BIPOP-CMA-ES | [Hansen (2009)](https://hal.inria.fr/inria-00382093/document) | [`BIPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/bipop_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb)
| IPOP-CMA-ES | [Auer & Hansen (2005)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cec2005ipopcmaes.pdf) | [`IPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ipop_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb)
@@ -52,8 +52,21 @@ state.best_member, state.best_fitness
| MA-ES | [Bayer & Sendhoff (2017)](https://www.honda-ri.de/pubs/pdf/3376.pdf) | [`MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| LM-MA-ES | [Loshchilov et al. (2017)](https://arxiv.org/pdf/1705.06693.pdf) | [`LM_MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/lm_ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| RmES | [Li & Zhang (2017)](https://ieeexplore.ieee.org/document/8080257) | [`RmES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/rm_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| Simple Genetic | [Such et al. (2017)](https://arxiv.org/abs/1712.06567) | [`SimpleGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| SAMR-GA | [Clune et al. (2008)](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000187) | [`SAMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/samr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| GESMR-GA | [Kumar et al. (2022)](https://arxiv.org/abs/2204.04817) | [`GESMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gesmr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| MR15-GA | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`MR15_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/mr15_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| DES | [Lange et al. (2022)](https://arxiv.org/abs/2211.11260) | [`DES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/des.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| GLD | [Golovin et al. (2019)](https://arxiv.org/pdf/1911.06317.pdf) | [`GLD`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gld.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
| Simulated Annealing | [Rasdi Rere et al. (2015)](https://www.sciencedirect.com/science/article/pii/S1877050915035759) | [`SimAnneal`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sim_anneal.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb)
+| Population-Based Training | [Jaderberg et al. (2017)](https://arxiv.org/abs/1711.09846) | [`PBT`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pbt.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb)
+
## Installation ⏳
@@ -78,12 +91,12 @@ In order to use JAX on your accelerators, you can find more details in the [JAX
* 📓 [LRateTune-PES](https://github.com/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb): Persistent ES on meta-learning problem as in [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a.html).
* 📓 [Quadratic-PBT](https://github.com/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb): PBT on toy quadratic problem as in [Jaderberg et al. (2017)](https://arxiv.org/abs/1711.09846).
* 📓 [Restart-Wrappers](https://github.com/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb): Custom restart wrappers as e.g. used in (B)IPOP-CMA-ES.
-* 📓 [Brax Control](https://github.com/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb): Evolve Tanh MLPs on Brax tasks using the `evosax` wrapper.
-* 📓 [Indirect Encodings](https://github.com/RobertTLange/evosax/blob/main/examples/08_encodings.ipynb): Find out how many parameters we need to evolve a pendulum controller.
+* 📓 [Brax Control](https://github.com/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb): Evolve Tanh MLPs on Brax tasks using the `EvoJAX` wrapper.
+
-## Key Selling Points 💵
+## Key Features 💵
-- **Strategy Diversity**: `evosax` implements more than 10 classical and modern neuroevolution strategies. All of them follow the same simple `ask`/`eval` API and come with tailored tools such as the [ClipUp](https://arxiv.org/abs/2008.02387) optimizer, parameter reshaping into PyTrees and fitness shaping (see below).
+- **Strategy Diversity**: `evosax` implements more than 30 classical and modern neuroevolution strategies. All of them follow the same simple `ask`/`eval` API and come with tailored tools such as the [ClipUp](https://arxiv.org/abs/2008.02387) optimizer, parameter reshaping into PyTrees and fitness shaping (see below).
- **Vectorization/Parallelization of `ask`/`tell` Calls**: Both `ask` and `tell` calls can leverage `jit`, `vmap`/`pmap`. This enables vectorized/parallel rollouts of different evolution strategies.
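+
+For illustration, a minimal jitted `ask`/`tell` loop (a sketch following the quickstart API above; the sphere objective is a toy stand-in for your evaluation function):
+
+```Python
+import jax
+from evosax import OpenES
+
+rng = jax.random.PRNGKey(0)
+strategy = OpenES(popsize=32, num_dims=2)
+es_params = strategy.default_params
+state = strategy.initialize(rng, es_params)
+
+@jax.jit
+def es_step(rng, state):
+    rng, rng_ask = jax.random.split(rng)
+    x, state = strategy.ask(rng_ask, state, es_params)
+    fitness = (x ** 2).sum(axis=1)  # toy sphere objective
+    state = strategy.tell(x, fitness, state, es_params)
+    return rng, state
+
+for _ in range(10):
+    rng, state = es_step(rng, state)
+```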
@@ -167,78 +180,79 @@ fit_shaper = FitnessShaper(centered_rank=True,
# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness)
```
-
-- **Strategy Restart Wrappers**: You can also choose from a set of different restart mechanisms, which will relaunch a strategy (with e.g. new population size) based on termination criteria. Note: For all restart strategies which alter the population size the ask and tell methods will have to be re-compiled at the time of change. Note that all strategies can also be executed without explicitly providing `es_params`. In this case the default parameters will be used.
-
-```Python
-from evosax import CMA_ES
-from evosax.restarts import BIPOP_Restarter
-
-# Define a termination criterion (kwargs - fitness, state, params)
-def std_criterion(fitness, state, params):
- """Restart strategy if fitness std across population is small."""
- return fitness.std() < 0.001
-
-# Instantiate Base CMA-ES & wrap with BIPOP restarts
-# Pass strategy-specific kwargs separately (e.g. elite_ration or opt_name)
-strategy = CMA_ES(num_dims, popsize, elite_ratio)
-re_strategy = BIPOP_Restarter(
- strategy,
- stop_criteria=[std_criterion],
- strategy_kwargs={"elite_ratio": elite_ratio}
- )
-state = re_strategy.initialize(rng)
-
-# ask/tell loop - restarts are automatically handled
-rng, rng_gen, rng_eval = jax.random.split(rng, 3)
-x, state = re_strategy.ask(rng_gen, state)
-fitness = ... # Your population evaluation fct
-state = re_strategy.tell(x, fitness, state)
-```
-
-- **Batch Strategy Rollouts**: *Work-in-progress*. We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols.
-
-```Python
-from evosax.experimental.subpops import BatchStrategy
-
-# Instantiates 5 CMA-ES subpops of 20 members
-strategy = BatchStrategy(
- strategy_name="CMA_ES",
- num_dims=4096,
- popsize=100,
- num_subpops=5,
- strategy_kwargs={"elite_ratio": 0.5},
- communication="best_subpop",
- )
-
-state = strategy.initialize(rng)
-# Ask for evaluation candidates of different subpopulation ES
-x, state = strategy.ask(rng_iter, state)
-fitness = ...
-state = strategy.tell(x, fitness, state)
-```
-
-- **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. due to harder estimation of covariances). One potential way to alleviate this challenge, is to use indirect parameter encodings in a lower dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`:
-
-```Python
-from evosax.experimental.decodings import RandomDecoder, HyperDecoder
-
-# For arbitrary network architectures / search spaces
-num_encoding_dims = 6
-param_reshaper = RandomDecoder(num_encoding_dims, net_params)
-x_shaped = param_reshaper.reshape(x)
-
-# For MLP-based models we also support a HyperNetwork en/decoding
-reshaper = HyperDecoder(
- net_params,
- hypernet_config={
- "num_latent_units": 3, # Latent units per module kernel/bias
- "num_hidden_units": 2, # Hidden dimensionality of a_i^j embedding
- },
- )
-x_shaped = param_reshaper.reshape(x)
-```
-
+
+  Additional work-in-progress features:
+  - **Strategy Restart Wrappers**: *Work-in-progress*. You can also choose from a set of different restart mechanisms, which relaunch a strategy (e.g. with a new population size) based on termination criteria. Note: for all restart strategies that alter the population size, the `ask` and `tell` methods have to be re-compiled at the time of change. All strategies can also be executed without explicitly providing `es_params`; in this case the default parameters are used.
+
+ ```Python
+ from evosax import CMA_ES
+ from evosax.restarts import BIPOP_Restarter
+
+ # Define a termination criterion (kwargs - fitness, state, params)
+ def std_criterion(fitness, state, params):
+ """Restart strategy if fitness std across population is small."""
+ return fitness.std() < 0.001
+
+ # Instantiate Base CMA-ES & wrap with BIPOP restarts
+  # Pass strategy-specific kwargs separately (e.g. elite_ratio or opt_name)
+ strategy = CMA_ES(num_dims, popsize, elite_ratio)
+ re_strategy = BIPOP_Restarter(
+ strategy,
+ stop_criteria=[std_criterion],
+ strategy_kwargs={"elite_ratio": elite_ratio}
+ )
+ state = re_strategy.initialize(rng)
+
+ # ask/tell loop - restarts are automatically handled
+ rng, rng_gen, rng_eval = jax.random.split(rng, 3)
+ x, state = re_strategy.ask(rng_gen, state)
+ fitness = ... # Your population evaluation fct
+ state = re_strategy.tell(x, fitness, state)
+ ```
+
+ - **Batch Strategy Rollouts**: *Work-in-progress*. We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols.
+
+ ```Python
+ from evosax.experimental.subpops import BatchStrategy
+
+ # Instantiates 5 CMA-ES subpops of 20 members
+ strategy = BatchStrategy(
+ strategy_name="CMA_ES",
+ num_dims=4096,
+ popsize=100,
+ num_subpops=5,
+ strategy_kwargs={"elite_ratio": 0.5},
+ communication="best_subpop",
+ )
+
+ state = strategy.initialize(rng)
+ # Ask for evaluation candidates of different subpopulation ES
+ x, state = strategy.ask(rng_iter, state)
+ fitness = ...
+ state = strategy.tell(x, fitness, state)
+ ```
+
+  - **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. due to harder estimation of covariances). One potential way to alleviate this challenge is to use indirect parameter encodings in a lower-dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`:
+
+ ```Python
+ from evosax.experimental.decodings import RandomDecoder, HyperDecoder
+
+ # For arbitrary network architectures / search spaces
+ num_encoding_dims = 6
+ param_reshaper = RandomDecoder(num_encoding_dims, net_params)
+ x_shaped = param_reshaper.reshape(x)
+
+ # For MLP-based models we also support a HyperNetwork en/decoding
+ reshaper = HyperDecoder(
+ net_params,
+ hypernet_config={
+ "num_latent_units": 3, # Latent units per module kernel/bias
+ "num_hidden_units": 2, # Hidden dimensionality of a_i^j embedding
+ },
+ )
+  x_shaped = reshaper.reshape(x)
+ ```
+
## Resources & Other Great JAX-ES Tools 📝
@@ -260,7 +274,7 @@ If you use `evosax` in your research, please cite it as follows:
}
```
-We acknowledge financial support the [Google TRC](https://sites.research.google/trc/about/) and the Deutsche
+We acknowledge financial support by the [Google TRC](https://sites.research.google/trc/about/) and the Deutsche
Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 ["Science of Intelligence"](https://www.scienceofintelligence.de/) - project number 390523135.
## Development 👷
diff --git a/evosax/__init__.py b/evosax/__init__.py
index 80d52e6..9bc366b 100755
--- a/evosax/__init__.py
+++ b/evosax/__init__.py
@@ -1,4 +1,4 @@
-from .strategy import Strategy
+from .strategy import Strategy, EvoState, EvoParams
from .strategies import (
SimpleGA,
SimpleES,
@@ -9,7 +9,6 @@
PGPE,
PBT,
PersistentES,
- xNES,
ARS,
Sep_CMA_ES,
BIPOP_CMA_ES,
@@ -21,6 +20,16 @@
RmES,
GLD,
SimAnneal,
+ SNES,
+ xNES,
+ ESMC,
+ DES,
+ SAMR_GA,
+ GESMR_GA,
+ GuidedES,
+ ASEBO,
+ CR_FM_NES,
+ MR15_GA,
)
from .utils import FitnessShaper, ParameterReshaper, ESLog
from .networks import NetworkMapper
@@ -37,7 +46,6 @@
"PGPE": PGPE,
"PBT": PBT,
"PersistentES": PersistentES,
- "xNES": xNES,
"ARS": ARS,
"Sep_CMA_ES": Sep_CMA_ES,
"BIPOP_CMA_ES": BIPOP_CMA_ES,
@@ -49,9 +57,27 @@
"RmES": RmES,
"GLD": GLD,
"SimAnneal": SimAnneal,
+ "SNES": SNES,
+ "xNES": xNES,
+ "ESMC": ESMC,
+ "DES": DES,
+ "SAMR_GA": SAMR_GA,
+ "GESMR_GA": GESMR_GA,
+ "GuidedES": GuidedES,
+ "ASEBO": ASEBO,
+ "CR_FM_NES": CR_FM_NES,
+ "MR15_GA": MR15_GA,
}
__all__ = [
+ "Strategies",
+ "EvoState",
+ "EvoParams",
+ "FitnessShaper",
+ "ParameterReshaper",
+ "ESLog",
+ "NetworkMapper",
+ "ProblemMapper",
"Strategy",
"SimpleGA",
"SimpleES",
@@ -62,7 +88,6 @@
"PGPE",
"PBT",
"PersistentES",
- "xNES",
"ARS",
"Sep_CMA_ES",
"BIPOP_CMA_ES",
@@ -74,10 +99,14 @@
"RmES",
"GLD",
"SimAnneal",
- "Strategies",
- "FitnessShaper",
- "ParameterReshaper",
- "ESLog",
- "NetworkMapper",
- "ProblemMapper",
+ "SNES",
+ "xNES",
+ "ESMC",
+ "DES",
+ "SAMR_GA",
+ "GESMR_GA",
+ "GuidedES",
+ "ASEBO",
+ "CR_FM_NES",
+ "MR15_GA",
]
diff --git a/evosax/experimental/decodings/decoder.py b/evosax/experimental/decodings/decoder.py
index f23cb5c..9a15017 100644
--- a/evosax/experimental/decodings/decoder.py
+++ b/evosax/experimental/decodings/decoder.py
@@ -8,13 +8,11 @@ def __init__(
self,
num_encoding_dims: int,
placeholder_params: Union[chex.ArrayTree, chex.Array],
- identity: bool = False,
n_devices: Optional[int] = None,
):
self.num_encoding_dims = num_encoding_dims
self.total_params = num_encoding_dims
self.placeholder_params = placeholder_params
- self.identity = identity
if n_devices is None:
self.n_devices = jax.local_device_count()
else:
diff --git a/evosax/experimental/decodings/hyper.py b/evosax/experimental/decodings/hyper.py
index 08feccf..716f71e 100644
--- a/evosax/experimental/decodings/hyper.py
+++ b/evosax/experimental/decodings/hyper.py
@@ -5,8 +5,8 @@
from flax.core import unfreeze
from typing import Union, Optional
from .decoder import Decoder
-from ...utils import ParameterReshaper
-from ...networks import HyperNetworkMLP
+from .hyper_networks import HyperNetworkMLP
+from ...utils import ParameterReshaper, ravel_pytree
class HyperDecoder(Decoder):
@@ -39,14 +39,13 @@ def __init__(
super().__init__(
hyper_reshaper.total_params,
placeholder_params,
- identity,
n_devices,
)
self.hyper_reshaper = hyper_reshaper
self.vmap_dict = self.hyper_reshaper.vmap_dict
def reshape(self, x: chex.Array) -> chex.ArrayTree:
- """Perform reshaping for random projection case."""
+ """Perform reshaping for hypernetwork case."""
# 0. Reshape genome into params for hypernetwork
x_params = self.hyper_reshaper.reshape(x)
# 1. Project parameters to raw dimensionality using hypernetwork
@@ -54,9 +53,17 @@ def reshape(self, x: chex.Array) -> chex.ArrayTree:
return hyper_x
def reshape_single(self, x: chex.Array) -> chex.ArrayTree:
- """Reshape a single flat vector using random projection matrix."""
+ """Reshape a single flat vector using hypernetwork."""
# 0. Reshape genome into params for hypernetwork
x_params = self.hyper_reshaper.reshape_single(x)
# 1. Project parameters to raw dimensionality using hypernetwork
hyper_x = jax.jit(self.hyper_network.apply)(x_params)
return hyper_x
+
+    def flatten(self, x: chex.ArrayTree) -> chex.Array:
+        """Flatten a batch of pytree parameters into flat vectors."""
+        return jax.vmap(ravel_pytree)(x)
+
+    def flatten_single(self, x: chex.ArrayTree) -> chex.Array:
+        """Flatten a single pytree of parameters into a flat vector."""
+        return ravel_pytree(x)
diff --git a/evosax/networks/hyper_networks.py b/evosax/experimental/decodings/hyper_networks.py
similarity index 100%
rename from evosax/networks/hyper_networks.py
rename to evosax/experimental/decodings/hyper_networks.py
diff --git a/evosax/experimental/decodings/random.py b/evosax/experimental/decodings/random.py
index 439124a..9ba17cd 100644
--- a/evosax/experimental/decodings/random.py
+++ b/evosax/experimental/decodings/random.py
@@ -12,17 +12,14 @@ def __init__(
placeholder_params: Union[chex.ArrayTree, chex.Array],
rng: chex.PRNGKey = jax.random.PRNGKey(0),
rademacher: bool = False,
- identity: bool = False,
n_devices: Optional[int] = None,
):
"""Random Projection Decoder (Gaussian/Rademacher random matrix)."""
- super().__init__(
- num_encoding_dims, placeholder_params, identity, n_devices
- )
+ super().__init__(num_encoding_dims, placeholder_params, n_devices)
self.rademacher = rademacher
# Instantiate base reshaper class
self.base_reshaper = ParameterReshaper(
- placeholder_params, identity, n_devices, verbose=False
+ placeholder_params, n_devices, verbose=False
)
self.vmap_dict = self.base_reshaper.vmap_dict
@@ -35,6 +32,10 @@ def __init__(
self.project_matrix = jax.random.rademacher(
rng, (self.num_encoding_dims, self.base_reshaper.total_params)
)
+ print(
+ "RandomDecoder: Encoding parameters to optimize -"
+ f" {num_encoding_dims}"
+ )
def reshape(self, x: chex.Array) -> chex.ArrayTree:
"""Perform reshaping for random projection case."""
diff --git a/evosax/networks/__init__.py b/evosax/networks/__init__.py
index 1d9e6a5..4f77b1c 100644
--- a/evosax/networks/__init__.py
+++ b/evosax/networks/__init__.py
@@ -1,7 +1,6 @@
from .mlp import MLP
from .cnn import CNN, All_CNN_C
from .lstm import LSTM
-from .hyper_networks import HyperNetworkMLP
# Helper that returns model based on string name
@@ -18,5 +17,4 @@
"All_CNN_C",
"LSTM",
"NetworkMapper",
- "HyperNetworkMLP",
]
diff --git a/evosax/problems/__init__.py b/evosax/problems/__init__.py
index 8b3fdb6..120ce38 100755
--- a/evosax/problems/__init__.py
+++ b/evosax/problems/__init__.py
@@ -1,22 +1,19 @@
-from .control_brax import BraxFitness
-from .control_gym import GymFitness
+from .control_gym import GymnaxFitness
from .vision import VisionFitness
-from .classic import ClassicFitness
+from .bbob import BBOBFitness
from .sequence import SequenceFitness
ProblemMapper = {
- "Gym": GymFitness,
- "Brax": BraxFitness,
+ "Gymnax": GymnaxFitness,
"Vision": VisionFitness,
- "Classic": ClassicFitness,
+ "BBOB": BBOBFitness,
"Sequence": SequenceFitness,
}
__all__ = [
- "BraxFitness",
- "GymFitness",
+ "GymnaxFitness",
"VisionFitness",
- "ClassicFitness",
+ "BBOBFitness",
"SequenceFitness",
"ProblemMapper",
]
diff --git a/evosax/problems/classic.py b/evosax/problems/classic.py
deleted file mode 100755
index a85d294..0000000
--- a/evosax/problems/classic.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import jax
-import jax.numpy as jnp
-import chex
-from functools import partial
-
-
-class ClassicFitness(object):
- def __init__(
- self,
- fct_name: str = "rosenbrock",
- num_dims: int = 2,
- num_rollouts: int = 1,
- noise_std: float = 0.0,
- ):
- self.fct_name = fct_name
- self.num_dims = num_dims
- self.num_rollouts = num_rollouts
- # Optional - add Gaussian noise to evaluation fitness
- self.noise_std = noise_std
- assert self.num_dims >= 2
-
- # Use default settings for classic BBOB evaluation functions
- if self.fct_name == "quadratic":
- self.eval = jax.vmap(quadratic_d_dim, 0)
- elif self.fct_name == "rosenbrock":
- fn = partial(rosenbrock_d_dim, params={"a": 1, "b": 100})
- self.eval = jax.vmap(fn, 0)
- elif self.fct_name == "ackley":
- fn = partial(
- ackley_d_dim, params={"c": 20, "d": 0.2, "e": 2 * jnp.pi}
- )
- self.eval = jax.vmap(fn, 0)
- elif self.fct_name == "griewank":
- self.eval = jax.vmap(griewank_d_dim, 0)
- elif self.fct_name == "rastrigin":
- fn = partial(rastrigin_d_dim, params={"f": 10})
- self.eval = jax.vmap(fn, 0)
- elif self.fct_name == "schwefel":
- self.eval = jax.vmap(schwefel_d_dim, 0)
- elif self.fct_name == "himmelblau":
- assert self.num_dims == 2
- self.eval = jax.vmap(himmelblau_2_dim, 0)
- elif self.fct_name == "six-hump":
- assert self.num_dims == 2
- self.eval = jax.vmap(six_hump_camel_2_dim, 0)
- else:
- raise ValueError("Please provide a valid problem name.")
-
- @partial(jax.jit, static_argnums=(0,))
- def rollout(
- self, rng_input: chex.PRNGKey, eval_params: chex.Array
- ) -> chex.Array:
- """Batch evaluate the proposal points."""
- fitness = self.eval(eval_params).reshape(eval_params.shape[0], 1)
- noise = self.noise_std * jax.random.normal(
- rng_input, (eval_params.shape[0], self.num_rollouts)
- )
- return (fitness + noise).squeeze()
-
-
-def himmelblau_2_dim(x: chex.Array) -> chex.Array:
- """
- 2-dim. Himmelblau function.
- f(x*)=0 - Minima at [3, 2], [-2.81, 3.13],
- [-3.78, -3.28], [3.58, -1.85]
- """
- return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2
-
-
-def six_hump_camel_2_dim(x: chex.Array) -> chex.Array:
- """
- 2-dim. 6-Hump Camel function.
- f(x*)=-1.0316 - Minimum at [0.0898, -0.7126], [-0.0898, 0.7126]
- """
- p1 = (4 - 2.1 * x[0] ** 2 + x[0] ** 4 / 3) * x[0] ** 2
- p2 = x[0] * x[1]
- p3 = (-4 + 4 * x[1] ** 2) * x[1] ** 2
- return p1 + p2 + p3
-
-
-def quadratic_d_dim(x: chex.Array) -> chex.Array:
- """
- Simple D-dim. quadratic function.
- f(x*)=0 - Minimum at [0.]ˆd
- """
- return jnp.sum(jnp.square(x))
-
-
-def rosenbrock_d_dim(x: chex.Array, params: dict) -> chex.Array:
- """
- D-Dim. Rosenbrock function. x_i ∈ [-32.768, 32.768] or x_i ∈ [-5, 10]
- f(x*)=0 - Minumum at x*=a
- """
- x_i, x_sq, x_p = x[:-1], x[:-1] ** 2, x[1:]
- return jnp.sum((params["a"] - x_i) ** 2 + params["b"] * (x_p - x_sq) ** 2)
-
-
-def ackley_d_dim(x: chex.Array, params: dict) -> chex.Array:
- """
- D-Dim. Ackley function. x_i ∈ [-32.768, 32.768]
- f(x*)=0 - Minimum at x*=[0,...,0]
- """
- return (
- -params["c"] * jnp.exp(-params["d"] * jnp.sqrt(jnp.mean(x ** 2)))
- - jnp.exp(jnp.mean(jnp.cos(params["e"] * x)))
- + params["c"]
- + jnp.exp(1)
- )
-
-
-def griewank_d_dim(x: chex.Array) -> chex.Array:
- """
- D-Dim. Griewank function. x_i ∈ [-600, 600]
- f(x*)=0 - Minimum at x*=[0,...,0]
- """
- return (
- jnp.sum(x ** 2 / 4000)
- - jnp.prod(jnp.cos(x / jnp.sqrt(jnp.arange(1, x.shape[0] + 1))))
- + 1
- )
-
-
-def rastrigin_d_dim(x: chex.Array, params: dict) -> chex.Array:
- """
- D-Dim. Rastrigin function. x_i ∈ [-5.12, 5.12]
- f(x*)=0 - Minimum at x*=[0,...,0]
- """
- return params["f"] * x.shape[0] + jnp.sum(
- x ** 2 - params["f"] * jnp.cos(2 * jnp.pi * x)
- )
-
-
-def schwefel_d_dim(x: chex.Array) -> chex.Array:
- """
- D-Dim. Schwefel function. x_i ∈ [-500, 500]
- f(x*)=0 - Minimum at x*=[420.9687,...,420.9687]
- """
- return 418.9829 * x.shape[0] - jnp.sum(
- x * jnp.sin(jnp.sqrt(jnp.absolute(x)))
- )
diff --git a/evosax/problems/control_brax.py b/evosax/problems/control_brax.py
deleted file mode 100644
index e99fa96..0000000
--- a/evosax/problems/control_brax.py
+++ /dev/null
@@ -1,233 +0,0 @@
-import jax
-import jax.numpy as jnp
-import chex
-from typing import Optional
-from .obs_norm import ObsNormalizer
-
-
-class BraxFitness(object):
- def __init__(
- self,
- env_name: str = "ant",
- num_env_steps: int = 1000,
- num_rollouts: int = 16,
- legacy_spring: bool = True,
- normalize: bool = False,
- modify_dict: dict = {"torso_mass": 15},
- test: bool = False,
- n_devices: Optional[int] = None,
- ):
- try:
- from brax import envs
- except ImportError:
- raise ImportError(
- "You need to install `brax` to use its fitness rollouts."
- )
- self.env_name = env_name
- self.num_env_steps = num_env_steps
- self.num_rollouts = num_rollouts
- self.steps_per_member = num_env_steps * num_rollouts
- self.test = test
-
- if self.env_name in [
- "ant",
- "halfcheetah",
- "hopper",
- "humanoid",
- "reacher",
- "walker2d",
- "fetch",
- "grasp",
- "ur5e",
- ]:
- # Define the RL environment & network forward fucntion
- self.env = envs.create(
- env_name=self.env_name,
- episode_length=num_env_steps,
- legacy_spring=legacy_spring,
- )
- elif self.env_name == "modified-ant":
- from .modified_ant import create_modified_ant_env
-
- self.env = create_modified_ant_env(modify_dict)
-
- self.action_shape = self.env.action_size
- self.input_shape = (self.env.observation_size,)
- self.obs_normalizer = ObsNormalizer(
- self.input_shape, dummy=not normalize
- )
- self.obs_params = self.obs_normalizer.get_init_params()
- if n_devices is None:
- self.n_devices = jax.local_device_count()
- else:
- self.n_devices = n_devices
-
- # Keep track of total steps executed in environment
- self.total_env_steps = 0
-
- def set_apply_fn(self, map_dict, network_apply, carry_init=None):
- """Set the network forward function."""
- self.network = network_apply
- # Set rollout function based on model architecture
- if carry_init is not None:
- self.single_rollout = self.rollout_rnn
- self.carry_init = carry_init
- else:
- self.single_rollout = self.rollout_ffw
-
- # vmap over stochastic evaluations
- self.rollout_repeats = jax.vmap(self.single_rollout, in_axes=(0, None))
- self.rollout_pop = jax.vmap(
- self.rollout_repeats, in_axes=(None, map_dict)
- )
- # pmap over popmembers if > 1 device is available - otherwise pmap
- if self.n_devices > 1:
- self.rollout_map = self.rollout_pmap
- print(
- f"BraxFitness: {self.n_devices} devices detected. Please make"
- " sure that the ES population size divides evenly across the"
- " number of devices to pmap/parallelize over."
- )
- else:
- self.rollout_map = self.rollout_pop
-
- def rollout_pmap(self, rng_input, policy_params):
- """Parallelize rollout across devices. Split keys/reshape correctly."""
- keys_pmap = jnp.tile(rng_input, (self.n_devices, 1, 1))
- rew_dev, obs_dev, masks_dev = jax.pmap(self.rollout_pop)(
- keys_pmap, policy_params
- )
- rew_re = rew_dev.reshape(-1, self.num_rollouts)
- obs_re = obs_dev.reshape(
- -1, self.num_rollouts, self.num_env_steps, self.env.observation_size
- )
- masks_re = masks_dev.reshape(
- -1, self.num_rollouts, self.num_env_steps, 1
- )
- return rew_re, obs_re, masks_re
-
- def rollout(self, rng_input, policy_params):
- """Placeholder fn call for rolling out a population for multi-evals."""
- rng_pop = jax.random.split(rng_input, self.num_rollouts)
- scores, all_obs, masks = jax.jit(self.rollout_map)(
- rng_pop, policy_params
- )
- # Update normalization parameters if train case!
- if not self.test:
- obs_re = all_obs.reshape(
- self.num_env_steps, -1, self.input_shape[0]
- )
- masks_re = masks.reshape(self.num_env_steps, -1)
- self.obs_params = self.obs_normalizer.update_normalization_params(
- obs_buffer=obs_re,
- obs_mask=masks_re,
- obs_params=self.obs_params,
- )
-
- # obs_steps = self.obs_params[0]
- # running_mean, running_var = jnp.split(self.obs_params[1:], 2)
- # print(
- # float(scores.mean()),
- # float(masks.mean()),
- # obs_steps,
- # running_mean.mean(),
- # running_var.mean() / (obs_steps + 1),
- # )
-
- # Update total step counter using only transitions before termination
- self.total_env_steps += masks_re.sum()
- return scores
-
- def rollout_ffw(
- self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree
- ) -> chex.Array:
- """Rollout a jitted brax episode with lax.scan for a feedforward policy."""
- # Reset the environment
- rng, rng_reset = jax.random.split(rng_input)
- state = self.env.reset(rng_reset)
-
- def policy_step(state_input, tmp):
- """lax.scan compatible step transition in jax env."""
- state, policy_params, rng, cum_reward, valid_mask = state_input
- rng, rng_net = jax.random.split(rng)
- org_obs = state.obs
- norm_obs = self.obs_normalizer.normalize_obs(
- org_obs, self.obs_params
- )
- action = self.network(policy_params, norm_obs, rng=rng_net)
- next_s = self.env.step(state, action)
- new_cum_reward = cum_reward + next_s.reward * valid_mask
- new_valid_mask = valid_mask * (1 - next_s.done.ravel())
- carry = [next_s, policy_params, rng, new_cum_reward, new_valid_mask]
- return carry, [new_valid_mask, org_obs]
-
- # Scan over episode step loop
- carry_out, scan_out = jax.lax.scan(
- policy_step,
- [state, policy_params, rng, jnp.array([0.0]), jnp.array([1.0])],
- (),
- self.num_env_steps,
- )
- # Return masked sum of rewards accumulated by agent in episode
- ep_mask, all_obs = scan_out[0], scan_out[1]
- cum_return = carry_out[-2].squeeze()
- return cum_return, all_obs, ep_mask
-
- def rollout_rnn(
- self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree
- ) -> chex.Array:
- """Rollout a jitted episode with lax.scan for a recurrent policy."""
- # Reset the environment
- rng, rng_reset = jax.random.split(rng_input)
- state = self.env.reset(rng_reset)
- hidden = self.carry_init()
-
- def policy_step(state_input, tmp):
- """lax.scan compatible step transition in jax env."""
- (
- state,
- policy_params,
- rng,
- hidden,
- cum_reward,
- valid_mask,
- ) = state_input
- rng, rng_net = jax.random.split(rng)
- org_obs = state.obs
- norm_obs = self.obs_normalizer.normalize_obs(
- state.obs, self.obs_params
- )
- hidden, action = self.network(
- policy_params, norm_obs, hidden, rng_net
- )
- next_s = self.env.step(state, action)
- new_cum_reward = cum_reward + next_s.reward * valid_mask
- new_valid_mask = valid_mask * (1 - next_s.done.ravel())
- carry = [
- next_s,
- policy_params,
- rng,
- hidden,
- new_cum_reward,
- new_valid_mask,
- ]
- return carry, [new_valid_mask, org_obs]
-
- # Scan over episode step loop
- carry_out, scan_out = jax.lax.scan(
- policy_step,
- [
- state,
- policy_params,
- rng,
- hidden,
- jnp.array([0.0]),
- jnp.array([1.0]),
- ],
- (),
- self.num_env_steps,
- )
- # Return masked sum of rewards accumulated by agent in episode
- ep_mask, all_obs = scan_out[0], scan_out[1]
- cum_return = carry_out[-2].squeeze()
- return cum_return, all_obs, ep_mask
diff --git a/evosax/problems/control_gym.py b/evosax/problems/control_gym.py
index 323eaad..5bc6583 100644
--- a/evosax/problems/control_gym.py
+++ b/evosax/problems/control_gym.py
@@ -1,11 +1,10 @@
import jax
import jax.numpy as jnp
-from functools import partial
from typing import Optional
import chex
-class GymFitness(object):
+class GymnaxFitness(object):
def __init__(
self,
env_name: str = "CartPole-v1",
@@ -47,7 +46,7 @@ def __init__(
# Keep track of total steps executed in environment
self.total_env_steps = 0
- def set_apply_fn(self, map_dict, network_apply, carry_init=None):
+ def set_apply_fn(self, network_apply, carry_init=None):
"""Set the network forward function."""
self.network = network_apply
# Set rollout function based on model architecture
@@ -57,9 +56,7 @@ def set_apply_fn(self, map_dict, network_apply, carry_init=None):
else:
self.single_rollout = self.rollout_ffw
self.rollout_repeats = jax.vmap(self.single_rollout, in_axes=(0, None))
- self.rollout_pop = jax.vmap(
- self.rollout_repeats, in_axes=(None, map_dict)
- )
+ self.rollout_pop = jax.vmap(self.rollout_repeats, in_axes=(None, 0))
# pmap over popmembers if > 1 device is available - otherwise pmap
if self.n_devices > 1:
self.rollout_map = self.rollout_pmap
diff --git a/evosax/problems/modified_ant.py b/evosax/problems/modified_ant.py
deleted file mode 100644
index 8fabe72..0000000
--- a/evosax/problems/modified_ant.py
+++ /dev/null
@@ -1,423 +0,0 @@
-"""Trains an ant to run in the +x direction.
-Adapted from https://raw.githubusercontent.com/google/brax/main/brax/envs/ant.py
-Increases mass of main torso body from 10 to 15.
-"""
-
-import brax
-from brax import jumpy as jp
-from brax.envs import env
-from brax.envs import wrappers
-from typing import Optional
-
-
-class ModifiedAnt(env.Env):
- """Trains an ant to run in the +x direction."""
-
- def __init__(self, config, **kwargs):
- super().__init__(config=config, **kwargs)
-
- def reset(self, rng: jp.ndarray) -> env.State:
- """Resets the environment to an initial state."""
- rng, rng1, rng2 = jp.random_split(rng, 3)
- qpos = self.sys.default_angle() + jp.random_uniform(
- rng1, (self.sys.num_joint_dof,), -0.1, 0.1
- )
- qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -0.1, 0.1)
- qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
- info = self.sys.info(qp)
- obs = self._get_obs(qp, info)
- reward, done, zero = jp.zeros(3)
- metrics = {
- "reward_ctrl_cost": zero,
- "reward_contact_cost": zero,
- "reward_forward": zero,
- "reward_survive": zero,
- }
- return env.State(qp, obs, reward, done, metrics)
-
- def step(self, state: env.State, action: jp.ndarray) -> env.State:
- """Run one timestep of the environment's dynamics."""
- qp, info = self.sys.step(state.qp, action)
- obs = self._get_obs(qp, info)
-
- x_before = state.qp.pos[0, 0]
- x_after = qp.pos[0, 0]
- forward_reward = (x_after - x_before) / self.sys.config.dt
- ctrl_cost = 0.5 * jp.sum(jp.square(action))
- contact_cost = (
- 0.5 * 1e-3 * jp.sum(jp.square(jp.clip(info.contact.vel, -1, 1)))
- )
- survive_reward = jp.float32(1)
- reward = forward_reward - ctrl_cost - contact_cost + survive_reward
-
- done = jp.where(qp.pos[0, 2] < 0.2, x=jp.float32(1), y=jp.float32(0))
- done = jp.where(qp.pos[0, 2] > 1.0, x=jp.float32(1), y=done)
- state.metrics.update(
- reward_ctrl_cost=ctrl_cost,
- reward_contact_cost=contact_cost,
- reward_forward=forward_reward,
- reward_survive=survive_reward,
- )
-
- return state.replace(qp=qp, obs=obs, reward=reward, done=done)
-
- def _get_obs(self, qp: brax.QP, info: brax.Info) -> jp.ndarray:
- """Observe ant body position and velocities."""
- # some pre-processing to pull joint angles and velocities
- (joint_angle,), (joint_vel,) = self.sys.joints[0].angle_vel(qp)
-
- # qpos:
- # Z of the torso (1,)
- # orientation of the torso as quaternion (4,)
- # joint angles (8,)
- qpos = [qp.pos[0, 2:], qp.rot[0], joint_angle]
-
- # qvel:
- # velocity of the torso (3,)
- # angular velocity of the torso (3,)
- # joint angle velocities (8,)
- qvel = [qp.vel[0], qp.ang[0], joint_vel]
-
- # external contact forces:
- # delta velocity (3,), delta ang (3,) * 10 bodies in the system
- # Note that mujoco has 4 extra bodies tucked inside the Torso that Brax
- # ignores
- cfrc = [
- jp.clip(info.contact.vel, -1, 1),
- jp.clip(info.contact.ang, -1, 1),
- ]
- # flatten bottom dimension
- cfrc = [jp.reshape(x, x.shape[:-2] + (-1,)) for x in cfrc]
-
- return jp.concatenate(qpos + qvel + cfrc)
-
-
-_CONFIG_MODIFIED = """
-bodies {{
- name: "$ Torso"
- colliders {{
- capsule {{
- radius: 0.25
- length: 0.5
- end: 1
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: {torso_mass}
-}}
-bodies {{
- name: "Aux 1"
- colliders {{
- rotation {{ x: 90 y: -45 }}
- capsule {{
- radius: 0.08
- length: 0.4428427219390869
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "$ Body 4"
- colliders {{
- rotation {{ x: 90 y: -45 }}
- capsule {{
- radius: 0.08
- length: 0.7256854176521301
- end: -1
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "Aux 2"
- colliders {{
- rotation {{ x: 90 y: 45 }}
- capsule {{
- radius: 0.08
- length: 0.4428427219390869
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "$ Body 7"
- colliders {{
- rotation {{ x: 90 y: 45 }}
- capsule {{
- radius: 0.08
- length: 0.7256854176521301
- end: -1
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "Aux 3"
- colliders {{
- rotation {{ x: -90 y: 45 }}
- capsule {{
- radius: 0.08
- length: 0.4428427219390869
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "$ Body 10"
- colliders {{
- rotation {{ x: -90 y: 45 }}
- capsule {{
- radius: 0.08
- length: 0.7256854176521301
- end: -1
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "Aux 4"
- colliders {{
- rotation {{ x: -90 y: -45 }}
- capsule {{
- radius: 0.08
- length: 0.4428427219390869
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "$ Body 13"
- colliders {{
- rotation {{ x: -90 y: -45 }}
- capsule {{
- radius: 0.08
- length: 0.7256854176521301
- end: -1
- }}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
-}}
-bodies {{
- name: "Ground"
- colliders {{
- plane {{}}
- }}
- inertia {{ x: 1.0 y: 1.0 z: 1.0 }}
- mass: 1
- frozen {{ all: true }}
-}}
-joints {{
- name: "$ Torso_Aux 1"
- parent_offset {{ x: 0.2 y: 0.2 }}
- child_offset {{ x: -0.1 y: -0.1 }}
- parent: "$ Torso"
- child: "Aux 1"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- angle_limit {{ min: -30.0 max: 30.0 }}
- rotation {{ y: -90 }}
-}}
-joints {{
- name: "Aux 1_$ Body 4"
- parent_offset {{ x: 0.1 y: 0.1 }}
- child_offset {{ x: -0.2 y: -0.2 }}
- parent: "Aux 1"
- child: "$ Body 4"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation: {{ z: 135 }}
- angle_limit {{
- min: 30.0
- max: 70.0
- }}
-}}
-joints {{
- name: "$ Torso_Aux 2"
- parent_offset {{ x: -0.2 y: 0.2 }}
- child_offset {{ x: 0.1 y: -0.1 }}
- parent: "$ Torso"
- child: "Aux 2"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation {{ y: -90 }}
- angle_limit {{ min: -30.0 max: 30.0 }}
-}}
-joints {{
- name: "Aux 2_$ Body 7"
- parent_offset {{ x: -0.1 y: 0.1 }}
- child_offset {{ x: 0.2 y: -0.2 }}
- parent: "Aux 2"
- child: "$ Body 7"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation {{ z: 45 }}
- angle_limit {{ min: -70.0 max: -30.0 }}
-}}
-joints {{
- name: "$ Torso_Aux 3"
- parent_offset {{ x: -0.2 y: -0.2 }}
- child_offset {{ x: 0.1 y: 0.1 }}
- parent: "$ Torso"
- child: "Aux 3"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation {{ y: -90 }}
- angle_limit {{ min: -30.0 max: 30.0 }}
-}}
-joints {{
- name: "Aux 3_$ Body 10"
- parent_offset {{ x: -0.1 y: -0.1 }}
- child_offset {{
- x: 0.2
- y: 0.2
- }}
- parent: "Aux 3"
- child: "$ Body 10"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation {{ z: 135 }}
- angle_limit {{ min: -70.0 max: -30.0 }}
-}}
-joints {{
- name: "$ Torso_Aux 4"
- parent_offset {{ x: 0.2 y: -0.2 }}
- child_offset {{ x: -0.1 y: 0.1 }}
- parent: "$ Torso"
- child: "Aux 4"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation {{ y: -90 }}
- angle_limit {{ min: -30.0 max: 30.0 }}
-}}
-joints {{
- name: "Aux 4_$ Body 13"
- parent_offset {{ x: 0.1 y: -0.1 }}
- child_offset {{ x: -0.2 y: 0.2 }}
- parent: "Aux 4"
- child: "$ Body 13"
- stiffness: 18000.0
- angular_damping: 20
- spring_damping: 80
- rotation {{ z: 45 }}
- angle_limit {{ min: 30.0 max: 70.0 }}
-}}
-actuators {{
- name: "$ Torso_Aux 1"
- joint: "$ Torso_Aux 1"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "Aux 1_$ Body 4"
- joint: "Aux 1_$ Body 4"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "$ Torso_Aux 2"
- joint: "$ Torso_Aux 2"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "Aux 2_$ Body 7"
- joint: "Aux 2_$ Body 7"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "$ Torso_Aux 3"
- joint: "$ Torso_Aux 3"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "Aux 3_$ Body 10"
- joint: "Aux 3_$ Body 10"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "$ Torso_Aux 4"
- joint: "$ Torso_Aux 4"
- strength: 350.0
- torque {{}}
-}}
-actuators {{
- name: "Aux 4_$ Body 13"
- joint: "Aux 4_$ Body 13"
- strength: 350.0
- torque {{}}
-}}
-friction: 1.0
-gravity {{ z: -9.8 }}
-angular_damping: -0.05
-baumgarte_erp: 0.1
-collide_include {{
- first: "$ Torso"
- second: "Ground"
-}}
-collide_include {{
- first: "$ Body 4"
- second: "Ground"
-}}
-collide_include {{
- first: "$ Body 7"
- second: "Ground"
-}}
-collide_include {{
- first: "$ Body 10"
- second: "Ground"
-}}
-collide_include {{
- first: "$ Body 13"
- second: "Ground"
-}}
-dt: {dt}
-substeps: 10
-dynamics_mode: "legacy_spring"
-"""
-
-
-def create_modified_ant_env(
- modify_dict: dict = {},
- episode_length: int = 1000,
- action_repeat: int = 1,
- auto_reset: bool = True,
- batch_size: Optional[int] = None,
- eval_metrics: bool = False,
- **kwargs
-):
- """Creates a config modified Ant Env with a specified brax system."""
- default_settings = {"torso_mass": 15, "dt": 0.05}
- for k, v in default_settings.items():
- if k not in modify_dict.keys():
- modify_dict[k] = v
- config = _CONFIG_MODIFIED.format(**modify_dict)
-
- env = ModifiedAnt(config, **kwargs)
- if episode_length is not None:
- env = wrappers.EpisodeWrapper(env, episode_length, action_repeat)
- if batch_size:
- env = wrappers.VectorWrapper(env, batch_size)
- if auto_reset:
- env = wrappers.AutoResetWrapper(env)
- if eval_metrics:
- env = wrappers.EvalWrapper(env)
-
- return env
diff --git a/evosax/problems/obs_norm.py b/evosax/problems/obs_norm.py
deleted file mode 100644
index 2dd20df..0000000
--- a/evosax/problems/obs_norm.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright 2022 The EvoJAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Adapted from https://github.com/google/evojax/blob/main/evojax/obs_norm.py
-
-from typing import Tuple
-from functools import partial
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-
-def normalize(
- obs: jnp.ndarray,
- obs_params: jnp.ndarray,
- obs_shape: Tuple,
- clip_value: float,
- std_min_value: float,
- std_max_value: float,
-) -> jnp.ndarray:
- """Normalize the given observation."""
-
- obs_steps = obs_params[0]
- running_mean, running_var = jnp.split(obs_params[1:], 2)
- running_mean = running_mean.reshape(obs_shape)
- running_var = running_var.reshape(obs_shape)
-
- variance = running_var / (obs_steps + 1.0)
- variance = jnp.clip(variance, std_min_value, std_max_value)
- return jnp.clip(
- (obs - running_mean) / jnp.sqrt(variance), -clip_value, clip_value
- )
-
-
-def update_obs_params(
- obs_buffer: jnp.ndarray, obs_mask: jnp.ndarray, obs_params: jnp.ndarray
-) -> jnp.ndarray:
- """Update observation normalization parameters."""
-
- obs_steps = obs_params[0]
- running_mean, running_var = jnp.split(obs_params[1:], 2)
- if obs_mask.ndim != obs_buffer.ndim:
- obs_mask = obs_mask.reshape(
- obs_mask.shape + (1,) * (obs_buffer.ndim - obs_mask.ndim)
- )
-
- new_steps = jnp.sum(obs_mask)
- total_steps = obs_steps + new_steps
-
- input_to_old_mean = (obs_buffer - running_mean) * obs_mask
- mean_diff = jnp.sum(input_to_old_mean / total_steps, axis=(0, 1))
- new_mean = running_mean + mean_diff
-
- input_to_new_mean = (obs_buffer - new_mean) * obs_mask
- var_diff = jnp.sum(input_to_new_mean * input_to_old_mean, axis=(0, 1))
- new_var = running_var + var_diff
-
- return jnp.concatenate([jnp.ones(1) * total_steps, new_mean, new_var])
-
-
-class ObsNormalizer(object):
- """Observation normalizer."""
-
- def __init__(
- self,
- obs_shape: Tuple,
- clip_value: float = 5.0,
- std_min_value: float = 1e-6,
- std_max_value: float = 1e6,
- dummy: bool = False,
- ):
- """Initialization.
-
- Args:
- obs_shape - Shape of the observations.
- std_min_value - Minimum standard deviation.
- std_max_value - Maximum standard deviation.
- dummy - Whether this is a dummy normalizer.
- """
-
- self._obs_shape = obs_shape
- self._obs_size = np.prod(obs_shape)
- self._std_min_value = std_min_value
- self._std_max_value = std_max_value
- self._clip_value = clip_value
- self.is_dummy = dummy
-
- @partial(jax.jit, static_argnums=(0,))
- def normalize_obs(
- self, obs: jnp.ndarray, obs_params: jnp.ndarray
- ) -> jnp.ndarray:
- """Normalize the given observation.
-
- Args:
- obs - The observation to be normalized.
- Returns:
- Normalized observation.
- """
-
- if self.is_dummy:
- return obs
- else:
- return normalize(
- obs=obs,
- obs_params=obs_params,
- obs_shape=self._obs_shape,
- clip_value=self._clip_value,
- std_min_value=self._std_min_value,
- std_max_value=self._std_max_value,
- )
-
- @partial(jax.jit, static_argnums=(0,))
- def update_normalization_params(
- self,
- obs_buffer: jnp.ndarray,
- obs_mask: jnp.ndarray,
- obs_params: jnp.ndarray,
- ) -> jnp.ndarray:
- """Update internal parameters."""
-
- if self.is_dummy:
- return jnp.zeros_like(obs_params)
- else:
- return update_obs_params(
- obs_buffer=obs_buffer,
- obs_mask=obs_mask,
- obs_params=obs_params,
- )
-
- @partial(jax.jit, static_argnums=(0,))
- def get_init_params(self) -> jnp.ndarray:
- return jnp.zeros(1 + self._obs_size * 2)
diff --git a/evosax/problems/sequence.py b/evosax/problems/sequence.py
index 5425df4..a9c2175 100644
--- a/evosax/problems/sequence.py
+++ b/evosax/problems/sequence.py
@@ -45,11 +45,11 @@ def __init__(
else:
self.n_devices = n_devices
- def set_apply_fn(self, map_dict, network, carry_init):
+ def set_apply_fn(self, network, carry_init):
"""Set the network forward function."""
self.network = network
self.carry_init = carry_init
- self.rollout_pop = jax.vmap(self.rollout_rnn, in_axes=(None, map_dict))
+ self.rollout_pop = jax.vmap(self.rollout_rnn, in_axes=(None, 0))
# pmap over popmembers if > 1 device is available - otherwise vmap
if self.n_devices > 1:
self.rollout = self.rollout_pmap
diff --git a/evosax/problems/vision.py b/evosax/problems/vision.py
index 4f4971a..72989ff 100644
--- a/evosax/problems/vision.py
+++ b/evosax/problems/vision.py
@@ -25,10 +25,10 @@ def __init__(
else:
self.n_devices = n_devices
- def set_apply_fn(self, map_dict, network):
+ def set_apply_fn(self, network):
"""Set the network forward function."""
self.network = network
- self.rollout_pop = jax.vmap(self.rollout_ffw, in_axes=(None, map_dict))
+ self.rollout_pop = jax.vmap(self.rollout_ffw, in_axes=(None, 0))
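+        # Usage sketch (hypothetical `model`): problem.set_apply_fn(model.apply)
+        # The old `map_dict` arg is gone - candidates are always vmapped on axis 0.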
# pmap over popmembers if > 1 device is available - otherwise vmap
if self.n_devices > 1:
self.rollout = self.rollout_pmap
diff --git a/evosax/restarts/termination.py b/evosax/restarts/termination.py
index 4a15edc..35c465c 100644
--- a/evosax/restarts/termination.py
+++ b/evosax/restarts/termination.py
@@ -36,7 +36,10 @@ def cma_criterion(
dC = jnp.diag(state.strategy_state.C)
# Note: Criterion requires full covariance matrix for decomposition!
C, B, D = full_eigen_decomp(
- state.strategy_state.C, state.strategy_state.B, state.strategy_state.D
+ state.strategy_state.C,
+ state.strategy_state.B,
+ state.strategy_state.D,
+ state.strategy_state.gen_counter,
)
# Stop if std of normal distrib is smaller than tolx in all coordinates
diff --git a/evosax/strategies/__init__.py b/evosax/strategies/__init__.py
index d445fe1..b28992e 100755
--- a/evosax/strategies/__init__.py
+++ b/evosax/strategies/__init__.py
@@ -7,7 +7,6 @@
from .pgpe import PGPE
from .pbt import PBT
from .persistent_es import PersistentES
-from .xnes import xNES
from .ars import ARS
from .sep_cma_es import Sep_CMA_ES
from .bipop_cma_es import BIPOP_CMA_ES
@@ -19,6 +18,16 @@
from .rm_es import RmES
from .gld import GLD
from .sim_anneal import SimAnneal
+from .snes import SNES
+from .xnes import xNES
+from .esmc import ESMC
+from .des import DES
+from .samr_ga import SAMR_GA
+from .gesmr_ga import GESMR_GA
+from .guided_es import GuidedES
+from .asebo import ASEBO
+from .cr_fm_nes import CR_FM_NES
+from .mr15_ga import MR15_GA
__all__ = [
@@ -31,7 +40,6 @@
"PGPE",
"PBT",
"PersistentES",
- "xNES",
"ARS",
"Sep_CMA_ES",
"BIPOP_CMA_ES",
@@ -43,4 +51,14 @@
"RmES",
"GLD",
"SimAnneal",
+ "SNES",
+ "xNES",
+ "ESMC",
+ "DES",
+ "SAMR_GA",
+ "GESMR_GA",
+ "GuidedES",
+ "ASEBO",
+ "CR_FM_NES",
+ "MR15_GA",
]
diff --git a/evosax/strategies/ars.py b/evosax/strategies/ars.py
index 7e53908..0273474 100644
--- a/evosax/strategies/ars.py
+++ b/evosax/strategies/ars.py
@@ -1,9 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
-from ..utils import GradientOptimizer, OptState, OptParams
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
from flax import struct
@@ -32,28 +32,54 @@ class EvoParams:
class ARS(Strategy):
def __init__(
self,
- num_dims: int,
popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
elite_ratio: float = 0.1,
opt_name: str = "sgd",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
):
"""Augmented Random Search (Mania et al., 2018)
Reference: https://arxiv.org/pdf/1803.07055.pdf"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert not self.popsize & 1, "Population size must be even"
# ARS performs antithetic sampling & allows you to select
# "b" elite perturbation directions for the update
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize / 2 * self.elite_ratio))
- assert opt_name in ["sgd", "adam", "rmsprop", "clipup"]
+ assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
self.optimizer = GradientOptimizer[opt_name](self.num_dims)
self.strategy_name = "ARS"
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
- return EvoParams(opt_params=self.optimizer.default_params)
+ opt_params = self.optimizer.default_params.replace(
+ lrate_init=self.lrate_init,
+ lrate_decay=self.lrate_decay,
+ lrate_limit=self.lrate_limit,
+ )
+ return EvoParams(
+ opt_params=opt_params,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -114,8 +140,5 @@ def tell_strategy(
state.mean, theta_grad, state.opt_state, params.opt_params
)
opt_state = self.optimizer.update(opt_state, params.opt_params)
-
- # Update lrate and standard deviation based on min and decay
- sigma = state.sigma * params.sigma_decay
- sigma = jnp.maximum(sigma, params.sigma_limit)
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
return state.replace(mean=mean, sigma=sigma, opt_state=opt_state)
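+
+
+# Usage sketch (illustrative only): schedules and fitness shaping settings
+# are now passed at instantiation and fed into the default EvoParams, e.g.
+#   strategy = ARS(popsize=64, num_dims=100, elite_ratio=0.1,
+#                  lrate_init=0.05, sigma_init=0.03, sigma_decay=0.999,
+#                  maximize=True)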
diff --git a/evosax/strategies/asebo.py b/evosax/strategies/asebo.py
new file mode 100644
index 0000000..e404d88
--- /dev/null
+++ b/evosax/strategies/asebo.py
@@ -0,0 +1,202 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
+from flax import struct
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ sigma: float
+ opt_state: OptState
+ grad_subspace: chex.Array
+ alpha: float
+ UUT: chex.Array
+ UUT_ort: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ opt_params: OptParams
+ sigma_init: float = 0.03
+ sigma_decay: float = 1.0
+ sigma_limit: float = 0.01
+ grad_decay: float = 0.99
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+class ASEBO(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ subspace_dims: int = 50,
+ opt_name: str = "adam",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float],
+ ):
+ """ASEBO (Choromanski et al., 2019)
+ Reference: https://arxiv.org/abs/1903.04268
+ Note that there are a couple of JAX-based adaptations:
+ 1. We always sample a fixed population size per generation
+ 2. We keep a fixed archive of gradients to estimate the subspace
+ """
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ assert not self.popsize & 1, "Population size must be even"
+        assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
+ self.optimizer = GradientOptimizer[opt_name](self.num_dims)
+ self.subspace_dims = min(subspace_dims, self.num_dims)
+ if self.subspace_dims < subspace_dims:
+ print(
+ "Subspace has to be smaller than optimization dims. Set to"
+ f" {self.subspace_dims} instead of {subspace_dims}."
+ )
+ self.strategy_name = "ASEBO"
+
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+ opt_params = self.optimizer.default_params.replace(
+ lrate_init=self.lrate_init,
+ lrate_decay=self.lrate_decay,
+ lrate_limit=self.lrate_limit,
+ )
+ return EvoParams(
+ opt_params=opt_params,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the evolution strategy."""
+ initialization = jax.random.uniform(
+ rng,
+ (self.num_dims,),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+
+ grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims))
+
+ state = EvoState(
+ mean=initialization,
+ sigma=params.sigma_init,
+ opt_state=self.optimizer.initialize(params.opt_params),
+ grad_subspace=grad_subspace,
+ alpha=1.0,
+ UUT=jnp.zeros((self.num_dims, self.num_dims)),
+ UUT_ort=jnp.zeros((self.num_dims, self.num_dims)),
+ best_member=initialization,
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new parameter candidates to evaluate next."""
+ # Antithetic sampling of noise
+ X = state.grad_subspace
+ X -= jnp.mean(X, axis=0)
+ U, S, Vt = jnp.linalg.svd(X, full_matrices=False)
+
+ def svd_flip(u, v):
+ # columns of u, rows of v
+ max_abs_cols = jnp.argmax(jnp.abs(u), axis=0)
+ signs = jnp.sign(u[max_abs_cols, jnp.arange(u.shape[1])])
+ u *= signs
+ v *= signs[:, jnp.newaxis]
+ return u, v
+
+ U, Vt = svd_flip(U, Vt)
+ U = Vt[: int(self.popsize / 2)]
+ UUT = jnp.matmul(U.T, U)
+
+ U_ort = Vt[int(self.popsize / 2) :]
+ UUT_ort = jnp.matmul(U_ort.T, U_ort)
+
+ subspace_ready = state.gen_counter > self.subspace_dims
+
+ UUT = jax.lax.select(
+ subspace_ready, UUT, jnp.zeros((self.num_dims, self.num_dims))
+ )
+ cov = (
+ state.sigma * (state.alpha / self.num_dims) * jnp.eye(self.num_dims)
+ + ((1 - state.alpha) / int(self.popsize / 2)) * UUT
+ )
+ chol = jnp.linalg.cholesky(cov)
+ noise = jax.random.normal(rng, (self.num_dims, int(self.popsize / 2)))
+ z_plus = jnp.swapaxes(chol @ noise, 0, 1)
+ z_plus /= jnp.linalg.norm(z_plus, axis=-1)[:, jnp.newaxis]
+ z = jnp.concatenate([z_plus, -1.0 * z_plus])
+ x = state.mean + z
+ return x, state.replace(UUT=UUT, UUT_ort=UUT_ort)
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """`tell` performance data for strategy state update."""
+ # Reconstruct noise from last mean/std estimates
+ noise = (x - state.mean) / state.sigma
+ noise_1 = noise[: int(self.popsize / 2)]
+ fit_1 = fitness[: int(self.popsize / 2)]
+ fit_2 = fitness[int(self.popsize / 2) :]
+ fit_diff_noise = jnp.dot(noise_1.T, fit_1 - fit_2)
+ theta_grad = 1.0 / 2.0 * fit_diff_noise
+
+ alpha = jnp.linalg.norm(
+ jnp.dot(theta_grad, state.UUT_ort)
+ ) / jnp.linalg.norm(jnp.dot(theta_grad, state.UUT))
+ subspace_ready = state.gen_counter > self.subspace_dims
+ alpha = jax.lax.select(subspace_ready, alpha, 1.0)
+
+ # Add grad FIFO-style to subspace archive (only if provided else FD)
+ grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims))
+ grad_subspace = grad_subspace.at[:-1, :].set(state.grad_subspace[1:, :])
+ grad_subspace = grad_subspace.at[-1, :].set(theta_grad)
+ state = state.replace(grad_subspace=grad_subspace)
+
+ # Normalize gradients by norm / num_dims
+ theta_grad /= jnp.linalg.norm(theta_grad) / self.num_dims + 1e-8
+
+ # Grad update using optimizer instance - decay lrate if desired
+ mean, opt_state = self.optimizer.step(
+ state.mean, theta_grad, state.opt_state, params.opt_params
+ )
+ opt_state = self.optimizer.update(opt_state, params.opt_params)
+
+        # Update standard deviation based on decay and limit
+        sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
+ return state.replace(
+ mean=mean, sigma=sigma, opt_state=opt_state, alpha=alpha
+ )
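+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): the standard evosax ask/tell
+    # loop on a toy quadratic. The names below are not part of the library.
+    rng = jax.random.PRNGKey(0)
+    strategy = ASEBO(popsize=20, num_dims=10, subspace_dims=5)
+    es_params = strategy.default_params
+    state = strategy.initialize(rng, es_params)
+    for _ in range(50):
+        rng, rng_ask = jax.random.split(rng)
+        x, state = strategy.ask(rng_ask, state, es_params)
+        fitness = jnp.sum(x ** 2, axis=-1)  # minimize the sphere function
+        state = strategy.tell(x, fitness, state, es_params)
+    print(state.best_fitness)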
diff --git a/evosax/strategies/bipop_cma_es.py b/evosax/strategies/bipop_cma_es.py
index 992fed1..d9d5c17 100644
--- a/evosax/strategies/bipop_cma_es.py
+++ b/evosax/strategies/bipop_cma_es.py
@@ -1,6 +1,6 @@
import jax
import chex
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
from functools import partial
from .cma_es import CMA_ES
from ..restarts.restarter import WrapperState, WrapperParams
@@ -19,14 +19,27 @@ class RestartParams:
class BIPOP_CMA_ES(object):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""BIPOP-CMA-ES (Hansen, 2009).
Reference: https://hal.inria.fr/inria-00382093/document
Inspired by: https://tinyurl.com/44y3ryhf"""
self.strategy_name = "BIPOP_CMA_ES"
# Instantiate base strategy & wrap it with restart wrapper
self.strategy = CMA_ES(
- num_dims=num_dims, popsize=popsize, elite_ratio=elite_ratio
+ num_dims=num_dims,
+ popsize=popsize,
+ pholder_params=pholder_params,
+ elite_ratio=elite_ratio,
+ sigma_init=sigma_init,
+ **fitness_kwargs,
)
from ..restarts import BIPOP_Restarter
from ..restarts.termination import spread_criterion, cma_criterion
diff --git a/evosax/strategies/cma_es.py b/evosax/strategies/cma_es.py
index 59b7ee3..84111e7 100755
--- a/evosax/strategies/cma_es.py
+++ b/evosax/strategies/cma_es.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from ..utils.eigen_decomp import full_eigen_decomp
from flax import struct
@@ -80,16 +80,27 @@ def get_cma_elite_weights(
class CMA_ES(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""CMA-ES (e.g. Hansen, 2016)
Reference: https://arxiv.org/abs/1604.00772
Inspired by: https://github.com/CyberAgentAILab/cmaes"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.strategy_name = "CMA_ES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -122,6 +133,7 @@ def params_strategy(self) -> EvoParams:
d_sigma=d_sigma,
c_c=c_c,
chi_n=chi_n,
+ sigma_init=self.sigma_init,
)
return params
@@ -157,7 +169,9 @@ def ask_strategy(
self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
) -> Tuple[chex.Array, EvoState]:
"""`ask` for new parameter candidates to evaluate next."""
- C, B, D = full_eigen_decomp(state.C, state.B, state.D)
+ C, B, D = full_eigen_decomp(
+ state.C, state.B, state.D, state.gen_counter
+ )
x = sample(
rng,
state.mean,
@@ -197,6 +211,7 @@ def tell_strategy(
y_w,
params.c_sigma,
params.mu_eff,
+ state.gen_counter,
)
p_c, norm_p_sigma, h_sigma = update_p_c(
@@ -259,9 +274,10 @@ def update_p_sigma(
y_w: chex.Array,
c_sigma: float,
mu_eff: float,
+ gen_counter: int,
) -> Tuple[chex.Array, chex.Array, chex.Array, None, None]:
"""Update evolution path for covariance matrix."""
- C, B, D = full_eigen_decomp(C, B, D)
+ C, B, D = full_eigen_decomp(C, B, D, gen_counter)
C_2 = B.dot(jnp.diag(1 / D)).dot(B.T) # C^(-1/2) = B D^(-1) B^T
p_sigma_new = (1 - c_sigma) * p_sigma + jnp.sqrt(
c_sigma * (2 - c_sigma) * mu_eff
diff --git a/evosax/strategies/cr_fm_nes.py b/evosax/strategies/cr_fm_nes.py
new file mode 100644
index 0000000..70c23e3
--- /dev/null
+++ b/evosax/strategies/cr_fm_nes.py
@@ -0,0 +1,300 @@
+import jax
+import jax.numpy as jnp
+import chex
+import math
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from flax import struct
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ sigma: float
+ v: chex.Array
+ D: chex.Array
+ p_sigma: chex.Array
+ p_c: chex.Array
+ w_rank_hat: chex.Array
+ w_rank: chex.Array
+ z: chex.Array
+ y: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ mu_eff: float
+ c_s: float
+ c_c: float
+ c1: float
+ chi_N: float
+ h_inv: float
+ alpha_dist: float
+ lrate_mean: float = 1.0
+ lrate_move_sigma: float = 0.1
+ lrate_stag_sigma: float = 0.1
+ lrate_conv_sigma: float = 0.1
+ lrate_B: float = 0.1
+ sigma_init: float = 1.0
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+def get_recombination_weights(popsize: int) -> Tuple[chex.Array, chex.Array]:
+ """Get recombination weights for different ranks."""
+
+ def get_weight(i):
+ return jnp.log(popsize / 2 + 1) - jnp.log(i)
+
+ w_rank_hat = jax.vmap(get_weight)(jnp.arange(1, popsize + 1))
+ w_rank_hat = w_rank_hat * (w_rank_hat >= 0)
+ w_rank = w_rank_hat / sum(w_rank_hat) - (1.0 / popsize)
+ return w_rank_hat.reshape(-1, 1), w_rank.reshape(-1, 1)
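+
+# Note: these are standard log-rank weights - positive, log-decaying weights
+# on the better half of the population and zero on the rest; `w_rank` is then
+# recentered so that its entries sum to zero.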
+
+
+def get_h_inv(dim: int) -> float:
+    """Solve for the CR-FM-NES constant h_inv via Newton's method."""
+    dim = min(dim, 2000)
+    f = lambda a: ((1.0 + a * a) * math.exp(a * a / 2.0) / 0.24) - 10.0 - dim
+    f_prime = lambda a: (1.0 / 0.24) * a * math.exp(a * a / 2.0) * (3.0 + a * a)
+    h_inv = 1.0
+    while abs(f(h_inv)) > 1e-10:
+        h_inv = h_inv - 0.5 * (f(h_inv) / f_prime(h_inv))
+    return h_inv
+
+
+def w_dist_hat(alpha_dist: float, z: chex.Array) -> chex.Array:
+ return jnp.exp(alpha_dist * jnp.linalg.norm(z))
+
+
+class CR_FM_NES(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
+ """Cost-Reduced Fast-Moving Natural ES (Nomura & Ono, 2022)
+ Reference: https://arxiv.org/abs/2201.11422
+ """
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ assert not self.popsize & 1, "Population size must be even"
+ self.strategy_name = "CR_FM_NES"
+
+ # Set core kwargs es_params (sigma)
+ self.sigma_init = sigma_init
+
+ @property
+ def default_params(self) -> EvoParams:
+ """Return default parameters of evolutionary strategy."""
+ w_rank_hat, w_rank = get_recombination_weights(self.popsize)
+ mueff = 1 / (
+ (w_rank + (1 / self.popsize)).T @ (w_rank + (1 / self.popsize))
+ )
+ c_s = (mueff + 2.0) / (self.num_dims + mueff + 5.0)
+ c_c = (4.0 + mueff / self.num_dims) / (
+ self.num_dims + 4.0 + 2.0 * mueff / self.num_dims
+ )
+ c1_cma = 2.0 / (jnp.power(self.num_dims + 1.3, 2) + mueff)
+ chi_N = jnp.sqrt(self.num_dims) * (
+ 1.0
+ - 1.0 / (4.0 * self.num_dims)
+ + 1.0 / (21.0 * self.num_dims * self.num_dims)
+ )
+ h_inv = get_h_inv(self.num_dims)
+ alpha_dist = h_inv * jnp.minimum(
+ 1.0, jnp.sqrt(self.popsize / self.num_dims)
+ )
+ lrate_move_sigma = 1.0
+ lrate_stag_sigma = jnp.tanh(
+ (0.024 * self.popsize + 0.7 * self.num_dims + 20.0)
+ / (self.num_dims + 12.0)
+ )
+ lrate_conv_sigma = 2.0 * jnp.tanh(
+ (0.025 * self.popsize + 0.75 * self.num_dims + 10.0)
+ / (self.num_dims + 4.0)
+ )
+ c1 = c1_cma * (self.num_dims - 5) / 6
+ lrate_B = jnp.tanh(
+ (jnp.minimum(0.02 * self.popsize, 3 * jnp.log(self.num_dims)) + 5)
+ / (0.23 * self.num_dims + 25)
+ )
+ params = EvoParams(
+ lrate_move_sigma=lrate_move_sigma,
+ lrate_stag_sigma=lrate_stag_sigma,
+ lrate_conv_sigma=lrate_conv_sigma,
+ lrate_B=lrate_B,
+ mu_eff=mueff,
+ c_s=c_s,
+ c_c=c_c,
+ c1=c1,
+ chi_N=chi_N,
+ alpha_dist=alpha_dist,
+ h_inv=h_inv,
+ sigma_init=self.sigma_init,
+ )
+ return params
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the evolutionary strategy."""
+ rng_init, rng_v = jax.random.split(rng)
+ initialization = jax.random.uniform(
+ rng_init,
+ (self.num_dims,),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ w_rank_hat, w_rank = get_recombination_weights(self.popsize)
+ state = EvoState(
+ mean=initialization,
+ sigma=params.sigma_init,
+ v=jax.random.normal(rng_v, shape=(self.num_dims, 1))
+ / jnp.sqrt(self.num_dims),
+ D=jnp.ones([self.num_dims, 1]),
+ p_sigma=jnp.zeros((self.num_dims, 1)),
+ p_c=jnp.zeros((self.num_dims, 1)),
+            z=jnp.zeros((self.num_dims, self.popsize)),
+            y=jnp.zeros((self.num_dims, self.popsize)),
+ w_rank_hat=w_rank_hat.reshape(-1, 1),
+ w_rank=w_rank,
+ best_member=initialization,
+ )
+
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new parameter candidates to evaluate next."""
+ z_plus = jax.random.normal(
+ rng,
+ (int(self.popsize / 2), self.num_dims),
+ )
+ z = jnp.concatenate([z_plus, -1.0 * z_plus])
+ z = jnp.swapaxes(z, 0, 1)
+ normv = jnp.linalg.norm(state.v)
+ normv2 = normv ** 2
+ vbar = state.v / normv
+
+ # Rescale/reparametrize noise
+ y = z + (jnp.sqrt(1 + normv2) - 1) * vbar @ (vbar.T @ z)
+ x = state.mean[:, None] + state.sigma * y * state.D
+ x = jnp.swapaxes(x, 0, 1)
+ return x, state.replace(z=z, y=y)
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """`tell` performance data for strategy state update."""
+ ranks = fitness.argsort()
+ z = state.z[:, ranks]
+ y = state.y[:, ranks]
+ x = jnp.swapaxes(x, 0, 1)[:, ranks]
+
+ # Update evolution path p_sigma
+ p_sigma = (1 - params.c_s) * state.p_sigma + jnp.sqrt(
+ params.c_s * (2.0 - params.c_s) * params.mu_eff
+ ) * (z @ state.w_rank)
+ p_sigma_norm = jnp.linalg.norm(p_sigma)
+
+ # Calculate distance weight
+ w_tmp = state.w_rank_hat * jax.vmap(w_dist_hat, in_axes=(None, 1))(
+ params.alpha_dist, z
+ ).reshape(-1, 1)
+ weights_dist = w_tmp / sum(w_tmp) - 1.0 / self.popsize
+
+ # switching weights and learning rate
+ p_sigma_cond = p_sigma_norm >= params.chi_N
+ weights = jax.lax.select(p_sigma_cond, weights_dist, state.w_rank)
+ lrate_sigma = jax.lax.select(
+ p_sigma_cond, params.lrate_move_sigma, params.lrate_stag_sigma
+ )
+ lrate_sigma = jax.lax.select(
+ p_sigma_norm >= 0.1 * params.chi_N,
+ lrate_sigma,
+ params.lrate_conv_sigma,
+ )
+
+ # update evolution path p_c and mean
+ wxm = (x - state.mean[:, None]) @ weights
+ p_c = (1.0 - params.c_c) * state.p_c + jnp.sqrt(
+ params.c_c * (2.0 - params.c_c) * params.mu_eff
+ ) * wxm / state.sigma
+ mean = state.mean + params.lrate_mean * wxm.squeeze()
+
+ normv = jnp.linalg.norm(state.v)
+ vbar = state.v / normv
+ normv2 = normv ** 2
+ normv4 = normv2 ** 2
+
+ exY = jnp.append(y, p_c / state.D, axis=1)
+ yy = exY * exY
+ ip_yvbar = vbar.T @ exY
+ yvbar = exY * vbar
+ gammav = 1.0 + normv2
+ vbarbar = vbar * vbar
+ alphavd = jnp.minimum(
+ 1,
+ jnp.sqrt(
+ normv4 + (2 * gammav - jnp.sqrt(gammav)) / jnp.max(vbarbar)
+ )
+ / (2 + normv2),
+ )
+ t = exY * ip_yvbar - vbar * (ip_yvbar ** 2 + gammav) / 2
+ b = -(1 - alphavd ** 2) * normv4 / gammav + 2 * alphavd ** 2
+ H = jnp.ones([self.num_dims, 1]) * 2 - (b + 2 * alphavd ** 2) * vbarbar
+ invH = H ** (-1)
+ s_step1 = (
+ yy
+ - normv2 / gammav * (yvbar * ip_yvbar)
+ - jnp.ones([self.num_dims, self.popsize + 1])
+ )
+ ip_vbart = vbar.T @ t
+ s_step2 = s_step1 - alphavd / gammav * (
+ (2 + normv2) * (t * vbar) - normv2 * vbarbar @ ip_vbart
+ )
+ invHvbarbar = invH * vbarbar
+ ip_s_step2invHvbarbar = invHvbarbar.T @ s_step2
+ s = (s_step2 * invH) - b / (
+ 1 + b * vbarbar.T @ invHvbarbar
+ ) * invHvbarbar @ ip_s_step2invHvbarbar
+ ip_svbarbar = vbarbar.T @ s
+ t = t - alphavd * ((2 + normv2) * (s * vbar) - vbar @ ip_svbarbar)
+
+ # update v, D covariance ingredients
+ exw = jnp.append(
+ params.lrate_B * weights,
+ jnp.array([params.c1]).reshape(1, 1),
+ axis=0,
+ )
+ v = state.v + (t @ exw) / normv
+ D = state.D + (s @ exw) * state.D
+ # calculate detA
+ nthrootdetA = jnp.exp(
+ jnp.sum(jnp.log(D)) / self.num_dims
+ + jnp.log(1 + v.T @ v) / (2 * self.num_dims)
+ )[0][0]
+ D = D / nthrootdetA
+ # update sigma
+ G_s = (
+ jnp.sum((z * z - jnp.ones([self.num_dims, self.popsize])) @ weights)
+ / self.num_dims
+ )
+ sigma = state.sigma * jnp.exp(lrate_sigma / 2 * G_s)
+ return state.replace(
+ p_sigma=p_sigma, mean=mean, p_c=p_c, v=v, D=D, sigma=sigma
+ )
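+
+
+if __name__ == "__main__":
+    # Usage sketch (illustrative only): instead of `num_dims` one can pass a
+    # placeholder pytree; `ask` then returns candidates in that structure.
+    rng = jax.random.PRNGKey(0)
+    pholder = {"w": jnp.zeros((4, 2)), "b": jnp.zeros(2)}
+    strategy = CR_FM_NES(popsize=8, pholder_params=pholder)
+    es_params = strategy.default_params
+    state = strategy.initialize(rng, es_params)
+    rng, rng_ask = jax.random.split(rng)
+    x, state = strategy.ask(rng_ask, state, es_params)  # pytree of 8 candidates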
diff --git a/evosax/strategies/de.py b/evosax/strategies/de.py
index 338adc2..0c6ff7c 100755
--- a/evosax/strategies/de.py
+++ b/evosax/strategies/de.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -29,11 +29,17 @@ class EvoParams:
class DE(Strategy):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Differential Evolution (Storn & Price, 1997)
Reference: https://tinyurl.com/4pje5a74"""
- assert popsize > 6
- super().__init__(num_dims, popsize)
+ assert popsize > 6, "DE requires popsize > 6."
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.strategy_name = "DE"
@property
diff --git a/evosax/strategies/des.py b/evosax/strategies/des.py
new file mode 100644
index 0000000..91ba64b
--- /dev/null
+++ b/evosax/strategies/des.py
@@ -0,0 +1,107 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from flax import struct
+from flax import linen as nn
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ sigma: chex.Array
+ weights: chex.Array # Weights for population members
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ temperature: float = 12.5 # Temperature for softmax weights
+ lrate_sigma: float = 0.1 # Learning rate for population std
+ lrate_mean: float = 1.0 # Learning rate for population mean
+ sigma_init: float = 1.0 # Standard deviation
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+def get_des_weights(popsize: int, temperature: float = 12.5) -> chex.Array:
+ """Compute discovered recombination weights."""
+ ranks = jnp.arange(popsize)
+ ranks /= ranks.size - 1
+ ranks = ranks - 0.5
+ sigout = nn.sigmoid(temperature * ranks)
+ weights = nn.softmax(-20 * sigout)
+ return weights
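+
+# Note: the weights decay smoothly from the best-ranked member towards (near)
+# zero for the worst-ranked one; `temperature` controls how sharp this soft
+# truncation selection is.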
+
+
+class DES(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+        **fitness_kwargs: Union[bool, int, float]
+    ):
+        """Discovered Evolution Strategy (Lange et al., 2022)"""
+        super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ self.strategy_name = "DES"
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+ return EvoParams()
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the evolution strategy."""
+ # Get DES discovered recombination weights.
+ weights = get_des_weights(self.popsize, params.temperature)
+ initialization = jax.random.uniform(
+ rng,
+ (self.num_dims,),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ state = EvoState(
+ mean=initialization,
+ sigma=params.sigma_init * jnp.ones(self.num_dims),
+ weights=weights.reshape(-1, 1),
+ best_member=initialization,
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new proposed candidates to evaluate next."""
+ z = jax.random.normal(rng, (self.popsize, self.num_dims)) # ~ N(0, I)
+ x = state.mean + z * state.sigma.reshape(
+ 1, self.num_dims
+ ) # ~ N(m, σ^2 I)
+ return x, state
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """`tell` update to ES state."""
+ weights = state.weights
+ x = x[fitness.argsort()]
+ # Weighted updates
+ weighted_mean = (weights * x).sum(axis=0)
+ weighted_sigma = jnp.sqrt(
+ (weights * (x - state.mean) ** 2).sum(axis=0) + 1e-06
+ )
+ mean = state.mean + params.lrate_mean * (weighted_mean - state.mean)
+ sigma = state.sigma + params.lrate_sigma * (
+ weighted_sigma - state.sigma
+ )
+ return state.replace(mean=mean, sigma=sigma)
diff --git a/evosax/strategies/esmc.py b/evosax/strategies/esmc.py
new file mode 100644
index 0000000..058fafb
--- /dev/null
+++ b/evosax/strategies/esmc.py
@@ -0,0 +1,142 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
+from flax import struct
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ sigma: chex.Array
+ opt_state: OptState
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ opt_params: OptParams
+ sigma_init: float = 0.03
+ sigma_decay: float = 0.999
+ sigma_limit: float = 0.01
+ sigma_lrate: float = 0.2 # Learning rate for std
+ sigma_max_change: float = 0.2 # Clip adaptive sigma to 20%
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+class ESMC(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ opt_name: str = "adam",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
+ """ESMC (Merchant et al., 2021)
+ Reference: https://proceedings.mlr.press/v139/merchant21a.html
+ """
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ assert self.popsize & 1, "Population size must be odd"
+ assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
+ self.optimizer = GradientOptimizer[opt_name](self.num_dims)
+ self.strategy_name = "ESMC"
+
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+ opt_params = self.optimizer.default_params.replace(
+ lrate_init=self.lrate_init,
+ lrate_decay=self.lrate_decay,
+ lrate_limit=self.lrate_limit,
+ )
+ return EvoParams(
+ opt_params=opt_params,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the evolution strategy."""
+ initialization = jax.random.uniform(
+ rng,
+ (self.num_dims,),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ state = EvoState(
+ mean=initialization,
+ sigma=jnp.ones(self.num_dims) * params.sigma_init,
+ opt_state=self.optimizer.initialize(params.opt_params),
+ best_member=initialization,
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new parameter candidates to evaluate next."""
+ # Antithetic sampling of noise
+ z_plus = jax.random.normal(
+ rng,
+ (int(self.popsize / 2), self.num_dims),
+ )
+ z = jnp.concatenate(
+ [jnp.zeros((1, self.num_dims)), z_plus, -1.0 * z_plus]
+ )
+ x = state.mean + z * state.sigma.reshape(1, self.num_dims)
+ return x, state
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """Update both mean and dim.-wise isotropic Gaussian scale."""
+ # Reconstruct noise from last mean/std estimates
+ noise = (x - state.mean) / state.sigma
+ bline_fitness = fitness[0]
+ noise = noise[1:]
+ fitness = fitness[1:]
+ noise_1 = noise[: int((self.popsize - 1) / 2)]
+ fit_1 = fitness[: int((self.popsize - 1) / 2)]
+ fit_2 = fitness[int((self.popsize - 1) / 2) :]
+ fit_diff = jnp.minimum(fit_1, bline_fitness) - jnp.minimum(
+ fit_2, bline_fitness
+ )
+ fit_diff_noise = jnp.dot(noise_1.T, fit_diff)
+ theta_grad = 1.0 / int((self.popsize - 1) / 2) * fit_diff_noise
+ # Grad update using optimizer instance - decay lrate if desired
+ mean, opt_state = self.optimizer.step(
+ state.mean, theta_grad, state.opt_state, params.opt_params
+ )
+ opt_state = self.optimizer.update(opt_state, params.opt_params)
+        # Update standard deviation based on decay and limit
+        sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
+ return state.replace(mean=mean, sigma=sigma, opt_state=opt_state)
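+
+
+# Usage note (illustrative only): ESMC requires an odd popsize - the first
+# candidate equals the current mean and serves as the baseline against which
+# the antithetic pairs are compared in `tell_strategy`, e.g.:
+#   strategy = ESMC(popsize=21, num_dims=10)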
diff --git a/evosax/strategies/full_iamalgam.py b/evosax/strategies/full_iamalgam.py
index eb6002f..b959780 100644
--- a/evosax/strategies/full_iamalgam.py
+++ b/evosax/strategies/full_iamalgam.py
@@ -1,8 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
+from ..utils import exp_decay
from flax import struct
@@ -29,7 +30,7 @@ class EvoParams:
delta_ams: float = 2.0
theta_sdr: float = 1.0
c_mult_init: float = 1.0
- sigma_init: float = 0.0
+ sigma_init: float = 0.1
sigma_decay: float = 0.999
sigma_limit: float = 0.0
init_min: float = 0.0
@@ -39,11 +40,21 @@ class EvoParams:
class Full_iAMaLGaM(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.35,
+ sigma_init: float = 0.0,
+ sigma_decay: float = 0.99,
+ sigma_limit: float = 0.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""(Iterative) AMaLGaM (Bosman et al., 2013) - Full Covariance
Reference: https://tinyurl.com/y9fcccx2
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
@@ -56,6 +67,11 @@ def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35):
self.ams_popsize = int(alpha_ams * (self.popsize - 1))
self.strategy_name = "Full_iAMaLGaM"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -72,8 +88,13 @@ def params_strategy(self) -> EvoParams:
/ (self.num_dims ** a_2_shift)
)
- params = EvoParams(eta_sigma=eta_sigma, eta_shift=eta_shift)
- return params
+ return EvoParams(
+ eta_sigma=eta_sigma,
+ eta_shift=eta_shift,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -153,8 +174,7 @@ def tell_strategy(
C = update_cov_amalgam(members_elite, state.C, mean, params.eta_sigma)
# Decay isotropic part of Gaussian search distribution
- sigma = state.sigma * params.sigma_decay
- sigma = jnp.maximum(sigma, params.sigma_limit)
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
return state.replace(
c_mult=c_mult,
nis_counter=nis_counter,
diff --git a/evosax/strategies/gesmr_ga.py b/evosax/strategies/gesmr_ga.py
new file mode 100644
index 0000000..1a67ee0
--- /dev/null
+++ b/evosax/strategies/gesmr_ga.py
@@ -0,0 +1,166 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from flax import struct
+
+
+@struct.dataclass
+class EvoState:
+ rng: chex.PRNGKey
+ mean: chex.Array
+ archive: chex.Array
+ fitness: chex.Array
+ sigma: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ sigma_init: float = 0.07
+ sigma_meta: float = 2.0
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+class GESMR_GA(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_ratio: float = 0.5,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
+ """Self-Adaptation Mutation Rate GA."""
+
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ self.elite_ratio = elite_ratio
+ self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
+ self.num_sigma_groups = int(jnp.sqrt(self.popsize))
+ self.members_per_group = int(
+ jnp.ceil(self.popsize / self.num_sigma_groups)
+ )
+ self.sigma_ratio = sigma_ratio
+ self.sigma_popsize = max(
+ 1, int(self.num_sigma_groups * self.sigma_ratio)
+ )
+ self.strategy_name = "GESMR_GA"
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+ return EvoParams()
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the differential evolution strategy."""
+ rng, rng_init = jax.random.split(rng)
+ initialization = jax.random.uniform(
+ rng_init,
+ (self.elite_popsize, self.num_dims),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ state = EvoState(
+ rng=rng,
+ mean=initialization[0],
+ archive=initialization,
+ fitness=jnp.zeros(self.elite_popsize) + jnp.finfo(jnp.float32).max,
+ sigma=jnp.zeros(self.num_sigma_groups) + params.sigma_init,
+ best_member=initialization[0],
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new proposed candidates to evaluate next."""
+ rng, rng_idx, rng_eps_x, rng_eps_s = jax.random.split(rng, 4)
+ # Sample noise for mutation of x and sigma
+ eps_x = jax.random.normal(rng_eps_x, (self.popsize, self.num_dims))
+ eps_s = jax.random.uniform(
+ rng_eps_s, (self.num_sigma_groups,), minval=-1, maxval=1
+ )
+
+ # Sample members to evaluate from parent archive
+ idx = jax.random.choice(
+ rng_idx, jnp.arange(self.elite_popsize), (self.popsize - 1,)
+ )
+ x = jnp.concatenate([state.archive[0][None, :], state.archive[idx]])
+
+ # Store fitness before perturbation (used to compute meta-fitness)
+ fitness_mem = jnp.concatenate(
+ [state.fitness[0][None], state.fitness[idx]]
+ )
+
+ # Apply sigma mutation on group level -> repeat for popmember broadcast
+ sigma_perturb = state.sigma * params.sigma_meta ** eps_s
+ sigma_repeated = jnp.repeat(sigma_perturb, self.members_per_group)[
+ : self.popsize
+ ]
+ sigma = jnp.concatenate([state.sigma[0][None], sigma_repeated[1:]])
+
+ # Apply x mutation -> scale specific to group membership
+ x += sigma[:, None] * eps_x
+ return x, state.replace(
+ archive=x, fitness=fitness_mem, sigma=sigma_perturb
+ )
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """`tell` update to ES state."""
+ # Select best x members
+ idx = jnp.argsort(fitness)[: self.elite_popsize]
+ archive = x[idx]
+
+ # Select best sigma based on function value improvement
+        group_ids = jnp.repeat(
+            jnp.arange(self.num_sigma_groups), self.members_per_group
+        )[: self.popsize]
+ delta_fitness = fitness - state.fitness
+
+ best_deltas = []
+ for k in range(self.num_sigma_groups):
+ sub_mask = group_ids == k
+ sub_delta = (
+ sub_mask * delta_fitness
+ + (1 - sub_mask) * jnp.finfo(jnp.float32).max
+ )
+            best_sub_delta = jnp.min(sub_delta)
+            best_deltas.append(best_sub_delta)
+
+ idx_select = jnp.argsort(jnp.array(best_deltas))[: self.sigma_popsize]
+ sigma_elite = state.sigma[idx_select]
+
+ # Resample sigmas with replacement
+ rng, rng_sigma = jax.random.split(state.rng)
+ idx_s = jax.random.choice(
+ rng_sigma,
+ jnp.arange(self.sigma_popsize),
+ (self.num_sigma_groups - 1,),
+ )
+ sigma = jnp.concatenate([state.sigma[0][None], sigma_elite[idx_s]])
+
+ # Set mean to best member seen so far
+        improved = fitness[idx[0]] < state.best_fitness
+ best_mean = jax.lax.select(improved, archive[0], state.best_member)
+ return state.replace(
+ rng=rng,
+ fitness=fitness[idx],
+ archive=archive,
+ sigma=sigma,
+ mean=best_mean,
+ )
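+
+
+# Note on the group structure: for popsize N the strategy keeps int(sqrt(N))
+# sigma groups with ceil(N / int(sqrt(N))) members each, e.g. (illustrative
+# only) GESMR_GA(popsize=16, num_dims=10) uses 4 groups of 4 members, each
+# group sharing one mutation rate.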
diff --git a/evosax/strategies/gld.py b/evosax/strategies/gld.py
index 028bd74..ae363f3 100644
--- a/evosax/strategies/gld.py
+++ b/evosax/strategies/gld.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -16,7 +16,7 @@ class EvoState:
@struct.dataclass
class EvoParams:
- radius_max: float = 0.1
+ radius_max: float = 0.2
radius_min: float = 0.001
radius_decay: float = 5
init_min: float = 0.0
@@ -26,10 +26,16 @@ class EvoParams:
class GLD(Strategy):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Gradientless Descent (Golovin et al., 2019)
Reference: https://arxiv.org/pdf/1911.06317.pdf"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.strategy_name = "GLD"
@property
diff --git a/evosax/strategies/guided_es.py b/evosax/strategies/guided_es.py
new file mode 100644
index 0000000..142517d
--- /dev/null
+++ b/evosax/strategies/guided_es.py
@@ -0,0 +1,186 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from functools import partial
+from ..strategy import Strategy
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
+from flax import struct
+from ..utils import get_best_fitness_member
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ sigma: float
+ opt_state: OptState
+ grad_subspace: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ opt_params: OptParams
+ sigma_init: float = 0.03
+ sigma_decay: float = 1.0
+ sigma_limit: float = 0.01
+ alpha: float = 0.5
+ beta: float = 1.0
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+class GuidedES(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ subspace_dims: int = 1, # k param in example notebook
+ opt_name: str = "sgd",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float],
+ ):
+ """Guided ES (Maheswaranathan et al., 2018)
+ Reference: https://arxiv.org/abs/1806.10230
+        Note a JAX-based adaptation: the gradient subspace is kept as a
+        FIFO archive and an external `gradient` can optionally be passed
+        to `tell` - otherwise the ES gradient estimate is appended."""
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ assert not self.popsize & 1, "Population size must be even"
+ assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
+ assert (
+ subspace_dims <= self.num_dims
+ ), "Subspace has to be smaller than optimization dims."
+ self.optimizer = GradientOptimizer[opt_name](self.num_dims)
+        # The assert above guarantees subspace_dims <= num_dims
+        self.subspace_dims = subspace_dims
+ self.strategy_name = "GuidedES"
+
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+        opt_params = self.optimizer.default_params.replace(
+            lrate_init=self.lrate_init,
+            lrate_decay=self.lrate_decay,
+            lrate_limit=self.lrate_limit,
+        )
+        return EvoParams(
+            opt_params=opt_params,
+            sigma_init=self.sigma_init,
+            sigma_decay=self.sigma_decay,
+            sigma_limit=self.sigma_limit,
+        )
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the evolution strategy."""
+ rng_init, rng_sub = jax.random.split(rng)
+ initialization = jax.random.uniform(
+ rng_init,
+ (self.num_dims,),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+
+ grad_subspace = jax.random.normal(
+ rng_sub, (self.subspace_dims, self.num_dims)
+ )
+
+ state = EvoState(
+ mean=initialization,
+ sigma=params.sigma_init,
+ opt_state=self.optimizer.initialize(params.opt_params),
+ grad_subspace=grad_subspace,
+ best_member=initialization,
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new parameter candidates to evaluate next."""
+ a = state.sigma * jnp.sqrt(params.alpha / self.num_dims)
+ c = state.sigma * jnp.sqrt((1.0 - params.alpha) / self.subspace_dims)
+ key_full, key_sub = jax.random.split(rng, 2)
+ eps_full = jax.random.normal(
+ key_full, shape=(self.num_dims, int(self.popsize / 2))
+ )
+ eps_subspace = jax.random.normal(
+ key_sub, shape=(self.subspace_dims, int(self.popsize / 2))
+ )
+ Q, _ = jnp.linalg.qr(state.grad_subspace)
+ # Antithetic sampling of noise
+ z_plus = a * eps_full + c * jnp.dot(Q, eps_subspace)
+ z_plus = jnp.swapaxes(z_plus, 0, 1)
+ z = jnp.concatenate([z_plus, -1.0 * z_plus])
+ x = state.mean + z
+ return x, state
+
+ @partial(jax.jit, static_argnums=(0,))
+ def tell(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: Optional[EvoParams] = None,
+ gradient: Optional[chex.Array] = None,
+ ) -> EvoState:
+ """`tell` performance data for strategy state update."""
+ # Use default hyperparameters if no other settings provided
+ if params is None:
+ params = self.default_params
+
+ # Flatten params if using param reshaper for ES update
+ if self.use_param_reshaper:
+ x = self.param_reshaper.flatten(x)
+
+ # Perform fitness reshaping inside of strategy tell call (if desired)
+ fitness_re = self.fitness_shaper.apply(x, fitness)
+
+ # Reconstruct noise from last mean/std estimates
+ noise = (x - state.mean) / state.sigma
+ noise_1 = noise[: int(self.popsize / 2)]
+ fit_1 = fitness_re[: int(self.popsize / 2)]
+ fit_2 = fitness_re[int(self.popsize / 2) :]
+ fit_diff = fit_1 - fit_2
+ fit_diff_noise = jnp.dot(noise_1.T, fit_diff)
+ theta_grad = (params.beta / self.popsize) * fit_diff_noise
+
+ # Add grad FIFO-style to subspace archive (only if provided else FD)
+ grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims))
+ grad_subspace = grad_subspace.at[:-1, :].set(state.grad_subspace[1:, :])
+ if gradient is not None:
+ grad_subspace = grad_subspace.at[-1, :].set(gradient)
+ else:
+ grad_subspace = grad_subspace.at[-1, :].set(theta_grad)
+ state = state.replace(grad_subspace=grad_subspace)
+
+ # Grad update using optimizer instance - decay lrate if desired
+ mean, opt_state = self.optimizer.step(
+ state.mean, theta_grad, state.opt_state, params.opt_params
+ )
+ opt_state = self.optimizer.update(opt_state, params.opt_params)
+
+ # Update lrate and standard deviation based on min and decay
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
+ state = state.replace(mean=mean, sigma=sigma, opt_state=opt_state)
+
+ # Check if there is a new best member & update trackers
+ best_member, best_fitness = get_best_fitness_member(x, fitness, state)
+ return state.replace(
+ best_member=best_member,
+ best_fitness=best_fitness,
+ gen_counter=state.gen_counter + 1,
+ )
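+
+
+if __name__ == "__main__":
+    # Usage sketch (illustrative only): feed an external gradient of a toy
+    # quadratic into `tell` to guide the search subspace.
+    rng = jax.random.PRNGKey(0)
+    strategy = GuidedES(popsize=10, num_dims=5, subspace_dims=2)
+    es_params = strategy.default_params
+    state = strategy.initialize(rng, es_params)
+    for _ in range(25):
+        rng, rng_ask = jax.random.split(rng)
+        x, state = strategy.ask(rng_ask, state, es_params)
+        fitness = jnp.sum(x ** 2, axis=-1)  # minimize the sphere function
+        grad = jax.grad(lambda m: jnp.sum(m ** 2))(state.mean)
+        state = strategy.tell(x, fitness, state, es_params, gradient=grad)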
diff --git a/evosax/strategies/indep_iamalgam.py b/evosax/strategies/indep_iamalgam.py
index 70faed4..ef73628 100644
--- a/evosax/strategies/indep_iamalgam.py
+++ b/evosax/strategies/indep_iamalgam.py
@@ -1,8 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
+from ..utils import exp_decay
from .full_iamalgam import (
anticipated_mean_shift,
adaptive_variance_scaling,
@@ -44,11 +45,21 @@ class EvoParams:
class Indep_iAMaLGaM(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.35,
+ sigma_init: float = 0.0,
+ sigma_decay: float = 0.99,
+ sigma_limit: float = 0.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""(Iterative) AMaLGaM (Bosman et al., 2013) - Diagonal Covariance
Reference: https://tinyurl.com/y9fcccx2
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
@@ -61,6 +72,11 @@ def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35):
self.ams_popsize = int(alpha_ams * (self.popsize - 1))
self.strategy_name = "Indep_iAMaLGaM"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -77,8 +93,13 @@ def params_strategy(self) -> EvoParams:
/ (self.num_dims ** a_2_shift)
)
- params = EvoParams(eta_sigma=eta_sigma, eta_shift=eta_shift)
- return params
+ return EvoParams(
+ eta_sigma=eta_sigma,
+ eta_shift=eta_shift,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -158,8 +179,7 @@ def tell_strategy(
C = update_cov_amalgam(members_elite, state.C, mean, params.eta_sigma)
# Decay isotropic part of Gaussian search distribution
- sigma = state.sigma * params.sigma_decay
- sigma = jnp.maximum(sigma, params.sigma_limit)
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
return state.replace(
c_mult=c_mult,
nis_counter=nis_counter,
diff --git a/evosax/strategies/ipop_cma_es.py b/evosax/strategies/ipop_cma_es.py
index bdbb071..65d6fca 100644
--- a/evosax/strategies/ipop_cma_es.py
+++ b/evosax/strategies/ipop_cma_es.py
@@ -1,6 +1,6 @@
import jax
import chex
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
from functools import partial
from .cma_es import CMA_ES
from ..restarts.restarter import WrapperState, WrapperParams
@@ -19,14 +19,27 @@ class RestartParams:
class IPOP_CMA_ES(object):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""IPOP-CMA-ES (Auer & Hansen, 2005).
Reference: http://www.cmap.polytechnique.fr/~nikolaus.hansen/cec2005ipopcmaes.pdf
"""
self.strategy_name = "IPOP_CMA_ES"
# Instantiate base strategy & wrap it with restart wrapper
self.strategy = CMA_ES(
- num_dims=num_dims, popsize=popsize, elite_ratio=elite_ratio
+ popsize=popsize,
+ num_dims=num_dims,
+ pholder_params=pholder_params,
+ elite_ratio=elite_ratio,
+ sigma_init=sigma_init,
+ **fitness_kwargs
)
from ..restarts import IPOP_Restarter
from ..restarts.termination import cma_criterion, spread_criterion
diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py
index 7c4d033..326b299 100644
--- a/evosax/strategies/lm_ma_es.py
+++ b/evosax/strategies/lm_ma_es.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from .cma_es import get_cma_elite_weights
from flax import struct
@@ -41,21 +41,27 @@ class EvoParams:
class LM_MA_ES(Strategy):
def __init__(
self,
- num_dims: int,
popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
elite_ratio: float = 0.5,
memory_size: int = 10,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
):
"""Limited Memory MA-ES (Loshchilov et al., 2017)
Reference: https://arxiv.org/pdf/1705.06693.pdf
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.memory_size = memory_size
self.strategy_name = "LM_MA_ES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -85,6 +91,7 @@ def params_strategy(self) -> EvoParams:
d_sigma=d_sigma,
chi_n=chi_n,
mu_w=mu_w,
+ sigma_init=self.sigma_init,
)
return params
diff --git a/evosax/strategies/ma_es.py b/evosax/strategies/ma_es.py
index 1e65e92..86a7cfc 100644
--- a/evosax/strategies/ma_es.py
+++ b/evosax/strategies/ma_es.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from .cma_es import get_cma_elite_weights
from flax import struct
@@ -36,16 +36,27 @@ class EvoParams:
class MA_ES(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""MA-ES (Bayer & Sendhoff, 2017)
Reference: https://www.honda-ri.de/pubs/pdf/3376.pdf
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.strategy_name = "MA_ES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -75,6 +86,7 @@ def params_strategy(self) -> EvoParams:
c_sigma=c_sigma,
d_sigma=d_sigma,
chi_n=chi_n,
+ sigma_init=self.sigma_init,
)
return params
diff --git a/evosax/strategies/mr15_ga.py b/evosax/strategies/mr15_ga.py
new file mode 100644
index 0000000..067731f
--- /dev/null
+++ b/evosax/strategies/mr15_ga.py
@@ -0,0 +1,139 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from .simple_ga import single_mate
+from flax import struct
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ archive: chex.Array
+ fitness: chex.Array
+ sigma: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ cross_over_rate: float = 0.0
+ sigma_init: float = 0.07
+ sigma_ratio: float = 0.15
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+class MR15_GA(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.0,
+ sigma_ratio: float = 0.15,
+ sigma_init: float = 0.1,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
+ """1/5 MR Genetic Algorithm (Rechenberg, 1987)
+ Reference: https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8
+ """
+
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ self.elite_ratio = elite_ratio
+ self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
+ self.strategy_name = "MR15_GA"
+
+ # Set core kwargs es_params
+ self.sigma_ratio = sigma_ratio  # ratio of mutations that have to improve
+ self.sigma_init = sigma_init
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+ return EvoParams(
+ sigma_init=self.sigma_init, sigma_ratio=self.sigma_ratio
+ )
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the differential evolution strategy."""
+ initialization = jax.random.uniform(
+ rng,
+ (self.elite_popsize, self.num_dims),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ state = EvoState(
+ mean=initialization.mean(axis=0),
+ archive=initialization,
+ fitness=jnp.zeros(self.elite_popsize) + jnp.finfo(jnp.float32).max,
+ sigma=params.sigma_init,
+ best_member=initialization.mean(axis=0),
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """
+ `ask` for new proposed candidates to evaluate next.
+ 1. For each population member:
+ - Sample two current elite members (a & b)
+ - Cross over each dim of a with the corresponding one of b
+ if a random number > cross-over rate
+ - Additionally add Gaussian noise on top of the mated parameters
+ """
+ rng, rng_eps, rng_idx_a, rng_idx_b = jax.random.split(rng, 4)
+ rng_mate = jax.random.split(rng, self.popsize)
+ epsilon = (
+ jax.random.normal(rng_eps, (self.popsize, self.num_dims))
+ * state.sigma
+ )
+ elite_ids = jnp.arange(self.elite_popsize)
+ idx_a = jax.random.choice(rng_idx_a, elite_ids, (self.popsize,))
+ idx_b = jax.random.choice(rng_idx_b, elite_ids, (self.popsize,))
+ members_a = state.archive[idx_a]
+ members_b = state.archive[idx_b]
+ x = jax.vmap(single_mate, in_axes=(0, 0, 0, None))(
+ rng_mate, members_a, members_b, params.cross_over_rate
+ )
+ x += epsilon
+ return jnp.squeeze(x), state
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """
+ `tell` update to ES state.
+ Combine offspring & archive, keep the elite, adapt sigma via 1/5 rule.
+ """
+ # Combine current elite and recent generation info
+ fitness = jnp.concatenate([fitness, state.fitness])
+ solution = jnp.concatenate([x, state.archive])
+ # Select top elite from total archive info
+ idx = jnp.argsort(fitness)[0 : self.elite_popsize]
+ fitness = fitness[idx]
+ archive = solution[idx]
+ # Update mutation sigma - double if more than 15% improved
+ good_mutations_ratio = jnp.mean(fitness < state.best_fitness)
+ increase_sigma = good_mutations_ratio > params.sigma_ratio
+ sigma = jax.lax.select(
+ increase_sigma, 2 * state.sigma, 0.5 * state.sigma
+ )
+ # Set mean to best member seen so far
+ improved = fitness[0] < state.best_fitness
+ best_mean = jax.lax.select(improved, archive[0], state.best_member)
+ return state.replace(
+ fitness=fitness, archive=archive, sigma=sigma, mean=best_mean
+ )
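A minimal usage sketch for the new strategy (illustration, not part of the diff): it assumes `MR15_GA` is exported from the package root like the other strategies and runs the standard ask/tell loop on a toy sphere objective.

import jax
import jax.numpy as jnp
from evosax import MR15_GA  # assumed top-level export

rng = jax.random.PRNGKey(0)
strategy = MR15_GA(popsize=32, num_dims=4, elite_ratio=0.5, sigma_init=0.1)
# Widen the init range; the defaults initialize the archive at zero.
params = strategy.default_params.replace(init_min=-3.0, init_max=3.0)
state = strategy.initialize(rng, params)

for _ in range(50):
    rng, rng_ask = jax.random.split(rng)
    x, state = strategy.ask(rng_ask, state, params)
    fitness = jnp.sum(x ** 2, axis=-1)  # minimize the sphere function
    state = strategy.tell(x, fitness, state, params)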
diff --git a/evosax/strategies/open_es.py b/evosax/strategies/open_es.py
index 8461950..3a54af6 100755
--- a/evosax/strategies/open_es.py
+++ b/evosax/strategies/open_es.py
@@ -1,9 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
-from ..utils import GradientOptimizer, OptState, OptParams
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
from flax import struct
@@ -30,20 +30,51 @@ class EvoParams:
class OpenES(Strategy):
- def __init__(self, num_dims: int, popsize: int, opt_name: str = "adam"):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ opt_name: str = "adam",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""OpenAI-ES (Salimans et al. (2017)
Reference: https://arxiv.org/pdf/1703.03864.pdf
Inspired by: https://github.com/hardmaru/estool/blob/master/es.py"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert not self.popsize & 1, "Population size must be even"
- assert opt_name in ["sgd", "adam", "rmsprop", "clipup"]
+ assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
self.optimizer = GradientOptimizer[opt_name](self.num_dims)
self.strategy_name = "OpenES"
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
- return EvoParams(opt_params=self.optimizer.default_params)
+ opt_params = self.optimizer.default_params.replace(
+ lrate_init=self.lrate_init,
+ lrate_decay=self.lrate_decay,
+ lrate_limit=self.lrate_limit,
+ )
+ return EvoParams(
+ opt_params=opt_params,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -95,6 +126,5 @@ def tell_strategy(
state.mean, theta_grad, state.opt_state, params.opt_params
)
opt_state = self.optimizer.update(opt_state, params.opt_params)
- sigma = state.sigma * params.sigma_decay
- sigma = jnp.maximum(sigma, params.sigma_limit)
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
return state.replace(mean=mean, sigma=sigma, opt_state=opt_state)
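A hedged sketch of the new schedule interface: the decay/limit kwargs above are stored on the instance and baked into `EvoParams` by `params_strategy`, so schedules are configured entirely at instantiation.

from evosax import OpenES

strategy = OpenES(
    popsize=64,
    num_dims=100,
    opt_name="adan",     # newly registered optimizer
    lrate_init=0.05,
    lrate_decay=0.999,   # exponential decay toward lrate_limit
    lrate_limit=0.001,
    sigma_init=0.03,
    sigma_decay=0.999,   # sigma now decays via exp_decay in tell
    sigma_limit=0.01,
)
params = strategy.default_params  # schedules already included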
diff --git a/evosax/strategies/pbt.py b/evosax/strategies/pbt.py
index 1f41cdb..41c7dd9 100755
--- a/evosax/strategies/pbt.py
+++ b/evosax/strategies/pbt.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -27,10 +27,16 @@ class EvoParams:
class PBT(Strategy):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Synchronous Population-Based Training (Jaderberg et al., 2017)
Reference: https://arxiv.org/abs/1711.09846"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.strategy_name = "PBT"
@property
diff --git a/evosax/strategies/persistent_es.py b/evosax/strategies/persistent_es.py
index 4d9f86f..ee50fc1 100644
--- a/evosax/strategies/persistent_es.py
+++ b/evosax/strategies/persistent_es.py
@@ -1,9 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
-from ..utils import GradientOptimizer, OptState, OptParams
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
from flax import struct
@@ -34,21 +34,52 @@ class EvoParams:
class PersistentES(Strategy):
- def __init__(self, num_dims: int, popsize: int, opt_name: str = "adam"):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ opt_name: str = "adam",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Persistent ES (Vicol et al., 2021).
Reference: http://proceedings.mlr.press/v139/vicol21a.html
Inspired by: http://proceedings.mlr.press/v139/vicol21a/vicol21a-supp.pdf
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert not self.popsize & 1, "Population size must be even"
- assert opt_name in ["sgd", "adam", "rmsprop", "clipup"]
+ assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
self.optimizer = GradientOptimizer[opt_name](self.num_dims)
self.strategy_name = "PersistentES"
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
- return EvoParams(opt_params=self.optimizer.default_params)
+ opt_params = self.optimizer.default_params.replace(
+ lrate_init=self.lrate_init,
+ lrate_decay=self.lrate_decay,
+ lrate_limit=self.lrate_limit,
+ )
+ return EvoParams(
+ opt_params=opt_params,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -105,8 +136,7 @@ def tell_strategy(
opt_state = self.optimizer.update(opt_state, params.opt_params)
inner_step_counter = state.inner_step_counter + params.K
- sigma = state.sigma * params.sigma_decay
- sigma = jnp.maximum(sigma, params.sigma_limit)
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
# Reset accumulated antithetic noise memory if done with inner problem
reset = inner_step_counter >= params.T
inner_step_counter = jax.lax.select(reset, 0, inner_step_counter)
diff --git a/evosax/strategies/pgpe.py b/evosax/strategies/pgpe.py
index 371da7b..6057fb7 100755
--- a/evosax/strategies/pgpe.py
+++ b/evosax/strategies/pgpe.py
@@ -1,9 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
-from ..utils import GradientOptimizer, OptState, OptParams
+from ..utils import GradientOptimizer, OptState, OptParams, exp_decay
from flax import struct
@@ -34,28 +34,54 @@ class EvoParams:
class PGPE(Strategy):
def __init__(
self,
- num_dims: int,
popsize: int,
- elite_ratio: float = 0.1,
- opt_name: str = "sgd",
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 1.0,
+ opt_name: str = "adam",
+ lrate_init: float = 0.05,
+ lrate_decay: float = 1.0,
+ lrate_limit: float = 0.001,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
):
"""PGPE (e.g. Sehnke et al., 2010)
Reference: https://tinyurl.com/2p8bn956
Inspired by: https://github.com/hardmaru/estool/blob/master/es.py"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize / 2 * self.elite_ratio))
assert not self.popsize & 1, "Population size must be even"
- assert opt_name in ["sgd", "adam", "rmsprop", "clipup"]
+ assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
self.optimizer = GradientOptimizer[opt_name](self.num_dims)
self.strategy_name = "PGPE"
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.lrate_init = lrate_init
+ self.lrate_decay = lrate_decay
+ self.lrate_limit = lrate_limit
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
- return EvoParams(opt_params=self.optimizer.default_params)
+ opt_params = self.optimizer.default_params.replace(
+ lrate_init=self.lrate_init,
+ lrate_decay=self.lrate_decay,
+ lrate_limit=self.lrate_limit,
+ )
+ return EvoParams(
+ opt_params=opt_params,
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -85,7 +111,7 @@ def ask_strategy(
(int(self.popsize / 2), self.num_dims),
)
z = jnp.concatenate([z_plus, -1.0 * z_plus])
- x = state.mean + z * state.sigma.reshape(1, self.num_dims)
+ x = state.mean + state.sigma * z
return x, state
def tell_strategy(
@@ -98,9 +124,9 @@ def tell_strategy(
"""Update both mean and dim.-wise isotropic Gaussian scale."""
# Reconstruct noise from last mean/std estimates
noise = (x - state.mean) / state.sigma
- noise_1 = noise[: int(self.popsize / 2)]
- fit_1 = fitness[: int(self.popsize / 2)]
- fit_2 = fitness[int(self.popsize / 2) :]
+ noise_1 = noise[::2]
+ fit_1 = fitness[::2]
+ fit_2 = fitness[1::2]
elite_idx = jnp.minimum(fit_1, fit_2).argsort()[: self.elite_popsize]
fitness_elite = jnp.concatenate([fit_1[elite_idx], fit_2[elite_idx]])
@@ -120,18 +146,18 @@ def tell_strategy(
- (state.sigma * state.sigma).reshape(1, self.num_dims)
) / state.sigma.reshape(1, self.num_dims)
rS = (fit_1 + fit_2) / 2.0 - jnp.mean(fitness_elite)
- delta_sigma = (jnp.dot(rS, S)) / self.elite_popsize
- change_sigma = params.sigma_lrate * delta_sigma
- change_sigma = jnp.minimum(
- change_sigma, params.sigma_max_change * state.sigma
- )
- change_sigma = jnp.maximum(
- change_sigma, -params.sigma_max_change * state.sigma
- )
+ delta_sigma = jnp.dot(rS, S) / (self.elite_popsize / 2)
+
+ allowed_delta = jnp.abs(state.sigma) * params.sigma_max_change
+ min_allowed = state.sigma - allowed_delta
+ max_allowed = state.sigma + allowed_delta
# adjust sigma according to the adaptive sigma calculation
# for stability, don't let sigma move more than 20% of orig value
- sigma = state.sigma - change_sigma
- sigma = sigma * params.sigma_decay
- sigma = jnp.maximum(sigma, params.sigma_limit)
+ sigma = jnp.clip(
+ state.sigma - params.sigma_lrate * delta_sigma,
+ min_allowed,
+ max_allowed,
+ )
+ sigma = exp_decay(sigma, params.sigma_decay, params.sigma_limit)
return state.replace(mean=mean, sigma=sigma, opt_state=opt_state)
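A tiny numeric check (illustration only) of the reworked sigma trust region above: the proposed update is clipped so sigma never moves by more than `sigma_max_change` of its current magnitude per generation, before the decay schedule applies.

import jax.numpy as jnp

sigma = jnp.array([0.10, 0.50])
delta_sigma = jnp.array([2.0, -8.0])  # hypothetical raw sigma gradients
sigma_lrate, sigma_max_change = 0.1, 0.2

allowed_delta = jnp.abs(sigma) * sigma_max_change  # [0.02, 0.10]
new_sigma = jnp.clip(
    sigma - sigma_lrate * delta_sigma,
    sigma - allowed_delta,
    sigma + allowed_delta,
)
# Unclipped steps would land at [-0.10, 1.30]; clipped to [0.08, 0.60].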
diff --git a/evosax/strategies/pso.py b/evosax/strategies/pso.py
index 67f11d0..35128ce 100755
--- a/evosax/strategies/pso.py
+++ b/evosax/strategies/pso.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -31,10 +31,16 @@ class EvoParams:
class PSO(Strategy):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Particle Swarm Optimization (Kennedy & Eberhart, 1995)
Reference: https://ieeexplore.ieee.org/document/488968"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.strategy_name = "PSO"
@property
diff --git a/evosax/strategies/rm_es.py b/evosax/strategies/rm_es.py
index d8bf0b6..105bc39 100644
--- a/evosax/strategies/rm_es.py
+++ b/evosax/strategies/rm_es.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -62,21 +62,27 @@ def get_cma_elite_weights(
class RmES(Strategy):
def __init__(
self,
- num_dims: int,
popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
elite_ratio: float = 0.5,
memory_size: int = 10,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
):
"""Rank-m ES (Li & Zhang, 2017)
Reference: https://ieeexplore.ieee.org/document/8080257
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.memory_size = memory_size # number of ranks
self.strategy_name = "RmES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -89,6 +95,7 @@ def params_strategy(self) -> EvoParams:
c_c=c_c,
c_sigma=jnp.minimum(2 / (self.num_dims + 7), 0.05),
mu_eff=mu_eff,
+ sigma_init=self.sigma_init,
)
return params
diff --git a/evosax/strategies/samr_ga.py b/evosax/strategies/samr_ga.py
new file mode 100644
index 0000000..c287ac1
--- /dev/null
+++ b/evosax/strategies/samr_ga.py
@@ -0,0 +1,110 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from flax import struct
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ archive: chex.Array
+ fitness: chex.Array
+ sigma: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ sigma_init: float = 0.07
+ sigma_meta: float = 2.0
+ sigma_best_limit: float = 0.0001
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+class SAMR_GA(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
+ """Self-Adaptation Mutation Rate GA."""
+
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ self.elite_ratio = elite_ratio
+ self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
+ self.strategy_name = "SAMR_GA"
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolution strategy."""
+ return EvoParams()
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the differential evolution strategy."""
+ initialization = jax.random.uniform(
+ rng,
+ (self.elite_popsize, self.num_dims),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ state = EvoState(
+ mean=initialization.mean(axis=0),
+ archive=initialization,
+ fitness=jnp.zeros(self.elite_popsize) + jnp.finfo(jnp.float32).max,
+ sigma=jnp.zeros(self.elite_popsize) + params.sigma_init,
+ best_member=initialization.mean(axis=0),
+ )
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new proposed candidates to evaluate next."""
+ rng, rng_idx, rng_eps_x, rng_eps_s = jax.random.split(rng, 4)
+ eps_x = jax.random.normal(rng_eps_x, (self.popsize, self.num_dims))
+ eps_s = jax.random.uniform(
+ rng_eps_s, (self.popsize,), minval=-1, maxval=1
+ )
+ idx = jax.random.choice(
+ rng_idx, jnp.arange(self.elite_popsize), (self.popsize - 1,)
+ )
+ x = jnp.concatenate([state.archive[0][None, :], state.archive[idx]])
+ sigma_0 = jnp.array(
+ [jnp.maximum(params.sigma_best_limit, state.sigma[0])]
+ )
+ sigma = jnp.concatenate([sigma_0, state.sigma[idx]])
+ sigma_gen = sigma * params.sigma_meta ** eps_s
+ x += sigma_gen[:, None] * eps_x
+ return x, state.replace(archive=x, sigma=sigma_gen)
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """`tell` update to ES state."""
+ idx = jnp.argsort(fitness)[: self.elite_popsize]
+ fitness = fitness[idx]
+ archive = x[idx]
+ sigma = state.sigma[idx]
+
+ # Set mean to best member seen so far
+ improved = fitness[0] < state.best_fitness
+ best_mean = jax.lax.select(improved, archive[0], state.best_member)
+ return state.replace(
+ fitness=fitness, archive=archive, sigma=sigma, mean=best_mean
+ )
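An illustrative mini-demo (not from the diff) of the self-adaptive mutation rate in `ask_strategy` above: each member's sigma is perturbed log-uniformly through `sigma_meta ** eps_s`, and selection in `tell` then keeps the sigmas of the fittest offspring.

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
sigma = jnp.full((4,), 0.07)  # per-member mutation rates
eps_s = jax.random.uniform(rng, (4,), minval=-1, maxval=1)
sigma_gen = sigma * 2.0 ** eps_s  # sigma_meta = 2.0
# Each entry lies in [0.035, 0.14]: rates can halve or double per generation.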
diff --git a/evosax/strategies/sep_cma_es.py b/evosax/strategies/sep_cma_es.py
index 3fbc2c6..a80f2ab 100644
--- a/evosax/strategies/sep_cma_es.py
+++ b/evosax/strategies/sep_cma_es.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from ..utils.eigen_decomp import diag_eigen_decomp
from flax import struct
@@ -57,17 +57,28 @@ def get_cma_elite_weights(
class Sep_CMA_ES(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Separable CMA-ES (e.g. Ros & Hansen, 2008)
Reference: https://hal.inria.fr/inria-00287367/document
Inspired by: github.com/CyberAgentAILab/cmaes/blob/main/cmaes/_sepcma.py
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
assert 0 <= elite_ratio <= 1
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.strategy_name = "Sep_CMA_ES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -107,6 +118,7 @@ def params_strategy(self) -> EvoParams:
d_sigma=d_sigma,
c_c=c_c,
chi_n=chi_n,
+ sigma_init=self.sigma_init,
)
return params
diff --git a/evosax/strategies/sim_anneal.py b/evosax/strategies/sim_anneal.py
index 5488abc..5b73bec 100644
--- a/evosax/strategies/sim_anneal.py
+++ b/evosax/strategies/sim_anneal.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -33,17 +33,35 @@ class EvoParams:
class SimAnneal(Strategy):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ sigma_init: float = 0.03,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Simulated Annealing (Rasdi Rere et al., 2015)
Reference: https://www.sciencedirect.com/science/article/pii/S1877050915035759
"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.strategy_name = "SimAnneal"
+ # Set core kwargs es_params (lrate/sigma schedules)
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
- return EvoParams()
+ return EvoParams(
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
diff --git a/evosax/strategies/simple_es.py b/evosax/strategies/simple_es.py
index 077b6b5..ac1d57d 100755
--- a/evosax/strategies/simple_es.py
+++ b/evosax/strategies/simple_es.py
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
@@ -28,20 +28,31 @@ class EvoParams:
class SimpleES(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Simple Gaussian Evolution Strategy (Rechenberg, 1975)
Reference: https://onlinelibrary.wiley.com/doi/abs/10.1002/fedr.19750860506
Inspired by: https://github.com/hardmaru/estool/blob/master/es.py"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.strategy_name = "SimpleES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
# Only parents have positive weight - equal weighting!
- return EvoParams()
+ return EvoParams(sigma_init=self.sigma_init)
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
diff --git a/evosax/strategies/simple_ga.py b/evosax/strategies/simple_ga.py
index fb4e2d3..17ec9a5 100755
--- a/evosax/strategies/simple_ga.py
+++ b/evosax/strategies/simple_ga.py
@@ -1,8 +1,9 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
+from ..utils import exp_decay
from flax import struct
@@ -19,9 +20,9 @@ class EvoState:
@struct.dataclass
class EvoParams:
- cross_over_rate: float = 0.5
+ cross_over_rate: float = 0.0
sigma_init: float = 0.07
- sigma_decay: float = 0.999
+ sigma_decay: float = 1.0
sigma_limit: float = 0.01
init_min: float = 0.0
init_max: float = 0.0
@@ -30,20 +31,39 @@ class EvoParams:
class SimpleGA(Strategy):
- def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ elite_ratio: float = 0.5,
+ sigma_init: float = 0.1,
+ sigma_decay: float = 1.0,
+ sigma_limit: float = 0.01,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Simple Genetic Algorithm (Such et al., 2017)
Reference: https://arxiv.org/abs/1712.06567
Inspired by: https://github.com/hardmaru/estool/blob/master/es.py"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.elite_ratio = elite_ratio
self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
self.strategy_name = "SimpleGA"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+ self.sigma_decay = sigma_decay
+ self.sigma_limit = sigma_limit
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
- return EvoParams()
+ return EvoParams(
+ sigma_init=self.sigma_init,
+ sigma_decay=self.sigma_decay,
+ sigma_limit=self.sigma_limit,
+ )
def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
@@ -111,15 +131,12 @@ def tell_strategy(
fitness = fitness[idx]
archive = solution[idx]
# Update mutation epsilon - multiplicative decay
- sigma = jax.lax.select(
- state.sigma > params.sigma_limit,
- state.sigma * params.sigma_decay,
- state.sigma,
- )
- # Keep mean across stored archive around for evaluation protocol
- mean = archive.mean(axis=0)
+ sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
+ # Set mean to best member seen so far
+ improved = fitness[0] < state.best_fitness
+ best_mean = jax.lax.select(improved, archive[0], state.best_member)
return state.replace(
- fitness=fitness, archive=archive, sigma=sigma, mean=mean
+ fitness=fitness, archive=archive, sigma=sigma, mean=best_mean
)
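The decay-and-floor pattern repeated across strategies is now the shared `exp_decay` helper (see the optimizer diff below). A quick check of its semantics, assuming the `evosax.utils` export added in this changeset:

import jax.numpy as jnp
from evosax.utils import exp_decay

sigma = jnp.array(0.1)
for _ in range(3):
    sigma = exp_decay(sigma, 0.5, 0.03)
# 0.1 -> 0.05 -> 0.03 (floored at the limit) -> 0.03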
diff --git a/evosax/strategies/snes.py b/evosax/strategies/snes.py
new file mode 100644
index 0000000..77d5b87
--- /dev/null
+++ b/evosax/strategies/snes.py
@@ -0,0 +1,135 @@
+import jax
+import jax.numpy as jnp
+import chex
+from typing import Tuple, Optional, Union
+from ..strategy import Strategy
+from flax import struct
+from flax import linen as nn
+
+
+@struct.dataclass
+class EvoState:
+ mean: chex.Array
+ sigma: chex.Array
+ weights: chex.Array
+ best_member: chex.Array
+ best_fitness: float = jnp.finfo(jnp.float32).max
+ gen_counter: int = 0
+
+
+@struct.dataclass
+class EvoParams:
+ lrate_mean: float = 1.0
+ lrate_sigma: float = 1.0
+ sigma_init: float = 1.0
+ temperature: float = 0.0
+ init_min: float = 0.0
+ init_max: float = 0.0
+ clip_min: float = -jnp.finfo(jnp.float32).max
+ clip_max: float = jnp.finfo(jnp.float32).max
+
+
+def get_recombination_weights(popsize: int, use_baseline: bool = True):
+ """Get recombination weights for different ranks."""
+
+ def get_weight(i):
+ return jnp.maximum(0, jnp.log(popsize / 2 + 1) - jnp.log(i))
+
+ weights = jax.vmap(get_weight)(jnp.arange(1, popsize + 1))
+ weights_norm = weights / jnp.sum(weights)
+ return weights_norm - use_baseline * (1 / popsize)
+
+
+def get_temp_weights(
+ popsize: int, temperature: float, use_baseline: bool = True
+):
+ """Get weights based on original discovered weights (Lange et al, 2022)."""
+ ranks = jnp.arange(popsize)
+ ranks /= ranks.size - 1
+ ranks = ranks - 0.5
+ weights = nn.softmax(-temperature * ranks)
+ return weights
+
+
+class SNES(Strategy):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ sigma_init: float = 1.0,
+ temperature: float = 0.0, # good values tend to be between 12 and 20
+ **fitness_kwargs: Union[bool, int, float]
+ ):
+ """Separable Exponential Natural ES (Wierstra et al., 2014)
+ Reference: https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
+ """
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
+ self.strategy_name = "SNES"
+
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+ self.temperature = temperature
+
+ @property
+ def params_strategy(self) -> EvoParams:
+ """Return default parameters of evolutionary strategy."""
+ lrate_sigma = (3 + jnp.log(self.num_dims)) / (
+ 5 * jnp.sqrt(self.num_dims)
+ )
+ params = EvoParams(
+ lrate_sigma=lrate_sigma,
+ sigma_init=self.sigma_init,
+ temperature=self.temperature,
+ )
+ return params
+
+ def initialize_strategy(
+ self, rng: chex.PRNGKey, params: EvoParams
+ ) -> EvoState:
+ """`initialize` the evolutionary strategy."""
+ initialization = jax.random.uniform(
+ rng,
+ (self.num_dims,),
+ minval=params.init_min,
+ maxval=params.init_max,
+ )
+ use_des_weights = params.temperature > 0.0
+ weights = jax.lax.select(
+ use_des_weights,
+ get_temp_weights(self.popsize, params.temperature),
+ get_recombination_weights(self.popsize),
+ )
+ state = EvoState(
+ mean=initialization,
+ sigma=params.sigma_init * jnp.ones(self.num_dims),
+ weights=weights.reshape(-1, 1),
+ best_member=initialization,
+ )
+
+ return state
+
+ def ask_strategy(
+ self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
+ ) -> Tuple[chex.Array, EvoState]:
+ """`ask` for new parameter candidates to evaluate next."""
+ noise = jax.random.normal(rng, (self.popsize, self.num_dims))
+ x = state.mean + noise * state.sigma.reshape(1, self.num_dims)
+ return x, state
+
+ def tell_strategy(
+ self,
+ x: chex.Array,
+ fitness: chex.Array,
+ state: EvoState,
+ params: EvoParams,
+ ) -> EvoState:
+ """`tell` performance data for strategy state update."""
+ s = (x - state.mean) / state.sigma
+ ranks = fitness.argsort()
+ sorted_noise = s[ranks]
+ grad_mean = (state.weights * sorted_noise).sum(axis=0)
+ grad_sigma = (state.weights * (sorted_noise ** 2 - 1)).sum(axis=0)
+ mean = state.mean + params.lrate_mean * state.sigma * grad_mean
+ sigma = state.sigma * jnp.exp(params.lrate_sigma / 2 * grad_sigma)
+ return state.replace(mean=mean, sigma=sigma)
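A short sanity check of the utility weights defined above (a sketch, assuming the module path from this diff): with the baseline term subtracted the weights sum to roughly zero, and better ranks receive larger weight.

import jax.numpy as jnp
from evosax.strategies.snes import get_recombination_weights

w = get_recombination_weights(8)
print(float(w.sum()))      # ~0.0 once the 1/popsize baseline is subtracted
print(bool(w[0] > w[-1]))  # True: weights decrease with rank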
diff --git a/evosax/strategies/xnes.py b/evosax/strategies/xnes.py
index 45744e3..609f8f5 100644
--- a/evosax/strategies/xnes.py
+++ b/evosax/strategies/xnes.py
@@ -1,21 +1,20 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple
+from typing import Tuple, Optional, Union
from ..strategy import Strategy
from flax import struct
+from .snes import get_recombination_weights
@struct.dataclass
class EvoState:
mean: chex.Array
sigma: float
- sigma_old: float
- amat: chex.Array
- bmat: chex.Array
+ B: chex.Array
noise: chex.Array
- eta_sigma: float
- utilities: chex.Array
+ lrate_sigma: float
+ weights: chex.Array
best_member: chex.Array
best_fitness: float = jnp.finfo(jnp.float32).max
gen_counter: int = 0
@@ -23,11 +22,13 @@ class EvoState:
@struct.dataclass
class EvoParams:
- eta_mean: float
- eta_sigma_init: float
- eta_bmat: float
- use_adaptive_sampling: bool = False
- use_fitness_shaping: bool = True
+ lrate_mean: float = 1.0
+ lrate_sigma_init: float = 0.1
+ lrate_B: float = 0.1
+ sigma_init: float = 1.0
+ use_adasam: bool = False # Adaptation sampling lrate sigma
+ rho: float = 0.5 # Significance level adaptation sampling
+ c_prime: float = 0.1 # Adaptation sampling step size
init_min: float = 0.0
init_max: float = 0.0
clip_min: float = -jnp.finfo(jnp.float32).max
@@ -35,24 +36,35 @@ class EvoParams:
class xNES(Strategy):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ sigma_init: float = 1.0,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Exponential Natural ES (Wierstra et al., 2014)
Reference: https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf
Inspired by: https://github.com/chanshing/xnes"""
- super().__init__(num_dims, popsize)
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
self.strategy_name = "xNES"
+ # Set core kwargs es_params
+ self.sigma_init = sigma_init
+
@property
def params_strategy(self) -> EvoParams:
"""Return default parameters of evolutionary strategy."""
+ lrate_sigma = (9 + 3 * jnp.log(self.num_dims)) / (
+ 5 * jnp.sqrt(self.num_dims) * self.num_dims
+ )
+ rho = 0.5 - 1.0 / (3 * (self.num_dims + 1))
params = EvoParams(
- eta_mean=1.0,
- eta_sigma_init=3
- * (3 + jnp.log(self.num_dims))
- * (1.0 / (5 * self.num_dims * jnp.sqrt(self.num_dims))),
- eta_bmat=3
- * (3 + jnp.log(self.num_dims))
- * (1.0 / (5 * self.num_dims * jnp.sqrt(self.num_dims))),
+ lrate_sigma_init=lrate_sigma,
+ lrate_B=lrate_sigma,
+ rho=rho,
+ sigma_init=self.sigma_init,
)
return params
@@ -60,33 +72,20 @@ def initialize_strategy(
self, rng: chex.PRNGKey, params: EvoParams
) -> EvoState:
"""`initialize` the evolutionary strategy."""
- amat = jnp.eye(self.num_dims)
- sigma = abs(jax.scipy.linalg.det(amat)) ** (1.0 / self.num_dims)
- bmat = amat * (1.0 / sigma)
- # Utility helper for fitness shaping - doesn't work without?!
- a = jnp.log(1 + 0.5 * self.popsize)
- utilities = jnp.array(
- [jnp.maximum(0, a - jnp.log(k)) for k in range(1, self.popsize + 1)]
- )
- utilities /= jnp.sum(utilities)
- utilities -= 1.0 / self.popsize # broadcast
- utilities = utilities[::-1] # ascending order
-
initialization = jax.random.uniform(
rng,
(self.num_dims,),
minval=params.init_min,
maxval=params.init_max,
)
+ weights = get_recombination_weights(self.popsize)
state = EvoState(
mean=initialization,
- sigma=sigma,
- sigma_old=sigma,
- amat=amat,
- bmat=bmat,
+ B=jnp.eye(self.num_dims) * params.sigma_init,
+ sigma=params.sigma_init,
noise=jnp.zeros((self.popsize, self.num_dims)),
- eta_sigma=params.eta_sigma_init,
- utilities=utilities,
+ lrate_sigma=params.lrate_sigma_init,
+ weights=weights.reshape(-1, 1),
best_member=initialization,
)
@@ -97,7 +96,14 @@ def ask_strategy(
) -> Tuple[chex.Array, EvoState]:
"""`ask` for new parameter candidates to evaluate next."""
noise = jax.random.normal(rng, (self.popsize, self.num_dims))
- x = state.mean + state.sigma * jnp.dot(noise, state.bmat)
+
+ def scale_orient(n, sigma, B):
+ return sigma * B.T @ n
+
+ scaled_noise = jax.vmap(scale_orient, in_axes=(0, None, None))(
+ noise, state.sigma, state.B
+ )
+ x = state.mean + scaled_noise
return x, state.replace(noise=noise)
def tell_strategy(
@@ -108,98 +114,80 @@ def tell_strategy(
params: EvoParams,
) -> EvoState:
"""`tell` performance data for strategy state update."""
- # By default the xNES maximizes the objective
- fitness_re = -fitness
- isort = fitness_re.argsort()
- sorted_fitness = fitness_re[isort]
- sorted_noise = state.noise[isort]
- sorted_candidates = x[isort]
- fitness_shaped = jax.lax.select(
- params.use_fitness_shaping, state.utilities, sorted_fitness
- )
+ ranks = fitness.argsort()
+ sorted_noise = state.noise[ranks]
+ grad_mean = (state.weights * sorted_noise).sum(axis=0)
- use_adasam = jnp.logical_and(
- params.use_adaptive_sampling, state.gen_counter > 1
- ) # sigma_old must be available
- eta_sigma = jax.lax.select(
- use_adasam,
- self.adaptive_sampling(
- state.eta_sigma,
- state.mean,
- state.sigma,
- state.bmat,
- state.sigma_old,
- sorted_candidates,
- state.eta_sigma,
- ),
- state.eta_sigma,
- )
+ def s_grad_m(weight, noise):
+ return weight * (noise @ noise.T - jnp.eye(self.num_dims))
- dj_delta = jnp.dot(fitness_shaped, sorted_noise)
- dj_mmat = (
- jnp.dot(
- sorted_noise.T,
- sorted_noise * fitness_shaped.reshape(self.popsize, 1),
- )
- - jnp.sum(fitness_shaped) * jnp.eye(self.num_dims)
- )
- dj_sigma = jnp.trace(dj_mmat) * (1.0 / self.num_dims)
- dj_bmat = dj_mmat - dj_sigma * jnp.eye(self.num_dims)
+ grad_m = jax.vmap(s_grad_m, in_axes=(0, 0))(
+ state.weights, sorted_noise
+ ).sum(axis=0)
+ grad_sigma = jnp.trace(grad_m) / self.num_dims
+ grad_B = grad_m - grad_sigma * jnp.eye(self.num_dims)
- sigma_old = state.sigma
- mean = state.mean + (
- params.eta_mean * state.sigma * jnp.dot(state.bmat, dj_delta)
+ mean = (
+ state.mean + params.lrate_mean * state.sigma * state.B @ grad_mean
)
- sigma = sigma_old * jnp.exp(0.5 * eta_sigma * dj_sigma)
- bmat = jnp.dot(
- state.bmat,
- jax.scipy.linalg.expm(0.5 * params.eta_bmat * dj_bmat),
+ sigma = state.sigma * jnp.exp(state.lrate_sigma / 2 * grad_sigma)
+ B = state.B * jnp.exp(params.lrate_B / 2 * grad_B)
+
+ lrate_sigma = adaptation_sampling(
+ state.lrate_sigma,
+ params.lrate_sigma_init,
+ mean,
+ B,
+ sigma,
+ state.sigma,
+ sorted_noise,
+ params.c_prime,
+ params.rho,
)
- return state.replace(
- eta_sigma=eta_sigma,
- mean=mean,
- sigma=sigma,
- bmat=bmat,
- sigma_old=sigma_old,
+ lrate_sigma = jax.lax.select(
+ params.use_adasam, lrate_sigma, state.lrate_sigma
)
-
- def adaptive_sampling(
- self,
- eta_sigma: float,
- mu: chex.Array,
- sigma: float,
- bmat: chex.Array,
- sigma_old: float,
- z_try: chex.Array,
- eta_sigma_init: float,
- ) -> float:
- """Adaptation sampling."""
- c = 0.1
- rho = 0.5 - 1.0 / (3 * (self.num_dims + 1)) # empirical
-
- bbmat = jnp.dot(bmat.T, bmat)
- cov = sigma ** 2 * bbmat
- sigma_ = sigma * jnp.sqrt(sigma * (1.0 / sigma_old)) # increase by 1.5
- cov_ = sigma_ ** 2 * bbmat
-
- p0 = jax.scipy.stats.multivariate_normal.logpdf(z_try, mean=mu, cov=cov)
- p1 = jax.scipy.stats.multivariate_normal.logpdf(
- z_try, mean=mu, cov=cov_
+ return state.replace(
+ mean=mean, sigma=sigma, B=B, lrate_sigma=lrate_sigma
)
- w = jnp.exp(p1 - p0)
- # Mann-Whitney. It is assumed z_try was in ascending order.
- n_ = jnp.sum(w)
- u_ = jnp.sum(w * (jnp.arange(self.popsize) + 0.5))
- u_mu = self.popsize * n_ * 0.5
- u_sigma = jnp.sqrt(self.popsize * n_ * (self.popsize + n_ + 1) / 12.0)
- cum = jax.scipy.stats.norm.cdf(u_, loc=u_mu, scale=u_sigma)
-
- decrease = cum < rho
- eta_out = jax.lax.select(
- decrease,
- (1 - c) * eta_sigma + c * eta_sigma_init,
- jnp.minimum(1, (1 + c) * eta_sigma),
- )
- return eta_out
+def adaptation_sampling(
+ lrate_sigma: float,
+ lrate_sigma_init: float,
+ mean: chex.Array,
+ B: chex.Array,
+ sigma: float,
+ sigma_old: float,
+ sorted_noise: chex.Array,
+ c_prime: float,
+ rho: float,
+) -> float:
+ """Adaptation sampling on sigma/std learning rate."""
+ BB = B.T @ B
+ A = sigma ** 2 * BB
+ sigma_prime = sigma * jnp.sqrt(sigma / sigma_old)
+ A_prime = sigma_prime ** 2 * BB
+
+ # Probability ratio and U-test - sorted order assumed for noise
+ prob_0 = jax.scipy.stats.multivariate_normal.logpdf(sorted_noise, mean, A)
+ prob_1 = jax.scipy.stats.multivariate_normal.logpdf(
+ sorted_noise, mean, A_prime
+ )
+ w = jnp.exp(prob_1 - prob_0)
+ popsize = sorted_noise.shape[0]
+ n = jnp.sum(w)
+ u = jnp.sum(w * (jnp.arange(popsize) + 0.5))
+ u_mean = popsize * n / 2
+ u_sigma = jnp.sqrt(popsize * n * (popsize + n + 1) / 12)
+ cumulative = jax.scipy.stats.norm.cdf(
+ u, loc=u_mean + 1e-10, scale=u_sigma + 1e-10
+ )
+
+ # Check test significance and update lrate
+ lrate_sigma = jax.lax.select(
+ cumulative < rho,
+ (1 - c_prime) * lrate_sigma + c_prime * lrate_sigma_init,
+ jnp.minimum(1, (1 + c_prime) * lrate_sigma),
+ )
+ return lrate_sigma
diff --git a/evosax/strategy.py b/evosax/strategy.py
index c550cd8..c6eb590 100755
--- a/evosax/strategy.py
+++ b/evosax/strategy.py
@@ -1,10 +1,10 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union
from functools import partial
from flax import struct
-from .utils import get_best_fitness_member
+from .utils import get_best_fitness_member, ParameterReshaper, FitnessShaper
@struct.dataclass
@@ -28,11 +28,30 @@ class EvoParams:
class Strategy(object):
- def __init__(self, num_dims: int, popsize: int):
+ def __init__(
+ self,
+ popsize: int,
+ num_dims: Optional[int] = None,
+ pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
+ **fitness_kwargs: Union[bool, int, float]
+ ):
"""Base Class for an Evolution Strategy."""
- self.num_dims = num_dims
self.popsize = popsize
+ # Setup optional parameter reshaper
+ self.use_param_reshaper = pholder_params is not None
+ if self.use_param_reshaper:
+ self.param_reshaper = ParameterReshaper(pholder_params)
+ self.num_dims = self.param_reshaper.total_params
+ else:
+ self.num_dims = num_dims
+ assert (
+ self.num_dims is not None
+ ), "Provide either num_dims or pholder_params to strategy."
+
+ # Setup optional fitness shaper
+ self.fitness_shaper = FitnessShaper(**fitness_kwargs)
+
@property
def default_params(self) -> EvoParams:
"""Return default parameters of evolution strategy."""
@@ -58,7 +77,7 @@ def ask(
rng: chex.PRNGKey,
state: EvoState,
params: Optional[EvoParams] = None,
- ) -> Tuple[chex.Array, EvoState]:
+ ) -> Tuple[Union[chex.Array, chex.ArrayTree], EvoState]:
"""`ask` for new parameter candidates to evaluate next."""
# Use default hyperparameters if no other settings provided
if params is None:
@@ -68,12 +87,18 @@ def ask(
x, state = self.ask_strategy(rng, state, params)
# Clip proposal candidates into allowed range
x_clipped = jnp.clip(jnp.squeeze(x), params.clip_min, params.clip_max)
- return x_clipped, state
+
+ # Reshape parameters into pytrees
+ if self.use_param_reshaper:
+ x_out = self.param_reshaper.reshape(x_clipped)
+ else:
+ x_out = x_clipped
+ return x_out, state
@partial(jax.jit, static_argnums=(0,))
def tell(
self,
- x: chex.Array,
+ x: Union[chex.Array, chex.ArrayTree],
fitness: chex.Array,
state: EvoState,
params: Optional[EvoParams] = None,
@@ -83,11 +108,20 @@ def tell(
if params is None:
params = self.default_params
+ # Flatten params if using param reshaper for ES update
+ if self.use_param_reshaper:
+ x = self.param_reshaper.flatten(x)
+
+ # Perform fitness reshaping inside of strategy tell call (if desired)
+ fitness_re = self.fitness_shaper.apply(x, fitness)
+
# Update the search state based on strategy-specific update
- state = self.tell_strategy(x, fitness, state, params)
+ state = self.tell_strategy(x, fitness_re, state, params)
# Check if there is a new best member & update trackers
- best_member, best_fitness = get_best_fitness_member(x, fitness, state)
+ best_member, best_fitness = get_best_fitness_member(
+ x, fitness, state, self.fitness_shaper.maximize
+ )
return state.replace(
best_member=best_member,
best_fitness=best_fitness,
@@ -115,3 +149,11 @@ def tell_strategy(
) -> EvoState:
"""Search-specific `tell` update. Returns updated state."""
raise NotImplementedError
+
+ def get_eval_params(self, state: EvoState):
+ """Return reshaped parameters to evaluate."""
+ if self.use_param_reshaper:
+ x_out = self.param_reshaper.reshape_single(state.mean)
+ else:
+ x_out = state.mean
+ return x_out
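An end-to-end sketch of the new interface above (per the changelog entry): passing `pholder_params` instead of `num_dims` makes `ask` return, and `tell` accept, pytrees shaped like the placeholder, with the `ParameterReshaper` applied internally.

import jax
import jax.numpy as jnp
from evosax import OpenES

pholder = {"w": jnp.zeros((8, 2)), "b": jnp.zeros(2)}  # toy network params
strategy = OpenES(popsize=16, pholder_params=pholder)

rng = jax.random.PRNGKey(0)
params = strategy.default_params
state = strategy.initialize(rng, params)
rng, rng_ask = jax.random.split(rng)
x, state = strategy.ask(rng_ask, state, params)  # pytree w/ leading popsize axis
fitness = jnp.zeros(16)  # placeholder evaluations
state = strategy.tell(x, fitness, state, params)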
diff --git a/evosax/utils/__init__.py b/evosax/utils/__init__.py
index e3af896..3c09aab 100755
--- a/evosax/utils/__init__.py
+++ b/evosax/utils/__init__.py
@@ -2,7 +2,7 @@
from .es_logger import ESLog
# Import additional utilities for reshaping flat parameters into net dict
-from .reshape_params import ParameterReshaper
+from .reshape_params import ParameterReshaper, ravel_pytree
# Import additional utilities for reshaping fitness
from .reshape_fitness import FitnessShaper
@@ -11,13 +11,23 @@
from .helpers import get_best_fitness_member
# Import Gradient Based Optimizer step functions
-from .optimizer import SGD, Adam, RMSProp, ClipUp, OptState, OptParams
+from .optimizer import (
+ SGD,
+ Adam,
+ RMSProp,
+ ClipUp,
+ Adan,
+ OptState,
+ OptParams,
+ exp_decay,
+)
GradientOptimizer = {
"sgd": SGD,
"adam": Adam,
"rmsprop": RMSProp,
"clipup": ClipUp,
+ "adan": Adan,
}
@@ -25,12 +35,15 @@
"get_best_fitness_member",
"ESLog",
"ParameterReshaper",
+ "ravel_pytree",
"FitnessShaper",
"GradientOptimizer",
"SGD",
"Adam",
"RMSProp",
"ClipUp",
+ "Adan",
"OptState",
"OptParams",
+ "exp_decay",
]
diff --git a/evosax/utils/eigen_decomp.py b/evosax/utils/eigen_decomp.py
index 25e42f9..dc26301 100644
--- a/evosax/utils/eigen_decomp.py
+++ b/evosax/utils/eigen_decomp.py
@@ -4,11 +4,12 @@
def full_eigen_decomp(
- C: chex.Array, B: chex.Array, D: chex.Array
+ C: chex.Array, B: chex.Array, D: chex.Array, gen_counter: int
) -> Tuple[chex.Array, chex.Array, chex.Array]:
"""Perform eigendecomposition of covariance matrix."""
if B is not None and D is not None:
return C, B, D
+ C = C + 1e-10 * (gen_counter == 0)
C = (C + C.T) / 2 # Make sure matrix is symmetric
D2, B = jnp.linalg.eigh(C)
D = jnp.sqrt(jnp.where(D2 < 0, 1e-20, D2))
diff --git a/evosax/utils/evojax_wrapper.py b/evosax/utils/evojax_wrapper.py
new file mode 100644
index 0000000..1c941cb
--- /dev/null
+++ b/evosax/utils/evojax_wrapper.py
@@ -0,0 +1,53 @@
+import chex
+import jax
+import jax.numpy as jnp
+from evojax.algo.base import NEAlgorithm
+from evosax import Strategy
+
+
+class Evosax2JAX_Wrapper(NEAlgorithm):
+ """Wrapper for evosax-style ES for EvoJAX deployment."""
+
+ def __init__(
+ self,
+ evosax_strategy: Strategy,
+ param_size: int,
+ pop_size: int,
+ es_config: dict = {},
+ es_params: dict = {},
+ opt_params: dict = {},
+ seed: int = 42,
+ ):
+ self.es = evosax_strategy(
+ popsize=pop_size, num_dims=param_size, **es_config, **opt_params
+ )
+ self.es_params = self.es.default_params.replace(**es_params)
+ self.pop_size = pop_size
+ self.param_size = param_size
+ self.rand_key = jax.random.PRNGKey(seed=seed)
+ self.rand_key, init_key = jax.random.split(self.rand_key)
+ self.es_state = self.es.initialize(init_key, self.es_params)
+
+ def ask(self) -> chex.Array:
+ """Ask strategy for next set of solution candidates to evaluate."""
+ self.rand_key, ask_key = jax.random.split(self.rand_key)
+ self.params, self.es_state = self.es.ask(
+ ask_key, self.es_state, self.es_params
+ )
+ return self.params
+
+ def tell(self, fitness: chex.Array) -> None:
+ """Tell strategy about most recent fitness evaluations."""
+ self.es_state = self.es.tell(
+ self.params, fitness, self.es_state, self.es_params
+ )
+
+ @property
+ def best_params(self) -> chex.Array:
+ """Return set of mean/best parameters."""
+ return jnp.array(self.es_state.mean, copy=True)
+
+ @best_params.setter
+ def best_params(self, params: chex.Array) -> None:
+ """Update the best parameters stored internally."""
+ self.es_state = self.es_state.replace(mean=jnp.array(params, copy=True))
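A hypothetical deployment sketch for the wrapper above; the EvoJAX task/evaluation side is assumed and not part of this diff.

from evosax import SNES
from evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper

solver = Evosax2JAX_Wrapper(
    SNES,
    param_size=100,
    pop_size=64,
    es_config={"sigma_init": 0.5},
    seed=0,
)
params = solver.ask()  # (pop_size, param_size) candidate matrix
# fitness = task.evaluate(params)  # any EvoJAX-style evaluation (assumed)
# solver.tell(fitness)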
diff --git a/evosax/utils/helpers.py b/evosax/utils/helpers.py
index 02c3890..21dd54a 100644
--- a/evosax/utils/helpers.py
+++ b/evosax/utils/helpers.py
@@ -5,18 +5,25 @@
def get_best_fitness_member(
- x: chex.Array, fitness: chex.Array, state
+ x: chex.Array, fitness: chex.Array, state, maximize: bool = False
) -> Tuple[chex.Array, float]:
- best_in_gen = jnp.argmin(fitness)
+ """Check if fitness improved & replace in ES state."""
+ fitness_min = jax.lax.select(maximize, -1 * fitness, fitness)
+ max_and_later = maximize and state.gen_counter > 0
+ best_fit_min = jax.lax.select(
+ max_and_later, -1 * state.best_fitness, state.best_fitness
+ )
+ best_in_gen = jnp.argmin(fitness_min)
best_in_gen_fitness, best_in_gen_member = (
- fitness[best_in_gen],
+ fitness_min[best_in_gen],
x[best_in_gen],
)
- replace_best = best_in_gen_fitness < state.best_fitness
+ replace_best = best_in_gen_fitness < best_fit_min
best_fitness = jax.lax.select(
- replace_best, best_in_gen_fitness, state.best_fitness
+ replace_best, best_in_gen_fitness, best_fit_min
)
best_member = jax.lax.select(
replace_best, best_in_gen_member, state.best_member
)
+ best_fitness = jax.lax.select(maximize, -1 * best_fitness, best_fitness)
return best_member, best_fitness
diff --git a/evosax/utils/optimizer.py b/evosax/utils/optimizer.py
index 02e46de..e686c96 100644
--- a/evosax/utils/optimizer.py
+++ b/evosax/utils/optimizer.py
@@ -11,11 +11,22 @@
# "clip_value": 5,
+def exp_decay(
+ param: chex.Array, param_decay: chex.Array, param_limit: chex.Array
+) -> chex.Array:
+ """Exponentially decay parameter & clip by minimal value."""
+ param = param * param_decay
+ param = jnp.maximum(param, param_limit)
+ return param
+
+
@struct.dataclass
class OptState:
lrate: float
m: chex.Array
v: Optional[chex.Array] = None
+ n: Optional[chex.Array] = None
+ last_grads: Optional[chex.Array] = None
gen_counter: int = 0
@@ -27,6 +38,7 @@ class OptParams:
momentum: Optional[float] = None
beta_1: Optional[float] = None
beta_2: Optional[float] = None
+ beta_3: Optional[float] = None
eps: Optional[float] = None
max_speed: Optional[float] = None
@@ -57,8 +69,7 @@ def step(
def update(self, state: OptState, params: OptParams) -> OptState:
"""Exponentially decay the learning rate if desired."""
- lrate = state.lrate * params.lrate_decay
- lrate = jnp.maximum(lrate, params.lrate_limit)
+ lrate = exp_decay(state.lrate, params.lrate_decay, params.lrate_limit)
return state.replace(lrate=lrate)
@property
@@ -91,7 +102,7 @@ def __init__(self, num_dims: int):
def params_opt(self) -> Dict[str, float]:
"""Return default SGD+Momentum parameters."""
return {
- "momentum": 0.9,
+ "momentum": 0.0,
}
def initialize_opt(self, params: OptParams) -> OptState:
@@ -241,3 +252,56 @@ def clip(velocity: chex.Array, max_speed: float):
m = clip(velocity, params.max_speed)
mean_new = mean - state.lrate * m
return mean_new, state.replace(m=m)
+
+
+class Adan(Optimizer):
+ def __init__(self, num_dims: int):
+ """JAX-Compatible Adan Optimizer (Xi et al., 2022)
+ Reference: https://arxiv.org/pdf/2208.06677.pdf"""
+ super().__init__(num_dims)
+ self.opt_name = "adan"
+
+ @property
+ def params_opt(self) -> Dict[str, float]:
+ """Return default Adam parameters."""
+ return {
+ "beta_1": 0.98,
+ "beta_2": 0.92,
+ "beta_3": 0.99,
+ "eps": 1e-8,
+ }
+
+ def initialize_opt(self, params: OptParams) -> OptState:
+ """Initialize the m, v, n trace of the optimizer."""
+ return OptState(
+ m=jnp.zeros(self.num_dims),
+ v=jnp.zeros(self.num_dims),
+ n=jnp.zeros(self.num_dims),
+ last_grads=jnp.zeros(self.num_dims),
+ lrate=params.lrate_init,
+ )
+
+ def step_opt(
+ self,
+ mean: chex.Array,
+ grads: chex.Array,
+ state: OptState,
+ params: OptParams,
+ ) -> Tuple[chex.Array, OptState]:
+ """Perform a simple Adan GD step."""
+ m = (1 - params.beta_1) * grads + params.beta_1 * state.m
+ grad_diff = grads - state.last_grads
+ v = (1 - params.beta_2) * grad_diff + params.beta_2 * state.v
+ n = (1 - params.beta_3) * (
+ grads + params.beta_2 * grad_diff
+ ) ** 2 + params.beta_3 * state.n
+
+ mhat = m / (1 - params.beta_1 ** (state.gen_counter + 1))
+ vhat = v / (1 - params.beta_2 ** (state.gen_counter + 1))
+ nhat = n / (1 - params.beta_3 ** (state.gen_counter + 1))
+ mean_new = mean - state.lrate * (mhat + params.beta_2 * vhat) / (
+ jnp.sqrt(nhat) + params.eps
+ )
+ return mean_new, state.replace(
+ m=m, v=v, n=n, last_grads=grads, gen_counter=state.gen_counter + 1
+ )
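A sketch of using the new registry entry; it assumes the base `Optimizer` exposes `default_params`/`initialize`/`step`/`update` wrappers around the `*_opt` methods, as the strategy `tell` updates above suggest.

import jax.numpy as jnp
from evosax.utils import GradientOptimizer

opt = GradientOptimizer["adan"](num_dims=10)
opt_params = opt.default_params.replace(lrate_init=0.05)
opt_state = opt.initialize(opt_params)  # assumed wrapper around initialize_opt

mean = jnp.zeros(10)
grads = 0.1 * jnp.ones(10)  # hypothetical ES pseudo-gradient
mean, opt_state = opt.step(mean, grads, opt_state, opt_params)
opt_state = opt.update(opt_state, opt_params)  # lrate exp_decay schedule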
diff --git a/evosax/utils/reshape_fitness.py b/evosax/utils/reshape_fitness.py
index 4d2205d..800b3c8 100755
--- a/evosax/utils/reshape_fitness.py
+++ b/evosax/utils/reshape_fitness.py
@@ -2,41 +2,57 @@
import jax.numpy as jnp
import chex
from functools import partial
+from typing import Union
class FitnessShaper(object):
def __init__(
self,
- centered_rank: bool = False,
- z_score: bool = False,
+ centered_rank: Union[bool, int] = False,
+ z_score: Union[bool, int] = False,
+ norm_range: Union[bool, int] = False,
w_decay: float = 0.0,
- maximize: bool = False,
+ maximize: Union[bool, int] = False,
):
"""JAX-compatible fitness shaping tool."""
self.w_decay = w_decay
self.centered_rank = bool(centered_rank)
self.z_score = bool(z_score)
+ self.norm_range = bool(norm_range)
self.maximize = bool(maximize)
+ # Check that only single fitness shaping transformation is used
+ num_options_on = self.centered_rank + self.z_score + self.norm_range
+ assert (
+ num_options_on < 2
+ ), "Only use one fitness shaping transformation."
+
@partial(jax.jit, static_argnums=(0,))
def apply(self, x: chex.Array, fitness: chex.Array) -> chex.Array:
"""Max objective trafo, rank shaping, z scoring & add weight decay."""
- fitness = jax.lax.select(self.maximize, -1 * fitness, fitness)
- fitness = jax.lax.select(
- self.centered_rank, compute_centered_ranks(fitness), fitness
- )
- fitness = jax.lax.select(
- self.z_score, z_score_fitness(fitness), fitness
- )
+ if self.maximize:
+ fitness = -1 * fitness
+
+ if self.centered_rank:
+ fitness = centered_rank_trafo(fitness)
+
+ if self.z_score:
+ fitness = z_score_trafo(fitness)
+
+ if self.norm_range:
+ fitness = range_norm_trafo(fitness, -1.0, 1.0)
+
+ # Apply w_decay after normalization - makes it easier to tune
# "Reduce" fitness based on L2 norm of parameters
- l2_fit_red = self.w_decay * compute_weight_norm(x)
- l2_fit_red = jax.lax.select(self.maximize, -1 * l2_fit_red, l2_fit_red)
- return fitness + l2_fit_red
+ if self.w_decay > 0.0:
+ l2_fit_red = self.w_decay * compute_l2_norm(x)
+ fitness += l2_fit_red
+ return fitness
-def z_score_fitness(fitness: chex.Array) -> chex.Array:
+def z_score_trafo(arr: chex.Array) -> chex.Array:
"""Make fitness 'Gaussian' by substracting mean and dividing by std."""
- return (fitness - jnp.mean(fitness)) / jnp.std(1e-05 + fitness)
+ return (arr - jnp.mean(arr)) / (jnp.std(arr) + 1e-10)
def compute_ranks(fitness: chex.Array) -> chex.Array:
@@ -46,13 +62,28 @@ def compute_ranks(fitness: chex.Array) -> chex.Array:
return ranks
-def compute_centered_ranks(fitness: chex.Array) -> chex.Array:
+def centered_rank_trafo(fitness: chex.Array) -> chex.Array:
"""Return ~ -0.5 to 0.5 centered ranks (best to worst - min!)."""
y = compute_ranks(fitness)
y /= fitness.size - 1
return y - 0.5
-def compute_weight_norm(x: chex.Array) -> chex.Array:
+def compute_l2_norm(x: chex.Array) -> chex.Array:
"""Compute L2-norm of x_i. Assumes x to have shape (popsize, num_dims)."""
return jnp.mean(x * x, axis=1)
+
+
+def range_norm_trafo(
+ arr: chex.Array, min_val: float = -1.0, max_val: float = 1.0
+) -> chex.Array:
+ """Map scores into a min/max range."""
+ arr = jnp.clip(arr, -1e10, 1e10)
+ normalized_arr = (
+ (max_val - min_val)
+ * (arr - jnp.nanmin(arr))
+ / (jnp.nanmax(arr) - jnp.nanmin(arr) + 1e-10)
+ + min_val
+ )
+ return normalized_arr
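A short usage sketch of the reworked shaper: exactly one of the rank/z-score/range transformations may be active, and weight decay is added after normalization (and no longer sign-flipped under maximization).

import jax.numpy as jnp
from evosax.utils import FitnessShaper

shaper = FitnessShaper(centered_rank=True, w_decay=0.01, maximize=True)
x = jnp.zeros((4, 3))  # (popsize, num_dims) candidates
fitness = jnp.array([1.0, 3.0, 2.0, 4.0])  # higher is better here
fit_re = shaper.apply(x, fitness)  # centered ranks in [-0.5, 0.5], sign-flipped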
diff --git a/evosax/utils/reshape_params.py b/evosax/utils/reshape_params.py
index aa6fb0c..bab0dfc 100755
--- a/evosax/utils/reshape_params.py
+++ b/evosax/utils/reshape_params.py
@@ -1,17 +1,29 @@
import jax
import jax.numpy as jnp
import chex
-from typing import Union, List, Optional
-from flax.core.frozen_dict import FrozenDict, unfreeze
-from flax.traverse_util import flatten_dict, unflatten_dict
-from jax.tree_util import tree_flatten, tree_unflatten
+from typing import Union, Optional
+from jax import vjp, flatten_util
+from jax.tree_util import tree_flatten
+
+
+def ravel_pytree(pytree):
+ leaves, _ = tree_flatten(pytree)
+ flat, _ = vjp(ravel_list, *leaves)
+ return flat
+
+
+def ravel_list(*lst):
+ return (
+ jnp.concatenate([jnp.ravel(elt) for elt in lst])
+ if lst
+ else jnp.array([])
+ )
class ParameterReshaper(object):
def __init__(
self,
placeholder_params: Union[chex.ArrayTree, chex.Array],
- identity: bool = False,
n_devices: Optional[int] = None,
verbose: bool = True,
):
@@ -19,19 +31,12 @@ def __init__(
# Get network shape to reshape
self.placeholder_params = placeholder_params
- leafs, treedef = jax.tree_util.tree_flatten(placeholder_params)
- self._treedef = treedef
- self.network_shape = jax.tree_map(jnp.shape, leafs)
- self.total_params = get_total_params(self.network_shape)
- self.l_id = get_layer_ids(self.network_shape)
-
- # Special case for no identity mapping (no pytree reshaping)
- if identity:
- self.reshape = jax.jit(self.reshape_identity)
- self.reshape_single = jax.jit(self.reshape_single_flat)
- else:
- self.reshape = jax.jit(self.reshape_network)
- self.reshape_single = jax.jit(self.reshape_single_net)
+ # Infer total number of parameters by flattening the placeholder pytree
+ flat, self.unravel_pytree = flatten_util.ravel_pytree(
+ placeholder_params
+ )
+ self.total_params = flat.shape[0]
+ self.reshape_single = jax.jit(self.unravel_pytree)
if n_devices is None:
self.n_devices = jax.local_device_count()
@@ -50,13 +55,9 @@ def __init__(
" for optimization."
)
- def reshape_identity(self, x: chex.Array) -> chex.Array:
- """Return parameters w/o reshaping for evaluation."""
- return x
-
- def reshape_network(self, x: chex.Array) -> chex.ArrayTree:
+ def reshape(self, x: chex.Array) -> chex.ArrayTree:
"""Perform reshaping for a 2D matrix (pop_members, params)."""
- vmap_shape = jax.vmap(self.flat_to_network, in_axes=(0,))
+ vmap_shape = jax.vmap(self.reshape_single)
if self.n_devices > 1:
x = self.split_params_for_pmap(x)
map_shape = jax.pmap(vmap_shape)
@@ -64,56 +65,38 @@ def reshape_network(self, x: chex.Array) -> chex.ArrayTree:
map_shape = vmap_shape
return map_shape(x)
- def reshape_single_flat(self, x: chex.Array) -> chex.Array:
- """Perform reshaping for a 1D vector (params,)."""
- return x
-
- def reshape_single_net(self, x: chex.Array) -> chex.ArrayTree:
- """Perform reshaping for a 1D vector (params,)."""
- unsqueezed_re = self.flat_to_network(x)
- return unsqueezed_re
+ def multi_reshape(self, x: chex.Array) -> chex.ArrayTree:
+ """Reshape parameters lying already on different devices."""
+ # No reshaping required!
+ vmap_shape = jax.vmap(self.reshape_single)
+ return jax.pmap(vmap_shape)(x)
- @property
- def vmap_dict(self) -> chex.ArrayTree:
- """Get a dictionary specifying axes to vmap over."""
- vmap_dict = jax.tree_map(lambda x: 0, self.placeholder_params)
- return vmap_dict
+ def flatten(self, x: chex.ArrayTree) -> chex.Array:
+ """Reshaping pytree parameters into flat array."""
+ vmap_flat = jax.vmap(ravel_pytree)
+ if self.n_devices > 1:
+ # Reshape pmap-sharded parameter trees so vmap flattening applies
+ def map_flat(x):
+ x_re = jax.tree_map(lambda x: x.reshape(-1, *x.shape[2:]), x)
+ return vmap_flat(x_re)
- def flat_to_network(self, flat_params: chex.Array) -> chex.ArrayTree:
- """Fill a FrozenDict with new proposed vector of params."""
- new_nn = list()
+ else:
+ map_flat = vmap_flat
+ flat = map_flat(x)
+ return flat
- # Loop over layers in network
- for i, shape in enumerate(self.network_shape):
- # Select params from flat to vector to be reshaped
- p_flat = jax.lax.dynamic_slice(
- flat_params, (self.l_id[i],), (self.l_id[i + 1] - self.l_id[i],)
- )
- # Reshape parameters into matrix/kernel/etc. shape
- p_reshaped = p_flat.reshape(shape)
- # Place reshaped params into dict and increase counter
- new_nn.append(p_reshaped)
- return tree_unflatten(self._treedef, new_nn)
+ def multi_flatten(self, x: chex.Array) -> chex.ArrayTree:
+ """Flatten parameters lying remaining on different devices."""
+ # No reshaping required!
+ vmap_flat = jax.vmap(ravel_pytree)
+ return jax.pmap(vmap_flat)(x)
def split_params_for_pmap(self, param: chex.Array) -> chex.Array:
"""Helper reshapes param (bs, #params) into (#dev, bs/#dev, #params)."""
return jnp.stack(jnp.split(param, self.n_devices))
-
-def get_total_params(params: List[chex.Array]) -> int:
- """Get total number of params in net. Loop over layer modules + params."""
- total_params = 0
- layer_keys = params
- # Loop over layers
- for l_k in layer_keys:
- total_params += jnp.prod(jnp.array(l_k))
- return total_params
-
-
-def get_layer_ids(network_shape_list: List[chex.Array]) -> List[int]:
- """Get indices to target when reshaping single flat net into dict."""
- l_id = [0]
- for shape in network_shape_list:
- add_pcount = jnp.prod(jnp.array(shape))
- l_id.append(int(l_id[-1] + add_pcount))
- return l_id
+ @property
+ def vmap_dict(self) -> chex.ArrayTree:
+ """Get a dictionary specifying axes to vmap over."""
+ vmap_dict = jax.tree_map(lambda x: 0, self.placeholder_params)
+ return vmap_dict
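
For reference, a short flatten/reshape round trip with the new `flatten_util`-based reshaper (the placeholder network below is hypothetical):

    import jax.numpy as jnp
    from evosax import ParameterReshaper

    pholder = {"dense": {"kernel": jnp.zeros((4, 2)), "bias": jnp.zeros(2)}}
    reshaper = ParameterReshaper(pholder, verbose=False)
    flat_pop = jnp.zeros((16, reshaper.total_params))  # (popsize, 10)
    pytrees = reshaper.reshape(flat_pop)   # flat -> per-member pytrees
    flat_back = reshaper.flatten(pytrees)  # pytrees -> (16, 10) again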
diff --git a/evosax/utils/visualizer_2d.py b/evosax/utils/visualizer_2d.py
new file mode 100644
index 0000000..8dcba14
--- /dev/null
+++ b/evosax/utils/visualizer_2d.py
@@ -0,0 +1,247 @@
+"""Fitness landscape visualizer and evaluation animator."""
+import chex
+import jax
+import jax.numpy as jnp
+import numpy as np
+import matplotlib.cm as cm
+import matplotlib.pyplot as plt
+import matplotlib.animation as animation
+
+cmap = cm.colors.LinearSegmentedColormap.from_list(
+ "Custom", [(0, "#2f9599"), (0.45, "#eee"), (1, "#8800ff")], N=256
+)
+
+
+class BBOBVisualizer(object):
+ """Fitness landscape visualizer and evaluation animator."""
+
+ def __init__(
+ self,
+ X: chex.Array,
+ fitness: chex.Array,
+ fn_name: str = "Rastrigin",
+ title: str = "",
+ use_3d: bool = False,
+ plot_log_fn: bool = False,
+ seed_id: int = 1,
+ ):
+ from evosax.problems.bbob import BBOB_fns, get_rotation
+
+ self.X = X
+ self.fitness = fitness
+ self.title = title
+ self.fn_name = fn_name
+ self.use_3d = use_3d
+ if not self.use_3d:
+ self.fig, self.ax = plt.subplots(figsize=(6, 5))
+ else:
+ self.fig = plt.figure(figsize=(6, 5))
+ self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
+ self.fn = BBOB_fns[self.fn_name]
+
+ rng = jax.random.PRNGKey(seed_id)
+ rng_q, rng_r = jax.random.split(rng)
+ self.R = get_rotation(rng_r, 2)
+ self.Q = get_rotation(rng_q, 2)
+ self.global_minima = []
+ self.plot_log_fn = plot_log_fn
+
+ # Set boundaries for evaluation range of black-box functions
+ self.x1_lower_bound, self.x1_upper_bound = -5, 5
+ self.x2_lower_bound, self.x2_upper_bound = -5, 5
+
+ # Set meta-data for rotation/azimuth
+ self.interval = 50 # Delay between frames in milliseconds.
+ try:
+ self.num_frames = X.shape[0]
+ self.static_frames = int(0.2 * self.num_frames)
+ self.azimuths = jnp.linspace(
+ 0, 90, self.num_frames - self.static_frames
+ )
+ self.angles = jnp.linspace(
+ 0, 90, self.num_frames - self.static_frames
+ )
+ except Exception:
+ # X may be None when only plotting static contours
+ pass
+
+ def animate(self, save_fname: str):
+ """Run animation for provided data."""
+ ani = animation.FuncAnimation(
+ self.fig,
+ self.update,
+ frames=self.num_frames,
+ init_func=self.init,
+ blit=False,
+ interval=self.interval,
+ )
+ ani.save(save_fname)
+
+ def init(self):
+ """Initialize the first frame for the animation."""
+ if self.use_3d:
+ self.plot_contour_3d()
+ (self.scat,) = self.ax.plot(
+ self.X[0, :, 0],
+ self.X[0, :, 1],
+ self.fitness[0, :],
+ marker="o",
+ c="r",
+ linestyle="",
+ markersize=3,
+ alpha=0.5,
+ )
+
+ else:
+ self.plot_contour_2d()
+ (self.scat,) = self.ax.plot(
+ self.X[0, :, 0],
+ self.X[0, :, 1],
+ marker="o",
+ c="r",
+ linestyle="",
+ markersize=3,
+ alpha=0.5,
+ )
+
+ return (self.scat,)
+
+ def update(self, frame):
+ """Update the frame with the solutions evaluated in generation."""
+ # Plot sample points
+ self.scat.set_data(self.X[frame, :, 0], self.X[frame, :, 1])
+ if self.use_3d:
+ self.scat.set_3d_properties(self.fitness[frame, :])
+ if frame < self.num_frames - self.static_frames:
+ self.ax.view_init(self.azimuths[frame], self.angles[frame])
+ self.ax.set_title(
+ f"{self.fn_name}: {self.title} - Generation {frame + 1}",
+ fontsize=15,
+ )
+ self.fig.tight_layout()
+ return (self.scat,)
+
+ def contour_function(self, x1, x2):
+ """Evaluate vmapped fitness landscape."""
+
+ def fn_val(x1, x2):
+ x = jnp.stack([x1, x2])
+ return self.fn(x, self.R, self.Q)
+
+ return jax.vmap(jax.vmap(fn_val, in_axes=(0, None)), in_axes=(None, 0))(
+ x1, x2
+ )
+
+ def plot_contour_2d(self, save: bool = False):
+ """Plot 2d landscape contour."""
+
+ if save:
+ self.fig, self.ax = plt.subplots(figsize=(6, 5))
+ self.ax.set_xlim(self.x1_lower_bound, self.x1_upper_bound)
+ self.ax.set_ylim(self.x2_lower_bound, self.x2_upper_bound)
+
+ # Plot global minima
+ for m in self.global_minima:
+ self.ax.plot(m[0], m[1], "y*", ms=10)
+
+ x1 = jnp.arange(self.x1_lower_bound, self.x1_upper_bound, 0.01)
+ x2 = jnp.arange(self.x2_lower_bound, self.x2_upper_bound, 0.01)
+ X, Y = np.meshgrid(x1, x2)
+ contour = self.contour_function(x1, x2)
+ if self.plot_log_fn:
+ contour = jnp.log(contour)
+ self.ax.contour(X, Y, contour, levels=30, linewidths=0.5, colors="#999")
+ im = self.ax.contourf(X, Y, contour, levels=30, cmap=cmap, alpha=0.7)
+ self.ax.set_title(f"{self.fn_name} Function", fontsize=15)
+ self.ax.set_xlabel(r"$x_1$")
+ self.ax.set_ylabel(r"$x_2$")
+ self.fig.colorbar(im, ax=self.ax)
+ self.fig.tight_layout()
+
+ if save:
+ plt.savefig(f"{self.fn_name}_2d.png", dpi=300)
+
+ def plot_contour_3d(self, save: bool = False):
+ """Plot 3d landscape contour."""
+ if save:
+ self.fig = plt.figure(figsize=(6, 5))
+ self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
+ x1 = jnp.arange(self.x1_lower_bound, self.x1_upper_bound, 0.01)
+ x2 = jnp.arange(self.x2_lower_bound, self.x2_upper_bound, 0.01)
+ contour = self.contour_function(x1, x2)
+ if self.plot_log_fn:
+ contour = jnp.log(contour)
+
+ X, Y = np.meshgrid(x1, x2)
+ self.ax.contour(
+ X,
+ Y,
+ contour,
+ zdir="z",
+ offset=np.min(contour),
+ levels=30,
+ cmap=cmap,
+ alpha=0.5,
+ )
+ self.ax.plot_surface(
+ X,
+ Y,
+ contour,
+ cmap=cmap,
+ linewidth=0,
+ antialiased=True,
+ alpha=0.7,
+ )
+
+ # Remove pane fills and set labels
+ self.ax.xaxis.pane.fill = False
+ self.ax.yaxis.pane.fill = False
+ self.ax.zaxis.pane.fill = False
+
+ self.ax.xaxis.set_tick_params(labelsize=8)
+ self.ax.yaxis.set_tick_params(labelsize=8)
+ self.ax.zaxis.set_tick_params(labelsize=8)
+
+ self.ax.set_xlabel(r"$x_1$")
+ self.ax.set_ylabel(r"$x_2$")
+ self.ax.set_zlabel(r"$f(x)$")
+ self.ax.set_title(f"{self.fn_name} Function", fontsize=15)
+ self.fig.tight_layout()
+ if save:
+ plt.savefig(f"{self.fn_name}_3d.png", dpi=300)
+
+
+if __name__ == "__main__":
+ import jax
+ from jax.config import config
+
+ config.update("jax_enable_x64", True)
+
+ rng = jax.random.PRNGKey(42)
+
+ # for fn_name in [
+ # "BuecheRastrigin",
+ # ]: # BBOB_fns.keys():
+ # print(f"Start 2d/3d - {fn_name}")
+ # visualizer = BBOBVisualizer(None, None, fn_name, "")
+ # visualizer.plot_contour_2d(save=True)
+ # visualizer.plot_contour_3d(save=True)
+
+ # Test animations
+ # All solutions from a single run (50 gens, 16 members, 2 dims)
+ X = jax.random.normal(rng, shape=(50, 16, 2))
+
+ def sphere(x):
+ return jnp.sum(x ** 2)
+
+ fitness = jax.vmap(jax.vmap(sphere))(X)
+ print(fitness.shape)
+ visualizer = BBOBVisualizer(
+ X, fitness, "Sphere", "Test Strategy", use_3d=True
+ )
+ visualizer.animate("Sphere_3d.gif")
+ # visualizer = BBOBVisualizer(X, None, "Sphere", "Test Strategy", use_3d=False)
+ # visualizer.animate("Sphere_2d.gif")
diff --git a/examples/00_getting_started.ipynb b/examples/00_getting_started.ipynb
index 9a4acaa..35187d4 100644
--- a/examples/00_getting_started.ipynb
+++ b/examples/00_getting_started.ipynb
@@ -35,15 +35,26 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "9dee8f4d-9ce8-4f5b-8d9a-ccde409dd873",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "EvoParams(mu_eff=DeviceArray(3.1672993, dtype=float32), c_1=DeviceArray(0.14227484, dtype=float32), c_mu=DeviceArray(0.1547454, dtype=float32), c_sigma=DeviceArray(0.50822735, dtype=float32), d_sigma=DeviceArray(1.5082273, dtype=float32), c_c=DeviceArray(0.60908335, dtype=float32), chi_n=DeviceArray(1.2542727, dtype=float32, weak_type=True), c_m=1.0, sigma_init=1.0, init_min=-3, init_max=3, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from evosax import CMA_ES\n",
- "from evosax.problems import ClassicFitness\n",
+ "from evosax.problems import BBOBFitness\n",
"\n",
"# Instantiate the evolution strategy instance\n",
"strategy = CMA_ES(num_dims=2, popsize=10)\n",
@@ -70,13 +81,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "969621cb",
"metadata": {},
"outputs": [],
"source": [
"# Instantiate helper class for classic evolution strategies benchmarks\n",
- "evaluator = ClassicFitness(\"rosenbrock\", num_dims=2)"
+ "evaluator = BBOBFitness(\"RosenbrockOriginal\", num_dims=2)"
]
},
{
@@ -89,10 +100,25 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"id": "e8982ce5-91b0-4ccc-ba46-2c69d4f11287",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "EvoState(p_sigma=DeviceArray([1.4210665 , 0.17011967], dtype=float32), p_c=DeviceArray([1.5021834, 0.1798304], dtype=float32), C=DeviceArray([[1.6521862, 0.2126719],\n",
+ " [0.2126719, 0.6064514]], dtype=float32), D=None, B=None, mean=DeviceArray([-0.7851852, 1.9345263], dtype=float32), sigma=DeviceArray(1.0486844, dtype=float32), weights=DeviceArray([ 0.45627266, 0.27075312, 0.16223112, 0.08523354,\n",
+ " 0.02550957, -0.09313666, -0.25813875, -0.4010702 ,\n",
+ " -0.5271447 , -0.639922 ], dtype=float32), weights_truncated=DeviceArray([0.45627266, 0.27075312, 0.16223112, 0.08523354, 0.02550957,\n",
+ " 0. , 0. , 0. , 0. , 0. ], dtype=float32), best_member=DeviceArray([0.7087054, 2.3278952], dtype=float32), best_fitness=DeviceArray(17.166704, dtype=float32), gen_counter=DeviceArray(1, dtype=int32, weak_type=True))"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# Ask for a set of candidate solutions to evaluate\n",
"x, state = strategy.ask(rng, state, es_params)\n",
@@ -113,7 +139,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"id": "57b3e83c-e81f-48bd-aa8c-e90b7542502a",
"metadata": {},
"outputs": [],
@@ -127,10 +153,34 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "791a2d65-7404-40a2-9c4b-4a1e1ac12dfa",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(