diff --git a/.gitignore b/.gitignore index ddc38be..9d540ed 100755 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +examples/experimental +bbob.py # Standard ROB excludes .sync-config.cson .vim-arsync diff --git a/CHANGELOG.md b/CHANGELOG.md index f896be6..2ec91e2 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,28 +1,47 @@ ### Work-in-Progress -- [ ] Make xNES work with all optimizers (currently only GD) - Implement more strategies - [ ] Large-scale CMA-ES variants - [ ] [LM-CMA](https://www.researchgate.net/publication/282612269_LM-CMA_An_alternative_to_L-BFGS_for_large-scale_black_Box_optimization) - [ ] [VkD-CMA](https://hal.inria.fr/hal-01306551v1/document), [Code](https://gist.github.com/youheiakimoto/2fb26c0ace43c22b8f19c7796e69e108) - - [ ] [sNES](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) (separable version of xNES) - - [ ] [ASEBO](https://proceedings.neurips.cc/paper/2019/file/88bade49e98db8790df275fcebb37a13-Paper.pdf) - [ ] [RBO](http://proceedings.mlr.press/v100/choromanski20a/choromanski20a.pdf) - Encoding methods - via special reshape wrappers - [ ] Discrete Cosine Transform - [ ] Wavelet Based Encoding (van Steenkiste, 2016) - - [ ] Hypernetworks (Ha - start with simple MLP) + - [ ] CNN Hypernetwork (Ha - start with simple MLP) ### [v0.1.0] - [TBD] ##### Added - Adds a `total_env_steps` counter to both `GymFitness` and `BraxFitness` for easier sample efficiency comparability with RL algorithms. 
+- Support for new strategies/genetic algorithms + - SAMR-GA (Clune et al., 2008) + - GESMR-GA (Kumar et al., 2022) + - SNES (Wierstra et al., 2014) + - DES (Lange et al., 2022) + - Guided ES (Maheswaranathan et al., 2018) + - ASEBO (Choromanski et al., 2019) + - CR-FM-NES (Nomura & Ono, 2022) + - MR15-GA (Rechenberg, 1978) +- Adds full set of BBOB low-dimensional functions (`BBOBFitness`) +- Adds 2D visualizer animating sampled points (`BBOBVisualizer`) +- Adds `Evosax2JAXWrapper` to wrap all evosax strategies +- Adds Adan optimizer (Xie et al., 2022) + +##### Changed + +- `ParameterReshaper` can now be directly applied from within the strategy. You simply have to provide a `pholder_params` pytree at strategy instantiation (and no `num_dims`). +- `FitnessShaper` can also be directly applied from within the strategy. This makes it easier to track the best performing member across generations and addresses issue #32. Simply provide the fitness shaping settings as args to the strategy (`maximize`, `centered_rank`, ...) +- Removes Brax fitness (use EvoJAX version instead) +- Adds lrate and sigma schedules to strategy instantiation ##### Fixed - Fixed reward masking in `GymFitness`. Using `jnp.sum(dones) >= 1` for cumulative return computation zeros out the final timestep, which is wrong. That's why there were problems with sparse reward gym environments (e.g. Mountain Car). +- Fixed PGPE sample indexing. +- Fixed weight decay, which was incorrectly multiplied by -1 when maximizing. ### [v0.0.9] - 15/06/2022 diff --git a/README.md b/README.md index d440cb2..67477a1 100755 --- a/README.md +++ b/README.md @@ -36,14 +36,14 @@ state.best_member, state.best_fitness | OpenES | [Salimans et al. 
(2017)](https://arxiv.org/pdf/1703.03864.pdf) | [`OpenES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/open_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/03_cnn_mnist.ipynb) | PGPE | [Sehnke et al. (2010)](https://citeseerx.ist.psu.edu/viewdoc/download;jsessionid=A64D1AE8313A364B814998E9E245B40A?doi=10.1.1.180.7104&rep=rep1&type=pdf) | [`PGPE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pgpe.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/02_mlp_control.ipynb) | ARS | [Mania et al. (2018)](https://arxiv.org/pdf/1803.07055.pdf) | [`ARS`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ars.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/00_getting_started.ipynb) -| CMA-ES | [Hansen & Ostermeier (2001)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf) | [`CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Simple Genetic | [Such et al. 
(2017)](https://arxiv.org/abs/1712.06567) | [`SimpleGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| x-NES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`xNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/xnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) -| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| ESMC | [Merchant et al. (2021)](https://proceedings.mlr.press/v139/merchant21a.html) | [`ESMC`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/esmc.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | Persistent ES | [Vicol et al. 
(2021)](http://proceedings.mlr.press/v139/vicol21a.html) | [`PersistentES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/persistent_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb) -| Population-Based Training | [Jaderberg et al. (2017)](https://arxiv.org/abs/1711.09846) | [`PBT`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pbt.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb) +| xNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`xNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/xnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| SNES | [Wierstra et al. (2014)](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) | [`SNES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sxnes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| CR-FM-NES | [Nomura & Ono (2022)](https://arxiv.org/abs/2201.11422) | [`CR_FM_NES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cr_fm_nes.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Guided ES | [Maheswaranathan et al. 
(2018)](https://arxiv.org/abs/1806.10230) | [`GuidedES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/guided_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| ASEBO | [Choromanski et al. (2019)](https://arxiv.org/abs/1903.04268) | [`ASEBO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/asebo.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| CMA-ES | [Hansen & Ostermeier (2001)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf) | [`CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | Sep-CMA-ES | [Ros & Hansen (2008)](https://hal.inria.fr/inria-00287367/document) | [`Sep_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sep_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | BIPOP-CMA-ES | [Hansen (2009)](https://hal.inria.fr/inria-00382093/document) | [`BIPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/bipop_cma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb) | IPOP-CMA-ES | [Auer & Hansen (2005)](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cec2005ipopcmaes.pdf) | [`IPOP_CMA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ipop_cma_es.py) | 
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb) @@ -52,8 +52,21 @@ state.best_member, state.best_fitness | MA-ES | [Bayer & Sendhoff (2017)](https://www.honda-ri.de/pubs/pdf/3376.pdf) | [`MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | LM-MA-ES | [Loshchilov et al. (2017)](https://arxiv.org/pdf/1705.06693.pdf) | [`LM_MA_ES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/lm_ma_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | RmES | [Li & Zhang (2017)](https://ieeexplore.ieee.org/document/8080257) | [`RmES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/rm_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Simple Genetic | [Such et al. (2017)](https://arxiv.org/abs/1712.06567) | [`SimpleGA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| SAMR-GA | [Clune et al. 
(2008)](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000187) | [`SAMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/samr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| GESMR-GA | [Kumar et al. (2022)](https://arxiv.org/abs/2204.04817) | [`GESMR_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gesmr_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| MR15-GA | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`MR15_GA`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/mr15_ga.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Simple Gaussian | [Rechenberg (1978)](https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8) | [`SimpleES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/simple_es.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| DES | [Lange et al. 
(2022)](https://arxiv.org/abs/2211.11260) | [`DES`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/des.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Particle Swarm Optimization | [Kennedy & Eberhart (1995)](https://ieeexplore.ieee.org/document/488968) | [`PSO`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pso.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Differential Evolution | [Storn & Price (1997)](https://www.metabolic-economics.de/pages/seminar_theoretische_biologie_2007/literatur/schaber/Storn1997JGlobOpt11.pdf) | [`DE`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/de.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | GLD | [Golovin et al. (2019)](https://arxiv.org/pdf/1911.06317.pdf) | [`GLD`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/gld.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) | Simulated Annealing | [Rasdi Rere et al. (2015)](https://www.sciencedirect.com/science/article/pii/S1877050915035759) | [`SimAnneal`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/sim_anneal.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/01_classic_benchmark.ipynb) +| Population-Based Training | [Jaderberg et al. 
(2017)](https://arxiv.org/abs/1711.09846) | [`PBT`](https://github.com/RobertTLange/evosax/tree/main/evosax/strategies/pbt.py) | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb) + + + + ## Installation ⏳ @@ -78,12 +91,12 @@ In order to use JAX on your accelerators, you can find more details in the [JAX * 📓 [LRateTune-PES](https://github.com/RobertTLange/evosax/blob/main/examples/04_lrate_pes.ipynb): Persistent ES on meta-learning problem as in [Vicol et al. (2021)](http://proceedings.mlr.press/v139/vicol21a.html). * 📓 [Quadratic-PBT](https://github.com/RobertTLange/evosax/blob/main/examples/05_quadratic_pbt.ipynb): PBT on toy quadratic problem as in [Jaderberg et al. (2017)](https://arxiv.org/abs/1711.09846). * 📓 [Restart-Wrappers](https://github.com/RobertTLange/evosax/blob/main/examples/06_restart_es.ipynb): Custom restart wrappers as e.g. used in (B)IPOP-CMA-ES. -* 📓 [Brax Control](https://github.com/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb): Evolve Tanh MLPs on Brax tasks using the `evosax` wrapper. -* 📓 [Indirect Encodings](https://github.com/RobertTLange/evosax/blob/main/examples/08_encodings.ipynb): Find out how many parameters we need to evolve a pendulum controller. +* 📓 [Brax Control](https://github.com/RobertTLange/evosax/blob/main/examples/07_brax_control.ipynb): Evolve Tanh MLPs on Brax tasks using the `EvoJAX` wrapper. + -## Key Selling Points 💵 +## Key Features 💵 -- **Strategy Diversity**: `evosax` implements more than 10 classical and modern neuroevolution strategies. All of them follow the same simple `ask`/`eval` API and come with tailored tools such as the [ClipUp](https://arxiv.org/abs/2008.02387) optimizer, parameter reshaping into PyTrees and fitness shaping (see below). +- **Strategy Diversity**: `evosax` implements more than 30 classical and modern neuroevolution strategies. 
All of them follow the same simple `ask`/`tell` API and come with tailored tools such as the [ClipUp](https://arxiv.org/abs/2008.02387) optimizer, parameter reshaping into PyTrees and fitness shaping (see below). - **Vectorization/Parallelization of `ask`/`tell` Calls**: Both `ask` and `tell` calls can leverage `jit`, `vmap`/`pmap`. This enables vectorized/parallel rollouts of different evolution strategies. @@ -167,78 +180,79 @@ fit_shaper = FitnessShaper(centered_rank=True, # Shape the evaluated fitness scores fit_shaped = fit_shaper.apply(x, fitness) ``` - -- **Strategy Restart Wrappers**: You can also choose from a set of different restart mechanisms, which will relaunch a strategy (with e.g. new population size) based on termination criteria. Note: For all restart strategies which alter the population size the ask and tell methods will have to be re-compiled at the time of change. Note that all strategies can also be executed without explicitly providing `es_params`. In this case the default parameters will be used. - -```Python -from evosax import CMA_ES -from evosax.restarts import BIPOP_Restarter - -# Define a termination criterion (kwargs - fitness, state, params) -def std_criterion(fitness, state, params): - """Restart strategy if fitness std across population is small.""" - return fitness.std() < 0.001 - -# Instantiate Base CMA-ES & wrap with BIPOP restarts -# Pass strategy-specific kwargs separately (e.g. elite_ration or opt_name) -strategy = CMA_ES(num_dims, popsize, elite_ratio) -re_strategy = BIPOP_Restarter( - strategy, - stop_criteria=[std_criterion], - strategy_kwargs={"elite_ratio": elite_ratio} - ) -state = re_strategy.initialize(rng) - -# ask/tell loop - restarts are automatically handled -rng, rng_gen, rng_eval = jax.random.split(rng, 3) -x, state = re_strategy.ask(rng_gen, state) -fitness = ... # Your population evaluation fct -state = re_strategy.tell(x, fitness, state) -``` - -- **Batch Strategy Rollouts**: *Work-in-progress*. 
We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols. - -```Python -from evosax.experimental.subpops import BatchStrategy - -# Instantiates 5 CMA-ES subpops of 20 members -strategy = BatchStrategy( - strategy_name="CMA_ES", - num_dims=4096, - popsize=100, - num_subpops=5, - strategy_kwargs={"elite_ratio": 0.5}, - communication="best_subpop", - ) - -state = strategy.initialize(rng) -# Ask for evaluation candidates of different subpopulation ES -x, state = strategy.ask(rng_iter, state) -fitness = ... -state = strategy.tell(x, fitness, state) -``` - -- **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. due to harder estimation of covariances). One potential way to alleviate this challenge, is to use indirect parameter encodings in a lower dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`: - -```Python -from evosax.experimental.decodings import RandomDecoder, HyperDecoder - -# For arbitrary network architectures / search spaces -num_encoding_dims = 6 -param_reshaper = RandomDecoder(num_encoding_dims, net_params) -x_shaped = param_reshaper.reshape(x) - -# For MLP-based models we also support a HyperNetwork en/decoding -reshaper = HyperDecoder( - net_params, - hypernet_config={ - "num_latent_units": 3, # Latent units per module kernel/bias - "num_hidden_units": 2, # Hidden dimensionality of a_i^j embedding - }, - ) -x_shaped = param_reshaper.reshape(x) -``` - +
+ Additional Work-In-Progress + - **Strategy Restart Wrappers**: *Work-in-progress*. You can also choose from a set of different restart mechanisms, which will relaunch a strategy (with e.g. new population size) based on termination criteria. Note: For all restart strategies which alter the population size the ask and tell methods will have to be re-compiled at the time of change. Note that all strategies can also be executed without explicitly providing `es_params`. In this case the default parameters will be used. + + ```Python + from evosax import CMA_ES + from evosax.restarts import BIPOP_Restarter + + # Define a termination criterion (kwargs - fitness, state, params) + def std_criterion(fitness, state, params): + """Restart strategy if fitness std across population is small.""" + return fitness.std() < 0.001 + + # Instantiate Base CMA-ES & wrap with BIPOP restarts + # Pass strategy-specific kwargs separately (e.g. elite_ratio or opt_name) + strategy = CMA_ES(num_dims, popsize, elite_ratio) + re_strategy = BIPOP_Restarter( + strategy, + stop_criteria=[std_criterion], + strategy_kwargs={"elite_ratio": elite_ratio} + ) + state = re_strategy.initialize(rng) + + # ask/tell loop - restarts are automatically handled + rng, rng_gen, rng_eval = jax.random.split(rng, 3) + x, state = re_strategy.ask(rng_gen, state) + fitness = ... # Your population evaluation fct + state = re_strategy.tell(x, fitness, state) + ``` + + - **Batch Strategy Rollouts**: *Work-in-progress*. We are currently also working on different ways of incorporating multiple subpopulations with different communication protocols. 
+ + ```Python + from evosax.experimental.subpops import BatchStrategy + + # Instantiates 5 CMA-ES subpops of 20 members + strategy = BatchStrategy( + strategy_name="CMA_ES", + num_dims=4096, + popsize=100, + num_subpops=5, + strategy_kwargs={"elite_ratio": 0.5}, + communication="best_subpop", + ) + + state = strategy.initialize(rng) + # Ask for evaluation candidates of different subpopulation ES + x, state = strategy.ask(rng_iter, state) + fitness = ... + state = strategy.tell(x, fitness, state) + ``` + + - **Indirect Encodings**: *Work-in-progress*. ES can struggle with high-dimensional search spaces (e.g. due to harder estimation of covariances). One potential way to alleviate this challenge is to use indirect parameter encodings in a lower-dimensional space. So far we provide JAX-compatible encodings with random projections (Gaussian/Rademacher) and Hypernetworks for MLPs. They act as drop-in replacements for the `ParameterReshaper`: + + ```Python + from evosax.experimental.decodings import RandomDecoder, HyperDecoder + + # For arbitrary network architectures / search spaces + num_encoding_dims = 6 + param_reshaper = RandomDecoder(num_encoding_dims, net_params) + x_shaped = param_reshaper.reshape(x) + + # For MLP-based models we also support a HyperNetwork en/decoding + param_reshaper = HyperDecoder( + net_params, + hypernet_config={ + "num_latent_units": 3, # Latent units per module kernel/bias + "num_hidden_units": 2, # Hidden dimensionality of a_i^j embedding + }, + ) + x_shaped = param_reshaper.reshape(x) + ``` +
## Resources & Other Great JAX-ES Tools 📝 @@ -260,7 +274,7 @@ If you use `evosax` in your research, please cite it as follows: } ``` -We acknowledge financial support the [Google TRC](https://sites.research.google/trc/about/) and the Deutsche +We acknowledge financial support by the [Google TRC](https://sites.research.google/trc/about/) and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 ["Science of Intelligence"](https://www.scienceofintelligence.de/) - project number 390523135. ## Development 👷 diff --git a/evosax/__init__.py b/evosax/__init__.py index 80d52e6..9bc366b 100755 --- a/evosax/__init__.py +++ b/evosax/__init__.py @@ -1,4 +1,4 @@ -from .strategy import Strategy +from .strategy import Strategy, EvoState, EvoParams from .strategies import ( SimpleGA, SimpleES, @@ -9,7 +9,6 @@ PGPE, PBT, PersistentES, - xNES, ARS, Sep_CMA_ES, BIPOP_CMA_ES, @@ -21,6 +20,16 @@ RmES, GLD, SimAnneal, + SNES, + xNES, + ESMC, + DES, + SAMR_GA, + GESMR_GA, + GuidedES, + ASEBO, + CR_FM_NES, + MR15_GA, ) from .utils import FitnessShaper, ParameterReshaper, ESLog from .networks import NetworkMapper @@ -37,7 +46,6 @@ "PGPE": PGPE, "PBT": PBT, "PersistentES": PersistentES, - "xNES": xNES, "ARS": ARS, "Sep_CMA_ES": Sep_CMA_ES, "BIPOP_CMA_ES": BIPOP_CMA_ES, @@ -49,9 +57,27 @@ "RmES": RmES, "GLD": GLD, "SimAnneal": SimAnneal, + "SNES": SNES, + "xNES": xNES, + "ESMC": ESMC, + "DES": DES, + "SAMR_GA": SAMR_GA, + "GESMR_GA": GESMR_GA, + "GuidedES": GuidedES, + "ASEBO": ASEBO, + "CR_FM_NES": CR_FM_NES, + "MR15_GA": MR15_GA, } __all__ = [ + "Strategies", + "EvoState", + "EvoParams", + "FitnessShaper", + "ParameterReshaper", + "ESLog", + "NetworkMapper", + "ProblemMapper", "Strategy", "SimpleGA", "SimpleES", @@ -62,7 +88,6 @@ "PGPE", "PBT", "PersistentES", - "xNES", "ARS", "Sep_CMA_ES", "BIPOP_CMA_ES", @@ -74,10 +99,14 @@ "RmES", "GLD", "SimAnneal", - "Strategies", - "FitnessShaper", - "ParameterReshaper", - "ESLog", - 
"NetworkMapper", - "ProblemMapper", + "SNES", + "xNES", + "ESMC", + "DES", + "SAMR_GA", + "GESMR_GA", + "GuidedES", + "ASEBO", + "CR_FM_NES", + "MR15_GA", ] diff --git a/evosax/experimental/decodings/decoder.py b/evosax/experimental/decodings/decoder.py index f23cb5c..9a15017 100644 --- a/evosax/experimental/decodings/decoder.py +++ b/evosax/experimental/decodings/decoder.py @@ -8,13 +8,11 @@ def __init__( self, num_encoding_dims: int, placeholder_params: Union[chex.ArrayTree, chex.Array], - identity: bool = False, n_devices: Optional[int] = None, ): self.num_encoding_dims = num_encoding_dims self.total_params = num_encoding_dims self.placeholder_params = placeholder_params - self.identity = identity if n_devices is None: self.n_devices = jax.local_device_count() else: diff --git a/evosax/experimental/decodings/hyper.py b/evosax/experimental/decodings/hyper.py index 08feccf..716f71e 100644 --- a/evosax/experimental/decodings/hyper.py +++ b/evosax/experimental/decodings/hyper.py @@ -5,8 +5,8 @@ from flax.core import unfreeze from typing import Union, Optional from .decoder import Decoder -from ...utils import ParameterReshaper -from ...networks import HyperNetworkMLP +from .hyper_networks import HyperNetworkMLP +from ...utils import ParameterReshaper, ravel_pytree class HyperDecoder(Decoder): @@ -39,14 +39,13 @@ def __init__( super().__init__( hyper_reshaper.total_params, placeholder_params, - identity, n_devices, ) self.hyper_reshaper = hyper_reshaper self.vmap_dict = self.hyper_reshaper.vmap_dict def reshape(self, x: chex.Array) -> chex.ArrayTree: - """Perform reshaping for random projection case.""" + """Perform reshaping for hypernetwork case.""" # 0. Reshape genome into params for hypernetwork x_params = self.hyper_reshaper.reshape(x) # 1. 
Project parameters to raw dimensionality using hypernetwork @@ -54,9 +53,17 @@ def reshape(self, x: chex.Array) -> chex.ArrayTree: return hyper_x def reshape_single(self, x: chex.Array) -> chex.ArrayTree: - """Reshape a single flat vector using random projection matrix.""" + """Reshape a single flat vector using hypernetwork.""" # 0. Reshape genome into params for hypernetwork x_params = self.hyper_reshaper.reshape_single(x) # 1. Project parameters to raw dimensionality using hypernetwork hyper_x = jax.jit(self.hyper_network.apply)(x_params) return hyper_x + + def flatten(self, x: chex.ArrayTree) -> chex.Array: + """Reshaping pytree parameters into flat array.""" + return jax.vmap(ravel_pytree)(x) + + def flatten_single(self, x: chex.ArrayTree) -> chex.Array: + """Reshaping pytree parameters into flat array.""" + return ravel_pytree(x) diff --git a/evosax/networks/hyper_networks.py b/evosax/experimental/decodings/hyper_networks.py similarity index 100% rename from evosax/networks/hyper_networks.py rename to evosax/experimental/decodings/hyper_networks.py diff --git a/evosax/experimental/decodings/random.py b/evosax/experimental/decodings/random.py index 439124a..9ba17cd 100644 --- a/evosax/experimental/decodings/random.py +++ b/evosax/experimental/decodings/random.py @@ -12,17 +12,14 @@ def __init__( placeholder_params: Union[chex.ArrayTree, chex.Array], rng: chex.PRNGKey = jax.random.PRNGKey(0), rademacher: bool = False, - identity: bool = False, n_devices: Optional[int] = None, ): """Random Projection Decoder (Gaussian/Rademacher random matrix).""" - super().__init__( - num_encoding_dims, placeholder_params, identity, n_devices - ) + super().__init__(num_encoding_dims, placeholder_params, n_devices) self.rademacher = rademacher # Instantiate base reshaper class self.base_reshaper = ParameterReshaper( - placeholder_params, identity, n_devices, verbose=False + placeholder_params, n_devices, verbose=False ) self.vmap_dict = self.base_reshaper.vmap_dict @@ -35,6 
+32,10 @@ def __init__( self.project_matrix = jax.random.rademacher( rng, (self.num_encoding_dims, self.base_reshaper.total_params) ) + print( + "RandomDecoder: Encoding parameters to optimize -" + f" {num_encoding_dims}" + ) def reshape(self, x: chex.Array) -> chex.ArrayTree: """Perform reshaping for random projection case.""" diff --git a/evosax/networks/__init__.py b/evosax/networks/__init__.py index 1d9e6a5..4f77b1c 100644 --- a/evosax/networks/__init__.py +++ b/evosax/networks/__init__.py @@ -1,7 +1,6 @@ from .mlp import MLP from .cnn import CNN, All_CNN_C from .lstm import LSTM -from .hyper_networks import HyperNetworkMLP # Helper that returns model based on string name @@ -18,5 +17,4 @@ "All_CNN_C", "LSTM", "NetworkMapper", - "HyperNetworkMLP", ] diff --git a/evosax/problems/__init__.py b/evosax/problems/__init__.py index 8b3fdb6..120ce38 100755 --- a/evosax/problems/__init__.py +++ b/evosax/problems/__init__.py @@ -1,22 +1,19 @@ -from .control_brax import BraxFitness -from .control_gym import GymFitness +from .control_gym import GymnaxFitness from .vision import VisionFitness -from .classic import ClassicFitness +from .bbob import BBOBFitness from .sequence import SequenceFitness ProblemMapper = { - "Gym": GymFitness, - "Brax": BraxFitness, + "Gymnax": GymnaxFitness, "Vision": VisionFitness, - "Classic": ClassicFitness, + "BBOB": BBOBFitness, "Sequence": SequenceFitness, } __all__ = [ - "BraxFitness", - "GymFitness", + "GymnaxFitness", "VisionFitness", - "ClassicFitness", + "BBOBFitness", "SequenceFitness", "ProblemMapper", ] diff --git a/evosax/problems/classic.py b/evosax/problems/classic.py deleted file mode 100755 index a85d294..0000000 --- a/evosax/problems/classic.py +++ /dev/null @@ -1,140 +0,0 @@ -import jax -import jax.numpy as jnp -import chex -from functools import partial - - -class ClassicFitness(object): - def __init__( - self, - fct_name: str = "rosenbrock", - num_dims: int = 2, - num_rollouts: int = 1, - noise_std: float = 0.0, - ): - 
self.fct_name = fct_name - self.num_dims = num_dims - self.num_rollouts = num_rollouts - # Optional - add Gaussian noise to evaluation fitness - self.noise_std = noise_std - assert self.num_dims >= 2 - - # Use default settings for classic BBOB evaluation functions - if self.fct_name == "quadratic": - self.eval = jax.vmap(quadratic_d_dim, 0) - elif self.fct_name == "rosenbrock": - fn = partial(rosenbrock_d_dim, params={"a": 1, "b": 100}) - self.eval = jax.vmap(fn, 0) - elif self.fct_name == "ackley": - fn = partial( - ackley_d_dim, params={"c": 20, "d": 0.2, "e": 2 * jnp.pi} - ) - self.eval = jax.vmap(fn, 0) - elif self.fct_name == "griewank": - self.eval = jax.vmap(griewank_d_dim, 0) - elif self.fct_name == "rastrigin": - fn = partial(rastrigin_d_dim, params={"f": 10}) - self.eval = jax.vmap(fn, 0) - elif self.fct_name == "schwefel": - self.eval = jax.vmap(schwefel_d_dim, 0) - elif self.fct_name == "himmelblau": - assert self.num_dims == 2 - self.eval = jax.vmap(himmelblau_2_dim, 0) - elif self.fct_name == "six-hump": - assert self.num_dims == 2 - self.eval = jax.vmap(six_hump_camel_2_dim, 0) - else: - raise ValueError("Please provide a valid problem name.") - - @partial(jax.jit, static_argnums=(0,)) - def rollout( - self, rng_input: chex.PRNGKey, eval_params: chex.Array - ) -> chex.Array: - """Batch evaluate the proposal points.""" - fitness = self.eval(eval_params).reshape(eval_params.shape[0], 1) - noise = self.noise_std * jax.random.normal( - rng_input, (eval_params.shape[0], self.num_rollouts) - ) - return (fitness + noise).squeeze() - - -def himmelblau_2_dim(x: chex.Array) -> chex.Array: - """ - 2-dim. Himmelblau function. - f(x*)=0 - Minima at [3, 2], [-2.81, 3.13], - [-3.78, -3.28], [3.58, -1.85] - """ - return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 - 7) ** 2 - - -def six_hump_camel_2_dim(x: chex.Array) -> chex.Array: - """ - 2-dim. 6-Hump Camel function. 
- f(x*)=-1.0316 - Minimum at [0.0898, -0.7126], [-0.0898, 0.7126] - """ - p1 = (4 - 2.1 * x[0] ** 2 + x[0] ** 4 / 3) * x[0] ** 2 - p2 = x[0] * x[1] - p3 = (-4 + 4 * x[1] ** 2) * x[1] ** 2 - return p1 + p2 + p3 - - -def quadratic_d_dim(x: chex.Array) -> chex.Array: - """ - Simple D-dim. quadratic function. - f(x*)=0 - Minimum at [0.]ˆd - """ - return jnp.sum(jnp.square(x)) - - -def rosenbrock_d_dim(x: chex.Array, params: dict) -> chex.Array: - """ - D-Dim. Rosenbrock function. x_i ∈ [-32.768, 32.768] or x_i ∈ [-5, 10] - f(x*)=0 - Minumum at x*=a - """ - x_i, x_sq, x_p = x[:-1], x[:-1] ** 2, x[1:] - return jnp.sum((params["a"] - x_i) ** 2 + params["b"] * (x_p - x_sq) ** 2) - - -def ackley_d_dim(x: chex.Array, params: dict) -> chex.Array: - """ - D-Dim. Ackley function. x_i ∈ [-32.768, 32.768] - f(x*)=0 - Minimum at x*=[0,...,0] - """ - return ( - -params["c"] * jnp.exp(-params["d"] * jnp.sqrt(jnp.mean(x ** 2))) - - jnp.exp(jnp.mean(jnp.cos(params["e"] * x))) - + params["c"] - + jnp.exp(1) - ) - - -def griewank_d_dim(x: chex.Array) -> chex.Array: - """ - D-Dim. Griewank function. x_i ∈ [-600, 600] - f(x*)=0 - Minimum at x*=[0,...,0] - """ - return ( - jnp.sum(x ** 2 / 4000) - - jnp.prod(jnp.cos(x / jnp.sqrt(jnp.arange(1, x.shape[0] + 1)))) - + 1 - ) - - -def rastrigin_d_dim(x: chex.Array, params: dict) -> chex.Array: - """ - D-Dim. Rastrigin function. x_i ∈ [-5.12, 5.12] - f(x*)=0 - Minimum at x*=[0,...,0] - """ - return params["f"] * x.shape[0] + jnp.sum( - x ** 2 - params["f"] * jnp.cos(2 * jnp.pi * x) - ) - - -def schwefel_d_dim(x: chex.Array) -> chex.Array: - """ - D-Dim. Schwefel function. 
x_i ∈ [-500, 500] - f(x*)=0 - Minimum at x*=[420.9687,...,420.9687] - """ - return 418.9829 * x.shape[0] - jnp.sum( - x * jnp.sin(jnp.sqrt(jnp.absolute(x))) - ) diff --git a/evosax/problems/control_brax.py b/evosax/problems/control_brax.py deleted file mode 100644 index e99fa96..0000000 --- a/evosax/problems/control_brax.py +++ /dev/null @@ -1,233 +0,0 @@ -import jax -import jax.numpy as jnp -import chex -from typing import Optional -from .obs_norm import ObsNormalizer - - -class BraxFitness(object): - def __init__( - self, - env_name: str = "ant", - num_env_steps: int = 1000, - num_rollouts: int = 16, - legacy_spring: bool = True, - normalize: bool = False, - modify_dict: dict = {"torso_mass": 15}, - test: bool = False, - n_devices: Optional[int] = None, - ): - try: - from brax import envs - except ImportError: - raise ImportError( - "You need to install `brax` to use its fitness rollouts." - ) - self.env_name = env_name - self.num_env_steps = num_env_steps - self.num_rollouts = num_rollouts - self.steps_per_member = num_env_steps * num_rollouts - self.test = test - - if self.env_name in [ - "ant", - "halfcheetah", - "hopper", - "humanoid", - "reacher", - "walker2d", - "fetch", - "grasp", - "ur5e", - ]: - # Define the RL environment & network forward fucntion - self.env = envs.create( - env_name=self.env_name, - episode_length=num_env_steps, - legacy_spring=legacy_spring, - ) - elif self.env_name == "modified-ant": - from .modified_ant import create_modified_ant_env - - self.env = create_modified_ant_env(modify_dict) - - self.action_shape = self.env.action_size - self.input_shape = (self.env.observation_size,) - self.obs_normalizer = ObsNormalizer( - self.input_shape, dummy=not normalize - ) - self.obs_params = self.obs_normalizer.get_init_params() - if n_devices is None: - self.n_devices = jax.local_device_count() - else: - self.n_devices = n_devices - - # Keep track of total steps executed in environment - self.total_env_steps = 0 - - def set_apply_fn(self, 
map_dict, network_apply, carry_init=None): - """Set the network forward function.""" - self.network = network_apply - # Set rollout function based on model architecture - if carry_init is not None: - self.single_rollout = self.rollout_rnn - self.carry_init = carry_init - else: - self.single_rollout = self.rollout_ffw - - # vmap over stochastic evaluations - self.rollout_repeats = jax.vmap(self.single_rollout, in_axes=(0, None)) - self.rollout_pop = jax.vmap( - self.rollout_repeats, in_axes=(None, map_dict) - ) - # pmap over popmembers if > 1 device is available - otherwise pmap - if self.n_devices > 1: - self.rollout_map = self.rollout_pmap - print( - f"BraxFitness: {self.n_devices} devices detected. Please make" - " sure that the ES population size divides evenly across the" - " number of devices to pmap/parallelize over." - ) - else: - self.rollout_map = self.rollout_pop - - def rollout_pmap(self, rng_input, policy_params): - """Parallelize rollout across devices. Split keys/reshape correctly.""" - keys_pmap = jnp.tile(rng_input, (self.n_devices, 1, 1)) - rew_dev, obs_dev, masks_dev = jax.pmap(self.rollout_pop)( - keys_pmap, policy_params - ) - rew_re = rew_dev.reshape(-1, self.num_rollouts) - obs_re = obs_dev.reshape( - -1, self.num_rollouts, self.num_env_steps, self.env.observation_size - ) - masks_re = masks_dev.reshape( - -1, self.num_rollouts, self.num_env_steps, 1 - ) - return rew_re, obs_re, masks_re - - def rollout(self, rng_input, policy_params): - """Placeholder fn call for rolling out a population for multi-evals.""" - rng_pop = jax.random.split(rng_input, self.num_rollouts) - scores, all_obs, masks = jax.jit(self.rollout_map)( - rng_pop, policy_params - ) - # Update normalization parameters if train case! 
- if not self.test: - obs_re = all_obs.reshape( - self.num_env_steps, -1, self.input_shape[0] - ) - masks_re = masks.reshape(self.num_env_steps, -1) - self.obs_params = self.obs_normalizer.update_normalization_params( - obs_buffer=obs_re, - obs_mask=masks_re, - obs_params=self.obs_params, - ) - - # obs_steps = self.obs_params[0] - # running_mean, running_var = jnp.split(self.obs_params[1:], 2) - # print( - # float(scores.mean()), - # float(masks.mean()), - # obs_steps, - # running_mean.mean(), - # running_var.mean() / (obs_steps + 1), - # ) - - # Update total step counter using only transitions before termination - self.total_env_steps += masks_re.sum() - return scores - - def rollout_ffw( - self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree - ) -> chex.Array: - """Rollout a jitted brax episode with lax.scan for a feedforward policy.""" - # Reset the environment - rng, rng_reset = jax.random.split(rng_input) - state = self.env.reset(rng_reset) - - def policy_step(state_input, tmp): - """lax.scan compatible step transition in jax env.""" - state, policy_params, rng, cum_reward, valid_mask = state_input - rng, rng_net = jax.random.split(rng) - org_obs = state.obs - norm_obs = self.obs_normalizer.normalize_obs( - org_obs, self.obs_params - ) - action = self.network(policy_params, norm_obs, rng=rng_net) - next_s = self.env.step(state, action) - new_cum_reward = cum_reward + next_s.reward * valid_mask - new_valid_mask = valid_mask * (1 - next_s.done.ravel()) - carry = [next_s, policy_params, rng, new_cum_reward, new_valid_mask] - return carry, [new_valid_mask, org_obs] - - # Scan over episode step loop - carry_out, scan_out = jax.lax.scan( - policy_step, - [state, policy_params, rng, jnp.array([0.0]), jnp.array([1.0])], - (), - self.num_env_steps, - ) - # Return masked sum of rewards accumulated by agent in episode - ep_mask, all_obs = scan_out[0], scan_out[1] - cum_return = carry_out[-2].squeeze() - return cum_return, all_obs, ep_mask - - def rollout_rnn( - 
self, rng_input: chex.PRNGKey, policy_params: chex.ArrayTree - ) -> chex.Array: - """Rollout a jitted episode with lax.scan for a recurrent policy.""" - # Reset the environment - rng, rng_reset = jax.random.split(rng_input) - state = self.env.reset(rng_reset) - hidden = self.carry_init() - - def policy_step(state_input, tmp): - """lax.scan compatible step transition in jax env.""" - ( - state, - policy_params, - rng, - hidden, - cum_reward, - valid_mask, - ) = state_input - rng, rng_net = jax.random.split(rng) - org_obs = state.obs - norm_obs = self.obs_normalizer.normalize_obs( - state.obs, self.obs_params - ) - hidden, action = self.network( - policy_params, norm_obs, hidden, rng_net - ) - next_s = self.env.step(state, action) - new_cum_reward = cum_reward + next_s.reward * valid_mask - new_valid_mask = valid_mask * (1 - next_s.done.ravel()) - carry = [ - next_s, - policy_params, - rng, - hidden, - new_cum_reward, - new_valid_mask, - ] - return carry, [new_valid_mask, org_obs] - - # Scan over episode step loop - carry_out, scan_out = jax.lax.scan( - policy_step, - [ - state, - policy_params, - rng, - hidden, - jnp.array([0.0]), - jnp.array([1.0]), - ], - (), - self.num_env_steps, - ) - # Return masked sum of rewards accumulated by agent in episode - ep_mask, all_obs = scan_out[0], scan_out[1] - cum_return = carry_out[-2].squeeze() - return cum_return, all_obs, ep_mask diff --git a/evosax/problems/control_gym.py b/evosax/problems/control_gym.py index 323eaad..5bc6583 100644 --- a/evosax/problems/control_gym.py +++ b/evosax/problems/control_gym.py @@ -1,11 +1,10 @@ import jax import jax.numpy as jnp -from functools import partial from typing import Optional import chex -class GymFitness(object): +class GymnaxFitness(object): def __init__( self, env_name: str = "CartPole-v1", @@ -47,7 +46,7 @@ def __init__( # Keep track of total steps executed in environment self.total_env_steps = 0 - def set_apply_fn(self, map_dict, network_apply, carry_init=None): + def 
set_apply_fn(self, network_apply, carry_init=None): """Set the network forward function.""" self.network = network_apply # Set rollout function based on model architecture @@ -57,9 +56,7 @@ def set_apply_fn(self, map_dict, network_apply, carry_init=None): else: self.single_rollout = self.rollout_ffw self.rollout_repeats = jax.vmap(self.single_rollout, in_axes=(0, None)) - self.rollout_pop = jax.vmap( - self.rollout_repeats, in_axes=(None, map_dict) - ) + self.rollout_pop = jax.vmap(self.rollout_repeats, in_axes=(None, 0)) # pmap over popmembers if > 1 device is available - otherwise pmap if self.n_devices > 1: self.rollout_map = self.rollout_pmap diff --git a/evosax/problems/modified_ant.py b/evosax/problems/modified_ant.py deleted file mode 100644 index 8fabe72..0000000 --- a/evosax/problems/modified_ant.py +++ /dev/null @@ -1,423 +0,0 @@ -"""Trains an ant to run in the +x direction. -Adapted from https://raw.githubusercontent.com/google/brax/main/brax/envs/ant.py -Increases mass of main torso body from 10 to 15. 
-""" - -import brax -from brax import jumpy as jp -from brax.envs import env -from brax.envs import wrappers -from typing import Optional - - -class ModifiedAnt(env.Env): - """Trains an ant to run in the +x direction.""" - - def __init__(self, config, **kwargs): - super().__init__(config=config, **kwargs) - - def reset(self, rng: jp.ndarray) -> env.State: - """Resets the environment to an initial state.""" - rng, rng1, rng2 = jp.random_split(rng, 3) - qpos = self.sys.default_angle() + jp.random_uniform( - rng1, (self.sys.num_joint_dof,), -0.1, 0.1 - ) - qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -0.1, 0.1) - qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel) - info = self.sys.info(qp) - obs = self._get_obs(qp, info) - reward, done, zero = jp.zeros(3) - metrics = { - "reward_ctrl_cost": zero, - "reward_contact_cost": zero, - "reward_forward": zero, - "reward_survive": zero, - } - return env.State(qp, obs, reward, done, metrics) - - def step(self, state: env.State, action: jp.ndarray) -> env.State: - """Run one timestep of the environment's dynamics.""" - qp, info = self.sys.step(state.qp, action) - obs = self._get_obs(qp, info) - - x_before = state.qp.pos[0, 0] - x_after = qp.pos[0, 0] - forward_reward = (x_after - x_before) / self.sys.config.dt - ctrl_cost = 0.5 * jp.sum(jp.square(action)) - contact_cost = ( - 0.5 * 1e-3 * jp.sum(jp.square(jp.clip(info.contact.vel, -1, 1))) - ) - survive_reward = jp.float32(1) - reward = forward_reward - ctrl_cost - contact_cost + survive_reward - - done = jp.where(qp.pos[0, 2] < 0.2, x=jp.float32(1), y=jp.float32(0)) - done = jp.where(qp.pos[0, 2] > 1.0, x=jp.float32(1), y=done) - state.metrics.update( - reward_ctrl_cost=ctrl_cost, - reward_contact_cost=contact_cost, - reward_forward=forward_reward, - reward_survive=survive_reward, - ) - - return state.replace(qp=qp, obs=obs, reward=reward, done=done) - - def _get_obs(self, qp: brax.QP, info: brax.Info) -> jp.ndarray: - """Observe ant body position and 
velocities.""" - # some pre-processing to pull joint angles and velocities - (joint_angle,), (joint_vel,) = self.sys.joints[0].angle_vel(qp) - - # qpos: - # Z of the torso (1,) - # orientation of the torso as quaternion (4,) - # joint angles (8,) - qpos = [qp.pos[0, 2:], qp.rot[0], joint_angle] - - # qvel: - # velocity of the torso (3,) - # angular velocity of the torso (3,) - # joint angle velocities (8,) - qvel = [qp.vel[0], qp.ang[0], joint_vel] - - # external contact forces: - # delta velocity (3,), delta ang (3,) * 10 bodies in the system - # Note that mujoco has 4 extra bodies tucked inside the Torso that Brax - # ignores - cfrc = [ - jp.clip(info.contact.vel, -1, 1), - jp.clip(info.contact.ang, -1, 1), - ] - # flatten bottom dimension - cfrc = [jp.reshape(x, x.shape[:-2] + (-1,)) for x in cfrc] - - return jp.concatenate(qpos + qvel + cfrc) - - -_CONFIG_MODIFIED = """ -bodies {{ - name: "$ Torso" - colliders {{ - capsule {{ - radius: 0.25 - length: 0.5 - end: 1 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: {torso_mass} -}} -bodies {{ - name: "Aux 1" - colliders {{ - rotation {{ x: 90 y: -45 }} - capsule {{ - radius: 0.08 - length: 0.4428427219390869 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "$ Body 4" - colliders {{ - rotation {{ x: 90 y: -45 }} - capsule {{ - radius: 0.08 - length: 0.7256854176521301 - end: -1 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "Aux 2" - colliders {{ - rotation {{ x: 90 y: 45 }} - capsule {{ - radius: 0.08 - length: 0.4428427219390869 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "$ Body 7" - colliders {{ - rotation {{ x: 90 y: 45 }} - capsule {{ - radius: 0.08 - length: 0.7256854176521301 - end: -1 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "Aux 3" - colliders {{ - rotation {{ x: -90 y: 45 }} - capsule {{ - radius: 0.08 - length: 0.4428427219390869 - }} - }} - inertia {{ x: 1.0 y: 
1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "$ Body 10" - colliders {{ - rotation {{ x: -90 y: 45 }} - capsule {{ - radius: 0.08 - length: 0.7256854176521301 - end: -1 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "Aux 4" - colliders {{ - rotation {{ x: -90 y: -45 }} - capsule {{ - radius: 0.08 - length: 0.4428427219390869 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "$ Body 13" - colliders {{ - rotation {{ x: -90 y: -45 }} - capsule {{ - radius: 0.08 - length: 0.7256854176521301 - end: -1 - }} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 -}} -bodies {{ - name: "Ground" - colliders {{ - plane {{}} - }} - inertia {{ x: 1.0 y: 1.0 z: 1.0 }} - mass: 1 - frozen {{ all: true }} -}} -joints {{ - name: "$ Torso_Aux 1" - parent_offset {{ x: 0.2 y: 0.2 }} - child_offset {{ x: -0.1 y: -0.1 }} - parent: "$ Torso" - child: "Aux 1" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - angle_limit {{ min: -30.0 max: 30.0 }} - rotation {{ y: -90 }} -}} -joints {{ - name: "Aux 1_$ Body 4" - parent_offset {{ x: 0.1 y: 0.1 }} - child_offset {{ x: -0.2 y: -0.2 }} - parent: "Aux 1" - child: "$ Body 4" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation: {{ z: 135 }} - angle_limit {{ - min: 30.0 - max: 70.0 - }} -}} -joints {{ - name: "$ Torso_Aux 2" - parent_offset {{ x: -0.2 y: 0.2 }} - child_offset {{ x: 0.1 y: -0.1 }} - parent: "$ Torso" - child: "Aux 2" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation {{ y: -90 }} - angle_limit {{ min: -30.0 max: 30.0 }} -}} -joints {{ - name: "Aux 2_$ Body 7" - parent_offset {{ x: -0.1 y: 0.1 }} - child_offset {{ x: 0.2 y: -0.2 }} - parent: "Aux 2" - child: "$ Body 7" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation {{ z: 45 }} - angle_limit {{ min: -70.0 max: -30.0 }} -}} -joints {{ - name: "$ Torso_Aux 3" - parent_offset {{ x: -0.2 y: -0.2 }} - child_offset {{ x: 0.1 y: 0.1 }} - 
parent: "$ Torso" - child: "Aux 3" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation {{ y: -90 }} - angle_limit {{ min: -30.0 max: 30.0 }} -}} -joints {{ - name: "Aux 3_$ Body 10" - parent_offset {{ x: -0.1 y: -0.1 }} - child_offset {{ - x: 0.2 - y: 0.2 - }} - parent: "Aux 3" - child: "$ Body 10" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation {{ z: 135 }} - angle_limit {{ min: -70.0 max: -30.0 }} -}} -joints {{ - name: "$ Torso_Aux 4" - parent_offset {{ x: 0.2 y: -0.2 }} - child_offset {{ x: -0.1 y: 0.1 }} - parent: "$ Torso" - child: "Aux 4" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation {{ y: -90 }} - angle_limit {{ min: -30.0 max: 30.0 }} -}} -joints {{ - name: "Aux 4_$ Body 13" - parent_offset {{ x: 0.1 y: -0.1 }} - child_offset {{ x: -0.2 y: 0.2 }} - parent: "Aux 4" - child: "$ Body 13" - stiffness: 18000.0 - angular_damping: 20 - spring_damping: 80 - rotation {{ z: 45 }} - angle_limit {{ min: 30.0 max: 70.0 }} -}} -actuators {{ - name: "$ Torso_Aux 1" - joint: "$ Torso_Aux 1" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "Aux 1_$ Body 4" - joint: "Aux 1_$ Body 4" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "$ Torso_Aux 2" - joint: "$ Torso_Aux 2" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "Aux 2_$ Body 7" - joint: "Aux 2_$ Body 7" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "$ Torso_Aux 3" - joint: "$ Torso_Aux 3" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "Aux 3_$ Body 10" - joint: "Aux 3_$ Body 10" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "$ Torso_Aux 4" - joint: "$ Torso_Aux 4" - strength: 350.0 - torque {{}} -}} -actuators {{ - name: "Aux 4_$ Body 13" - joint: "Aux 4_$ Body 13" - strength: 350.0 - torque {{}} -}} -friction: 1.0 -gravity {{ z: -9.8 }} -angular_damping: -0.05 -baumgarte_erp: 0.1 -collide_include {{ - first: "$ Torso" - second: "Ground" -}} -collide_include {{ - first: "$ Body 
4" - second: "Ground" -}} -collide_include {{ - first: "$ Body 7" - second: "Ground" -}} -collide_include {{ - first: "$ Body 10" - second: "Ground" -}} -collide_include {{ - first: "$ Body 13" - second: "Ground" -}} -dt: {dt} -substeps: 10 -dynamics_mode: "legacy_spring" -""" - - -def create_modified_ant_env( - modify_dict: dict = {}, - episode_length: int = 1000, - action_repeat: int = 1, - auto_reset: bool = True, - batch_size: Optional[int] = None, - eval_metrics: bool = False, - **kwargs -): - """Creates a config modified Ant Env with a specified brax system.""" - default_settings = {"torso_mass": 15, "dt": 0.05} - for k, v in default_settings.items(): - if k not in modify_dict.keys(): - modify_dict[k] = v - config = _CONFIG_MODIFIED.format(**modify_dict) - - env = ModifiedAnt(config, **kwargs) - if episode_length is not None: - env = wrappers.EpisodeWrapper(env, episode_length, action_repeat) - if batch_size: - env = wrappers.VectorWrapper(env, batch_size) - if auto_reset: - env = wrappers.AutoResetWrapper(env) - if eval_metrics: - env = wrappers.EvalWrapper(env) - - return env diff --git a/evosax/problems/obs_norm.py b/evosax/problems/obs_norm.py deleted file mode 100644 index 2dd20df..0000000 --- a/evosax/problems/obs_norm.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2022 The EvoJAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# Adapted from https://github.com/google/evojax/blob/main/evojax/obs_norm.py - -from typing import Tuple -from functools import partial - -import jax -import jax.numpy as jnp -import numpy as np - - -def normalize( - obs: jnp.ndarray, - obs_params: jnp.ndarray, - obs_shape: Tuple, - clip_value: float, - std_min_value: float, - std_max_value: float, -) -> jnp.ndarray: - """Normalize the given observation.""" - - obs_steps = obs_params[0] - running_mean, running_var = jnp.split(obs_params[1:], 2) - running_mean = running_mean.reshape(obs_shape) - running_var = running_var.reshape(obs_shape) - - variance = running_var / (obs_steps + 1.0) - variance = jnp.clip(variance, std_min_value, std_max_value) - return jnp.clip( - (obs - running_mean) / jnp.sqrt(variance), -clip_value, clip_value - ) - - -def update_obs_params( - obs_buffer: jnp.ndarray, obs_mask: jnp.ndarray, obs_params: jnp.ndarray -) -> jnp.ndarray: - """Update observation normalization parameters.""" - - obs_steps = obs_params[0] - running_mean, running_var = jnp.split(obs_params[1:], 2) - if obs_mask.ndim != obs_buffer.ndim: - obs_mask = obs_mask.reshape( - obs_mask.shape + (1,) * (obs_buffer.ndim - obs_mask.ndim) - ) - - new_steps = jnp.sum(obs_mask) - total_steps = obs_steps + new_steps - - input_to_old_mean = (obs_buffer - running_mean) * obs_mask - mean_diff = jnp.sum(input_to_old_mean / total_steps, axis=(0, 1)) - new_mean = running_mean + mean_diff - - input_to_new_mean = (obs_buffer - new_mean) * obs_mask - var_diff = jnp.sum(input_to_new_mean * input_to_old_mean, axis=(0, 1)) - new_var = running_var + var_diff - - return jnp.concatenate([jnp.ones(1) * total_steps, new_mean, new_var]) - - -class ObsNormalizer(object): - """Observation normalizer.""" - - def __init__( - self, - obs_shape: Tuple, - clip_value: float = 5.0, - std_min_value: float = 1e-6, - std_max_value: float = 1e6, - dummy: bool = False, - ): - """Initialization. - - Args: - obs_shape - Shape of the observations. 
- std_min_value - Minimum standard deviation. - std_max_value - Maximum standard deviation. - dummy - Whether this is a dummy normalizer. - """ - - self._obs_shape = obs_shape - self._obs_size = np.prod(obs_shape) - self._std_min_value = std_min_value - self._std_max_value = std_max_value - self._clip_value = clip_value - self.is_dummy = dummy - - @partial(jax.jit, static_argnums=(0,)) - def normalize_obs( - self, obs: jnp.ndarray, obs_params: jnp.ndarray - ) -> jnp.ndarray: - """Normalize the given observation. - - Args: - obs - The observation to be normalized. - Returns: - Normalized observation. - """ - - if self.is_dummy: - return obs - else: - return normalize( - obs=obs, - obs_params=obs_params, - obs_shape=self._obs_shape, - clip_value=self._clip_value, - std_min_value=self._std_min_value, - std_max_value=self._std_max_value, - ) - - @partial(jax.jit, static_argnums=(0,)) - def update_normalization_params( - self, - obs_buffer: jnp.ndarray, - obs_mask: jnp.ndarray, - obs_params: jnp.ndarray, - ) -> jnp.ndarray: - """Update internal parameters.""" - - if self.is_dummy: - return jnp.zeros_like(obs_params) - else: - return update_obs_params( - obs_buffer=obs_buffer, - obs_mask=obs_mask, - obs_params=obs_params, - ) - - @partial(jax.jit, static_argnums=(0,)) - def get_init_params(self) -> jnp.ndarray: - return jnp.zeros(1 + self._obs_size * 2) diff --git a/evosax/problems/sequence.py b/evosax/problems/sequence.py index 5425df4..a9c2175 100644 --- a/evosax/problems/sequence.py +++ b/evosax/problems/sequence.py @@ -45,11 +45,11 @@ def __init__( else: self.n_devices = n_devices - def set_apply_fn(self, map_dict, network, carry_init): + def set_apply_fn(self, network, carry_init): """Set the network forward function.""" self.network = network self.carry_init = carry_init - self.rollout_pop = jax.vmap(self.rollout_rnn, in_axes=(None, map_dict)) + self.rollout_pop = jax.vmap(self.rollout_rnn, in_axes=(None, 0)) # pmap over popmembers if > 1 device is available - 
otherwise pmap if self.n_devices > 1: self.rollout = self.rollout_pmap diff --git a/evosax/problems/vision.py b/evosax/problems/vision.py index 4f4971a..72989ff 100644 --- a/evosax/problems/vision.py +++ b/evosax/problems/vision.py @@ -25,10 +25,10 @@ def __init__( else: self.n_devices = n_devices - def set_apply_fn(self, map_dict, network): + def set_apply_fn(self, network): """Set the network forward function.""" self.network = network - self.rollout_pop = jax.vmap(self.rollout_ffw, in_axes=(None, map_dict)) + self.rollout_pop = jax.vmap(self.rollout_ffw, in_axes=(None, 0)) # pmap over popmembers if > 1 device is available - otherwise pmap if self.n_devices > 1: self.rollout = self.rollout_pmap diff --git a/evosax/restarts/termination.py b/evosax/restarts/termination.py index 4a15edc..35c465c 100644 --- a/evosax/restarts/termination.py +++ b/evosax/restarts/termination.py @@ -36,7 +36,10 @@ def cma_criterion( dC = jnp.diag(state.strategy_state.C) # Note: Criterion requires full covariance matrix for decomposition! 
C, B, D = full_eigen_decomp( - state.strategy_state.C, state.strategy_state.B, state.strategy_state.D + state.strategy_state.C, + state.strategy_state.B, + state.strategy_state.D, + state.strategy_state.gen_counter, ) # Stop if std of normal distrib is smaller than tolx in all coordinates diff --git a/evosax/strategies/__init__.py b/evosax/strategies/__init__.py index d445fe1..b28992e 100755 --- a/evosax/strategies/__init__.py +++ b/evosax/strategies/__init__.py @@ -7,7 +7,6 @@ from .pgpe import PGPE from .pbt import PBT from .persistent_es import PersistentES -from .xnes import xNES from .ars import ARS from .sep_cma_es import Sep_CMA_ES from .bipop_cma_es import BIPOP_CMA_ES @@ -19,6 +18,16 @@ from .rm_es import RmES from .gld import GLD from .sim_anneal import SimAnneal +from .snes import SNES +from .xnes import xNES +from .esmc import ESMC +from .des import DES +from .samr_ga import SAMR_GA +from .gesmr_ga import GESMR_GA +from .guided_es import GuidedES +from .asebo import ASEBO +from .cr_fm_nes import CR_FM_NES +from .mr15_ga import MR15_GA __all__ = [ @@ -31,7 +40,6 @@ "PGPE", "PBT", "PersistentES", - "xNES", "ARS", "Sep_CMA_ES", "BIPOP_CMA_ES", @@ -43,4 +51,14 @@ "RmES", "GLD", "SimAnneal", + "SNES", + "xNES", + "ESMC", + "DES", + "SAMR_GA", + "GESMR_GA", + "GuidedES", + "ASEBO", + "CR_FM_NES", + "MR15_GA", ] diff --git a/evosax/strategies/ars.py b/evosax/strategies/ars.py index 7e53908..0273474 100644 --- a/evosax/strategies/ars.py +++ b/evosax/strategies/ars.py @@ -1,9 +1,9 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy -from ..utils import GradientOptimizer, OptState, OptParams +from ..utils import GradientOptimizer, OptState, OptParams, exp_decay from flax import struct @@ -32,28 +32,54 @@ class EvoParams: class ARS(Strategy): def __init__( self, - num_dims: int, popsize: int, + num_dims: Optional[int] = None, + pholder_params: 
Optional[Union[chex.ArrayTree, chex.Array]] = None, elite_ratio: float = 0.1, opt_name: str = "sgd", + lrate_init: float = 0.05, + lrate_decay: float = 1.0, + lrate_limit: float = 0.001, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float] ): """Augmented Random Search (Mania et al., 2018) Reference: https://arxiv.org/pdf/1803.07055.pdf""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert not self.popsize & 1, "Population size must be even" # ARS performs antithetic sampling & allows you to select # "b" elite perturbation directions for the update assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize / 2 * self.elite_ratio)) - assert opt_name in ["sgd", "adam", "rmsprop", "clipup"] + assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"] self.optimizer = GradientOptimizer[opt_name](self.num_dims) self.strategy_name = "ARS" + # Set core kwargs es_params (lrate/sigma schedules) + self.lrate_init = lrate_init + self.lrate_decay = lrate_decay + self.lrate_limit = lrate_limit + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" - return EvoParams(opt_params=self.optimizer.default_params) + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -114,8 +140,5 @@ def tell_strategy( state.mean, theta_grad, state.opt_state, params.opt_params ) opt_state = self.optimizer.update(opt_state, params.opt_params) - - # Update lrate 
and standard deviation based on min and decay - sigma = state.sigma * params.sigma_decay - sigma = jnp.maximum(sigma, params.sigma_limit) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) return state.replace(mean=mean, sigma=sigma, opt_state=opt_state) diff --git a/evosax/strategies/asebo.py b/evosax/strategies/asebo.py new file mode 100644 index 0000000..e404d88 --- /dev/null +++ b/evosax/strategies/asebo.py @@ -0,0 +1,202 @@ +import jax +import jax.numpy as jnp +import chex +from typing import Tuple, Optional, Union +from ..strategy import Strategy +from ..utils import GradientOptimizer, OptState, OptParams, exp_decay +from flax import struct + + +@struct.dataclass +class EvoState: + mean: chex.Array + sigma: float + opt_state: OptState + grad_subspace: chex.Array + alpha: float + UUT: chex.Array + UUT_ort: chex.Array + best_member: chex.Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + + +@struct.dataclass +class EvoParams: + opt_params: OptParams + sigma_init: float = 0.03 + sigma_decay: float = 1.0 + sigma_limit: float = 0.01 + grad_decay: float = 0.99 + init_min: float = 0.0 + init_max: float = 0.0 + clip_min: float = -jnp.finfo(jnp.float32).max + clip_max: float = jnp.finfo(jnp.float32).max + + +class ASEBO(Strategy): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + subspace_dims: int = 50, + opt_name: str = "adam", + lrate_init: float = 0.05, + lrate_decay: float = 1.0, + lrate_limit: float = 0.001, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float], + ): + """ASEBO (Choromanski et al., 2019) + Reference: https://arxiv.org/abs/1903.04268 + Note that there are a couple of JAX-based adaptations: + 1. We always sample a fixed population size per generation + 2. 
We keep a fixed archive of gradients to estimate the subspace + """ + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) + assert not self.popsize & 1, "Population size must be even" + assert opt_name in ["sgd", "adam", "rmsprop", "clipup"] + self.optimizer = GradientOptimizer[opt_name](self.num_dims) + self.subspace_dims = min(subspace_dims, self.num_dims) + if self.subspace_dims < subspace_dims: + print( + "Subspace has to be smaller than optimization dims. Set to" + f" {self.subspace_dims} instead of {subspace_dims}." + ) + self.strategy_name = "ASEBO" + + # Set core kwargs es_params (lrate/sigma schedules) + self.lrate_init = lrate_init + self.lrate_decay = lrate_decay + self.lrate_limit = lrate_limit + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + + @property + def params_strategy(self) -> EvoParams: + """Return default parameters of evolution strategy.""" + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) + + def initialize_strategy( + self, rng: chex.PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the evolution strategy.""" + initialization = jax.random.uniform( + rng, + (self.num_dims,), + minval=params.init_min, + maxval=params.init_max, + ) + + grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims)) + + state = EvoState( + mean=initialization, + sigma=params.sigma_init, + opt_state=self.optimizer.initialize(params.opt_params), + grad_subspace=grad_subspace, + alpha=1.0, + UUT=jnp.zeros((self.num_dims, self.num_dims)), + UUT_ort=jnp.zeros((self.num_dims, self.num_dims)), + best_member=initialization, + ) + return state + + def ask_strategy( + self, rng: chex.PRNGKey, state: EvoState, params: EvoParams + ) -> Tuple[chex.Array, 
EvoState]: + """`ask` for new parameter candidates to evaluate next.""" + # Antithetic sampling of noise + X = state.grad_subspace + X -= jnp.mean(X, axis=0) + U, S, Vt = jnp.linalg.svd(X, full_matrices=False) + + def svd_flip(u, v): + # columns of u, rows of v + max_abs_cols = jnp.argmax(jnp.abs(u), axis=0) + signs = jnp.sign(u[max_abs_cols, jnp.arange(u.shape[1])]) + u *= signs + v *= signs[:, jnp.newaxis] + return u, v + + U, Vt = svd_flip(U, Vt) + U = Vt[: int(self.popsize / 2)] + UUT = jnp.matmul(U.T, U) + + U_ort = Vt[int(self.popsize / 2) :] + UUT_ort = jnp.matmul(U_ort.T, U_ort) + + subspace_ready = state.gen_counter > self.subspace_dims + + UUT = jax.lax.select( + subspace_ready, UUT, jnp.zeros((self.num_dims, self.num_dims)) + ) + cov = ( + state.sigma * (state.alpha / self.num_dims) * jnp.eye(self.num_dims) + + ((1 - state.alpha) / int(self.popsize / 2)) * UUT + ) + chol = jnp.linalg.cholesky(cov) + noise = jax.random.normal(rng, (self.num_dims, int(self.popsize / 2))) + z_plus = jnp.swapaxes(chol @ noise, 0, 1) + z_plus /= jnp.linalg.norm(z_plus, axis=-1)[:, jnp.newaxis] + z = jnp.concatenate([z_plus, -1.0 * z_plus]) + x = state.mean + z + return x, state.replace(UUT=UUT, UUT_ort=UUT_ort) + + def tell_strategy( + self, + x: chex.Array, + fitness: chex.Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` performance data for strategy state update.""" + # Reconstruct noise from last mean/std estimates + noise = (x - state.mean) / state.sigma + noise_1 = noise[: int(self.popsize / 2)] + fit_1 = fitness[: int(self.popsize / 2)] + fit_2 = fitness[int(self.popsize / 2) :] + fit_diff_noise = jnp.dot(noise_1.T, fit_1 - fit_2) + theta_grad = 1.0 / 2.0 * fit_diff_noise + + alpha = jnp.linalg.norm( + jnp.dot(theta_grad, state.UUT_ort) + ) / jnp.linalg.norm(jnp.dot(theta_grad, state.UUT)) + subspace_ready = state.gen_counter > self.subspace_dims + alpha = jax.lax.select(subspace_ready, alpha, 1.0) + + # Add grad FIFO-style to subspace archive 
(only if provided else FD) + grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims)) + grad_subspace = grad_subspace.at[:-1, :].set(state.grad_subspace[1:, :]) + grad_subspace = grad_subspace.at[-1, :].set(theta_grad) + state = state.replace(grad_subspace=grad_subspace) + + # Normalize gradients by norm / num_dims + theta_grad /= jnp.linalg.norm(theta_grad) / self.num_dims + 1e-8 + + # Grad update using optimizer instance - decay lrate if desired + mean, opt_state = self.optimizer.step( + state.mean, theta_grad, state.opt_state, params.opt_params + ) + opt_state = self.optimizer.update(opt_state, params.opt_params) + + # Update lrate and standard deviation based on min and decay + sigma = state.sigma * params.sigma_decay + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) + return state.replace( + mean=mean, sigma=sigma, opt_state=opt_state, alpha=alpha + ) diff --git a/evosax/strategies/bipop_cma_es.py b/evosax/strategies/bipop_cma_es.py index 992fed1..d9d5c17 100644 --- a/evosax/strategies/bipop_cma_es.py +++ b/evosax/strategies/bipop_cma_es.py @@ -1,6 +1,6 @@ import jax import chex -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from functools import partial from .cma_es import CMA_ES from ..restarts.restarter import WrapperState, WrapperParams @@ -19,14 +19,27 @@ class RestartParams: class BIPOP_CMA_ES(object): - def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): """BIPOP-CMA-ES (Hansen, 2009). 
Reference: https://hal.inria.fr/inria-00382093/document Inspired by: https://tinyurl.com/44y3ryhf""" self.strategy_name = "BIPOP_CMA_ES" # Instantiate base strategy & wrap it with restart wrapper self.strategy = CMA_ES( - num_dims=num_dims, popsize=popsize, elite_ratio=elite_ratio + num_dims=num_dims, + popsize=popsize, + pholder_params=pholder_params, + elite_ratio=elite_ratio, + sigma_init=sigma_init, + **fitness_kwargs, ) from ..restarts import BIPOP_Restarter from ..restarts.termination import spread_criterion, cma_criterion diff --git a/evosax/strategies/cma_es.py b/evosax/strategies/cma_es.py index 59b7ee3..84111e7 100755 --- a/evosax/strategies/cma_es.py +++ b/evosax/strategies/cma_es.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from ..strategy import Strategy from ..utils.eigen_decomp import full_eigen_decomp from flax import struct @@ -80,16 +80,27 @@ def get_cma_elite_weights( class CMA_ES(Strategy): - def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): """CMA-ES (e.g. 
Hansen, 2016) Reference: https://arxiv.org/abs/1604.00772 Inspired by: https://github.com/CyberAgentAILab/cmaes""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) self.strategy_name = "CMA_ES" + # Set core kwargs es_params + self.sigma_init = sigma_init + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" @@ -122,6 +133,7 @@ def params_strategy(self) -> EvoParams: d_sigma=d_sigma, c_c=c_c, chi_n=chi_n, + sigma_init=self.sigma_init, ) return params @@ -157,7 +169,9 @@ def ask_strategy( self, rng: chex.PRNGKey, state: EvoState, params: EvoParams ) -> Tuple[chex.Array, EvoState]: """`ask` for new parameter candidates to evaluate next.""" - C, B, D = full_eigen_decomp(state.C, state.B, state.D) + C, B, D = full_eigen_decomp( + state.C, state.B, state.D, state.gen_counter + ) x = sample( rng, state.mean, @@ -197,6 +211,7 @@ def tell_strategy( y_w, params.c_sigma, params.mu_eff, + state.gen_counter, ) p_c, norm_p_sigma, h_sigma = update_p_c( @@ -259,9 +274,10 @@ def update_p_sigma( y_w: chex.Array, c_sigma: float, mu_eff: float, + gen_counter: int, ) -> Tuple[chex.Array, chex.Array, chex.Array, None, None]: """Update evolution path for covariance matrix.""" - C, B, D = full_eigen_decomp(C, B, D) + C, B, D = full_eigen_decomp(C, B, D, gen_counter) C_2 = B.dot(jnp.diag(1 / D)).dot(B.T) # C^(-1/2) = B D^(-1) B^T p_sigma_new = (1 - c_sigma) * p_sigma + jnp.sqrt( c_sigma * (2 - c_sigma) * mu_eff diff --git a/evosax/strategies/cr_fm_nes.py b/evosax/strategies/cr_fm_nes.py new file mode 100644 index 0000000..70c23e3 --- /dev/null +++ b/evosax/strategies/cr_fm_nes.py @@ -0,0 +1,300 @@ +import jax +import jax.numpy as jnp +import chex +import math +from typing import Tuple, Optional, Union +from ..strategy import Strategy +from flax import 
struct + + +@struct.dataclass +class EvoState: + mean: chex.Array + sigma: float + v: chex.Array + D: chex.Array + p_sigma: chex.Array + p_c: chex.Array + w_rank_hat: chex.Array + w_rank: chex.Array + z: chex.Array + y: chex.Array + best_member: chex.Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + + +@struct.dataclass +class EvoParams: + mu_eff: float + c_s: float + c_c: float + c1: float + chi_N: float + h_inv: float + alpha_dist: float + lrate_mean: float = 1.0 + lrate_move_sigma: float = 0.1 + lrate_stag_sigma: float = 0.1 + lrate_conv_sigma: float = 0.1 + lrate_B: float = 0.1 + sigma_init: float = 1.0 + init_min: float = 0.0 + init_max: float = 0.0 + clip_min: float = -jnp.finfo(jnp.float32).max + clip_max: float = jnp.finfo(jnp.float32).max + + +def get_recombination_weights(popsize: int) -> Tuple[chex.Array, chex.Array]: + """Get recombination weights for different ranks.""" + + def get_weight(i): + return jnp.log(popsize / 2 + 1) - jnp.log(i) + + w_rank_hat = jax.vmap(get_weight)(jnp.arange(1, popsize + 1)) + w_rank_hat = w_rank_hat * (w_rank_hat >= 0) + w_rank = w_rank_hat / sum(w_rank_hat) - (1.0 / popsize) + return w_rank_hat.reshape(-1, 1), w_rank.reshape(-1, 1) + + +def get_h_inv(dim: int) -> float: + dim = min(dim, 2000) + f = lambda a: ((1.0 + a * a) * math.exp(a * a / 2.0) / 0.24) - 10.0 - dim + f_prime = lambda a: (1.0 / 0.24) * a * math.exp(a * a / 2.0) * (3.0 + a * a) + h_inv = 1.0 + counter = 0 + while abs(f(h_inv)) > 1e-10: + counter += 1 + h_inv = h_inv - 0.5 * (f(h_inv) / f_prime(h_inv)) + return h_inv + + +def w_dist_hat(alpha_dist: float, z: chex.Array) -> chex.Array: + return jnp.exp(alpha_dist * jnp.linalg.norm(z)) + + +class CR_FM_NES(Strategy): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): + """Cost-Reduced Fast-Moving Natural ES 
(Nomura & Ono, 2022) + Reference: https://arxiv.org/abs/2201.11422 + """ + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) + assert not self.popsize & 1, "Population size must be even" + self.strategy_name = "CR_FM_NES" + + # Set core kwargs es_params (sigma) + self.sigma_init = sigma_init + + @property + def default_params(self) -> EvoParams: + """Return default parameters of evolutionary strategy.""" + w_rank_hat, w_rank = get_recombination_weights(self.popsize) + mueff = 1 / ( + (w_rank + (1 / self.popsize)).T @ (w_rank + (1 / self.popsize)) + ) + c_s = (mueff + 2.0) / (self.num_dims + mueff + 5.0) + c_c = (4.0 + mueff / self.num_dims) / ( + self.num_dims + 4.0 + 2.0 * mueff / self.num_dims + ) + c1_cma = 2.0 / (jnp.power(self.num_dims + 1.3, 2) + mueff) + chi_N = jnp.sqrt(self.num_dims) * ( + 1.0 + - 1.0 / (4.0 * self.num_dims) + + 1.0 / (21.0 * self.num_dims * self.num_dims) + ) + h_inv = get_h_inv(self.num_dims) + alpha_dist = h_inv * jnp.minimum( + 1.0, jnp.sqrt(self.popsize / self.num_dims) + ) + lrate_move_sigma = 1.0 + lrate_stag_sigma = jnp.tanh( + (0.024 * self.popsize + 0.7 * self.num_dims + 20.0) + / (self.num_dims + 12.0) + ) + lrate_conv_sigma = 2.0 * jnp.tanh( + (0.025 * self.popsize + 0.75 * self.num_dims + 10.0) + / (self.num_dims + 4.0) + ) + c1 = c1_cma * (self.num_dims - 5) / 6 + lrate_B = jnp.tanh( + (jnp.minimum(0.02 * self.popsize, 3 * jnp.log(self.num_dims)) + 5) + / (0.23 * self.num_dims + 25) + ) + params = EvoParams( + lrate_move_sigma=lrate_move_sigma, + lrate_stag_sigma=lrate_stag_sigma, + lrate_conv_sigma=lrate_conv_sigma, + lrate_B=lrate_B, + mu_eff=mueff, + c_s=c_s, + c_c=c_c, + c1=c1, + chi_N=chi_N, + alpha_dist=alpha_dist, + h_inv=h_inv, + sigma_init=self.sigma_init, + ) + return params + + def initialize_strategy( + self, rng: chex.PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the evolutionary strategy.""" + rng_init, rng_v = jax.random.split(rng) + initialization = jax.random.uniform( + 
rng_init, + (self.num_dims,), + minval=params.init_min, + maxval=params.init_max, + ) + w_rank_hat, w_rank = get_recombination_weights(self.popsize) + state = EvoState( + mean=initialization, + sigma=params.sigma_init, + v=jax.random.normal(rng_v, shape=(self.num_dims, 1)) + / jnp.sqrt(self.num_dims), + D=jnp.ones([self.num_dims, 1]), + p_sigma=jnp.zeros((self.num_dims, 1)), + p_c=jnp.zeros((self.num_dims, 1)), + z=jnp.zeros((self.popsize, self.num_dims)), + y=jnp.zeros((self.popsize, self.num_dims)), + w_rank_hat=w_rank_hat.reshape(-1, 1), + w_rank=w_rank, + best_member=initialization, + ) + + return state + + def ask_strategy( + self, rng: chex.PRNGKey, state: EvoState, params: EvoParams + ) -> Tuple[chex.Array, EvoState]: + """`ask` for new parameter candidates to evaluate next.""" + z_plus = jax.random.normal( + rng, + (int(self.popsize / 2), self.num_dims), + ) + z = jnp.concatenate([z_plus, -1.0 * z_plus]) + z = jnp.swapaxes(z, 0, 1) + normv = jnp.linalg.norm(state.v) + normv2 = normv ** 2 + vbar = state.v / normv + + # Rescale/reparametrize noise + y = z + (jnp.sqrt(1 + normv2) - 1) * vbar @ (vbar.T @ z) + x = state.mean[:, None] + state.sigma * y * state.D + x = jnp.swapaxes(x, 0, 1) + return x, state.replace(z=z, y=y) + + def tell_strategy( + self, + x: chex.Array, + fitness: chex.Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` performance data for strategy state update.""" + ranks = fitness.argsort() + z = state.z[:, ranks] + y = state.y[:, ranks] + x = jnp.swapaxes(x, 0, 1)[:, ranks] + + # Update evolution path p_sigma + p_sigma = (1 - params.c_s) * state.p_sigma + jnp.sqrt( + params.c_s * (2.0 - params.c_s) * params.mu_eff + ) * (z @ state.w_rank) + p_sigma_norm = jnp.linalg.norm(p_sigma) + + # Calculate distance weight + w_tmp = state.w_rank_hat * jax.vmap(w_dist_hat, in_axes=(None, 1))( + params.alpha_dist, z + ).reshape(-1, 1) + weights_dist = w_tmp / sum(w_tmp) - 1.0 / self.popsize + + # switching weights and learning 
rate + p_sigma_cond = p_sigma_norm >= params.chi_N + weights = jax.lax.select(p_sigma_cond, weights_dist, state.w_rank) + lrate_sigma = jax.lax.select( + p_sigma_cond, params.lrate_move_sigma, params.lrate_stag_sigma + ) + lrate_sigma = jax.lax.select( + p_sigma_norm >= 0.1 * params.chi_N, + lrate_sigma, + params.lrate_conv_sigma, + ) + + # update evolution path p_c and mean + wxm = (x - state.mean[:, None]) @ weights + p_c = (1.0 - params.c_c) * state.p_c + jnp.sqrt( + params.c_c * (2.0 - params.c_c) * params.mu_eff + ) * wxm / state.sigma + mean = state.mean + params.lrate_mean * wxm.squeeze() + + normv = jnp.linalg.norm(state.v) + vbar = state.v / normv + normv2 = normv ** 2 + normv4 = normv2 ** 2 + + exY = jnp.append(y, p_c / state.D, axis=1) + yy = exY * exY + ip_yvbar = vbar.T @ exY + yvbar = exY * vbar + gammav = 1.0 + normv2 + vbarbar = vbar * vbar + alphavd = jnp.minimum( + 1, + jnp.sqrt( + normv4 + (2 * gammav - jnp.sqrt(gammav)) / jnp.max(vbarbar) + ) + / (2 + normv2), + ) + t = exY * ip_yvbar - vbar * (ip_yvbar ** 2 + gammav) / 2 + b = -(1 - alphavd ** 2) * normv4 / gammav + 2 * alphavd ** 2 + H = jnp.ones([self.num_dims, 1]) * 2 - (b + 2 * alphavd ** 2) * vbarbar + invH = H ** (-1) + s_step1 = ( + yy + - normv2 / gammav * (yvbar * ip_yvbar) + - jnp.ones([self.num_dims, self.popsize + 1]) + ) + ip_vbart = vbar.T @ t + s_step2 = s_step1 - alphavd / gammav * ( + (2 + normv2) * (t * vbar) - normv2 * vbarbar @ ip_vbart + ) + invHvbarbar = invH * vbarbar + ip_s_step2invHvbarbar = invHvbarbar.T @ s_step2 + s = (s_step2 * invH) - b / ( + 1 + b * vbarbar.T @ invHvbarbar + ) * invHvbarbar @ ip_s_step2invHvbarbar + ip_svbarbar = vbarbar.T @ s + t = t - alphavd * ((2 + normv2) * (s * vbar) - vbar @ ip_svbarbar) + + # update v, D covariance ingredients + exw = jnp.append( + params.lrate_B * weights, + jnp.array([params.c1]).reshape(1, 1), + axis=0, + ) + v = state.v + (t @ exw) / normv + D = state.D + (s @ exw) * state.D + # calculate detA + nthrootdetA = jnp.exp( 
+ jnp.sum(jnp.log(D)) / self.num_dims + + jnp.log(1 + v.T @ v) / (2 * self.num_dims) + )[0][0] + D = D / nthrootdetA + # update sigma + G_s = ( + jnp.sum((z * z - jnp.ones([self.num_dims, self.popsize])) @ weights) + / self.num_dims + ) + sigma = state.sigma * jnp.exp(lrate_sigma / 2 * G_s) + return state.replace( + p_sigma=p_sigma, mean=mean, p_c=p_c, v=v, D=D, sigma=sigma + ) diff --git a/evosax/strategies/de.py b/evosax/strategies/de.py index 338adc2..0c6ff7c 100755 --- a/evosax/strategies/de.py +++ b/evosax/strategies/de.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from flax import struct @@ -29,11 +29,17 @@ class EvoParams: class DE(Strategy): - def __init__(self, num_dims: int, popsize: int): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + **fitness_kwargs: Union[bool, int, float] + ): """Differential Evolution (Storn & Price, 1997) Reference: https://tinyurl.com/4pje5a74""" - assert popsize > 6 - super().__init__(num_dims, popsize) + assert popsize > 6, "DE requires popsize > 6." 
+ super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.strategy_name = "DE" @property diff --git a/evosax/strategies/des.py b/evosax/strategies/des.py new file mode 100644 index 0000000..91ba64b --- /dev/null +++ b/evosax/strategies/des.py @@ -0,0 +1,107 @@ +import jax +import jax.numpy as jnp +import chex +from typing import Tuple, Optional, Union +from ..strategy import Strategy +from flax import struct +from flax import linen as nn + + +@struct.dataclass +class EvoState: + mean: chex.Array + sigma: chex.Array + weights: chex.Array # Weights for population members + best_member: chex.Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + + +@struct.dataclass +class EvoParams: + temperature: float = 12.5 # Temperature for softmax weights + lrate_sigma: float = 0.1 # Learning rate for population std + lrate_mean: float = 1.0 # Learning rate for population mean + sigma_init: float = 1.0 # Standard deviation + init_min: float = 0.0 + init_max: float = 0.0 + clip_min: float = -jnp.finfo(jnp.float32).max + clip_max: float = jnp.finfo(jnp.float32).max + + +def get_des_weights(popsize: int, temperature: float = 12.5): + """Compute discovered recombination weights.""" + ranks = jnp.arange(popsize) + ranks /= ranks.size - 1 + ranks = ranks - 0.5 + sigout = nn.sigmoid(temperature * ranks) + weights = nn.softmax(-20 * sigout) + return weights + + +class DES(Strategy): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + ): + """Discovered Evolution Strategy (Lange et al., 2022)""" + super().__init__(popsize, num_dims, pholder_params) + self.strategy_name = "DES" + + @property + def params_strategy(self) -> EvoParams: + """Return default parameters of evolution strategy.""" + return EvoParams() + + def initialize_strategy( + self, rng: chex.PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the evolution strategy.""" + # Get 
DES discovered recombination weights. + weights = get_des_weights(self.popsize, params.temperature) + initialization = jax.random.uniform( + rng, + (self.num_dims,), + minval=params.init_min, + maxval=params.init_max, + ) + state = EvoState( + mean=initialization, + sigma=params.sigma_init * jnp.ones(self.num_dims), + weights=weights.reshape(-1, 1), + best_member=initialization, + ) + return state + + def ask_strategy( + self, rng: chex.PRNGKey, state: EvoState, params: EvoParams + ) -> Tuple[chex.Array, EvoState]: + """`ask` for new proposed candidates to evaluate next.""" + z = jax.random.normal(rng, (self.popsize, self.num_dims)) # ~ N(0, I) + x = state.mean + z * state.sigma.reshape( + 1, self.num_dims + ) # ~ N(m, σ^2 I) + return x, state + + def tell_strategy( + self, + x: chex.Array, + fitness: chex.Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` update to ES state.""" + weights = state.weights + x = x[fitness.argsort()] + # Weighted updates + weighted_mean = (weights * x).sum(axis=0) + weighted_sigma = jnp.sqrt( + (weights * (x - state.mean) ** 2).sum(axis=0) + 1e-06 + ) + mean = state.mean + params.lrate_mean * (weighted_mean - state.mean) + sigma = state.sigma + params.lrate_sigma * ( + weighted_sigma - state.sigma + ) + return state.replace(mean=mean, sigma=sigma) diff --git a/evosax/strategies/esmc.py b/evosax/strategies/esmc.py new file mode 100644 index 0000000..058fafb --- /dev/null +++ b/evosax/strategies/esmc.py @@ -0,0 +1,142 @@ +import jax +import jax.numpy as jnp +import chex +from typing import Tuple, Optional, Union +from ..strategy import Strategy +from ..utils import GradientOptimizer, OptState, OptParams, exp_decay +from flax import struct + + +@struct.dataclass +class EvoState: + mean: chex.Array + sigma: chex.Array + opt_state: OptState + best_member: chex.Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + + +@struct.dataclass +class EvoParams: + opt_params: OptParams + 
sigma_init: float = 0.03 + sigma_decay: float = 0.999 + sigma_limit: float = 0.01 + sigma_lrate: float = 0.2 # Learning rate for std + sigma_max_change: float = 0.2 # Clip adaptive sigma to 20% + init_min: float = 0.0 + init_max: float = 0.0 + clip_min: float = -jnp.finfo(jnp.float32).max + clip_max: float = jnp.finfo(jnp.float32).max + + +class ESMC(Strategy): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + opt_name: str = "adam", + lrate_init: float = 0.05, + lrate_decay: float = 1.0, + lrate_limit: float = 0.001, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float] + ): + """ESMC (Merchant et al., 2021) + Reference: https://proceedings.mlr.press/v139/merchant21a.html + """ + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) + assert self.popsize & 1, "Population size must be odd" + assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"] + self.optimizer = GradientOptimizer[opt_name](self.num_dims) + self.strategy_name = "ESMC" + + # Set core kwargs es_params (lrate/sigma schedules) + self.lrate_init = lrate_init + self.lrate_decay = lrate_decay + self.lrate_limit = lrate_limit + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + + @property + def params_strategy(self) -> EvoParams: + """Return default parameters of evolution strategy.""" + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) + + def initialize_strategy( + self, rng: chex.PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the evolution strategy.""" + initialization = jax.random.uniform( + rng, + 
(self.num_dims,), + minval=params.init_min, + maxval=params.init_max, + ) + state = EvoState( + mean=initialization, + sigma=jnp.ones(self.num_dims) * params.sigma_init, + opt_state=self.optimizer.initialize(params.opt_params), + best_member=initialization, + ) + return state + + def ask_strategy( + self, rng: chex.PRNGKey, state: EvoState, params: EvoParams + ) -> Tuple[chex.Array, EvoState]: + """`ask` for new parameter candidates to evaluate next.""" + # Antithetic sampling of noise + z_plus = jax.random.normal( + rng, + (int(self.popsize / 2), self.num_dims), + ) + z = jnp.concatenate( + [jnp.zeros((1, self.num_dims)), z_plus, -1.0 * z_plus] + ) + x = state.mean + z * state.sigma.reshape(1, self.num_dims) + return x, state + + def tell_strategy( + self, + x: chex.Array, + fitness: chex.Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """Update both mean and dim.-wise isotropic Gaussian scale.""" + # Reconstruct noise from last mean/std estimates + noise = (x - state.mean) / state.sigma + bline_fitness = fitness[0] + noise = noise[1:] + fitness = fitness[1:] + noise_1 = noise[: int((self.popsize - 1) / 2)] + fit_1 = fitness[: int((self.popsize - 1) / 2)] + fit_2 = fitness[int((self.popsize - 1) / 2) :] + fit_diff = jnp.minimum(fit_1, bline_fitness) - jnp.minimum( + fit_2, bline_fitness + ) + fit_diff_noise = jnp.dot(noise_1.T, fit_diff) + theta_grad = 1.0 / int((self.popsize - 1) / 2) * fit_diff_noise + # Grad update using optimizer instance - decay lrate if desired + mean, opt_state = self.optimizer.step( + state.mean, theta_grad, state.opt_state, params.opt_params + ) + opt_state = self.optimizer.update(opt_state, params.opt_params) + sigma = state.sigma * params.sigma_decay + sigma = jnp.maximum(sigma, params.sigma_limit) + return state.replace(mean=mean, sigma=sigma, opt_state=opt_state) diff --git a/evosax/strategies/full_iamalgam.py b/evosax/strategies/full_iamalgam.py index eb6002f..b959780 100644 --- a/evosax/strategies/full_iamalgam.py 
+++ b/evosax/strategies/full_iamalgam.py @@ -1,8 +1,9 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy +from ..utils import exp_decay from flax import struct @@ -29,7 +30,7 @@ class EvoParams: delta_ams: float = 2.0 theta_sdr: float = 1.0 c_mult_init: float = 1.0 - sigma_init: float = 0.0 + sigma_init: float = 0.1 sigma_decay: float = 0.999 sigma_limit: float = 0.0 init_min: float = 0.0 @@ -39,11 +40,21 @@ class EvoParams: class Full_iAMaLGaM(Strategy): - def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.35, + sigma_init: float = 0.0, + sigma_decay: float = 0.99, + sigma_limit: float = 0.0, + **fitness_kwargs: Union[bool, int, float] + ): """(Iterative) AMaLGaM (Bosman et al., 2013) - Full Covariance Reference: https://tinyurl.com/y9fcccx2 """ - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) @@ -56,6 +67,11 @@ def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35): self.ams_popsize = int(alpha_ams * (self.popsize - 1)) self.strategy_name = "Full_iAMaLGaM" + # Set core kwargs es_params + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" @@ -72,8 +88,13 @@ def params_strategy(self) -> EvoParams: / (self.num_dims ** a_2_shift) ) - params = EvoParams(eta_sigma=eta_sigma, eta_shift=eta_shift) - return params + return EvoParams( + eta_sigma=eta_sigma, + eta_shift=eta_shift, + sigma_init=self.sigma_init, + 
sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -153,8 +174,7 @@ def tell_strategy( C = update_cov_amalgam(members_elite, state.C, mean, params.eta_sigma) # Decay isotropic part of Gaussian search distribution - sigma = state.sigma * params.sigma_decay - sigma = jnp.maximum(sigma, params.sigma_limit) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) return state.replace( c_mult=c_mult, nis_counter=nis_counter, diff --git a/evosax/strategies/gesmr_ga.py b/evosax/strategies/gesmr_ga.py new file mode 100644 index 0000000..1a67ee0 --- /dev/null +++ b/evosax/strategies/gesmr_ga.py @@ -0,0 +1,166 @@ +import jax +import jax.numpy as jnp +import chex +from typing import Tuple, Optional, Union +from ..strategy import Strategy +from flax import struct + + +@struct.dataclass +class EvoState: + rng: chex.PRNGKey + mean: chex.Array + archive: chex.Array + fitness: chex.Array + sigma: chex.Array + best_member: chex.Array + best_fitness: float = jnp.finfo(jnp.float32).max + gen_counter: int = 0 + + +@struct.dataclass +class EvoParams: + sigma_init: float = 0.07 + sigma_meta: float = 2.0 + init_min: float = 0.0 + init_max: float = 0.0 + clip_min: float = -jnp.finfo(jnp.float32).max + clip_max: float = jnp.finfo(jnp.float32).max + + +class GESMR_GA(Strategy): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_ratio: float = 0.5, + **fitness_kwargs: Union[bool, int, float] + ): + """Self-Adaptation Mutation Rate GA.""" + + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) + self.elite_ratio = elite_ratio + self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) + self.num_sigma_groups = int(jnp.sqrt(self.popsize)) + self.members_per_group = int( + jnp.ceil(self.popsize / self.num_sigma_groups) + ) + self.sigma_ratio 
= sigma_ratio + self.sigma_popsize = max( + 1, int(self.num_sigma_groups * self.sigma_ratio) + ) + self.strategy_name = "GESMR_GA" + + @property + def params_strategy(self) -> EvoParams: + """Return default parameters of evolution strategy.""" + return EvoParams() + + def initialize_strategy( + self, rng: chex.PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the differential evolution strategy.""" + rng, rng_init = jax.random.split(rng) + initialization = jax.random.uniform( + rng_init, + (self.elite_popsize, self.num_dims), + minval=params.init_min, + maxval=params.init_max, + ) + state = EvoState( + rng=rng, + mean=initialization[0], + archive=initialization, + fitness=jnp.zeros(self.elite_popsize) + jnp.finfo(jnp.float32).max, + sigma=jnp.zeros(self.num_sigma_groups) + params.sigma_init, + best_member=initialization[0], + ) + return state + + def ask_strategy( + self, rng: chex.PRNGKey, state: EvoState, params: EvoParams + ) -> Tuple[chex.Array, EvoState]: + """`ask` for new proposed candidates to evaluate next.""" + rng, rng_idx, rng_eps_x, rng_eps_s = jax.random.split(rng, 4) + # Sample noise for mutation of x and sigma + eps_x = jax.random.normal(rng_eps_x, (self.popsize, self.num_dims)) + eps_s = jax.random.uniform( + rng_eps_s, (self.num_sigma_groups,), minval=-1, maxval=1 + ) + + # Sample members to evaluate from parent archive + idx = jax.random.choice( + rng_idx, jnp.arange(self.elite_popsize), (self.popsize - 1,) + ) + x = jnp.concatenate([state.archive[0][None, :], state.archive[idx]]) + + # Store fitness before perturbation (used to compute meta-fitness) + fitness_mem = jnp.concatenate( + [state.fitness[0][None], state.fitness[idx]] + ) + + # Apply sigma mutation on group level -> repeat for popmember broadcast + sigma_perturb = state.sigma * params.sigma_meta ** eps_s + sigma_repeated = jnp.repeat(sigma_perturb, self.members_per_group)[ + : self.popsize + ] + sigma = jnp.concatenate([state.sigma[0][None], sigma_repeated[1:]]) + + # Apply 
x mutation -> scale specific to group membership + x += sigma[:, None] * eps_x + return x, state.replace( + archive=x, fitness=fitness_mem, sigma=sigma_perturb + ) + + def tell_strategy( + self, + x: chex.Array, + fitness: chex.Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` update to ES state.""" + # Select best x members + idx = jnp.argsort(fitness)[: self.elite_popsize] + archive = x[idx] + + # Select best sigma based on function value improvement + group_ids = jnp.repeat( + jnp.arange(self.members_per_group), self.num_sigma_groups + )[: self.popsize] + delta_fitness = fitness - state.fitness + + best_deltas = [] + for k in range(self.num_sigma_groups): + sub_mask = group_ids == k + sub_delta = ( + sub_mask * delta_fitness + + (1 - sub_mask) * jnp.finfo(jnp.float32).max + ) + max_sub_delta = jnp.min(sub_delta) + best_deltas.append(max_sub_delta) + + idx_select = jnp.argsort(jnp.array(best_deltas))[: self.sigma_popsize] + sigma_elite = state.sigma[idx_select] + + # Resample sigmas with replacement + rng, rng_sigma = jax.random.split(state.rng) + idx_s = jax.random.choice( + rng_sigma, + jnp.arange(self.sigma_popsize), + (self.num_sigma_groups - 1,), + ) + sigma = jnp.concatenate([state.sigma[0][None], sigma_elite[idx_s]]) + + # Set mean to best member seen so far + improved = fitness[0] < state.best_fitness + best_mean = jax.lax.select(improved, archive[0], state.best_member) + return state.replace( + rng=rng, + fitness=fitness[idx], + archive=archive, + sigma=sigma, + mean=best_mean, + ) diff --git a/evosax/strategies/gld.py b/evosax/strategies/gld.py index 028bd74..ae363f3 100644 --- a/evosax/strategies/gld.py +++ b/evosax/strategies/gld.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from flax import struct @@ -16,7 +16,7 @@ class EvoState: @struct.dataclass class EvoParams: - radius_max: float = 0.1 + radius_max: 
@struct.dataclass
class EvoState:
    # Current search-distribution mean, shape (num_dims,)
    mean: chex.Array
    # Isotropic perturbation scale
    sigma: float
    # State of the gradient-based optimizer applied to the mean
    opt_state: OptState
    # FIFO archive of recent gradient estimates, (subspace_dims, num_dims)
    grad_subspace: chex.Array
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    gen_counter: int = 0


@struct.dataclass
class EvoParams:
    opt_params: OptParams
    sigma_init: float = 0.03
    sigma_decay: float = 1.0
    sigma_limit: float = 0.01
    # Trade-off between full-space (alpha) and subspace (1 - alpha) noise
    alpha: float = 0.5
    # Scale of the gradient estimate
    beta: float = 1.0
    init_min: float = 0.0
    init_max: float = 0.0
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


class GuidedES(Strategy):
    def __init__(
        self,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        subspace_dims: int = 1,  # k param in example notebook
        opt_name: str = "sgd",
        lrate_init: float = 0.05,
        lrate_decay: float = 1.0,
        lrate_limit: float = 0.001,
        sigma_init: float = 0.03,
        sigma_decay: float = 1.0,
        sigma_limit: float = 0.01,
        **fitness_kwargs: Union[bool, int, float],
    ):
        """Guided ES (Maheswaranathan et al., 2018)
        Reference: https://arxiv.org/abs/1806.10230

        Perturbations are drawn antithetically from a mixture of isotropic
        noise and noise restricted to the span of the last `subspace_dims`
        gradient estimates. If no surrogate gradient is passed to `tell`,
        the ES gradient estimate itself is stored in the subspace archive.
        """
        super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
        assert not self.popsize & 1, "Population size must be even"
        assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"]
        assert (
            subspace_dims <= self.num_dims
        ), "Subspace has to be smaller than optimization dims."
        self.optimizer = GradientOptimizer[opt_name](self.num_dims)
        # The clamp below is only reachable when asserts are stripped
        # (`python -O`); it then degrades gracefully instead of failing.
        self.subspace_dims = min(subspace_dims, self.num_dims)
        if self.subspace_dims < subspace_dims:
            print(
                "Subspace has to be smaller than optimization dims. Set to"
                f" {self.subspace_dims} instead of {subspace_dims}."
            )
        self.strategy_name = "GuidedES"

        # Set core kwargs es_params (lrate/sigma schedules)
        self.lrate_init = lrate_init
        self.lrate_decay = lrate_decay
        self.lrate_limit = lrate_limit
        self.sigma_init = sigma_init
        self.sigma_decay = sigma_decay
        self.sigma_limit = sigma_limit

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy.

        The learning-rate and sigma schedules given at construction time
        are forwarded so that they actually take effect.
        """
        opt_params = self.optimizer.default_params.replace(
            lrate_init=self.lrate_init,
            lrate_decay=self.lrate_decay,
            lrate_limit=self.lrate_limit,
        )
        return EvoParams(
            opt_params=opt_params,
            sigma_init=self.sigma_init,
            sigma_decay=self.sigma_decay,
            sigma_limit=self.sigma_limit,
        )

    def initialize_strategy(
        self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy.

        The mean is drawn uniformly between `params.init_min` and
        `params.init_max`; the gradient subspace archive starts out as
        standard-normal noise of shape `(subspace_dims, num_dims)`.
        """
        rng_init, rng_sub = jax.random.split(rng)
        initialization = jax.random.uniform(
            rng_init,
            (self.num_dims,),
            minval=params.init_min,
            maxval=params.init_max,
        )

        grad_subspace = jax.random.normal(
            rng_sub, (self.subspace_dims, self.num_dims)
        )

        state = EvoState(
            mean=initialization,
            sigma=params.sigma_init,
            opt_state=self.optimizer.initialize(params.opt_params),
            grad_subspace=grad_subspace,
            best_member=initialization,
        )
        return state

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """`ask` for new parameter candidates to evaluate next.

        Returns `(popsize, num_dims)` candidates: the first half is
        `mean + z`, the second half the mirrored `mean - z`.
        """
        half_popsize = self.popsize // 2  # popsize is asserted to be even
        a = state.sigma * jnp.sqrt(params.alpha / self.num_dims)
        c = state.sigma * jnp.sqrt((1.0 - params.alpha) / self.subspace_dims)
        key_full, key_sub = jax.random.split(rng, 2)
        eps_full = jax.random.normal(
            key_full, shape=(self.num_dims, half_popsize)
        )
        eps_subspace = jax.random.normal(
            key_sub, shape=(self.subspace_dims, half_popsize)
        )
        # Orthonormal basis of the gradient subspace. The archive stores one
        # gradient per row, so it has to be transposed to obtain a
        # (num_dims, subspace_dims) basis; decomposing the untransposed
        # matrix yields a (k, k) factor that cannot be mapped back into
        # parameter space.
        Q, _ = jnp.linalg.qr(state.grad_subspace.T)
        # Antithetic sampling of noise
        z_plus = a * eps_full + c * jnp.dot(Q, eps_subspace)
        z_plus = jnp.swapaxes(z_plus, 0, 1)
        z = jnp.concatenate([z_plus, -1.0 * z_plus])
        x = state.mean + z
        return x, state

    @partial(jax.jit, static_argnums=(0,))
    def tell(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: Optional[EvoParams] = None,
        gradient: Optional[chex.Array] = None,
    ) -> EvoState:
        """`tell` performance data for strategy state update.

        Overrides the generic `tell` so that an optional surrogate
        `gradient` can be supplied. If given, it is pushed into the
        subspace archive; otherwise the ES gradient estimate is used.
        """
        # Use default hyperparameters if no other settings provided
        if params is None:
            params = self.default_params

        # Flatten params if using param reshaper for ES update
        if self.use_param_reshaper:
            x = self.param_reshaper.flatten(x)

        # Perform fitness reshaping inside of strategy tell call (if desired)
        fitness_re = self.fitness_shaper.apply(x, fitness)

        # Reconstruct noise from last mean/std estimates; the first half of
        # the population holds the positive perturbations (see `ask`).
        half_popsize = self.popsize // 2
        noise = (x - state.mean) / state.sigma
        noise_1 = noise[:half_popsize]
        fit_1 = fitness_re[:half_popsize]
        fit_2 = fitness_re[half_popsize:]
        fit_diff = fit_1 - fit_2
        fit_diff_noise = jnp.dot(noise_1.T, fit_diff)
        theta_grad = (params.beta / self.popsize) * fit_diff_noise

        # Add grad FIFO-style to subspace archive (only if provided else FD):
        # drop the oldest row, shift the rest up, append the newest last.
        grad_subspace = jnp.zeros((self.subspace_dims, self.num_dims))
        grad_subspace = grad_subspace.at[:-1, :].set(state.grad_subspace[1:, :])
        if gradient is not None:
            grad_subspace = grad_subspace.at[-1, :].set(gradient)
        else:
            grad_subspace = grad_subspace.at[-1, :].set(theta_grad)
        state = state.replace(grad_subspace=grad_subspace)

        # Grad update using optimizer instance - decay lrate if desired
        mean, opt_state = self.optimizer.step(
            state.mean, theta_grad, state.opt_state, params.opt_params
        )
        opt_state = self.optimizer.update(opt_state, params.opt_params)

        # Update standard deviation based on decay and lower limit
        sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit)
        state = state.replace(mean=mean, sigma=sigma, opt_state=opt_state)

        # Check if there is a new best member & update trackers
        # (uses the raw, unshaped fitness)
        best_member, best_fitness = get_best_fitness_member(x, fitness, state)
        return state.replace(
            best_member=best_member,
            best_fitness=best_fitness,
            gen_counter=state.gen_counter + 1,
        )
super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) @@ -61,6 +72,11 @@ def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.35): self.ams_popsize = int(alpha_ams * (self.popsize - 1)) self.strategy_name = "Indep_iAMaLGaM" + # Set core kwargs es_params + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" @@ -77,8 +93,13 @@ def params_strategy(self) -> EvoParams: / (self.num_dims ** a_2_shift) ) - params = EvoParams(eta_sigma=eta_sigma, eta_shift=eta_shift) - return params + return EvoParams( + eta_sigma=eta_sigma, + eta_shift=eta_shift, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -158,8 +179,7 @@ def tell_strategy( C = update_cov_amalgam(members_elite, state.C, mean, params.eta_sigma) # Decay isotropic part of Gaussian search distribution - sigma = state.sigma * params.sigma_decay - sigma = jnp.maximum(sigma, params.sigma_limit) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) return state.replace( c_mult=c_mult, nis_counter=nis_counter, diff --git a/evosax/strategies/ipop_cma_es.py b/evosax/strategies/ipop_cma_es.py index bdbb071..65d6fca 100644 --- a/evosax/strategies/ipop_cma_es.py +++ b/evosax/strategies/ipop_cma_es.py @@ -1,6 +1,6 @@ import jax import chex -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from functools import partial from .cma_es import CMA_ES from ..restarts.restarter import WrapperState, WrapperParams @@ -19,14 +19,27 @@ class RestartParams: class IPOP_CMA_ES(object): - def __init__(self, num_dims: int, popsize: int, 
elite_ratio: float = 0.5): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): """IPOP-CMA-ES (Auer & Hansen, 2005). Reference: http://www.cmap.polytechnique.fr/~nikolaus.hansen/cec2005ipopcmaes.pdf """ self.strategy_name = "IPOP_CMA_ES" # Instantiate base strategy & wrap it with restart wrapper self.strategy = CMA_ES( - num_dims=num_dims, popsize=popsize, elite_ratio=elite_ratio + popsize=popsize, + num_dims=num_dims, + pholder_params=pholder_params, + elite_ratio=elite_ratio, + sigma_init=sigma_init, + **fitness_kwargs ) from ..restarts import IPOP_Restarter from ..restarts.termination import cma_criterion, spread_criterion diff --git a/evosax/strategies/lm_ma_es.py b/evosax/strategies/lm_ma_es.py index 7c4d033..326b299 100644 --- a/evosax/strategies/lm_ma_es.py +++ b/evosax/strategies/lm_ma_es.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from .cma_es import get_cma_elite_weights from flax import struct @@ -41,21 +41,27 @@ class EvoParams: class LM_MA_ES(Strategy): def __init__( self, - num_dims: int, popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, elite_ratio: float = 0.5, memory_size: int = 10, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] ): """Limited Memory MA-ES (Loshchilov et al., 2017) Reference: https://arxiv.org/pdf/1705.06693.pdf """ - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) self.memory_size = memory_size self.strategy_name = "LM_MA_ES" + # Set core kwargs 
es_params + self.sigma_init = sigma_init + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" @@ -85,6 +91,7 @@ def params_strategy(self) -> EvoParams: d_sigma=d_sigma, chi_n=chi_n, mu_w=mu_w, + sigma_init=self.sigma_init, ) return params diff --git a/evosax/strategies/ma_es.py b/evosax/strategies/ma_es.py index 1e65e92..86a7cfc 100644 --- a/evosax/strategies/ma_es.py +++ b/evosax/strategies/ma_es.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from .cma_es import get_cma_elite_weights from flax import struct @@ -36,16 +36,27 @@ class EvoParams: class MA_ES(Strategy): - def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): """MA-ES (Bayer & Sendhoff, 2017) Reference: https://www.honda-ri.de/pubs/pdf/3376.pdf """ - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) self.strategy_name = "MA_ES" + # Set core kwargs es_params + self.sigma_init = sigma_init + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" @@ -75,6 +86,7 @@ def params_strategy(self) -> EvoParams: c_sigma=c_sigma, d_sigma=d_sigma, chi_n=chi_n, + sigma_init=self.sigma_init, ) return params diff --git a/evosax/strategies/mr15_ga.py b/evosax/strategies/mr15_ga.py new file mode 100644 index 0000000..067731f --- /dev/null +++ b/evosax/strategies/mr15_ga.py @@ -0,0 +1,139 @@ +import jax +import jax.numpy as jnp +import chex +from typing 
@struct.dataclass
class EvoState:
    mean: chex.Array
    # Current elite members, shape (elite_popsize, num_dims)
    archive: chex.Array
    # Fitness of the archive members (lower is better)
    fitness: chex.Array
    # Mutation strength shared by the whole population
    sigma: chex.Array
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    gen_counter: int = 0


@struct.dataclass
class EvoParams:
    cross_over_rate: float = 0.0
    sigma_init: float = 0.07
    # Fraction of mutations that has to improve for sigma to be increased
    sigma_ratio: float = 0.15
    init_min: float = 0.0
    init_max: float = 0.0
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


class MR15_GA(Strategy):
    def __init__(
        self,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        elite_ratio: float = 0.0,
        sigma_ratio: float = 0.15,
        sigma_init: float = 0.1,
        **fitness_kwargs: Union[bool, int, float]
    ):
        """1/5 MR Genetic Algorithm (Rechenberg, 1978)
        Reference: https://link.springer.com/chapter/10.1007/978-3-642-81283-5_8

        The mutation strength is doubled when more than `sigma_ratio` of
        the new mutations improve on the best fitness seen so far and is
        halved otherwise.
        """

        super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
        self.elite_ratio = elite_ratio
        # Always keep at least one elite member
        self.elite_popsize = max(1, int(self.popsize * self.elite_ratio))
        self.strategy_name = "MR15_GA"

        # Set core kwargs es_params
        self.sigma_ratio = sigma_ratio  # share of mutations that must improve
        self.sigma_init = sigma_init

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy."""
        return EvoParams(
            sigma_init=self.sigma_init, sigma_ratio=self.sigma_ratio
        )

    def initialize_strategy(
        self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the genetic algorithm state.

        Draws `elite_popsize` members uniformly between `params.init_min`
        and `params.init_max` and marks their fitness as the largest
        float32 value so any evaluated candidate replaces them.
        """
        initialization = jax.random.uniform(
            rng,
            (self.elite_popsize, self.num_dims),
            minval=params.init_min,
            maxval=params.init_max,
        )
        state = EvoState(
            mean=initialization.mean(axis=0),
            archive=initialization,
            fitness=jnp.zeros(self.elite_popsize) + jnp.finfo(jnp.float32).max,
            sigma=params.sigma_init,
            best_member=initialization.mean(axis=0),
        )
        return state

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """
        `ask` for new proposed candidates to evaluate next.
        1. For each member of the new population:
          - Sample two current elite members (a & b)
          - Mate them via `single_mate` using `params.cross_over_rate`
          - Additionally add Gaussian noise scaled by `state.sigma`
        """
        rng, rng_eps, rng_idx_a, rng_idx_b = jax.random.split(rng, 4)
        rng_mate = jax.random.split(rng, self.popsize)
        epsilon = (
            jax.random.normal(rng_eps, (self.popsize, self.num_dims))
            * state.sigma
        )
        elite_ids = jnp.arange(self.elite_popsize)
        idx_a = jax.random.choice(rng_idx_a, elite_ids, (self.popsize,))
        idx_b = jax.random.choice(rng_idx_b, elite_ids, (self.popsize,))
        members_a = state.archive[idx_a]
        members_b = state.archive[idx_b]
        x = jax.vmap(single_mate, in_axes=(0, 0, 0, None))(
            rng_mate, members_a, members_b, params.cross_over_rate
        )
        x += epsilon
        return jnp.squeeze(x), state

    def tell_strategy(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """
        `tell` update to ES state.
        Merges the new candidates with the current elite, keeps the
        `elite_popsize` best and adapts sigma with the success rule.
        """
        # Success ratio of the *new* mutations. It has to be measured before
        # merging with / truncating to the elite; otherwise it is computed
        # over `elite_popsize` entries only (a 0/1 value for a single elite).
        good_mutations_ratio = jnp.mean(fitness < state.best_fitness)

        # Combine current elite and recent generation info
        merged_fitness = jnp.concatenate([fitness, state.fitness])
        merged_solutions = jnp.concatenate([x, state.archive])
        # Select top elite from total archive info (ascending -> minimize)
        idx = jnp.argsort(merged_fitness)[0 : self.elite_popsize]
        elite_fitness = merged_fitness[idx]
        archive = merged_solutions[idx]

        # Update mutation sigma - double if more than `sigma_ratio` improved
        increase_sigma = good_mutations_ratio > params.sigma_ratio
        sigma = jax.lax.select(
            increase_sigma, 2 * state.sigma, 0.5 * state.sigma
        )
        # Set mean to best member seen so far
        improved = elite_fitness[0] < state.best_fitness
        best_mean = jax.lax.select(improved, archive[0], state.best_member)
        return state.replace(
            fitness=elite_fitness, archive=archive, sigma=sigma, mean=best_mean
        )
(2017) Reference: https://arxiv.org/pdf/1703.03864.pdf Inspired by: https://github.com/hardmaru/estool/blob/master/es.py""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert not self.popsize & 1, "Population size must be even" - assert opt_name in ["sgd", "adam", "rmsprop", "clipup"] + assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"] self.optimizer = GradientOptimizer[opt_name](self.num_dims) self.strategy_name = "OpenES" + # Set core kwargs es_params (lrate/sigma schedules) + self.lrate_init = lrate_init + self.lrate_decay = lrate_decay + self.lrate_limit = lrate_limit + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" - return EvoParams(opt_params=self.optimizer.default_params) + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -95,6 +126,5 @@ def tell_strategy( state.mean, theta_grad, state.opt_state, params.opt_params ) opt_state = self.optimizer.update(opt_state, params.opt_params) - sigma = state.sigma * params.sigma_decay - sigma = jnp.maximum(sigma, params.sigma_limit) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) return state.replace(mean=mean, sigma=sigma, opt_state=opt_state) diff --git a/evosax/strategies/pbt.py b/evosax/strategies/pbt.py index 1f41cdb..41c7dd9 100755 --- a/evosax/strategies/pbt.py +++ b/evosax/strategies/pbt.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import 
Strategy from flax import struct @@ -27,10 +27,16 @@ class EvoParams: class PBT(Strategy): - def __init__(self, num_dims: int, popsize: int): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + **fitness_kwargs: Union[bool, int, float] + ): """Synchronous Population-Based Training (Jaderberg et al., 2017) Reference: https://arxiv.org/abs/1711.09846""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.strategy_name = "PBT" @property diff --git a/evosax/strategies/persistent_es.py b/evosax/strategies/persistent_es.py index 4d9f86f..ee50fc1 100644 --- a/evosax/strategies/persistent_es.py +++ b/evosax/strategies/persistent_es.py @@ -1,9 +1,9 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy -from ..utils import GradientOptimizer, OptState, OptParams +from ..utils import GradientOptimizer, OptState, OptParams, exp_decay from flax import struct @@ -34,21 +34,52 @@ class EvoParams: class PersistentES(Strategy): - def __init__(self, num_dims: int, popsize: int, opt_name: str = "adam"): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + opt_name: str = "adam", + lrate_init: float = 0.05, + lrate_decay: float = 1.0, + lrate_limit: float = 0.001, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float] + ): """Persistent ES (Vicol et al., 2021). 
Reference: http://proceedings.mlr.press/v139/vicol21a.html Inspired by: http://proceedings.mlr.press/v139/vicol21a/vicol21a-supp.pdf """ - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert not self.popsize & 1, "Population size must be even" - assert opt_name in ["sgd", "adam", "rmsprop", "clipup"] + assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"] self.optimizer = GradientOptimizer[opt_name](self.num_dims) self.strategy_name = "PersistentES" + # Set core kwargs es_params (lrate/sigma schedules) + self.lrate_init = lrate_init + self.lrate_decay = lrate_decay + self.lrate_limit = lrate_limit + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" - return EvoParams(opt_params=self.optimizer.default_params) + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -105,8 +136,7 @@ def tell_strategy( opt_state = self.optimizer.update(opt_state, params.opt_params) inner_step_counter = state.inner_step_counter + params.K - sigma = state.sigma * params.sigma_decay - sigma = jnp.maximum(sigma, params.sigma_limit) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) # Reset accumulated antithetic noise memory if done with inner problem reset = inner_step_counter >= params.T inner_step_counter = jax.lax.select(reset, 0, inner_step_counter) diff --git a/evosax/strategies/pgpe.py b/evosax/strategies/pgpe.py index 371da7b..6057fb7 100755 --- a/evosax/strategies/pgpe.py +++ b/evosax/strategies/pgpe.py @@ -1,9 +1,9 @@ import jax 
import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy -from ..utils import GradientOptimizer, OptState, OptParams +from ..utils import GradientOptimizer, OptState, OptParams, exp_decay from flax import struct @@ -34,28 +34,54 @@ class EvoParams: class PGPE(Strategy): def __init__( self, - num_dims: int, popsize: int, - elite_ratio: float = 0.1, - opt_name: str = "sgd", + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 1.0, + opt_name: str = "adam", + lrate_init: float = 0.05, + lrate_decay: float = 1.0, + lrate_limit: float = 0.001, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float] ): """PGPE (e.g. Sehnke et al., 2010) Reference: https://tinyurl.com/2p8bn956 Inspired by: https://github.com/hardmaru/estool/blob/master/es.py""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) assert 0 <= elite_ratio <= 1 self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize / 2 * self.elite_ratio)) assert not self.popsize & 1, "Population size must be even" - assert opt_name in ["sgd", "adam", "rmsprop", "clipup"] + assert opt_name in ["sgd", "adam", "rmsprop", "clipup", "adan"] self.optimizer = GradientOptimizer[opt_name](self.num_dims) self.strategy_name = "PGPE" + # Set core kwargs es_params (lrate/sigma schedules) + self.lrate_init = lrate_init + self.lrate_decay = lrate_decay + self.lrate_limit = lrate_limit + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" - return EvoParams(opt_params=self.optimizer.default_params) + opt_params = self.optimizer.default_params.replace( + lrate_init=self.lrate_init, + 
lrate_decay=self.lrate_decay, + lrate_limit=self.lrate_limit, + ) + return EvoParams( + opt_params=opt_params, + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -85,7 +111,7 @@ def ask_strategy( (int(self.popsize / 2), self.num_dims), ) z = jnp.concatenate([z_plus, -1.0 * z_plus]) - x = state.mean + z * state.sigma.reshape(1, self.num_dims) + x = state.mean + state.sigma * z return x, state def tell_strategy( @@ -98,9 +124,9 @@ def tell_strategy( """Update both mean and dim.-wise isotropic Gaussian scale.""" # Reconstruct noise from last mean/std estimates noise = (x - state.mean) / state.sigma - noise_1 = noise[: int(self.popsize / 2)] - fit_1 = fitness[: int(self.popsize / 2)] - fit_2 = fitness[int(self.popsize / 2) :] + noise_1 = noise[::2] + fit_1 = fitness[::2] + fit_2 = fitness[1::2] elite_idx = jnp.minimum(fit_1, fit_2).argsort()[: self.elite_popsize] fitness_elite = jnp.concatenate([fit_1[elite_idx], fit_2[elite_idx]]) @@ -120,18 +146,18 @@ def tell_strategy( - (state.sigma * state.sigma).reshape(1, self.num_dims) ) / state.sigma.reshape(1, self.num_dims) rS = (fit_1 + fit_2) / 2.0 - jnp.mean(fitness_elite) - delta_sigma = (jnp.dot(rS, S)) / self.elite_popsize - change_sigma = params.sigma_lrate * delta_sigma - change_sigma = jnp.minimum( - change_sigma, params.sigma_max_change * state.sigma - ) - change_sigma = jnp.maximum( - change_sigma, -params.sigma_max_change * state.sigma - ) + delta_sigma = jnp.dot(rS, S) / (self.elite_popsize / 2) + + allowed_delta = jnp.abs(state.sigma) * params.sigma_max_change + min_allowed = state.sigma - allowed_delta + max_allowed = state.sigma + allowed_delta # adjust sigma according to the adaptive sigma calculation # for stability, don't let sigma move more than 20% of orig value - sigma = state.sigma - change_sigma - sigma = sigma * params.sigma_decay - sigma = jnp.maximum(sigma, params.sigma_limit) + 
sigma = jnp.clip( + state.sigma - params.sigma_lrate * delta_sigma, + min_allowed, + max_allowed, + ) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) return state.replace(mean=mean, sigma=sigma, opt_state=opt_state) diff --git a/evosax/strategies/pso.py b/evosax/strategies/pso.py index 67f11d0..35128ce 100755 --- a/evosax/strategies/pso.py +++ b/evosax/strategies/pso.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from flax import struct @@ -31,10 +31,16 @@ class EvoParams: class PSO(Strategy): - def __init__(self, num_dims: int, popsize: int): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + **fitness_kwargs: Union[bool, int, float] + ): """Particle Swarm Optimization (Kennedy & Eberhart, 1995) Reference: https://ieeexplore.ieee.org/document/488968""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.strategy_name = "PSO" @property diff --git a/evosax/strategies/rm_es.py b/evosax/strategies/rm_es.py index d8bf0b6..105bc39 100644 --- a/evosax/strategies/rm_es.py +++ b/evosax/strategies/rm_es.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from flax import struct @@ -62,21 +62,27 @@ def get_cma_elite_weights( class RmES(Strategy): def __init__( self, - num_dims: int, popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, elite_ratio: float = 0.5, memory_size: int = 10, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] ): """Rank-m ES (Li & Zhang, 2017) Reference: https://ieeexplore.ieee.org/document/8080257 """ - super().__init__(num_dims, popsize) + 
@struct.dataclass
class EvoState:
    mean: chex.Array
    # Members carried between generations
    archive: chex.Array
    # Fitness of the archive members (lower is better)
    fitness: chex.Array
    # One mutation strength per archive member
    sigma: chex.Array
    best_member: chex.Array
    best_fitness: float = jnp.finfo(jnp.float32).max
    gen_counter: int = 0


@struct.dataclass
class EvoParams:
    sigma_init: float = 0.07
    # Base of the multiplicative sigma mutation
    sigma_meta: float = 2.0
    # Lower bound applied to the sigma of the first archive member
    sigma_best_limit: float = 0.0001
    init_min: float = 0.0
    init_max: float = 0.0
    clip_min: float = -jnp.finfo(jnp.float32).max
    clip_max: float = jnp.finfo(jnp.float32).max


class SAMR_GA(Strategy):
    def __init__(
        self,
        popsize: int,
        num_dims: Optional[int] = None,
        pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None,
        elite_ratio: float = 0.0,
        **fitness_kwargs: Union[bool, int, float]
    ):
        """Self-Adaptation Mutation Rate GA (Clune et al., 2008).

        Every member carries its own mutation strength, which is mutated
        and inherited together with the member.
        """
        super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs)
        self.strategy_name = "SAMR_GA"
        self.elite_ratio = elite_ratio
        # At least one member always survives
        num_elite = int(self.popsize * self.elite_ratio)
        self.elite_popsize = max(1, num_elite)

    @property
    def params_strategy(self) -> EvoParams:
        """Return the default hyperparameters of SAMR-GA."""
        return EvoParams()

    def initialize_strategy(
        self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the genetic algorithm state.

        Parents are drawn uniformly between `params.init_min` and
        `params.init_max`; their fitness starts at the largest float32
        value and every member starts with `params.sigma_init`.
        """
        parents = jax.random.uniform(
            rng,
            (self.elite_popsize, self.num_dims),
            minval=params.init_min,
            maxval=params.init_max,
        )
        center = parents.mean(axis=0)
        worst_fitness = jnp.finfo(jnp.float32).max
        return EvoState(
            mean=center,
            archive=parents,
            fitness=jnp.zeros(self.elite_popsize) + worst_fitness,
            sigma=jnp.zeros(self.elite_popsize) + params.sigma_init,
            best_member=center,
        )

    def ask_strategy(
        self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """`ask` for new proposed candidates to evaluate next.

        Slot 0 always descends from the first archive member; the other
        `popsize - 1` parents are sampled with replacement. The state
        returned stores the children and their mutated sigmas.
        """
        _, key_parent, key_x, key_s = jax.random.split(rng, 4)
        noise_x = jax.random.normal(key_x, (self.popsize, self.num_dims))
        noise_s = jax.random.uniform(
            key_s, (self.popsize,), minval=-1, maxval=1
        )
        parent_ids = jax.random.choice(
            key_parent, jnp.arange(self.elite_popsize), (self.popsize - 1,)
        )

        # Assemble parents and the sigma each of them carries
        parents = jnp.concatenate(
            [state.archive[0][None, :], state.archive[parent_ids]]
        )
        first_sigma = jnp.array(
            [jnp.maximum(params.sigma_best_limit, state.sigma[0])]
        )
        parent_sigma = jnp.concatenate([first_sigma, state.sigma[parent_ids]])

        # Mutate sigma multiplicatively, then mutate the member with it
        child_sigma = parent_sigma * params.sigma_meta ** noise_s
        children = parents + child_sigma[:, None] * noise_x
        return children, state.replace(archive=children, sigma=child_sigma)

    def tell_strategy(
        self,
        x: chex.Array,
        fitness: chex.Array,
        state: EvoState,
        params: EvoParams,
    ) -> EvoState:
        """`tell` update to ES state.

        Keeps the `elite_popsize` candidates with the lowest fitness
        together with the sigmas that produced them.
        """
        order = jnp.argsort(fitness)[: self.elite_popsize]
        elite_fitness = fitness[order]
        elite_members = x[order]
        elite_sigma = state.sigma[order]

        # Set mean to best member seen so far
        new_mean = jax.lax.select(
            elite_fitness[0] < state.best_fitness,
            elite_members[0],
            state.best_member,
        )
        return state.replace(
            fitness=elite_fitness,
            archive=elite_members,
            sigma=elite_sigma,
            mean=new_mean,
        )
struct @@ -33,17 +33,35 @@ class EvoParams: class SimAnneal(Strategy): - def __init__(self, num_dims: int, popsize: int): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + sigma_init: float = 0.03, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float] + ): """Simulated Annealing (Rasdi Rere et al., 2015) Reference: https://www.sciencedirect.com/science/article/pii/S1877050915035759 """ - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.strategy_name = "SimAnneal" + # Set core kwargs es_params (lrate/sigma schedules) + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" - return EvoParams() + return EvoParams( + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams diff --git a/evosax/strategies/simple_es.py b/evosax/strategies/simple_es.py index 077b6b5..ac1d57d 100755 --- a/evosax/strategies/simple_es.py +++ b/evosax/strategies/simple_es.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from flax import struct @@ -28,20 +28,31 @@ class EvoParams: class SimpleES(Strategy): - def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): """Simple Gaussian Evolution Strategy (Rechenberg, 1975) Reference: 
https://onlinelibrary.wiley.com/doi/abs/10.1002/fedr.19750860506 Inspired by: https://github.com/hardmaru/estool/blob/master/es.py""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) self.strategy_name = "SimpleES" + # Set core kwargs es_params + self.sigma_init = sigma_init + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" # Only parents have positive weight - equal weighting! - return EvoParams() + return EvoParams(sigma_init=self.sigma_init) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams diff --git a/evosax/strategies/simple_ga.py b/evosax/strategies/simple_ga.py index fb4e2d3..17ec9a5 100755 --- a/evosax/strategies/simple_ga.py +++ b/evosax/strategies/simple_ga.py @@ -1,8 +1,9 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy +from ..utils import exp_decay from flax import struct @@ -19,9 +20,9 @@ class EvoState: @struct.dataclass class EvoParams: - cross_over_rate: float = 0.5 + cross_over_rate: float = 0.0 sigma_init: float = 0.07 - sigma_decay: float = 0.999 + sigma_decay: float = 1.0 sigma_limit: float = 0.01 init_min: float = 0.0 init_max: float = 0.0 @@ -30,20 +31,39 @@ class EvoParams: class SimpleGA(Strategy): - def __init__(self, num_dims: int, popsize: int, elite_ratio: float = 0.5): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + elite_ratio: float = 0.5, + sigma_init: float = 0.1, + sigma_decay: float = 1.0, + sigma_limit: float = 0.01, + **fitness_kwargs: Union[bool, int, float] + ): """Simple Genetic Algorithm (Such et al., 2017) Reference: https://arxiv.org/abs/1712.06567 Inspired by: 
https://github.com/hardmaru/estool/blob/master/es.py""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.elite_ratio = elite_ratio self.elite_popsize = max(1, int(self.popsize * self.elite_ratio)) self.strategy_name = "SimpleGA" + # Set core kwargs es_params + self.sigma_init = sigma_init + self.sigma_decay = sigma_decay + self.sigma_limit = sigma_limit + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolution strategy.""" - return EvoParams() + return EvoParams( + sigma_init=self.sigma_init, + sigma_decay=self.sigma_decay, + sigma_limit=self.sigma_limit, + ) def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams @@ -111,15 +131,12 @@ def tell_strategy( fitness = fitness[idx] archive = solution[idx] # Update mutation epsilon - multiplicative decay - sigma = jax.lax.select( - state.sigma > params.sigma_limit, - state.sigma * params.sigma_decay, - state.sigma, - ) - # Keep mean across stored archive around for evaluation protocol - mean = archive.mean(axis=0) + sigma = exp_decay(state.sigma, params.sigma_decay, params.sigma_limit) + # Set mean to best member seen so far + improved = fitness[0] < state.best_fitness + best_mean = jax.lax.select(improved, archive[0], state.best_member) return state.replace( - fitness=fitness, archive=archive, sigma=sigma, mean=mean + fitness=fitness, archive=archive, sigma=sigma, mean=best_mean ) diff --git a/evosax/strategies/snes.py b/evosax/strategies/snes.py new file mode 100644 index 0000000..77d5b87 --- /dev/null +++ b/evosax/strategies/snes.py @@ -0,0 +1,135 @@ +import jax +import jax.numpy as jnp +import chex +from typing import Tuple, Optional, Union +from ..strategy import Strategy +from flax import struct +from flax import linen as nn + + +@struct.dataclass +class EvoState: + mean: chex.Array + sigma: chex.Array + weights: chex.Array + best_member: chex.Array + best_fitness: float = jnp.finfo(jnp.float32).max + 
gen_counter: int = 0 + + +@struct.dataclass +class EvoParams: + lrate_mean: float = 1.0 + lrate_sigma: float = 1.0 + sigma_init: float = 1.0 + temperature: float = 0.0 + init_min: float = 0.0 + init_max: float = 0.0 + clip_min: float = -jnp.finfo(jnp.float32).max + clip_max: float = jnp.finfo(jnp.float32).max + + +def get_recombination_weights(popsize: int, use_baseline: bool = True): + """Get recombination weights for different ranks.""" + + def get_weight(i): + return jnp.maximum(0, jnp.log(popsize / 2 + 1) - jnp.log(i)) + + weights = jax.vmap(get_weight)(jnp.arange(1, popsize + 1)) + weights_norm = weights / jnp.sum(weights) + return weights_norm - use_baseline * (1 / popsize) + + +def get_temp_weights( + popsize: int, temperature: float, use_baseline: bool = True +): + """Get weights based on original discovered weights (Lange et al, 2022).""" + ranks = jnp.arange(popsize) + ranks /= ranks.size - 1 + ranks = ranks - 0.5 + weights = nn.softmax(-temperature * ranks) + return weights + + +class SNES(Strategy): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + sigma_init: float = 1.0, + temperature: float = 0.0, # good values tend to be between 12 and 20 + **fitness_kwargs: Union[bool, int, float] + ): + """Separable Exponential Natural ES (Wierstra et al., 2014) + Reference: https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf + """ + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) + self.strategy_name = "SNES" + + # Set core kwargs es_params + self.sigma_init = sigma_init + self.temperature = temperature + + @property + def params_strategy(self) -> EvoParams: + """Return default parameters of evolutionary strategy.""" + lrate_sigma = (3 + jnp.log(self.num_dims)) / ( + 5 * jnp.sqrt(self.num_dims) + ) + params = EvoParams( + lrate_sigma=lrate_sigma, + sigma_init=self.sigma_init, + temperature=self.temperature, + ) + return params + + def 
initialize_strategy( + self, rng: chex.PRNGKey, params: EvoParams + ) -> EvoState: + """`initialize` the evolutionary strategy.""" + initialization = jax.random.uniform( + rng, + (self.num_dims,), + minval=params.init_min, + maxval=params.init_max, + ) + use_des_weights = params.temperature > 0.0 + weights = jax.lax.select( + use_des_weights, + get_temp_weights(self.popsize, params.temperature), + get_recombination_weights(self.popsize), + ) + state = EvoState( + mean=initialization, + sigma=params.sigma_init * jnp.ones(self.num_dims), + weights=weights.reshape(-1, 1), + best_member=initialization, + ) + + return state + + def ask_strategy( + self, rng: chex.PRNGKey, state: EvoState, params: EvoParams + ) -> Tuple[chex.Array, EvoState]: + """`ask` for new parameter candidates to evaluate next.""" + noise = jax.random.normal(rng, (self.popsize, self.num_dims)) + x = state.mean + noise * state.sigma.reshape(1, self.num_dims) + return x, state + + def tell_strategy( + self, + x: chex.Array, + fitness: chex.Array, + state: EvoState, + params: EvoParams, + ) -> EvoState: + """`tell` performance data for strategy state update.""" + s = (x - state.mean) / state.sigma + ranks = fitness.argsort() + sorted_noise = s[ranks] + grad_mean = (state.weights * sorted_noise).sum(axis=0) + grad_sigma = (state.weights * (sorted_noise ** 2 - 1)).sum(axis=0) + mean = state.mean + params.lrate_mean * state.sigma * grad_mean + sigma = state.sigma * jnp.exp(params.lrate_sigma / 2 * grad_sigma) + return state.replace(mean=mean, sigma=sigma) diff --git a/evosax/strategies/xnes.py b/evosax/strategies/xnes.py index 45744e3..609f8f5 100644 --- a/evosax/strategies/xnes.py +++ b/evosax/strategies/xnes.py @@ -1,21 +1,20 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple +from typing import Tuple, Optional, Union from ..strategy import Strategy from flax import struct +from .snes import get_recombination_weights @struct.dataclass class EvoState: mean: chex.Array sigma: 
float - sigma_old: float - amat: chex.Array - bmat: chex.Array + B: chex.Array noise: chex.Array - eta_sigma: float - utilities: chex.Array + lrate_sigma: float + weights: chex.Array best_member: chex.Array best_fitness: float = jnp.finfo(jnp.float32).max gen_counter: int = 0 @@ -23,11 +22,13 @@ class EvoState: @struct.dataclass class EvoParams: - eta_mean: float - eta_sigma_init: float - eta_bmat: float - use_adaptive_sampling: bool = False - use_fitness_shaping: bool = True + lrate_mean: float = 1.0 + lrate_sigma_init: float = 0.1 + lrate_B: float = 0.1 + sigma_init: float = 1.0 + use_adasam: bool = False # Adaptation sampling lrate sigma + rho: float = 0.5 # Significance level adaptation sampling + c_prime: float = 0.1 # Adaptation sampling step size init_min: float = 0.0 init_max: float = 0.0 clip_min: float = -jnp.finfo(jnp.float32).max @@ -35,24 +36,35 @@ class EvoParams: class xNES(Strategy): - def __init__(self, num_dims: int, popsize: int): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + sigma_init: float = 1.0, + **fitness_kwargs: Union[bool, int, float] + ): """Exponential Natural ES (Wierstra et al., 2014) Reference: https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf Inspired by: https://github.com/chanshing/xnes""" - super().__init__(num_dims, popsize) + super().__init__(popsize, num_dims, pholder_params, **fitness_kwargs) self.strategy_name = "xNES" + # Set core kwargs es_params + self.sigma_init = sigma_init + @property def params_strategy(self) -> EvoParams: """Return default parameters of evolutionary strategy.""" + lrate_sigma = (9 + 3 * jnp.log(self.num_dims)) / ( + 5 * jnp.sqrt(self.num_dims) * self.num_dims + ) + rho = 0.5 - 1.0 / (3 * (self.num_dims + 1)) params = EvoParams( - eta_mean=1.0, - eta_sigma_init=3 - * (3 + jnp.log(self.num_dims)) - * (1.0 / (5 * self.num_dims * jnp.sqrt(self.num_dims))), - eta_bmat=3 - * (3 + 
jnp.log(self.num_dims)) - * (1.0 / (5 * self.num_dims * jnp.sqrt(self.num_dims))), + lrate_sigma_init=lrate_sigma, + lrate_B=lrate_sigma, + rho=rho, + sigma_init=self.sigma_init, ) return params @@ -60,33 +72,20 @@ def initialize_strategy( self, rng: chex.PRNGKey, params: EvoParams ) -> EvoState: """`initialize` the evolutionary strategy.""" - amat = jnp.eye(self.num_dims) - sigma = abs(jax.scipy.linalg.det(amat)) ** (1.0 / self.num_dims) - bmat = amat * (1.0 / sigma) - # Utility helper for fitness shaping - doesn't work without?! - a = jnp.log(1 + 0.5 * self.popsize) - utilities = jnp.array( - [jnp.maximum(0, a - jnp.log(k)) for k in range(1, self.popsize + 1)] - ) - utilities /= jnp.sum(utilities) - utilities -= 1.0 / self.popsize # broadcast - utilities = utilities[::-1] # ascending order - initialization = jax.random.uniform( rng, (self.num_dims,), minval=params.init_min, maxval=params.init_max, ) + weights = get_recombination_weights(self.popsize) state = EvoState( mean=initialization, - sigma=sigma, - sigma_old=sigma, - amat=amat, - bmat=bmat, + B=jnp.eye(self.num_dims) * params.sigma_init, + sigma=params.sigma_init, noise=jnp.zeros((self.popsize, self.num_dims)), - eta_sigma=params.eta_sigma_init, - utilities=utilities, + lrate_sigma=params.lrate_sigma_init, + weights=weights.reshape(-1, 1), best_member=initialization, ) @@ -97,7 +96,14 @@ def ask_strategy( ) -> Tuple[chex.Array, EvoState]: """`ask` for new parameter candidates to evaluate next.""" noise = jax.random.normal(rng, (self.popsize, self.num_dims)) - x = state.mean + state.sigma * jnp.dot(noise, state.bmat) + + def scale_orient(n, sigma, B): + return sigma * B.T @ n + + scaled_noise = jax.vmap(scale_orient, in_axes=(0, None, None))( + noise, state.sigma, state.B + ) + x = state.mean + scaled_noise return x, state.replace(noise=noise) def tell_strategy( @@ -108,98 +114,80 @@ def tell_strategy( params: EvoParams, ) -> EvoState: """`tell` performance data for strategy state update.""" - # By default 
the xNES maximizes the objective - fitness_re = -fitness - isort = fitness_re.argsort() - sorted_fitness = fitness_re[isort] - sorted_noise = state.noise[isort] - sorted_candidates = x[isort] - fitness_shaped = jax.lax.select( - params.use_fitness_shaping, state.utilities, sorted_fitness - ) + ranks = fitness.argsort() + sorted_noise = state.noise[ranks] + grad_mean = (state.weights * sorted_noise).sum(axis=0) - use_adasam = jnp.logical_and( - params.use_adaptive_sampling, state.gen_counter > 1 - ) # sigma_old must be available - eta_sigma = jax.lax.select( - use_adasam, - self.adaptive_sampling( - state.eta_sigma, - state.mean, - state.sigma, - state.bmat, - state.sigma_old, - sorted_candidates, - state.eta_sigma, - ), - state.eta_sigma, - ) + def s_grad_m(weight, noise): + return weight * (noise @ noise.T - jnp.eye(self.num_dims)) - dj_delta = jnp.dot(fitness_shaped, sorted_noise) - dj_mmat = ( - jnp.dot( - sorted_noise.T, - sorted_noise * fitness_shaped.reshape(self.popsize, 1), - ) - - jnp.sum(fitness_shaped) * jnp.eye(self.num_dims) - ) - dj_sigma = jnp.trace(dj_mmat) * (1.0 / self.num_dims) - dj_bmat = dj_mmat - dj_sigma * jnp.eye(self.num_dims) + grad_m = jax.vmap(s_grad_m, in_axes=(0, 0))( + state.weights, sorted_noise + ).sum(axis=0) + grad_sigma = jnp.trace(grad_m) / self.num_dims + grad_B = grad_m - grad_sigma * jnp.eye(self.num_dims) - sigma_old = state.sigma - mean = state.mean + ( - params.eta_mean * state.sigma * jnp.dot(state.bmat, dj_delta) + mean = ( + state.mean + params.lrate_mean * state.sigma * state.B @ grad_mean ) - sigma = sigma_old * jnp.exp(0.5 * eta_sigma * dj_sigma) - bmat = jnp.dot( - state.bmat, - jax.scipy.linalg.expm(0.5 * params.eta_bmat * dj_bmat), + sigma = state.sigma * jnp.exp(state.lrate_sigma / 2 * grad_sigma) + B = state.B * jnp.exp(params.lrate_B / 2 * grad_B) + + lrate_sigma = adaptation_sampling( + state.lrate_sigma, + params.lrate_sigma_init, + mean, + B, + sigma, + state.sigma, + sorted_noise, + params.c_prime, + 
params.rho, ) - return state.replace( - eta_sigma=eta_sigma, - mean=mean, - sigma=sigma, - bmat=bmat, - sigma_old=sigma_old, + lrate_sigma = jax.lax.select( + params.use_adasam, lrate_sigma, state.lrate_sigma ) - - def adaptive_sampling( - self, - eta_sigma: float, - mu: chex.Array, - sigma: float, - bmat: chex.Array, - sigma_old: float, - z_try: chex.Array, - eta_sigma_init: float, - ) -> float: - """Adaptation sampling.""" - c = 0.1 - rho = 0.5 - 1.0 / (3 * (self.num_dims + 1)) # empirical - - bbmat = jnp.dot(bmat.T, bmat) - cov = sigma ** 2 * bbmat - sigma_ = sigma * jnp.sqrt(sigma * (1.0 / sigma_old)) # increase by 1.5 - cov_ = sigma_ ** 2 * bbmat - - p0 = jax.scipy.stats.multivariate_normal.logpdf(z_try, mean=mu, cov=cov) - p1 = jax.scipy.stats.multivariate_normal.logpdf( - z_try, mean=mu, cov=cov_ + return state.replace( + mean=mean, sigma=sigma, B=B, lrate_sigma=lrate_sigma ) - w = jnp.exp(p1 - p0) - # Mann-Whitney. It is assumed z_try was in ascending order. - n_ = jnp.sum(w) - u_ = jnp.sum(w * (jnp.arange(self.popsize) + 0.5)) - u_mu = self.popsize * n_ * 0.5 - u_sigma = jnp.sqrt(self.popsize * n_ * (self.popsize + n_ + 1) / 12.0) - cum = jax.scipy.stats.norm.cdf(u_, loc=u_mu, scale=u_sigma) - - decrease = cum < rho - eta_out = jax.lax.select( - decrease, - (1 - c) * eta_sigma + c * eta_sigma_init, - jnp.minimum(1, (1 + c) * eta_sigma), - ) - return eta_out +def adaptation_sampling( + lrate_sigma: float, + lrate_sigma_init: float, + mean: chex.Array, + B: chex.Array, + sigma: float, + sigma_old: float, + sorted_noise: chex.Array, + c_prime: float, + rho: float, +) -> float: + """Adaptation sampling on sigma/std learning rate.""" + BB = B.T @ B + A = sigma ** 2 * BB + sigma_prime = sigma * jnp.sqrt(sigma / sigma_old) + A_prime = sigma_prime ** 2 * BB + + # Probability ration and u-test - sorted order assumed for noise + prob_0 = jax.scipy.stats.multivariate_normal.logpdf(sorted_noise, mean, A) + prob_1 = jax.scipy.stats.multivariate_normal.logpdf( + 
sorted_noise, mean, A_prime + ) + w = jnp.exp(prob_1 - prob_0) + popsize = sorted_noise.shape[0] + n = jnp.sum(w) + u = jnp.sum(w * (jnp.arange(popsize) + 0.5)) + u_mean = popsize * n / 2 + u_sigma = jnp.sqrt(popsize * n * (popsize + n + 1) / 12) + cumulative = jax.scipy.stats.norm.cdf( + u, loc=u_mean + 1e-10, scale=u_sigma + 1e-10 + ) + + # Check test significance and update lrate + lrate_sigma = jax.lax.select( + cumulative < rho, + (1 - c_prime) * lrate_sigma + c_prime * lrate_sigma_init, + jnp.minimum(1, (1 - c_prime) * lrate_sigma), + ) + return lrate_sigma diff --git a/evosax/strategy.py b/evosax/strategy.py index c550cd8..c6eb590 100755 --- a/evosax/strategy.py +++ b/evosax/strategy.py @@ -1,10 +1,10 @@ import jax import jax.numpy as jnp import chex -from typing import Tuple, Optional +from typing import Tuple, Optional, Union from functools import partial from flax import struct -from .utils import get_best_fitness_member +from .utils import get_best_fitness_member, ParameterReshaper, FitnessShaper @struct.dataclass @@ -28,11 +28,30 @@ class EvoParams: class Strategy(object): - def __init__(self, num_dims: int, popsize: int): + def __init__( + self, + popsize: int, + num_dims: Optional[int] = None, + pholder_params: Optional[Union[chex.ArrayTree, chex.Array]] = None, + **fitness_kwargs: Union[bool, int, float] + ): """Base Class for an Evolution Strategy.""" - self.num_dims = num_dims self.popsize = popsize + # Setup optional parameter reshaper + self.use_param_reshaper = pholder_params is not None + if self.use_param_reshaper: + self.param_reshaper = ParameterReshaper(pholder_params) + self.num_dims = self.param_reshaper.total_params + else: + self.num_dims = num_dims + assert ( + self.num_dims is not None + ), "Provide either num_dims or pholder_params to strategy." 
+ + # Setup optional fitness shaper + self.fitness_shaper = FitnessShaper(**fitness_kwargs) + @property def default_params(self) -> EvoParams: """Return default parameters of evolution strategy.""" @@ -58,7 +77,7 @@ def ask( rng: chex.PRNGKey, state: EvoState, params: Optional[EvoParams] = None, - ) -> Tuple[chex.Array, EvoState]: + ) -> Tuple[Union[chex.Array, chex.ArrayTree], EvoState]: """`ask` for new parameter candidates to evaluate next.""" # Use default hyperparameters if no other settings provided if params is None: @@ -68,12 +87,18 @@ def ask( x, state = self.ask_strategy(rng, state, params) # Clip proposal candidates into allowed range x_clipped = jnp.clip(jnp.squeeze(x), params.clip_min, params.clip_max) - return x_clipped, state + + # Reshape parameters into pytrees + if self.use_param_reshaper: + x_out = self.param_reshaper.reshape(x_clipped) + else: + x_out = x_clipped + return x_out, state @partial(jax.jit, static_argnums=(0,)) def tell( self, - x: chex.Array, + x: Union[chex.Array, chex.ArrayTree], fitness: chex.Array, state: EvoState, params: Optional[EvoParams] = None, @@ -83,11 +108,20 @@ def tell( if params is None: params = self.default_params + # Flatten params if using param reshaper for ES update + if self.use_param_reshaper: + x = self.param_reshaper.flatten(x) + + # Perform fitness reshaping inside of strategy tell call (if desired) + fitness_re = self.fitness_shaper.apply(x, fitness) + # Update the search state based on strategy-specific update - state = self.tell_strategy(x, fitness, state, params) + state = self.tell_strategy(x, fitness_re, state, params) # Check if there is a new best member & update trackers - best_member, best_fitness = get_best_fitness_member(x, fitness, state) + best_member, best_fitness = get_best_fitness_member( + x, fitness, state, self.fitness_shaper.maximize + ) return state.replace( best_member=best_member, best_fitness=best_fitness, @@ -115,3 +149,11 @@ def tell_strategy( ) -> EvoState: """Search-specific 
`tell` update. Returns updated state.""" raise NotImplementedError + + def get_eval_params(self, state: EvoState): + """Return reshaped parameters to evaluate.""" + if self.use_param_reshaper: + x_out = self.param_reshaper.reshape_single(state.mean) + else: + x_out = state.mean + return x_out diff --git a/evosax/utils/__init__.py b/evosax/utils/__init__.py index e3af896..3c09aab 100755 --- a/evosax/utils/__init__.py +++ b/evosax/utils/__init__.py @@ -2,7 +2,7 @@ from .es_logger import ESLog # Import additional utilities for reshaping flat parameters into net dict -from .reshape_params import ParameterReshaper +from .reshape_params import ParameterReshaper, ravel_pytree # Import additional utilities for reshaping fitness from .reshape_fitness import FitnessShaper @@ -11,13 +11,23 @@ from .helpers import get_best_fitness_member # Import Gradient Based Optimizer step functions -from .optimizer import SGD, Adam, RMSProp, ClipUp, OptState, OptParams +from .optimizer import ( + SGD, + Adam, + RMSProp, + ClipUp, + Adan, + OptState, + OptParams, + exp_decay, +) GradientOptimizer = { "sgd": SGD, "adam": Adam, "rmsprop": RMSProp, "clipup": ClipUp, + "adan": Adan, } @@ -25,12 +35,15 @@ "get_best_fitness_member", "ESLog", "ParameterReshaper", + "ravel_pytree", "FitnessShaper", "GradientOptimizer", "SGD", "Adam", "RMSProp", "ClipUp", + "Adan", "OptState", "OptParams", + "exp_decay", ] diff --git a/evosax/utils/eigen_decomp.py b/evosax/utils/eigen_decomp.py index 25e42f9..dc26301 100644 --- a/evosax/utils/eigen_decomp.py +++ b/evosax/utils/eigen_decomp.py @@ -4,11 +4,12 @@ def full_eigen_decomp( - C: chex.Array, B: chex.Array, D: chex.Array + C: chex.Array, B: chex.Array, D: chex.Array, gen_counter: int ) -> Tuple[chex.Array, chex.Array, chex.Array]: """Perform eigendecomposition of covariance matrix.""" if B is not None and D is not None: return C, B, D + C = C + 1e-10 * (gen_counter == 0) C = (C + C.T) / 2 # Make sure matrix is symmetric D2, B = jnp.linalg.eigh(C) D = 
jnp.sqrt(jnp.where(D2 < 0, 1e-20, D2)) diff --git a/evosax/utils/evojax_wrapper.py b/evosax/utils/evojax_wrapper.py new file mode 100644 index 0000000..1c941cb --- /dev/null +++ b/evosax/utils/evojax_wrapper.py @@ -0,0 +1,53 @@ +import chex +import jax +import jax.numpy as jnp +from evojax.algo.base import NEAlgorithm +from evosax import Strategy + + +class Evosax2JAX_Wrapper(NEAlgorithm): + """Wrapper for evosax-style ES for EvoJAX deployment.""" + + def __init__( + self, + evosax_strategy: Strategy, + param_size: int, + pop_size: int, + es_config: dict = {}, + es_params: dict = {}, + opt_params: dict = {}, + seed: int = 42, + ): + self.es = evosax_strategy( + popsize=pop_size, num_dims=param_size, **es_config, **opt_params + ) + self.es_params = self.es.default_params.replace(**es_params) + self.pop_size = pop_size + self.param_size = param_size + self.rand_key = jax.random.PRNGKey(seed=seed) + self.rand_key, init_key = jax.random.split(self.rand_key) + self.es_state = self.es.initialize(init_key, self.es_params) + + def ask(self) -> chex.Array: + """Ask strategy for next set of solution candidates to evaluate.""" + self.rand_key, ask_key = jax.random.split(self.rand_key) + self.params, self.es_state = self.es.ask( + ask_key, self.es_state, self.es_params + ) + return self.params + + def tell(self, fitness: chex.Array) -> None: + """Tell strategy about most recent fitness evaluations.""" + self.es_state = self.es.tell( + self.params, fitness, self.es_state, self.es_params + ) + + @property + def best_params(self) -> chex.Array: + """Return set of mean/best parameters.""" + return jnp.array(self.es_state.mean, copy=True) + + @best_params.setter + def best_params(self, params: chex.Array) -> None: + """Update the best parameters stored internally.""" + self.es_state = self.es_state.replace(mean=jnp.array(params, copy=True)) diff --git a/evosax/utils/helpers.py b/evosax/utils/helpers.py index 02c3890..21dd54a 100644 --- a/evosax/utils/helpers.py +++ 
b/evosax/utils/helpers.py @@ -5,18 +5,25 @@ def get_best_fitness_member( - x: chex.Array, fitness: chex.Array, state + x: chex.Array, fitness: chex.Array, state, maximize: bool = False ) -> Tuple[chex.Array, float]: - best_in_gen = jnp.argmin(fitness) + """Check if fitness improved & replace in ES state.""" + fitness_min = jax.lax.select(maximize, -1 * fitness, fitness) + max_and_later = maximize and state.gen_counter > 0 + best_fit_min = jax.lax.select( + max_and_later, -1 * state.best_fitness, state.best_fitness + ) + best_in_gen = jnp.argmin(fitness_min) best_in_gen_fitness, best_in_gen_member = ( - fitness[best_in_gen], + fitness_min[best_in_gen], x[best_in_gen], ) - replace_best = best_in_gen_fitness < state.best_fitness + replace_best = best_in_gen_fitness < best_fit_min best_fitness = jax.lax.select( - replace_best, best_in_gen_fitness, state.best_fitness + replace_best, best_in_gen_fitness, best_fit_min ) best_member = jax.lax.select( replace_best, best_in_gen_member, state.best_member ) + best_fitness = jax.lax.select(maximize, -1 * best_fitness, best_fitness) return best_member, best_fitness diff --git a/evosax/utils/optimizer.py b/evosax/utils/optimizer.py index 02e46de..e686c96 100644 --- a/evosax/utils/optimizer.py +++ b/evosax/utils/optimizer.py @@ -11,11 +11,22 @@ # "clip_value": 5, +def exp_decay( + param: chex.Array, param_decay: chex.Array, param_limit: chex.Array +) -> chex.Array: + """Exponentially decay parameter & clip by minimal value.""" + param = param * param_decay + param = jnp.maximum(param, param_limit) + return param + + @struct.dataclass class OptState: lrate: float m: chex.Array v: Optional[chex.Array] = None + n: Optional[chex.Array] = None + last_grads: Optional[chex.Array] = None gen_counter: int = 0 @@ -27,6 +38,7 @@ class OptParams: momentum: Optional[float] = None beta_1: Optional[float] = None beta_2: Optional[float] = None + beta_3: Optional[float] = None eps: Optional[float] = None max_speed: Optional[float] = None @@ -57,8 
class Adan(Optimizer):
    def __init__(self, num_dims: int):
        """JAX-Compatible Adan Optimizer (Xie et al., 2022)
        Reference: https://arxiv.org/pdf/2208.06677.pdf"""
        super().__init__(num_dims)
        self.opt_name = "adan"

    @property
    def params_opt(self) -> Dict[str, float]:
        """Return default Adan parameters.

        `beta_1`, `beta_2` and `beta_3` are the decay rates `step_opt`
        applies to the `m`, `v` and `n` traces; `eps` is added to the
        denominator of the update for numerical stability.
        """
        return {
            "beta_1": 0.98,
            "beta_2": 0.92,
            "beta_3": 0.99,
            "eps": 1e-8,
        }

    def initialize_opt(self, params: OptParams) -> OptState:
        """Initialize the m, v, n traces and gradient memory of the optimizer.

        All traces and the stored previous gradient start at zero; the
        learning rate starts at `params.lrate_init`.
        """
        return OptState(
            m=jnp.zeros(self.num_dims),
            v=jnp.zeros(self.num_dims),
            n=jnp.zeros(self.num_dims),
            last_grads=jnp.zeros(self.num_dims),
            lrate=params.lrate_init,
        )

    def step_opt(
        self,
        mean: chex.Array,
        grads: chex.Array,
        state: OptState,
        params: OptParams,
    ) -> Tuple[chex.Array, OptState]:
        """Perform a single Adan gradient descent step.

        Args:
            mean: Current parameter vector to be updated.
            grads: Gradient (estimate) for the current step.
            state: Optimizer state holding the traces, the previous
                gradient, the learning rate and the step counter.
            params: Optimizer hyperparameters (`beta_1..3`, `eps`).

        Returns:
            Tuple of the updated mean and the updated optimizer state.
        """
        # Number of steps taken including this one (for bias correction).
        step = state.gen_counter + 1

        # NOTE(review): `last_grads` starts at zeros, so on the very first
        # step the gradient difference equals `grads` itself - confirm this
        # matches the intended reference behavior.
        grad_diff = grads - state.last_grads

        # Exponential moving averages of the gradient, the gradient
        # difference and the squared "look-ahead" gradient.
        m = (1 - params.beta_1) * grads + params.beta_1 * state.m
        v = (1 - params.beta_2) * grad_diff + params.beta_2 * state.v
        n = (1 - params.beta_3) * (
            grads + params.beta_2 * grad_diff
        ) ** 2 + params.beta_3 * state.n

        # Bias-correct the zero-initialized traces.
        mhat = m / (1 - params.beta_1 ** step)
        vhat = v / (1 - params.beta_2 ** step)
        nhat = n / (1 - params.beta_3 ** step)

        mean_new = mean - state.lrate * (mhat + params.beta_2 * vhat) / (
            jnp.sqrt(nhat) + params.eps
        )
        return mean_new, state.replace(
            m=m, v=v, n=n, last_grads=grads, gen_counter=step
        )
def z_score_trafo(arr: chex.Array) -> chex.Array:
    """Standardize scores: subtract the mean, then divide by the std.

    A small constant (1e-10) is added to the standard deviation so that a
    population with identical scores maps to zeros instead of NaNs.
    """
    centered = arr - jnp.mean(arr)
    spread = jnp.std(arr) + 1e-10
    return centered / spread
def range_norm_trafo(
    arr: chex.Array, min_val: float = -1.0, max_val: float = 1.0
) -> chex.Array:
    """Map scores linearly into the range [min_val, max_val].

    The smallest (non-NaN) score is mapped to `min_val` and the largest to
    `max_val`. Scores are first clipped to [-1e10, 1e10] so that infinite
    values cannot blow up the range computation.

    Args:
        arr: Array of scores to normalize.
        min_val: Lower end of the target range.
        max_val: Upper end of the target range.

    Returns:
        Array of the same shape as `arr` with values in [min_val, max_val].
        If all scores are identical, every entry maps to `min_val`.
    """
    arr = jnp.clip(arr, -1e10, 1e10)
    lowest = jnp.nanmin(arr)
    # Small constant guards against division by zero for constant scores.
    spread = jnp.nanmax(arr) - lowest + 1e-10
    unit = (arr - lowest) / spread  # scores rescaled to [0, 1]
    # Stretch to the requested width and shift to start at `min_val`.
    return (max_val - min_val) * unit + min_val
jax.jit(self.reshape_network) - self.reshape_single = jax.jit(self.reshape_single_net) + # Set total parameters depending on type of placeholder params + flat, self.unravel_pytree = flatten_util.ravel_pytree( + placeholder_params + ) + self.total_params = flat.shape[0] + self.reshape_single = jax.jit(self.unravel_pytree) if n_devices is None: self.n_devices = jax.local_device_count() @@ -50,13 +55,9 @@ def __init__( " for optimization." ) - def reshape_identity(self, x: chex.Array) -> chex.Array: - """Return parameters w/o reshaping for evaluation.""" - return x - - def reshape_network(self, x: chex.Array) -> chex.ArrayTree: + def reshape(self, x: chex.Array) -> chex.ArrayTree: """Perform reshaping for a 2D matrix (pop_members, params).""" - vmap_shape = jax.vmap(self.flat_to_network, in_axes=(0,)) + vmap_shape = jax.vmap(self.reshape_single) if self.n_devices > 1: x = self.split_params_for_pmap(x) map_shape = jax.pmap(vmap_shape) @@ -64,56 +65,38 @@ def reshape_network(self, x: chex.Array) -> chex.ArrayTree: map_shape = vmap_shape return map_shape(x) - def reshape_single_flat(self, x: chex.Array) -> chex.Array: - """Perform reshaping for a 1D vector (params,).""" - return x - - def reshape_single_net(self, x: chex.Array) -> chex.ArrayTree: - """Perform reshaping for a 1D vector (params,).""" - unsqueezed_re = self.flat_to_network(x) - return unsqueezed_re + def multi_reshape(self, x: chex.Array) -> chex.ArrayTree: + """Reshape parameters lying already on different devices.""" + # No reshaping required! 
def flatten(self, x: chex.ArrayTree) -> chex.Array:
    """Flatten a population of parameter pytrees into a flat 2D array.

    Each population member's pytree is raveled into one row via the
    module-level `ravel_pytree` helper, vectorized over the leading axis.

    Args:
        x: Pytree whose leaves carry the population axis first. When more
            than one device is used, leaves are assumed to have separate
            leading device and per-device member axes (TODO confirm against
            the caller); these two axes are merged before flattening.

    Returns:
        Array with one flattened parameter vector per population member.
    """
    vmap_flat = jax.vmap(ravel_pytree)
    if self.n_devices > 1:
        # Merge the two leading axes of every leaf of the pmap-shaped
        # parameter tree so the vmapped flattening sees a single
        # population axis.
        def map_flat(tree):
            merged = jax.tree_util.tree_map(
                lambda leaf: leaf.reshape(-1, *leaf.shape[2:]), tree
            )
            return vmap_flat(merged)

    else:
        map_flat = vmap_flat
    return map_flat(x)
class BBOBVisualizer(object):
    """Fitness landscape visualizer and evaluation animator.

    Draws a 2D contour (or 3D surface) of a BBOB function over the box
    [-5, 5] x [-5, 5] and animates sampled points on top of it, one frame
    per entry along the first axis of `X`.
    """

    def __init__(
        self,
        X: chex.Array,
        fitness: chex.Array,
        fn_name: str = "Rastrigin",
        title: str = "",
        use_3d: bool = False,
        plot_log_fn: bool = False,
        seed_id: int = 1,
    ):
        """Set up the figure, the landscape function and animation metadata.

        Args:
            X: Sampled points, indexed as `X[frame, member, dim]` with two
                search dimensions. May be `None` if only the static
                contour plots are needed (no animation).
            fitness: Fitness values indexed as `fitness[frame, member]`;
                only read when `use_3d` is True.
            fn_name: Key of the function in `BBOB_fns` to visualize.
            title: Extra text shown in the animation title.
            use_3d: Plot a 3D surface instead of a 2D contour.
            plot_log_fn: Plot the logarithm of the function values.
            seed_id: Seed for the random rotation matrices.
        """
        from evosax.problems.bbob import BBOB_fns, get_rotation

        self.X = X
        self.fitness = fitness
        self.title = title
        self.fn_name = fn_name
        self.use_3d = use_3d
        if not self.use_3d:
            self.fig, self.ax = plt.subplots(figsize=(6, 5))
        else:
            self.fig = plt.figure(figsize=(6, 5))
            self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
        self.fn = BBOB_fns[self.fn_name]

        # Rotation matrices passed on to the landscape function
        rng = jax.random.PRNGKey(seed_id)
        rng_q, rng_r = jax.random.split(rng)
        self.R = get_rotation(rng_r, 2)
        self.Q = get_rotation(rng_q, 2)
        self.global_minima = []
        self.plot_log_fn = plot_log_fn

        # Set boundaries for evaluation range of black-box functions
        self.x1_lower_bound, self.x1_upper_bound = -5, 5
        self.x2_lower_bound, self.x2_upper_bound = -5, 5

        # Set meta-data for rotation/azimuth
        self.interval = 50  # Delay between frames in milliseconds.
        if X is not None:
            self.num_frames = X.shape[0]
            # The camera stops rotating for the last 20% of the frames
            self.static_frames = int(0.2 * self.num_frames)
            num_moving = self.num_frames - self.static_frames
            self.azimuths = jnp.linspace(0, 90, num_moving)
            self.angles = jnp.linspace(0, 90, num_moving)
        else:
            # No data to animate - only static contour plotting is possible
            self.num_frames = 0
            self.static_frames = 0
            self.azimuths = None
            self.angles = None

    def animate(self, save_fname: str):
        """Run animation for provided data and save it to `save_fname`.

        Raises:
            ValueError: If the visualizer was created without sample data.
        """
        if self.X is None:
            raise ValueError(
                "Cannot animate without sampled points: `X` is None."
            )
        ani = animation.FuncAnimation(
            self.fig,
            self.update,
            frames=self.num_frames,
            init_func=self.init,
            blit=False,
            interval=self.interval,
        )
        ani.save(save_fname)

    def init(self):
        """Initialize the first frame for the animation.

        Draws the landscape and the points of the first frame, and stores
        the point artist in `self.scat` for later updates.
        """
        if self.use_3d:
            self.plot_contour_3d()
            (self.scat,) = self.ax.plot(
                self.X[0, :, 0],
                self.X[0, :, 1],
                self.fitness[0, :],
                marker="o",
                c="r",
                linestyle="",
                markersize=3,
                alpha=0.5,
            )

        else:
            self.plot_contour_2d()
            (self.scat,) = self.ax.plot(
                self.X[0, :, 0],
                self.X[0, :, 1],
                marker="o",
                c="r",
                linestyle="",
                markersize=3,
                alpha=0.5,
            )

        return (self.scat,)

    def update(self, frame):
        """Update the frame with the solutions evaluated in generation."""
        # Plot sample points
        self.scat.set_data(self.X[frame, :, 0], self.X[frame, :, 1])
        if self.use_3d:
            self.scat.set_3d_properties(self.fitness[frame, :])
            # Rotate the camera only during the non-static frames
            if frame < self.num_frames - self.static_frames:
                self.ax.view_init(self.azimuths[frame], self.angles[frame])
        self.ax.set_title(
            f"{self.fn_name}: {self.title} - Generation {frame + 1}",
            fontsize=15,
        )
        self.fig.tight_layout()
        return (self.scat,)

    def contour_function(self, x1, x2):
        """Evaluate vmapped fitness landscape on the grid spanned by x1, x2.

        The outer vmap runs over `x2` and the inner one over `x1`, so the
        result has shape `(len(x2), len(x1))`.
        """

        def fn_val(x1, x2):
            x = jnp.stack([x1, x2])
            return self.fn(x, self.R, self.Q)

        return jax.vmap(jax.vmap(fn_val, in_axes=(0, None)), in_axes=(None, 0))(
            x1, x2
        )

    def plot_contour_2d(self, save: bool = False):
        """Plot 2d landscape contour.

        Args:
            save: If True, draw into a fresh figure and save it as
                `"<fn_name>_2d.png"`.
        """
        if save:
            self.fig, self.ax = plt.subplots(figsize=(6, 5))
        self.ax.set_xlim(self.x1_lower_bound, self.x1_upper_bound)
        self.ax.set_ylim(self.x2_lower_bound, self.x2_upper_bound)

        # Mark the global minima
        for m in self.global_minima:
            self.ax.plot(m[0], m[1], "y*", ms=10)

        x1 = jnp.arange(self.x1_lower_bound, self.x1_upper_bound, 0.01)
        x2 = jnp.arange(self.x2_lower_bound, self.x2_upper_bound, 0.01)
        X, Y = np.meshgrid(x1, x2)
        contour = self.contour_function(x1, x2)
        if self.plot_log_fn:
            contour = jnp.log(contour)
        self.ax.contour(X, Y, contour, levels=30, linewidths=0.5, colors="#999")
        im = self.ax.contourf(X, Y, contour, levels=30, cmap=cmap, alpha=0.7)
        self.ax.set_title(f"{self.fn_name} Function", fontsize=15)
        self.ax.set_xlabel(r"$x_1$")
        self.ax.set_ylabel(r"$x_2$")
        self.fig.colorbar(im, ax=self.ax)
        self.fig.tight_layout()

        if save:
            plt.savefig(f"{self.fn_name}_2d.png", dpi=300)

    def plot_contour_3d(self, save: bool = False):
        """Plot 3d landscape contour.

        Args:
            save: If True, draw into a fresh 3D figure and save it as
                `"<fn_name>_3d.png"`.
        """
        if save:
            self.fig = plt.figure(figsize=(6, 5))
            self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
        x1 = jnp.arange(self.x1_lower_bound, self.x1_upper_bound, 0.01)
        x2 = jnp.arange(self.x2_lower_bound, self.x2_upper_bound, 0.01)
        contour = self.contour_function(x1, x2)
        if self.plot_log_fn:
            contour = jnp.log(contour)

        X, Y = np.meshgrid(x1, x2)
        # Contour lines projected at the height of the lowest value
        self.ax.contour(
            X,
            Y,
            contour,
            zdir="z",
            offset=np.min(contour),
            levels=30,
            cmap=cmap,
            alpha=0.5,
        )
        self.ax.plot_surface(
            X,
            Y,
            contour,
            cmap=cmap,
            linewidth=0,
            antialiased=True,
            alpha=0.7,
        )

        # Remove fills and set labels
        self.ax.xaxis.pane.fill = False
        self.ax.yaxis.pane.fill = False
        self.ax.zaxis.pane.fill = False

        self.ax.xaxis.set_tick_params(labelsize=8)
        self.ax.yaxis.set_tick_params(labelsize=8)
        self.ax.zaxis.set_tick_params(labelsize=8)

        self.ax.set_xlabel(r"$x_1$")
        self.ax.set_ylabel(r"$x_2$")
        self.ax.set_zlabel(r"$f(x)$")
        self.ax.set_title(f"{self.fn_name} Function", fontsize=15)
        self.fig.tight_layout()
        if save:
            plt.savefig(f"{self.fn_name}_3d.png", dpi=300)
c_mu=DeviceArray(0.1547454, dtype=float32), c_sigma=DeviceArray(0.50822735, dtype=float32), d_sigma=DeviceArray(1.5082273, dtype=float32), c_c=DeviceArray(0.60908335, dtype=float32), chi_n=DeviceArray(1.2542727, dtype=float32, weak_type=True), c_m=1.0, sigma_init=1.0, init_min=-3, init_max=3, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from evosax import CMA_ES\n", - "from evosax.problems import ClassicFitness\n", + "from evosax.problems import BBOBFitness\n", "\n", "# Instantiate the evolution strategy instance\n", "strategy = CMA_ES(num_dims=2, popsize=10)\n", @@ -70,13 +81,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "969621cb", "metadata": {}, "outputs": [], "source": [ "# Instantiate helper class for classic evolution strategies benchmarks\n", - "evaluator = ClassicFitness(\"rosenbrock\", num_dims=2)" + "evaluator = BBOBFitness(\"RosenbrockOriginal\", num_dims=2)" ] }, { @@ -89,10 +100,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "e8982ce5-91b0-4ccc-ba46-2c69d4f11287", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "EvoState(p_sigma=DeviceArray([1.4210665 , 0.17011967], dtype=float32), p_c=DeviceArray([1.5021834, 0.1798304], dtype=float32), C=DeviceArray([[1.6521862, 0.2126719],\n", + " [0.2126719, 0.6064514]], dtype=float32), D=None, B=None, mean=DeviceArray([-0.7851852, 1.9345263], dtype=float32), sigma=DeviceArray(1.0486844, dtype=float32), weights=DeviceArray([ 0.45627266, 0.27075312, 0.16223112, 0.08523354,\n", + " 0.02550957, -0.09313666, -0.25813875, -0.4010702 ,\n", + " -0.5271447 , -0.639922 ], dtype=float32), weights_truncated=DeviceArray([0.45627266, 0.27075312, 0.16223112, 0.08523354, 0.02550957,\n", + " 0. , 0. , 0. , 0. , 0. 
], dtype=float32), best_member=DeviceArray([0.7087054, 2.3278952], dtype=float32), best_fitness=DeviceArray(17.166704, dtype=float32), gen_counter=DeviceArray(1, dtype=int32, weak_type=True))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Ask for a set of candidate solutions to evaluate\n", "x, state = strategy.ask(rng, state, es_params)\n", @@ -113,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "57b3e83c-e81f-48bd-aa8c-e90b7542502a", "metadata": {}, "outputs": [], @@ -127,10 +153,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "791a2d65-7404-40a2-9c4b-4a1e1ac12dfa", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(
,\n", + " )" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ "state = strategy.initialize(rng, es_params)\n", "for i in range(num_gens):\n", @@ -159,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "c8bb5d1f-2b3b-4c3a-91fd-58c4be4b10ce", "metadata": {}, "outputs": [], @@ -190,10 +240,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "772c9449-6fab-4650-bd82-d0c589eb45b1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ParameterReshaper: 4610 parameters detected for optimization.\n" + ] + }, + { + "data": { + "text/plain": [ + "4610" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from evosax.utils import ParameterReshaper\n", "\n", @@ -212,10 +280,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "40ff50bd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(100, 4610)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from evosax import DE\n", "strategy = DE(popsize=100, num_dims=param_reshaper.total_params)\n", @@ -234,34 +313,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "607fab0a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(frozen_dict_keys(['params']), (100, 4, 64))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "net_params = param_reshaper.reshape(x)\n", "net_params.keys(), net_params['params']['Dense_0']['kernel'].shape" ] }, - { - "cell_type": "markdown", - "id": "38cb1644", - "metadata": {}, - "source": [ - "If you now want to map over the population member axis, you can do so with the of the `vmap_dict` (more about this 
later):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b1ac3ea", - "metadata": {}, - "outputs": [], - "source": [ - "# Get dictionary to vectorize/parallelize rollouts with\n", - "param_reshaper.vmap_dict" - ] - }, { "cell_type": "markdown", "id": "4d1192e9", @@ -274,10 +345,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "9c488bab", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([ 0.49, -0.04, -0.59], dtype=float32)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from evosax import FitnessShaper\n", "fit_shaper = FitnessShaper(centered_rank=True, w_decay=0.01, maximize=True)\n", @@ -294,28 +376,39 @@ "source": [ "## ARS on CartPole Task\n", "\n", - "`evosax` also comes with a simple fitness evaluation helper for a JAX-based version of Cartpole. You will have to make use of the `vmap_dict` in order to vectorize the rollouts along the population axis:" + "`evosax` also comes with a simple fitness evaluation helper for all [`gymnax`](https://github.com/RobertTLange/gymnax) environments (e.g. CartPole, MinAtar, etc.). 
We will vectorize the rollouts of the different population members:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "92cbdc9d-fe18-4582-91c8-4713568c1199", "metadata": {}, "outputs": [], "source": [ - "from evosax.problems import GymFitness\n", + "from evosax.problems import GymnaxFitness\n", "\n", - "evaluator = GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", - "evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)" + "evaluator = GymnaxFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", + "evaluator.set_apply_fn(network.apply)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "5186a497", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=0.0, beta_1=None, beta_2=None, beta_3=None, eps=None, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from evosax import ARS\n", "\n", @@ -330,15 +423,66 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "c99235f5-6cb1-4e3b-b00b-1b5789d7898e", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:740: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n", + " abs_value_flat = jax.tree_leaves(abs_value)\n", + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:741: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. 
Use jax.tree_util.tree_leaves instead.\n", + " value_flat = jax.tree_leaves(value)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation: 0 Performance: 22.875\n", + "Generation: 5 Performance: 26.25\n", + "Generation: 10 Performance: 27.8125\n", + "Generation: 15 Performance: 31.3125\n", + "Generation: 20 Performance: 53.0\n", + "Generation: 25 Performance: 99.0625\n", + "Generation: 30 Performance: 115.8125\n", + "Generation: 35 Performance: 130.125\n", + "Generation: 40 Performance: 192.9375\n", + "Generation: 45 Performance: 200.0\n" + ] + }, + { + "data": { + "text/plain": [ + "(
,\n", + " )" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], "source": [ - "num_generations = 250\n", + "num_generations = 50\n", "num_rollouts = 20\n", - "print_every_k_gens = 20\n", + "print_every_k_gens = 5\n", "\n", + "rng = jax.random.PRNGKey(0)\n", "es_logging = ESLog(param_reshaper.total_params,\n", " num_generations,\n", " top_k=5,\n", @@ -364,11 +508,117 @@ "es_logging.plot(log, \"CartPole Augmented Random Search\")" ] }, + { + "cell_type": "markdown", + "id": "50ba4ab2", + "metadata": {}, + "source": [ + "# More Minimalism (no `es_params`, `fit_shaper` or `param_reshaper`)\n", + "\n", + "We also provide utilities that abstract away all the details if you are only interested in a default implementation or want to avoid 10 additional lines of boilerplate code :)\n", + "\n", + "This means that you can directly provide the placeholder parameters and fitness shaping arguments at the time of strategy instantiation. Furthermore, if you don't explicitly provide `es_params` at the time of `initialize`, `ask`, `tell` the strategy will use a set of default parameters:" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "e1fab05b", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ParameterReshaper: 4610 parameters detected for optimization.\n" + ] + } + ], + "source": [ + "strategy = ARS(popsize=100,\n", + " pholder_params=policy_params,\n", + " elite_ratio=0.1,\n", + " opt_name=\"sgd\",\n", + " maximize=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "07ab2a62", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation: 0 Performance: 22.875\n", + "Generation: 5 Performance: 26.25\n", + "Generation: 10 Performance: 27.8125\n", + "Generation: 15 Performance: 31.3125\n", + "Generation: 20 Performance: 53.0\n", + "Generation: 25 Performance: 99.0625\n", + 
"Generation: 30 Performance: 115.8125\n", + "Generation: 35 Performance: 130.125\n", + "Generation: 40 Performance: 192.9375\n", + "Generation: 45 Performance: 200.0\n" + ] + }, + { + "data": { + "text/plain": [ + "(
,\n", + " )" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAADgCAYAAADsbXoVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAB3DElEQVR4nO2dd3hb5fmw70fT8t524kxnEggJJAGSsDeFsmdZgbIplFHo/BUKbT+g0FJaSoECYYe9QimUVQgrJIwMsryS2In3trX1fn+cI0W2JVtO7NgJ731duqyz3vMeST7PebYopdBoNBqNBsAy1BPQaDQazfBBCwWNRqPRRNBCQaPRaDQRtFDQaDQaTQQtFDQajUYTQQsFjUaj0UTQQkGz3YjIoSJSOdTz2NURESUiE4fBPBaIyJKhnsfOQkRuFZGnhnoeww0tFHYBRORHIrJMRNpFZKuIvCUiB+7AeF1uQubNPWSO3yYi60TkooGZfZ9zudWcz/4743wDzWDfSEXkQxHxmN9NvYi8LCIjBut8OwsR+bGIrDV/bzUi8m8RSRvqeWm0UBj2iMgNwL3AH4ECYAzwD+Ck7RjL1svmLUqpVCAd+DnwsIhM6/eE+zcfAS4AGs2/mtj8xPxuJgKpwN1DPJ8dQkQOwfg9n6OUSgP2AJ4bhPOIiOh7XD/RH9gwRkQygNuAq5VSLyulOpRSfqXUG0qpm8x99hORz0Sk2dQi/i4ijqgxlIhcLSIbgA0i8pG56Vvz6fOs6HMqg1eBJmCaiDhF5F4R2WK+7hURZ5z5jhSRl0SkTkTKReTaPi7xIGAEcC1wdrd5d1HtRWSceS02c3m8iHxkPmm+KyL3h/eP2vciEdksIk0icoWIzBGRFeZn9fduc79YRNaY+74tImO7fYZXiMgG89j7zRvOHsA/gbnmZ9ls7u8UkbtFZJP5FPxPEXFFjXeT+V1tEZGL+/iMIiilmoFXgZlRY11kzrtNRMpE5PKobYeKSKWI3CgiteY5L4raniMir4tIq4gsBSZ0+0zmiciXItJi/p0Xte1DEfm9iHxqXvsb5nhPm+N9KSLj4lzKHOAzpdTX5nU1KqUeV0q19fX5iUiWiCw2f2NN5vtR3eb1BxH5BOgEikVkTxH5r4g0muP9KmouDhF5wvz8VovI7ES/j90WpZR+DdMXcCwQAGy97DMLOACwAeOANcB1UdsV8F8gG3BFrZsYtc+hQKX53gKcAviBKRhC6XMgH8gDPgVuj3PccuC3gAMoBsqAY3qZ+yPA84AdaABOi9p2K/BU1PI4c942c/kzjCdmB3Ag0BreP2rffwJJwNGAB+OGmg8UAbXAIeb+JwElGE+sNuA3wKfdPsPFQCaGplYHHGtuWwAs6XZdfwFeNz/zNOAN4P9Ffac1wF5ACvBM9++j21gfApeY73OAd4HXorYfj3EzF+AQjBvhvlHfT8D8Du3AD8ztWeb2Rebnn2LOpyp8Lebcm4Dzzc/kHHM5J2peJea5M4DvgPXAkeb+TwCPxbmmgwA38DtgPuDsx+eXA5wGJJvbXgBe7fZ5bQL2NOeRBmwFbjR/C2nA/lG/MY/5uViB/wd8PtT/90P9GvIJ6FcvXw6cC1T385jrgFeilhVweLd9YgmFENCMYcr5Bjjb3FYK/CBq32OAiqjjwkJhf2BTt/P8spcbQzLGjfxkc/lBut7sbiWOUMC4MQeA5KjtT9FTKBRFbW8AzopafglTeAJvAT+O2mbBuHmOjfq8Doza/jzwC/P9AqKEAsbNuQOYELVuLlBuvn8UuCNq2+Tu30e3z+lDcy4t5n7fAGN6+f5fBX4a9f24iXqowBCGB5g3QT8wNWrbH9kmFM4HlnYb+zNgQdS8fh21
7R7grajlHwLf9DLP4zBu9s1AO/Bnc069fn4xxpkJNHX7vG6LWj4H+DrOsbcC70YtTwPc2/v/uru8erMxa4aeBiBXRGxKqUCsHURkMsY/1GyMG60N44k9ms0JnGuLUmpUjPUjgY1RyxvNdd0ZC4wMm1BMrMDHcc53CsaN/d/m8tPAuyKSp5Sq62OuI4FGpVRn1LrNwOhu+9VEvXfHWE6NmvtfReSeqO2CoVGEr706altn1LHdycP4HpaLSPRY1qi5R38/0Z9tPK5VSv1LRKZjaCyjMJ6GEZHjgFswhIvFPPfKqGMbuv12wnPPw/itRP82oufS/XsPby+KWk708+2BUuot4C0xbP6HYTzxrwNeoZfPT0SSMTSJY4Esc3uaiFiVUkFzOfqaRmM82MSj+/ea1Nv/2/cB7VMY3nwGeIGTe9nnAWAtMEkplQ78CuOfKJodKYW7BeOmGWaMua47mzGe5jKjXmlKqR/EGfdCjJvGJhGpxrgp2IEfmds7MG4OYQqj3m8Fss0bRJjuAqE/bAYu7zZ3l1Lq0wSO7f7Z1mPcEPeMGitDGY7i8Nyj5zom0UkqpVYCvwfCPg0nhsZzN1CglMrEELLdv/9Y1GEI5Xhz6f69h7dXJTrfRFBKhZRS7wHvY5iw+vr8bsQwa+5v/t4PNtdHX3P0d7IZw5SpSRAtFIYxSqkWDBv9/SJysogki4hdRI4TkbvM3dIwzDDtIjIVuDKBoWtI/B/lWeA3IpInIrnmfGLFdi8F2kTk5yLiEhGriOwlInO67ygiRcARwAkY6v9MYAZwJ9uikL4BDhaRMWI43H8ZPl4ptRFYBtwqIg4RmYthrthe/gn8UkT2NOeXISJnJHhsDTBKTCe5UioEPAz8RUTyzfGKROQYc//ngQUiMs0Uarf0c66PY0ShnYjhT3Fi3uBNreHoRAYxn6pfxvgMk8WINLswapd/A5PFCIe2iRGQMA1DU9khROQkETnbdBqLiOyH4Q/5PIHPLw1DaDSLSDZ9f36LgREicp3pwE6TXTT8eWehhcIwRyl1D3ADhvOzDuPJ5ycYtmOAn2E8Xbdh/DMlEtp3K/C4GJE0Z/ax7+8xbsArMMwSX5nrus8zyLabfDnGE9+/MJyQ3Tkfw978jlKqOvwC7gP2FpG9lFL/Na9lBYa5pfvN6FwMW3ODOZ/nMLSqfqOUegVDIC0SkVZgFYbNOxHeB1YD1SJSb677OYYT9nNzvHcxnm7DZpN7zeNKzL/9masP+Cvwf8qI1rkWQ9A0YfwOXu/HcD/B0NaqgYXAY1HnacD4Pm/E+IxvBk5QStX3HKbfNAGXAhswAwSAPymlnja3x/38MD47F8bv63PgP72dyPyMjsJ4aKg2z3nYAFzDbouYDhaNZpdGRJ4D1iql+vvkrdFootCagmaXRIycgwkiYhGRYzHCSl8d4mlpNLs8OvpIs6tSiGETzwEqgSuVmQyl0Wi2H20+0mg0Gk0EbT7SaDQaTQQtFDQajUYTYZf2KRx77LHqP//pNSJNo9FoND2Jm+C4S2sK9fUDETKt0Wg0mjC7tFDQaDQazcCihYJGo9FoIgyaUBCR0SLygYh8Zzav+Km5PttseLHB/JtlrhcRuU9ESsRohLLvYM1No9FoNLEZTEdzALhRKfWVGL1Xl4vIfzHqz7+nlLpDRH4B/AKj1slxwCTztT9G9c9+F67y+/1UVlbi8XgG6DI0vZGUlMSoUaOw2+1DPRWNRjMADJpQUEptxSgTjFKqTUTWYNRiPwmj+QcYFR8/xBAKJwFPKCOb7nMRyRSREeY4CVNZWUlaWhrjxo0jqh67ZhBQStHQ0EBlZSXjx48f6uloNAPOhtot3P3WrznANwknXR98BLApH46QG0fIY77cfGWrZo2tNeZ4EwLJHOLL3qE5tUmA15Nq2afgEBacfNsOjRWLnRKSKkav1n2ALzDqvodv9NUYZYDBEBjRzTEqzXVdhIKIXAZcBjBmTM9S9B6PRwuEnYSIkJOTQ11d
Xz1xNJpdj7LaZm585VTKkzo4tektjul097p/EAsenNwyOod2i5Aa6lotot0ifGtr4vzaku2aj1fglXQbz2bYcbUrMpq/265x+mLQhYKIpLKt9WFr9M1aKaVEpF91NpRSDwEPAcyePTvmsVog7Dz0Z63ZHfluSyu3vXQO5ekdAJQddC1MPqvnjjYn2JPBkYLV6kD5O6h7di4/3fenXDL9ki67/vPbf3L/N/eT/X9rcFgdCc8lpEK8WfYm9319H9Ud1Rwy8mCu/vlnZJ0+OG7XQY0+EhE7hkB4Win1srm6RkRGmNtHYPSMBaOjU3QXqFEMcJennUFDQwMzZ85k5syZFBYWUlRUFFn2+Xz9Hm/t2rXMnTsXp9PJ3XffPQgz1mg00Swtb+R3z1zH6vRKziSHotQiynxNkDOh5ytjFCRnG8JBhPKWcgDGZ/Q0p+a4cgBo9DQmPJcvq7/k7MVn86slvyLLmcUjRz/Cn/f4JXS6cU6cODAX3I1B0xTEeIR8BFijlPpz1KbXMTo83WH+fS1q/U9EZBGGg7mlv/6E4UBOTg7ffPMNALfeeiupqan87Gc/2+7xsrOzue+++3j11VcHZoIajSYu762p4a8v3kfVqKXMCTr45blvcO3HP6espSyh48P7FWf0bGyYk2QIhQZ3A4UphT22d6eqvYpL37mUvOQ8/njgHzm++HgsYqHtvfcASJoypY8Rto/B1BTmY3TYOlxEvjFfP8AQBkeJyAbgSHMZjPZ/ZRgdlx4GrhrEue1U3nvvPfbZZx+mT5/OxRdfjNdrNAgbN24cN998M9OnT2e//fajpKSnrTE/P585c+bo6B6NZpB5aXkltz3zAu0jXyVXCfec/AI2ZxrFGcVUtFQQDAX7HKO0pRSbxcbotJ4tw8OaQoOnIaH5bGzZSFAFueOgO/jhhB9iEeN27Vm3DkR2PU1BKbWE+PU1joixvwKuHsg5/O6N1Xy3JXYUwPYybWQ6t/xwz4T393g8LFiwgPfee4/JkydzwQUX8MADD3DdddcBkJGRwcqVK3niiSe47rrrWLx4h1vgajTfe4558lpqG3KwdcxP+BhfezVTx/+TMovw1CF/ISt7AgATMifgC/moaq9iTHrP4JZoypvLGZc+Dpul5601WlNIhHqPUcYnz5XXZb133XrsY0ZjSUlJaJz+sksXxNsVCAaDjB8/nsmTJwNw4YUXcv/990eEwjnnnBP5e/311w/VNDWa3YZWj58q31KS00dyxJhTe99ZKRwhN6n+Bura7uUtp3D3tMuYUnxkZJewf6CspaxPoVDWUsbU7Kkxt/VXU6h3G0Ih15XbZb133TqSJg+O6Qh2c6HQnyf6oSI6ekdH8mg0O85XG5uwWDuxU8Yfg3/p+X/l74T2WuiohfY6CLh5N9nF9QV5XFp4EMfMuabL7sWZhn+gtLmUQ0cfGve83qCXyvZKflD8g5jbXTYXybbkxDUFd71xjD05si7U2Ylv40bSTzghoTG2h91aKAwHrFYrFRUVlJSUMHHiRJ588kkOOeSQyPbnnnuOX/ziFzz33HPMnTt3CGeq0ewefFG2GSWKVoLUVX9NvuomFOwuSMkzoodS8iAlj8+avyKtZR1XH/W3HuOlO9LJc+X16WyuaKkgpEIxncxhclw5/dIUemgJJSWgFM4pkxMaY3vQQmGQSUpK4rHHHuOMM84gEAgwZ84crrjiisj2pqYm9t57b5xOJ88++2yP46urq5k9ezatra1YLBbuvfdevvvuO9LT03fmZWg0uwyV5UvBNLeXnHo/+UXz+jxm7b/PZXLOHlgt1pjbizOKI+Gm8Qhv71UoJOXQ6E4sJLXB3dBDKHjWrQMGL/IItFAYVG699dbI+6+/jt1T/qabbuLOO++MO0ZhYSGVlZUDPTWNZrfEHwwhjd9GhMKG5g3M60MohFSIDU0bOGXiKXH3Kc4s5vXS11FKxTXzlraUYhEL4zLGxR0nx5XDxtaNfV4HGJrChMwJXdZ5163H
kpyMfdSohMbYHnTpbI1Gs9uweksrI6mILK9vWt/nMZvbNuMOuOM6iMF4+u/wd1DTWRN3n7LmMopSi3BanXH3yU7K7pdPIZaT2Tl5MmIZvFu3FgpDSEVFBbm5uX3vqNFoEmJZRSOFti0AZDozKWnuu87QukbDJDM5O76dPmwS6s2vUNZSxoSMCXG3hzo6GNnppNnbTCAU6HVO3qCXVl9rF6GglMKzfj3OQTQdgRYKGo1mN+LbsmrSbcaT+D75+1DaXNpn0tnaxrVYxcrEzNjJYN6yskgEUllzbKEQCAWoaK1gfGbsasG+TZsoP/U0Zv/+dRSKJk9Tr3MKaxPRQiFQU0OopWVQncyghYJGo9lNUErRvulrOi2GzX9WwaxImGhvrG9az7j0cTHNPu5vv6XsB8eTvGYz6Y70uJpCZVslgVAgpqbgXrGCirPPwbdxI46GNqDvXIVYOQreneBkBi0UNBrNbsLGhk7GetbSatrb98nfB4ANTRt6PW5d0zqmZMe+0bpXrwbAt3EjEzInxBUKpS2lQM/Io7b3P2DjBRdiSU4m49RTEZ8fu1/16VcIC4VwwhuAZ53hH3FO1pqCRqPR9MmXFY3MsJTS4kwhxZ7CpKxJCNKrUGjxtlDdUR1XKPhKDSEQqK2lOKM4rvkoVnXUpmefpfInP8E5cSLjFj2La8YMAFI9/dAUkrpqCvaRI7GmpfV67I6ihcIAM9ClswEOPfRQpkyZEhmntra2xz4LFy5ERHj33Xcj61599VVEhBdffHG7r0ej2VVYvrGJfaxltKfmku5Ix2VzMTptNBua4wuFcHTSlKzYQsFbZmgAgdpaxmeMp8nbFNMfUNpcSkFyAamOVFQoRO09f6b6d7eRevDBjH3icWy5uVgzMgBIdfdd/yi8Pdu1rUubd/26QXcyg85TGHAGunR2mKeffprZs2f3us/06dNZtGgRRx5p1G159tlnmWE+nWg0uzvflW9iHFtpTRpPmsMoDTEpa1KvmsLaxrUA8TWFkrBQqGFC5qGAEWU0K2lWl/3KWsoipqPaO++i8fHHyTz7LAp/8xvEZtxmrZmGUMj22RMyH2U5s7BbjOrIIZ8Pb1k5qUf0qCU64GhNYSewI6Wz+8NBBx3E0qVL8fv9tLe3U1JSwsyZMyPbly9fziGHHMKsWbM45phj2LrVaFfx8MMPM2fOHGbMmMFpp51GZ2cnAAsWLODaa69l3rx5FBcXa41DM2xp7PCR1rgKgDabkz0qBd+mTUzMnMimtk14g96Yx61rXEd2UnaPfACAYFsbAbPVrN80H4GhFUQTUiHKW8ojEUqt/32H1MMPp/CWWyICAcBqViEoDKb22Win3l3fxZ/gKy2FYHDQncywu2sKb/0CqlcO7JiF0+G4O/rez2SgSmdfdNFFWK1WTjvtNH7zm9/EzKoUEY488kjefvttWlpaOPHEEykvN2ydfr+fa665htdee428vDyee+45fv3rX/Poo49y6qmncumllwLwm9/8hkceeYRrrjGKgm3dupUlS5awdu1aTjzxRE4//fT+fFoazU5h+cYmZohh728VxQ+fLqWu4u9MuuoIQipEWXMZe+Ts0eO49U3r4yat+UqNm781K4tATS1jUwpx2Vw9yl1Ud1TjDrgpzihGBYMEampx/vDEHv+jYfNRXiCZsr58Cp76LiWzw+Utdob5SGsKg0ys0tkfffRRZHt06ezPPvss5hhPP/00K1eu5OOPP+bjjz/mySefjHu+s88+m0WLFrFo0aLI2ADr1q1j1apVHHXUUcycOZPf//73kfIZq1at4qCDDmL69Ok8/fTTrDYjLgBOPvlkLBYL06ZNo6YmfjanRjOULNvYyExrGaGsYtoCnbg6AgTq6piUNQkgpl/BH/JT0lwS359gOpmTD9ifQF0dogxHcndNIbrbWqC+AYJB7CN6dlazZGQCkOt39m0+6qzvFo66HnE6cYzpvXT3QLB7awr9eKIfKrqXzg4Gg8yaZdgrTzzx
RG677TaKiooASEtL40c/+hFLly7lggsuiDnefvvtx8qVK0lOTo4IIjBiuPfcc8+YgmfBggW8+uqrzJgxg4ULF/Lhhx9Gtjmdzi5jaDTDkWUVTVxmK8My6nDa3CuwewIE6xsYkzYGh8VBSVNP02x5Szn+kD9uJrO3tBRxOEjeZ1/a3voPwcZGijOK+bL6yy77hYXEhMwJBNZtAsBWUNBjPEtKMthsZPpsvUYfKaV6lLjwrluHc+LELuaowUJrCoNMdOlsIGbp7PDfuXPnYrVa+eabb/jmm2+47bbbCAQC1Ncb4Wl+v5/Fixez11579XrOO+64gz/+8Y9d1k2ZMoW6urqIUPD7/RGNoK2tjREjRuD3+3n66acH5sI1mp2Exx+kurKCnFADgREzCZk+sUBjIzaLjeLMYtY396yBFC5vEU9T8JWW4hg/Hpv51O+vqWFC5gRqOmto97VH9itvKSfLmUVWUhb+akObto8Y0WM8EcGakUG6R2jyNBFSoZjnbfO34Qv5uuYo7ITyFmF2b01hGLCjpbO9Xi/HHHMMfr+fYDDIkUceGbH/x+O4447rsc7hcPDiiy9y7bXX0tLSQiAQ4LrrrmPPPffk9ttvZ//99ycvL4/999+ftra2Hb9wjWYnsbKqhT2U8dDVlj+FZNOnHGxqQgWDTMyc2OPpHgx/gt1ij1vV1FtWhmv6XtjNp/5AbS3jJxh5COUt5UzPmw4Y5qNwfkKg2gjeiKUpgOFXSHErgipIs7eZ7KTsHvt0z2YO1NcTrK8naZDLW4TRQmEQGYjS2SkpKSxfvrzPcy1YsIAFCxb0WL9w4cLI+5kzZ3bxZ4S58sorufLKK3s9FqC9vb3HPhrNUPNlRSN7W0pRYqUtawyucKBRKESwuZlJWZNYXLaYFm8LGc6MyHHrGtcxMXNiJOwzmpDHg7+ykoyTTsKWnw9AoLaOCfvuBxiCYHredJRSlDaXcsy4YwDwV9cgSUlYMzNjztWakUFSp/F/1OBuiCkUutc98q43M5l3kqagzUcajWaXZnlFE3OdFUj+NFqVf5tQAAINDUzKNJzN0RVTlVK9lrfwlZcbHc4mTsCWkwMiBGpqGJU2CrvFHnEuN3gaaPW1RsJV/dVbsRcUxO25YE1Px9Hpjxwbi+6aws4qbxFGC4UhRJfO1mh2jFBIsayikWmUQdE+tPpaSfZuC4gINjREIpCinc317noaPY19Rh45iosRux1rbg6BulpsFhtj08dGyl1077YWqK7BFsOfEMaamYG1zQ3Ez2ruLhS869Zhy8vDlt1TqxgMtFDQaDS7LCV17WR4q0gJtkLRLEMoRFWTCTQ0UpBcQJo9rUtY6rom08kcT1MoKwWLBce4cQDY8/LxmyHZxRnFEU0hHHkUTlzzV1dHfBCxsGRkIG3bzEexqHfXY7fYSXcYyW6enVTeIjLHnXYmjUajGWCWVTQxU8y8gZH70uZr62I+CjbUIyJMzJrYpdxFpLFOVrxw1DIco0djcTgAsOXnE6g1spuLM4upbK/EG/RS1lJGij2FguQCI3GttjYSrRQLa0YGqr0Dp7L2aj7KdeUiIqhAAN+GkkHvoRCNFgoajWaXZVlFI/s5K1C2JMjfg1ZvayT6CAxNAWBS5iQ2NG+I5Nqsa1rHiJQRXRzP0XhLS3BM2NYbwVZQQMAsRFmcUUxIhahoqaCs2ah5JCIE6uuNxLXC3oRCJgBFZPaqKYRNR76KCpTfv1PKW4TR0UcajWbY88WmMpKkp039i/JGrnRWILl7g9VOm6+NFJ/h5LXm5RJoNG68k7Im8fz656nprKEwpZB1jevi+hNUIIBv4ybSDjs8ss6Wn0ewsZGQzxfxH5S3lFPWUsbckXMBCFRXG/v2KhQMITQymNGrpjAydSSwc8tbhNFCYRCoqanh+uuv5/PPPycrKwuHw8HNN9/MKaecMqDnWbt2LRdddBFfffUVf/jDHwakGqtGM9z48at3sLTlaTpKbyDky++yzUqQ8ckl
UHQRAK2+VrIDDiwpFmy5eQTrjRtvuNVmSXMJmc5MKlorOHLskTHP59u0Gfx+HBO2NcwJh6UG6+oYVzgOi1j4tu5b6tx12yKPthpCoXdNIVwUL4XVvWgKe+ftDRjlLbDZcI6P3eZzMNBCYYBRSnHyySdz4YUX8swzzwCwceNGXn/99QE/V3Z2Nvfddx+vvvrqgI+t0QwH/rzkZZa2GFn21x+XzZ5ZXcvHp7esx/aWB4r2BaDN18ZYvw1LWgq2nBwCjab5KFwDqWkDWc4sQioUP5PZ7KHgjDIfhZ3H/tpakouKGJU6inc3Gb1LJmQa+wVqEhEK4aJ4Lho9G3tsD4QCNHmaukQeOYuLEdO3sTPQPoUB5v3338fhcHTJWh47dmyk6mgwGOSmm25izpw57L333jz44IMAfPjhhxx66KGcfvrpTJ06lXPPPbfPWkP5+fnMmTMHu71n8o1Gs6vz3w3f8OiGP2ILGdVCR+WGOHxqQZfXbLtZsXTkNqGQ5rdiSU3BlpNN0CwRk+HMIN+VT0lzSZ+RR16zh4JjfE9NIVCzza9Q3VEdeQ+GpiBJSVgyYvspYJtQyPY7afQ09vgfb/I0oVCRjms7s7xFmN1aU7hz6Z2RJhoDxdTsqfx8v5/H3b569Wr23XffuNsfeeQRMjIy+PLLL/F6vcyfP5+jjz4aMLKeV69ezciRI5k/fz6ffPIJBx544IDOX6PZFahorOVnH12HiJN/HfMgC949NbYNvmo5JGVAttnLwAxJtaakYs02NAWlFCISabiTak+NdGWLhbesFNuIEVhTUyLrbFGlLgDGZ47nw8oPcVgcFKUaBSv9NdXYCwvjJq4BEYGR6bXiD/lp9bV2cXZH5ygEW1oIbN2608pbROa4U8/2PeTqq69mxowZzJkzB4B33nmHJ554gpkzZ7L//vvT0NDAhg1GqNx+++3HqFGjsFgszJw5k4qKiiGcuUYzNHT6vZz92lUELS3cMucuZhVNItmWHDtap+orGLkPWIxbWZuvjWQPWFJTseXmoDwelFkgb1LWJEqbS/mu4TsmZ03GIrFvf77SMpzFxV3WWTMzwW4nUGvkKkzIMExG4zLGYbVYAQhsre7VyQzbGu2kewzB0V3QRYRCci6+TUbFVUe3uQw2g6YpiMijwAlArVJqL3PdrcClQJ2526+UUv82t/0S+DEQBK5VSr29o3Po7Yl+sNhzzz156aWXIsv3338/9fX1kVaaSin+9re/ccwxx3Q57sMPP+xSptpqtRIIBHbOpDWaYcQ5L/6KDss6Tim6kdP3PACWPUaOz0PDykWw7NWuO9ethfnXRRZbfa04vSEsaWlYs40qo4GGBhwpKUzMnIgv5OPbum85Y/IZMc+tQiG85eVknn5al/Uigj0vD39UWGr0XzCqqKbsv3+v1yZWK5b0dJLdhtmowd3QZYxoTSFQZ5S3sOXl9xxoEBlMTWEhcGyM9X9RSs00X2GBMA04G9jTPOYfImIdxLkNGocffjgej4cHHnggsi7c3hLgmGOO4YEHHsDvN+qfrF+/no6Ojp0+T41mOHLz2w9R5nuHqa4TuL14PDx4MCy+jhwsNDqSIGdC19e0k2HmjwDjgavV14rDEzB8CrnbhAJsczYrVFx/QmDrVlRnJ87iCT22GbkK2xLYHBZHZJxI4lph/GzmMNb0dJLi1D8KC4WcpBwC9ca5bHk7txTOoGkKSqmPRGRcgrufBCxSSnmBchEpAfYDYrciG8aICK+++irXX389d911F3l5eaSkpEQqoV5yySVUVFSw7777opQiLy+vz+ih3/72t8yePZsTTzyxy/rq6mpmz55Na2srFouFe++9l++++450U0XVaHYlXlj1Cf/e+g/yQhN5NlAOT5wImWPgjMfJrvmATW2b4KSn4h7vDrgJhALY3WL6FIy8hqApFIozirGIhZAKxc9kLjPKVzgnxhAK+fmRiqUp9hSe/+HzEX/CtsS1+HWPwlgzMpB2I8Ouu0ms3l1P
mj2NJFsSbaaTfGfVPAozFI7mn4jIBcAy4EalVBNQBHwetU+luW6XZMSIESxatCjmNovFwh//+MceTXAOPfRQDj300Mjy3//+98j72267LeZYhYWFkZaaGs2uzqPLXsJKiFerlmCzOuCIW+CAq8CeRE7rCr6ujV1+Pkybrw0JKayeAJa0NKO6KduympNsSYxJG8PG1o29lLcwI48mxBYKHUuWRJbDoahgaBhAYppCRga0d2ARS0yhEG6uE6yvx5qZuVPDUWHnO5ofACYAM4GtwD39HUBELhORZSKyrK6uru8DNBrNsKe2zUNm6zKygwEy9j4LrvkKDroB7EkA5LhyaPY2EwjF97O1+lpxmcXwLKkpkSfsYOO2G+/03OlMzppMsj055hi+0jKsWVnYsrJ6bLMX5BPq6CDY3tPc21vHte5YMzMItbSQ5cyi0dPYZVt0iYtAXf1ONx3BTtYUlFKRzu8i8jCw2FysAqLjw0aZ62KN8RDwEMDs2bN102CNZjfg9a82k2Wtpd2WCif9vcf2nKQcFIpmb3OX3sXRRBfDs6amIg4HlvR0AvXbhMKv9v8VvpAv5vFgaArRmczRbGu2U4s1tWuGsd/suNZbhdQwlowMgi0t5LhG9NAUGjwN7JG9h3Ge+nqsQ1BaP2FNQURii9Z+ICLRYvQUYJX5/nXgbBFxish4YBKwdEfPp9Fohj9KKdYtfQefNUhmSuxIm3CHsnhF5MDMUTCFgiU1DcDMat52TKojNWa3s/A8fKWlMZ3MALb8rrkK0QSqaxCXq9fEtTDWsFBwZMd0NEe34bTl5vU53kDTp1AQkXki8h2w1lyeISL/SOC4ZzEcxVNEpFJEfgzcJSIrRWQFcBhwPYBSajXwPPAd8B/gaqVUcHsvSqPR7DqsqmplZst7tFitZGSMiblP2M7em1Bo87VFeilYUlMBsOZkR+of9UWwsZFgS0tMJzNEawo1PbaF+yj0lrgWxpqRCaEQhZLR5Xo6/Z10+DvIdeWilDKFwvA0H/0FOAbjaR6l1LcicnBfBymlzomx+pFe9v8D8IcE5qPRaHYjXlpWzk+tS3nCWcCkOE/xOUmmUIhTWRRMn4LZdS2cjWzLzok4j/siUt4irqawzXzUnUB1da99FKIJJ7AVBFJo8DREMq7D15bryiXU0Ylyu4dEKCRkPlJKbe62Sj/FazSaHcYbCFLzzTtkSRutFiHdGTucOqwpdHfMRtPFfJRmmo9ycyIhqX2xrRBebJ+CNTUFS0pKJIEtGkNTSFAoZBomptxgEt6gl86AkccU1hpyXbkEhyhHARITCptFZB6gRMQuIj8D1gzyvHZpampq+NGPfkRxcTGzZs1i7ty5vPLKKwN+noULFyIivPvuu5F1r776KiLCiy++OODn02gGmvfW1HJ44GO89jTag95IC8rupNpTsVvsvfsUvK1kBIzwzYj5KDuHYHMzykwW7Q1vaRmW5OReS1XY8vMjRfHCqECAQF1d4ppCuCiez5hr+Jrq3IYgyHXlGnkPMGw1hSuAqzHyBqowwkmvHsQ57dKES2cffPDBlJWVsXz5chYtWjRo+QTTp0/vkhPx7LPPMmPGjEE5l0Yz0Ly6rIzjbMvo2OM4gLid0ESEHFdOr+ajNl8bWUGjVIwlxRAKkazmpqY+5+IrK8UxYUKvfgGjLWdXoRBJXEtUU4gqigfbTGKRbGZXTkQoDLvoI7PUxF+VUucqpQqUUvlKqfOUUonpY99DdmbpbICDDjqIpUuX4vf7aW9vp6SkhJkzZ0a2L1++nEMOOYRZs2ZxzDHHsNVMsnn44YeZM2cOM2bM4LTTTouU4liwYAHXXnst8+bNo7i4WGscmkGjttWDlLxHKp20TzK6nMXTFMDwK/TlU8jw20AES4oRLNk9q7k3vCWlPQrhdcdWkE+gpqujOdxxzZ6gphCOUEr1GMthTaHeXY9FLGQ5swjUmZpC3s6PPurV0ayUCorIWBFxKKXiB/cOU6r/+Ee8awa2
dLZzj6kU/upXcbfv7NLZIsKRRx7J22+/TUtLCyeeeCLl5UaNeb/fzzXXXMNrr71GXl4ezz33HL/+9a959NFHOfXUU7n00ksB+M1vfsMjjzwSEVxbt25lyZIlrF27lhNPPJHTTz+9X5+RRpMIr3xdxfGWTwkmZdOSPxXoQyi4cqjrjJ+wavRSsGFJTY087XfPao5HsK2NQG1tzEzmaOz5+bTW1UWcw2D4E6D3NpzRhDWF5M4QpG8TCg3uBrKTsrFarIamYLNF9t2ZJBJ9VAZ8IiKvA5FUPqXUnwdtVrsRV199NUuWLMHhcPDll1/yzjvvsGLFisgTeEtLCxs2bMDhcERKZwOR0tmJ9FM4++yzue+++2hpaeGee+6JlNBYt24dq1at4qijjgIMLWWEmXG5atUqfvOb39Dc3Ex7e3uXqq0nn3wyFouFadOmUVPTM/xOo9lRlFIsXlbCC9avsO51Lq2mszWeoxmMXIW1DfEf8oz+zJaIPwG2CYXorOZY+MI1j+I4mSPj5eeD30+wuTmS9RwWCr11XIvG4nQiSUk4OnxIunQxH23LUajDlpODWHZ+d4NEhEKp+bIAaYM7nYGltyf6wWIoSmfvt99+rFy5kuTkZCZP3lbTRSnFnnvuyWef9awruGDBAl599VVmzJjBwoUL+fDDDyPboueRiAlLo+kvKypbGNvwMUkOL+x1Gq2+VgAyHPGfjHOScmj0NBJSoZi9EFp9raR4pUtzHGtYU+gjV8FbGhYKvWsKkQS2mpqIUAhsrTYS1/pRiNKakYFqbSNzXGYX81E4ymqochQgAUezUup3SqnfYdQpuidqWRODoSqdfccdd/QosjdlyhTq6uoiQsHv97N69WoA2traGDFiBH6/n6effnqHz6/R9IcXl1dyku1zQqmFMGYuLd4WoHdNIceVQ0AFaPW2xtxulLlQXTQFi1nuoi9NwVtagtjt2E1NPR6xchX8NTV9dlzrjjUjg2BrC9lJ2V00hTyX4UMI1g1joSAie4nI18BqYLWILBeRPQd/arsm4dLZ//vf/xg/fjz77bcfF154YZfS2dOmTWPfffdlr7324vLLL+9TI/jtb3/L66+/3us+xx13HIcddliXdQ6HgxdffJGf//znzJgxg5kzZ/Lpp58CcPvtt7P//vszf/58pk6dugNXrNH0D48/yHvfbOAw6zdY9joVLNaIptCXoxli5yoEQ0Ha/e1Gg50ooSAiWHNy+tQUfKVlOMaPR2y9G0/sBT2FQmDr1oSqo0Zjzcgg1NxiRFS5GwipEA2ehi4lLqxDkKMAiZmPHgJuUEp9ACAihwIPA/MGb1q7NjurdPaCBQtYsGBBj/ULFy6MvJ85cyYfffRRj32uvPJKrrzyyl6PBWhvb495bs3uS2OHj9Agmg3fX1vLXN/n2Bx+2MvocNbqbSXJmoTDGr9MdKTUhaeBYrra/tv9xu/U4fZjTUvtss2Wnd2l/lEsvKWlJO3V97Ou1YwG8kf52vw1NaTMndvnsV3GyczAV7GRnKSprGpYRau3lUAoYJS4CIUINDQMmaaQiFBICQsEAKXUhyKS0tsBGo1m1+TCl+7iy/p36Cy/flDP86zrC1TGGKRoFmD4A3ozHUHvRfHCJiWb2x/JUQhjzc3ptf5RyOPBX1lJxkkn9Tlvi8OBNSsr0oFNBQIJd1zrMk6kUqqhKYQT13JcRrIdwSC2nOErFMpE5P+AJ83l8zAikjQazW5ERX0HS7eswJZew60nTsEqg1NZ3+lr4oAPVyB7XgOmHb7F29Kr6Qi6agrdafUbQsHS6e1iPgKz/tG69XHH9ZWXg1J9Rh5FxisoiOQqBOrrIRRKqONaNNYoodAZ6KSyzUhuzU3KjcpRGL5C4WLgd8DLgAI+NtdpNJrdiLvfWYfNbjh8jxuxmYI4xel2mPX/BRWMmI7A1BT6EAqZzkysYo2rKViDCovXh6W7+cisfxSdWxBNOPKorxyFyHj5eRGfgt9MBrX316eQnoHy
eskV45rXNa0DzLpHZVvMeQ9ToWC2y7x2J8xlwIj35WsGHh2yunuwsrKFxSu2MmnKFqqBhmfPoMDXd72g7SZ3MhROjyy2+loZmTqy10MsYiErqWe3MjAij5LM9FprN03Bmp2D8vsJtbVFKpRG4y0tAYsFx7hxCU3dlp+PZ41R/i2sMdi2Q1MAyA0Y4d/rmwxNxqh7tMIYc7gKBRH5L3CGUqrZXM4CFimljun1wCEiKSmJhoYGcnJytGAYZJRSNDQ0kJSUNNRT0ewgd729lr1cjdTiASw0HP4ryJwyeCcs2CtiOgJDKOzh2KPPw7KTsmNrCtEVUrv5FGw5hsYTaGiIKRR8JaU4xozBkmAvZHt+AcH6BpTfH5W41k9NwayUmmUWxVvftJ4kaxIp9hQa68J1j3Z+iQtIzHyUGxYIYGgOIhK7PdIwYNSoUVRWVqL7N+8ckpKSIlnYml2TJRvq+XhDPS8Xv8uFZgZtQ24xTDy+z2OVUvy7/N8cNvqwuH2PE6HF29Knoxni1z9q87VFlc3upilEspobYfz47ofiLSvDEaexTixs+fmgFIGGBiNxLTm5X4lrsE1TSPcan/em1k0UpRYhIgTq641kuJQdbna5XSQiFEIiMkYptQlARMZi+BaGJXa7nfExvniNRtOTUEhx53/Wsk9GB5n1/4Ei43mvtxLV0ZS3lvOLj3/Bj/f6MdfNum675uAP+XEH3H36FMBwNm9q29RjfauvlRSfBQj2MB/ZeslqVn4/vo0bSTviiITnawvnKtTUGIlrCXZciyYsFFLdxq1Uobq14cwdMktHIoU1fg0sEZEnReQp4CPgl4M7LY1GszP496qtrKxq4U8jPqDBuu120Fs10mhqOgyb+gvrX6DT39nH3rEJh5MmJBSSjBDO7r6sNl8bOSEXQM/oo17qH/k2bYJAIG4LzliEs5r9tbUEtm5NuDpqNGGhIG0dpDmM6kFd6h4NkT8BEitz8R9gX+A54FlgllLq7cGemEajGVz8wRB3v72OuXl+Jmx+idpio/iiRSwJawrh+PpWXyuLyxZv1zwidY/i9FKIJtuVjSfoiXQri4zhbSUraPi2LKldS7RZwzWKYmgKfbXgjIU9qtSFv6YGW4J9FKIJl88ONrdEMrXDIbfBIax7BL0IBbNkdgaAUqoeo0Lq0cAFIpKYR0aj0QxbnvtyMxUNndwx8n9IyE/92AMAKM4oTlhTCJeynpg5kSe/e5KQCvV7HpG6RwlqCtDTvNXqbyXLH+661jW3Vmw2rJmZMbOavaUlADiLEzc5W3NywGolsHUrgdra7dIULCkpYLVGchUgSlOoqx+yHAXoXVN4HkgBEJGZwAvAJmAG8I9Bn5lGoxk0On0B/vreBo4YI4wpWwTTz6BWQrhsLsakjUlYU6h315NsS+aS6ZdQ0VrBkqol/Z5LpO5RIo7mOAlsbb420gOGi7S7TwHMrOYYPRV8pWXYR47Ekpy4U1csFmx5ebhXr4ZQaLs0BRGJFMULC7pcVy7K5yPY3DwkHdfC9OZodimltpjvzwMeVUrdIyIW4JtBn5lGo+lCaUMt937yBuOT+u6x0Rfratqoa/Py+z2WILVuOOhG6tY8Qq4rlxxXDt/UfZPQOHXuOvKS8zh67NH8edmfeeq7pzh41MH9mksixfDCRIriubve4Fu9raT5bWC1Ii5Xj+Ns2TkEYnRf85aW9ivyKDJefj6elauAxDuudSec1ZydZAiAXFcugUbjuobSfNSbUIh2fR+O6VxWSoV0/L9Gs/P51TsL+c73JG+W2lD+Hb9pXLRvJiPWPgnTToK8KdQtryPPlUeOK4cmTxOBUACbpfcAxbrOOnJduditds7Z4xz++tVf2dC0gUlZkxKeR9jRnJBPIVz/KIamkOJL79J1LRpbbg6e79Z0WaeCQXzl5aQccEDCcw1jL8jHs8JMMtsOTQHMSqktLeS4jHyQXFcuga1miYshylGA3oXC+yLyPLAVyALeBxCREcAu15pT
o9mVqW3zsKK6Els2/OuS0Rw+5vAdH/TDO+C7Njj4JsB46p+aPZWcpBwUimZvc8TOHY96dz3TcqYBcMbkM3jw2wd5as1T/G5e4i1XWnyGTyEchdMb2a6eRfGUUkbymi8da0rsWp3W7JzIU3gYf1UVyuvtV+RRGFvetlStHdEUAnV1jE0fi81iY0TKiG3ZzMPUp3AdRr2jCuBApVQ4570QI0xVo9HsJB7/tAJlaQOgtLl0xwf0tMLn/4Apx0PhXoDx1B/WFCCxXIU6d11EcGQ4MzhxwoksLl2csE8CDE0h2ZaM3WLvc1+7xU6GM6OLpuANevGH/CR5FJa02ILFlpNNqLWVkG/b86y31Iw8SrDmUZfxzAgkSU6Oe86+sGSkE2xp4eixR/PGyW8YTYTqw5rCMDQfKSMQuEdTAKXU14M6I41G04V2b4AnP9vIiPFu6oHSkregYweV9S1fg6cFDv4ZAB3+DjoDneQl58WN8OlOh78Dd8BNXvI2U8e5087l+fXP88L6F7hixhUJTSWRstnRhNtyRh8P4PQEsaTGNkFFZzVbzF7KPlMo9NWCMxa2AqOsRX87rnWZU0YmwZYWrBYro9KMqgBBUyiE5zsUDE5tXI1GM2A89+VmWj0BJssm6hWU1q+GFe/t+MBTT4CifYFtoaVdNIU+wlKjjwlTnFHMgUUHsmjtIi7e6+Jem+aEafW19tqbuTvhHgRh2nyGBmX3+LFkxzYfRWc1202h4C0pxZaXF7MeUl/Y8o1r7m/No2isGRmE2ttRgUCk41ugrh5rRkbCdZgGAy0UNJphjD8Y4tEl5Rw21k6VvxlsVsqTUwneXI7VYt2xwaNs+OEktP5oCuFjuvsdzp92Ppf/93LeKn+Lkyb23bim1ds/TSE7KZu1jWu3HW9qCrZOP9bUeOajnlnN3rKy7TIdAdhNTaG/1VGjCWc1B9vasEUS7IauDWeYRMpcRBCRLBHZe7Amo9FouvLmiq1UNbv5Vd4nNFqETHsa3qCPLYEOSMrYsZdl279/+Kk/35VPij0Fp9XZp6ZQ7zZMHdGaAsDcEXMjyWyJlFZPpJdCNOFSF2HCmoKl09OjxEWYsDkmYOYqKKXwlZZul+kItvkUwlrH9hCulBpsbo6sM+oeDV3kESQgFETkQxFJF5Fs4CvgYRH58+BPTaP5fqOU4sGPytgjz0nhxmfxWizMGWmET5Y0lwzouaI1BRHpceONeUzntmOiERHO2+M81jWtY1nNsj7P3ertp1Bw5dDub8cbNMqihjOipdPdI5s5jC3biFoKNhiCLFBTQ6ijA0eC3da6Y01PZ+Q9d5N51lnbdTxs0xRCLS2RdYEhLnEBiWkKGUqpVuBU4Aml1P7AkYM7LY1Gs6SknjVbW/nd+O9oNG/QswtmA1DaMgARSFHUddaRZE0i1W48aee4ciKaQNxj3HU4LI6YN/Tji48nxZ7C2xV9l0lr9bVGchRUqO8yGd0T2Np8bViDCrw+rHEigSwpKYjLFdEUwjWPnBMm9nm+eGQcfzz2gu3vIhAxH3UXCkPoZIbEhILNzE04E9i+ilcajabfPPi/MvJTHcze8hSNecbNa1z6OPKT8wcmLDWKWndtREuA+H0LoglnM8eKvkmyJTEqdRTVHdW9juENevEEPaQ70ml6/nlKjjiSUEdHr8d0T2DrrcFONLbsbAKmpuArCwuF7dMUBoJwD4awUAh1dKA6O4c0RwESEwq3AW8DJUqpL0WkGNjQ10Ei8qiI1IrIqqh12SLyXxHZYP7NMteLiNwnIiUiskJE9t3eC9JodgdWVbWwpKSeW/bYiqV+HQ1TjwOM5K2JmRMHXCjUu+u7+Aa6R/jEPKazvtfktsKUQmo6a3odI7pstmfVagJbt9L8yqu9HtM9j6LN1xZVITW+UIiuf+QtKcWakTGkoZ/WzEzAqJQKRHIUhrLuESRWOvsFpdTeSqmrzOUypdRpfR0HLASO7bbuF8B7SqlJwHvmMsBxwCTzdRnwQGLT
12h2Tx7+uIwUh5VjWl+AtBE05k0GjCf44oxiylvKt6siaTzqOuu6+Aayk7Jp8jYRDAXjH+Ou6+FkjqYguaBPTSG6GF6gthaAxieeQAXjnzcsFMK5Cq2+VnJDRkG7eD4FMOsfmVnNRs2jiUPasjds6gprCtsS14a/o/ku09FsF5H3RKRORM7r6zil1EdA97KEJwGPm+8fB06OWv+EMvgcyDRNVhrN947Kpk4Wr9jKdXu5sW38CPa/ggazFERmUiYTMyfiCXqoaq8asHPWdtb20BRCKkSztznuMdHZzLEoTCmk2duMJ+CJu0+kl4Ijg0BtLZKcjH/TJto//DDuMd3NR22+NnKCRhG8eD4FAGtONsH6eiPyqKQEZ/HQmY7AKOltSUsj2Gp8BgGzN/NQm48SyVM4Wil1s4icglHy4lSM7mtPbcf5CpRSW8331UA486MI2By1X6W5bivdEJHLMLQJxowZsx1T0Gh2HqtrNnPe4kvw1h8JHdMTOiYQVAhwbugNcKTCrAU0fPt3Mp2Z2C12JmQaYZRlzWWMThu9w3OMzmYOE53AFn4fjSfgoc3X1iPyKJqCFOPfu7azljHpsf9XI70UnOn462pJP/poOpZ+QePCx+O2yHTZXCTbkiPmo1ZfK+MDTqAPn0JOLoGmJoINDQRbWrar5tFAY1RKbQYYFiUuIDGhEN7neOAFpVTLQKhcSiklIv3u9ayUegh4CGD27NnDtle0RuPx+7jozZ8SsG1hzJi1HJR+fMLHzs7qJPnd12C/y8CVSaOnMfKEXJxpPOGWNJdwyOhDdniesTKTuySwZfU8Jl6OQjQFyYZQqO6ojisUwppCmiWFzoZG7CNHkH3e+dTedRfu1atx7blnzOOifR5tvjYyArEb7ERjy8mGQIDOr74C+tdtbbAIl88Gow0nVmvE1zBUJCIUFovIWsANXCkieUB8fbB3akRkhFJqq2keqjXXVwHRjzyjzHUazS7Lgldvx23dQI6jiFa1hp8fN7nPUtQR3vkNKAX7G/WDGtzbntjTHenku/IpaykbkHlG5yiE6avURVgo9GU+Anp1NocdzantATpDIWz5+aQffzz1f/87jY8/TtFdd8U8Ljo6qs3XRro/E+jLfGRcU+fSLwGGiaaQTsh0NAcbGrBlZyPWHcxU30EScTT/ApgHzDYrpXZi+AC2h9eBC833FwKvRa2/wIxCOgBoiTIzaTS7HH/77DVWd77KGPvh/HLu9bT521hVv6rvA8GoYLr8cdjzZMgaC9BFUwBDWxioBLY+NYVYx8QQJN3JTzZi+HtzNoc1haRmo+eyLT8fa1oaGaedRuu/38JfUxvzuBzXtqJ4RoMd40baW/SRLSIUlmJJTsa2A9nIA4UlWlOoG/oSF5CApiAiycBVwBgMW/5IYAp95CyIyLPAoUCuiFQCtwB3AM+LyI+BjRi5DwD/Bn4AlGAInYu241o0mmHB8qpSHlrzB+xqFM+cfht8+DsswKdv38BMawLJTh214G2FuT+JrGpwN0Ru1GD0RH5pw0uEVAiL9KtaTQ9i3eDTHenYLfa4mkJYkPSmKbhsLjKdmb1qCi3eFtLsaYTqjfOE+xRkX3A+TU89RdMzz5B//XU9jstOyuarmq8IqRDt/naSfQI2G+J0xj2X1cxq9q5fT9L06UMaeRSmq/lo6LOZITHz0WPAcgxtAQyzzgv0IRSUUufE2dTDe2SW6b46gbloNMOaDq+Xy9++HkTx9yP+QsaXD8CXj7DXmPF8Emrkqg5/34MAzLkkUsHUF/TR5m/r4vAtzizGHXCzpX1LpOzy9hLOZk6zbzO9iEivuQr17nqsYu2ivcSir7DUcNnscDhquKaQY/Ro0o48guZFi8i94nIs3Vps5rhyaPY20+xtRqFI9iqscbquhYm+4W5vzaOBxpqRSbC1FaUUgfp6nJMnD/WUEhIKE5RSZ4nIOQBKqU4ZDiJWoxmGnPfK/+G1lnP++P9jnsMHH/0Jpp/J3OIZPLzy
YVouWpxQ28lowmaS6BvwxEwjw7mspWzHhYIZWtr937q3rOY6dx05STl9ail9JbCFi+EFqmrBYjGcwSbZF15I23/fpeW118g6++wec1MoNrVuAsDpDfVqOgKzrITFAqHQdtc8GmisGRkQDBJqayPQ0DAsNIVE9E6fiLgABSAiEwDvoM5Ko9kFueuj5ynxvsVE53HcPP8UePUqcGXDcXcyv2g+IRViafXSfo8bvjFHm4+KM4yb2kBkNte56yL2/87ly6l/+GHjfK6cSH2hWMfkJvd9AytILqCmo3fzUbozHX9tLbacnEhfAQDXrFkk7bUXjY8/0aMmUlhrKm8pB8INdnoXCmK1RkxIO1LzaCAJ1z/ybdoMfv8uIxRuAf4DjBaRpzEykW8e1FlpNLsYn2xcw5Olf8IZHMdTp9wOH98DNSvhh/dCcjZ75e5Fqj2VT6o+6ffYYRNOuD8xGK0v81x5A+JsruvcloTW/Pzz1N37V1Qg0Gul1LrOOvJdfftHClIKaPI2xU1gi2gKtbUR01EYESH7wgvxlZfT/tFHXbaFBWRFawUANncAax9CAbZVSx3KmkfRhMtn+0qN73GoE9cgseij/2IkrC0AnsWIQvpwcKel0ew6VLc1cfW714Cy8OAx95LStN40G50BU43cBLvFzn6F+/HZls8S6jEQTdh8FK0pgOFXKGve8bDUaE3Bt7kSgkH81dWRCJ9Y5TTq3fUJaQrhsNTazthRROGy2YHauh5CASD92GOwFRTQ+PjjXdaHTWkVLRUA2Dq9fWoKYGQ1i8OBfdSOmdwGinDXt3DV1l1FUwBIApqAVmCaiBw8eFPSaHYdfIEAZ7x0DQFrHTfO+AOzCsfAq1eaZqOuMfbzi+azpWMLG1s39uscEU2hm1N3YuZESltKd6gGUqe/kw5/R0RT8FdWRv7mJOUQUIFILkEYf8hPo6ex18S1MNEJbN1RSnVxNMcSCmK3k3XuuXR+9jm+iorI+oj5qNUwH/XWYCca18yZpBx44JDnAoSxmOYjr9kveqiL4UFiIal3AmcBq4Hwr09hlLrQaL7XLHj1NprlW47Ov4IFs46ED++E6pVw9jOQ3PUmPnfkXAA+2fIJ4zLGJXyOBk+DUdrBntxlfXGGEYFU3VHNyNSR2zX/cDhqfnI+IY8nEgXkr6wkZ9a2BLbMpMxt8zGFVG/hqGF6S2BzB9z4Q34yJZVgY2Ok73F3Ug8+iLo//xnPmjU4xo0z1tlTcVgcbG4zq+N0urGk9S0U8n/60z732ZlYMzIB8EbMR0NbDA8S0xROBqYopY5XSv3QfJ04yPPSaIY9f/jwGVZ2vMIY+2HcfcyVhjD46K4uZqNoRqeNZnTaaD7b8lm/ztM9cS1MOAJpR/wKYbNOrisX/5YtkfU+U1OAnglsiZS4CNNbAls4cS3bbUQ9xdIUABxjx4II3rJtprJwyGwgFMAqVlRbe0I+heGGNcMwH/k3VyJJSVhS4pfp2FkkIhTKAPtgT0Sj2ZV4Y82XPFt+N67gBJ4/7U9Ygt64ZqNo5o2cx9LqpfiDCeYr0LXERTTRhfG2l/ANPj85H//mbTUp/ZVVcUtdxGvDGQuXzUWGMyOmphCpkNpqGCDscYSCxeXCPnIkvrLyLuvDgjLLkory+3sthjdcsSQlIUlJEAphy+0ZFjwkc0pgn07gGxF50GyEc5+I3DfYE9Nohisb6rfy609vxKJSePKH/yBl8xL4x1xDUzjhLz3MRtHMGzkPd8DNN3XfJHy+Rk9jDyczGBFIua7cHWrNGa0p+Ex/gmPChIhPAXpqCmGTUyLmI4DC5MKYYalhX0VaiyEg42kKAI7iYrzlXYVfWGjlhoyn60TMR8ORcFjqcHAyQ2JC4XXgduBTjMzm5UDf3bg1mt2QDq+Xc1+/mpCljf8389dM+d//wVOngljggtdgjxN6PX6/wv2wiY1Pt3ya8Dkb3A1xM4cnZEzYoVyFenc9TquTdEd6xIThmjEDX1Ul6c50bGLr
oSnUu+sRJKb2EouClAKqO3uaj1rMHhHJLUa4am9CwVlcjK+8oku+QlhohRvs7IrmI9gWgWTNHdrezGESEQqZSqnHo1/ELKar0ez+nP/Kb3FbN3B1yjyOf/syWPM6HPILuPJTKD60z+NTHansnbd3wkIhpEI0eZvi3oCLM4spbS7td5hrmNrO2kg2s7+qEvuoIhyjRxGsqwevj+yk7JiaQlZSFnZLYlblvjQFZ1MH2GxYs+LfVhzFxSi3m0D1NuES/kyyE2jFOZzZFTWFC2OsWzDA89Bohj1/++w1Nnj/zQ877Vz53VMwYm9DGBz2S7AnJTzOvJHzWNOwJpJ/0BvN3mZCKhRXU5iYOZHOQGefbS/jUe+u75Kj4CgaFYnh91cZfoUemkIfvZm7Ey+BLexTsDW2YcvLQyzxb0fO4vEAeEu3mZAiPoVwg53U+GWzhzOWzLBQGPrII+hFKIjIOSLyBjBeRF6Pen1AzzabGs1uzZraSh5a8/8Y6xN+21wHJ/8TLnwDcif1e6x5I+ehUHy+5fM+9w0/pcfVFMLlLrbTrxBuw6mUwr95M/bRo7EXmUKhspJsV2xNIZHIozDxEthafa0IAvVNccNRwzjM1pm+KL9C2HyUHjA0lt4a7AxndiVN4VPgHmCt+Tf8uhE4ZvCnptEMDwLBIBe/eQM2cXNfbRVJJ/0dZp4D2xkpMi1nGhnOjIRMSPGymcOEw1K3169Q564jLzmPYHMzoY4O7KOKsI8qAraFpfaIPuqjN3N34iWwheseBetq40YehbFmZ2PJyOgSlhppOuQ30q12WZ+CmaswHEpcQC/Ja0qpjRg9D+buvOloNMOPKxf/mXbLGm6pb6J42hkwbXt7TBlYLVYOGHFApORFb2GIEU0hjlDITMokOyl7u4RCOJs5z5UXyWR2jB5tmHKcTiMsdapR/yg8z5AK0eBuiJiclFK0f/AhKfvvFzfGPiwUuoelhuse+WvrSJ6zX69zFRGc48d3CUsNfyapfuPZ1tJL17XhzC6jKYjIEvNvm4i0Rr3aRKQ13nEaze7Eiys/4fOmpzi4U3GqJQOOu3NAxp03ch617to+E89ilc3uTrjcRX+Jbq4TFgr2UaMQEexFRZGwVH/IH7H/N3oaCapgRFPwrt9A5VVXUXXTzT0qmYYpSIkvFLIllVBLS6+RR2EcE7qGpY5IHUGaPY28oBmSuotqCrb8fBDBVjhiqKcC9G4+OhdAKZWmlEqPeqUppdJ30vw0miFjS2sjty/9NZlBK3+sq8Jy8j8hqX+9EOIxb6TRs6ovE1KDpwGb2Eh3xv+XK84wCuP1NwIpug2nb7MpFEx/gn1UEb6qyh4JbJFsZjNxzbt2DQDt779Pw0MPxTxPOIGtu/mozdvGCLfhJE5EKDiLiwnW1RNsNQRUij2FD8/6kHG2AsRux+JwJHjlw4v043/AuGefwV6QQFe+nUBvQuGV8BsReWknzEWjGTaEQiHOf+3nhKyN/K12Mxn7XwXjDxqw8QtTChmXPo5l1b2n/IRLXPTWzGZi5kTa/e29NrOJRURTMM1H1uxsrKaz1jFqlGE+6pbA1r2fs2f9esRuJ/0HP6Dur/fR/vHHMc8VKyy1xddCgdu4kSekKYw3nc1RfgWH1UGoo32XNR0BWBwOXDNnDvU0IvRWEC/a0Dk8io9rNIPIpuY6lm5ez8raEr6q/Yba0Odc1upjZvoEOPz/Bvx8e+TswTe13/S6T4O7oUsfhVgUZ25ruBOO9EmE6HIVLZWbu5STtheNItTaSrYZ7tldU4iYj9atxzFxIiP+8Hu8ZWVU/ewmxr/4Ao7Ro7ucK1YCW6u3lZx2Y759RR9BVFhqWXmXm2iorX2XNR0NR3oTCirOe41mt+HKN/7Ml3Uf4qUGrJ2R9UoJc7zJXNm8FS59oV95CIkyOWsyb5W/FXG4xqLB3RDXyRwmHJZa1lLG/KL5CZ+/zl2Hw+Ig3ZFOXWUVrr32imwLC4iMBm9kHuFjIMp8tG4d
KfPmYXG5GPW3+yg/7XQqr7mWcc8+06WvcmFyISvrVkaWw2WzM9uNW0tf0UeROdntXcJSAULt7btsOOpwpDfz0YywYxnYWzuaNbsbL6/+jCWNj6EIMDZpLgdmX8SPJ93OP2b/jWWjzuCxrWuxHfYrI0ltEJicZTRp39C0Ie4+8SqkRpOdlE2GM4Oylv4VxguHoxIK4d+ypaumYIalJtW2YBFLF/NRmiMNp9VJoKmJQF0dzilTACNyqejuP+Fdt46tt9zSxcfRPYGtM9BJUAVJbw0gDkekr0BviM2GY+wYvN0K44Xa27Huoolrw5HeQlKHRxcKjWaQuHvpX0Gl8O8zn6XAUwtrF8Pqv0GVaeefeCTMH7z6+1OyjJvpusZ1zCqY1WO7UooGT+wKqdGISMTZ3B/qOo0ktEB1NQQC2EdvEwoOU0AEq7aSlZIViYKqd9dH/AnedesBcE6eHDku9eCDybv2Gur+eh+uvWeQfd65wLaw1NrOWsakj6HFa9Q9SmnxYsvPT7g6qLN4At4NXYVosL0d+8jt6yeh6UmfTXY0mt2RJ79+nzbLak6z7EvBwqOgfp2xYeQ+hv9g6gmQN2W7E9QSIT85nwxnBuub1sfc3hnoxBv09qkpgGFCen/T+/06f21nLZOyJkUijxxRmoI1IwNLWpoRljozp4v5KCIU1hvzTpoyucu4OZdfjnvlKmruuIOkadNI3nefLs12xqSPiYS4OpvdCTmZwziKx9P2/vsovx+xG5nM2nw0sCTajlOj2W0IhUL87eu/YQ+k8Iuy18CZCsf9Ca5fDZd9CAf/DPKnDphAiBcqKiJMzpoc13wUq8RFvLHGZ4ynydtEk6cp4XmF6x75q8xw1G7OYfuoUfiqKsl15XZxNId7M3vWr8OaldWjhaRYLIy88w6sGRk0PfUU0DOrOVwMz97Y1i+h4CwuhkAA36ZNkXXafDSwaKGg+d7x4LL/4LaWcLU3SFJKgVHyev/LIGPgm7kH6uooOfgQWl5/Peb2KVlT2NC8gWAo2GNb9xIXbe+9x4a58/DX1PbYN9rZnAid/k7a/e1GH4XNm8FqxV7YNXLJMaooEpYazmoOm5zAMB85p0yJafqxpqWRPGsW7hUrgJ4JbGFNwdrQklDkUWROZlhquNyFUopgu44+Gki0UNB8rwiFQvxr1T9ICSRzQe16OOp34By8p8y6++4jUFdHy5tvxtw+OWsy7oCbyvbKHtvCmkLYfNTx+RcEm5tpeuaZHvuGw1ITFQpdO65VYh8xArF1tSbbi0YZlVKTsmnwNNDqa8UX8pHrykUFg3hLSnqYjqJxzZiBv7KSQENDjwS2Fm8LTp+Cjs6EIo/COMYbYanhchfK64VAQAuFAUQLBc33ins+eRmfdSM3tjVhHzUHpp/Zr+P9W7fS+u9/J5Q97FmzhuYXX8KSkkLnF0sJeb099glHIMXyK4RNNmHzUdjB2rxoESG3u8u+I1JG4LK5EnY2d+nNXFnZJfIojH3UKJTHQ6EnCW/QS0VrBbCtdadyu7s4mbvjmmFEbbm/NbWF5IJIAlurr5WsdmO//piPrKkp2AoKIglsoXZjEOsu2nVtOKKFguZ7gy8Q4Jn1D5Hjd3JKc41Rx6iXGv7d8dfWsvH8C6i64UaaX3yx132VUtTccSfWjAwKf/c7lMdD55c9s5cnZE7AIpZehUJWktF8xrt+Pc5JEwm2tNDy2mtd9rWIhXHp4yhvKe8xTiwimoIrH19lJY7RsYSCEZaabwQKsa7RcMbnunLxrA9HHk2Je46kPfcEqxX3im8BI4s72nyU22589v0RCgDOCcV4y43rDAsFrSkMHFooaL433PHxIgK2rfy8qRrbzPOgqGcYaDyCLS1svuRSAo2NJE2fTs3v/xC5Mcai/f336fziC3Kv+QlpRxyOOBx0fPxRj/2SbEmMTR8bueFG0+BuIMOZgd1iJ9DQQLCxkczTTydpr71oXPh4jwJ04zPGJ2w+CmsKOaQSbGjAPmp0
j33C0UiZjT4A1jauBYwSF95160EE58QJcc9hcblwTpmM+1tDKBQkF3RxNI/0GAmB/RUKjvHF+MqMWk/BNlMopGihMFBooaD5XtDp9/JS2aOM9tk4OiBw5C0JHxvq7GTzFVfiKy9n9N//xuh/3I8lNZWqG27oYcYBUD4fNXfdhWPCBLLOOguLy0XyfvvR/vGSmONPzpocU1OITlzzrt+WE5B94YX4Kipo/6irkCnOKGZrx1Y6/Z09xupOvbseh8WBq9aMAjK1gmjsRca6tHpjvLDgykvOw7t+PY6xY7tkLcfCNWMGnhUrUcEghSmFNHmb8Aa9tPhaKOxH3aNoHMXjCbW3E6itI9ShzUcDjRYKmu8Fv3v/CUK2Om5u3Ir1kJshNbEbkfL5qPzpdbi//ZaRd99Nyrx52PLyKPrTXfhKy6j+wx96HNP49DP4N26i4Oc3R5y3qQcdiK+sDF9lVY/9p2RNoaq9inZfe5f10SUuIkJh0iTSjz0GW0EBjQsf77J/2Nlc3tq3CanWXWuUzK4y5tO9VhEYT/rW3Fyctc2A4fdw2Vyk2FPwrF/Xqz8hjGvvGYQ6OvCVlW3rq9BRQ6u3ldwOK5KcHLcPQzycUV3YtPlo4BkSoSAiFSKyUkS+EZFl5rpsEfmviGww/8bv4q3R9EEoFOKbrRXc++mrLHjlD7xV9RhTvYqDk4tgv8sTGkOFQmz55a/o+PhjCm+9hfRjjo5sS5k3j5zLLqPlxZdoeWNxZH2gqYn6f/yDlAMPJPXgg7ftf5DxvmNJzyqikXIXzV3zFaI1Bc+GDVizs7Hl5iJ2O1nnnUvn55/jWbs2sn8kLDUBZ3NdZ53pZN4MENPRDOAoKkK21iMInqCHPFceoc5O/Js24+wl8iiMa8YMANwrVnRJYDMczQp7Xl7C2cyRORVvC0uNmI+0UBgwhlJTOEwpNVMpNdtc/gXwnlJqEvCeuazRJExteytnv/BbDnjsDGYsnMv57/yQRzb8H8taniMzFOSW+hosx94Btr7r7iulqPn9H2h9803ybryBrDN7RinlXfMTXPvuS/Utt+CrqACg/m9/J9TZScHPb+6yr2P8OOxFRbR/FF8orG/sakKKLnHhXb8B56Rt/aCzzjwTcblofPyJyLoxaWOwirVPZ3MwFGRNwxomZk7EV1mJJTkZa1bsZzD7qFEEqqoizu5cVy7ekhJQiqQp8Z3MkeseNxZLejrub1d0SWBr8baQ0Rrst+kIDHOTJTkZX1m51hQGgeFkPjoJCOvDjwMnD91UNLsiP3njflZ3vkJAuRmTdABH5V/Fb2bez0d73sRHtZXsNfYwmHRU3ONVMIhnzRoan36ayquupumZZ8i++GJyLrkk5v5is1F0z92I3U7lDTfg+e47mp57jqyzzuxyAwcjeznl4IPo+PxzlM/XZVthSiFpjrQufgV/0E+br42cpBxUKIS3pKSLucaakUHmKafQungxgTqjcqndamd02ug+nc0lzSW0+duYVTDLyFEwu63Fwj5qFP6tW8m1G0IhLzkPzzrDt5CI+UgsFlzTp+P+9tsuCWytvlZSW/3bJRREBEex4WyO+BT6aYLSxGeoah8p4B0RUcCDSqmHgAKl1FZzezVQEOtAEbkMuAxgzJgxO2Ouml2AbzY3s6plCbmp4/jowjeMlY1l8OaNUPo+jNwXjr+nx3GetWtpe/c93F99hfvbbwl1dADG02jOpZeSd8P1vZo37CNGMOL//ZHKq65m43nnY0lOJveaa2Lum3rQQTQ/u4jOr74i5YADIuvD5S7WNW2LQAqHo2a7svFXVaE6O3FOmthlvOwLzqfp2WdpevZZ8q69FjC7sPUhFJbVGKGxswtm4658GHsv/0f2UUUQDDLWm8p6zMij9RuQ5OS4JqfuuGbMoP6f/8TpDZHhzGBr+1bavK24mvvvZA7jnFBMx9Ivce4xFXE6kV2069pwZKg0hQOVUvsCxwFXi8jB0RuVkRkUMztIKfWQUmq2Ump2
Xl7i6fGa3ZdQSPHr1z/C6trEmdOOh4APPr4H/jEXNn9p1DW65F3I7OpM7Vy+nIozz6L+/vsJNDSQfuIPGfmnu5jw7rtM/N+H5N94Q0L27rTDDyf7wgsIdXaSe+WV2OKYYlL23x/s9rgmpA1NGwgpI8w0kriWlBNJWkvq9mTuGDeO1MMOo+nZRYQ8Rknq4sxiNrduxh/yx53v8prlFKUWUZhSaOQoxIg8ipzDvPGPbjOa7eS6cvGuW4dz0kQkwRwP14y9IRTCvWo1BckFlDSXkORVWL2B7RYKjvHFBLZuJVBbp01HA8yQaApKqSrzb62IvALsB9SIyAil1FYRGQH0LPCi0cTg+WWbWd/+GUmpcEJyITx4MNStgT1ONBLU0nuWVfaWlbP5qquxFxUx9sknsHUr6tZf8n/2M1IOPIiUeXPj7mNJSSF51iw6Pv4Ybr6py7YpWVPoDHRS1V7F6LTRNLqNukfZSdl4138OgGPipB5jZi+4kE3vv0/L66+TdeaZFGcUE1ABNrdujkQjRaOUYnnNcg4sOpBgYyPK7Y6ZoxAmrA0UtgrkQZ4pFNKOPjruMd1J2tvMbF7xLYXFhXxd8zXZkWzm7Xuwc5hd2DwrV2LVQmFA2emagoikiEha+D1wNLAKeB240NztQuC12CNoNNto7vRx53/WkpW3hslJeYxbdCH42uGcRXDWkzEFQqC+ns2XXYbYbIx+6MEdFggAYreTetCBiLX3NiSpBx2Ed8MG/Fu3dlnfvdxFdIkL7/r12IuKIv2To0meMwfntD1ofPwJlFJ9FsYrby2n0dNo+hPCkUfxNQV7YSFYLOQ0GQX78jptBFtaEvInhLFlZWEfO8bwKyQX0OZvI6sfHddiEQlLrajQmsIAMxTmowJgiYh8CywF3lRK/Qe4AzhKRDYAR5rLGk2v3P3OOtoCDbitpRxdXQrFh8FVn8OU42LuH+rsZPOVVxGor2f0Px+IGZ8/mKQcdCAA7Uu6JrJNyJyAIJEIpOgKqd4NG+LehEWEnAsvxFdaSucXXzA+w3iCjicUltcsB2BWwaxtfRR6+QzEbsdeWEhmo1G3KW+L4XNJJBw1GteMGYZQcBlCIKvNWG/bThOwfcwYMAWwFgoDy04XCkqpMqXUDPO1p1LqD+b6BqXUEUqpSUqpI5VSjTt7bppdi1VVLTz9xSaOnGIUXDvaXgBnPm70R4iBCgapuvFneFavpujP9+CaPn1nThcwks9shYV0dPMrJNuTGZs+dpum4DYqi7qUDW95RY9opmjSjj4aMX0VyfZkClMKexUKua5cxqSN2dZHoSi+pgCGCSmvSfG3w/9GVpVxN+/u3+gL194zCNbVM6rTyICOmI+2UyhYHI6Iv8Ois5kHlOEUkqrRJEwopPjta6uYlNyJ272YiYEQ43/0IiTF7vWrlKLmD3+g/YMPKPjNr0k7/PCdPGMDESH1oAPp+OwzlL+rM3hS1qSIUAgnrnnLKyAQ6NVcY3G5cM2eRccnnwDEbc2plGJZ9TJmFcxCRPBt3ow1L7fPUhXhXIVDRx+KZ906bAUFWDMz+3Xd4YqpBRuNshpZ7QpSUvqdzRyNY4JRd8mq6x4NKFooaHZJXvqqktWbavl7+r18bVMcPfk0yIwdWqlCIRoe/hdNzzxLziU/JvtHP9rJs+1KykEHEWpvx/3NN13WT86azOa2zXT6OyMlLrbVPIqvKQCkzp+Pd906/LW1FGcUU9FaEYlkCrOlYws1nTWRftD+yiocvTiZw9hHFRGoqyPk8ZiNdfqnJQAkTZmCOBxkbDAK4mW1g307ncxhnKazWZuPBhbdo1kzbPH6/dS2e7BZuv5MPf4Qd731HY9m/IvlgY0oyeLo6Rd22Sfk89H5xRe0vfsebe+/R7CunvQfHEfeDTfszEuIScrcuWCz0f7xEpLnzImsn5I1BYViQ/MGGjwNjEwdifebDWCz4Rw3rvcx58+Hu++h87PPGD91PO6Am+qOakambnO0L6velp8A
4N+8GdesvivFhs00vo2b8JaVkWr6RfqDOBwkTZtGaE0ZjIWcdsE+KmYqUsKEu7Bp89HAooWCZliyrm4LZ79+MV6/0LnxclD2Lttvti1ifnAJ/5qwPxMcyUzInIBSira336btnXdo/99HhDo6sCQnk3LwwaQdcQTpxx6TcGz9YGJNSyN55kzaP/6Y/Buuj6yfnL0tAqnR08j03OlGD4Xx4/tMznJOmYI1J4f2JZ9QvL9RkqOspayLUFhes5wMZ4bxWfn9+KuryYjRR6E74bDUjiVLwO/HmUB5i1i4ZsygadEisk9IJ7u9ebtzFMKEw1J1SOrAooWCZtjxXU0V5y5egLLUYHUpzh/3O67vcEa2iwqS1bae+n1+xPLmT7h8slHgrvm556m+9VasOTmk/+A40o48kuQDDsDidMY71ZCRctBB1P3lL/hrayNhmSNTRpJqT2Vtw1qaPE1mjsLHuGbO7HM8sVhImTePjk8/ZfzvjLpLZc1lHFi07al+ec1y9s3fF4tY8G2thFAIe1ECQsHcp+3994HEylvEwjVjbxoff5y9m0eS2da03eGoYZyTJmPNy8UxcWLfO2sSZugfmzSaKFZsqeS8Ny7AYqnmnzXVXEgGryYFWJ6fTXbhOLILx5E1YgLMu4b3Jh+EQnH02KNRwSANjz5K0vTpTProf4y4/XZSDzlkWAoEgNSDDwKgY8knkXXhchdf1nxJUAXJU6n4t2xJ+CacMn8ewYYGUjbWkenM7BKBVNtZy6a2TVH+BDPyKAFNwZaXizgcuL/+2jBlmX2S+0u4YuohlanYgmqHNQVragqTP/6YtEMP3aFxNF3RQkEzbFi+eTM/fvM8LNYa/l7XxAEnPMBPz/uA6bnTucXSTOUP74YfLTJeR/+e/256n3Hp45iYOZG2/76Lf9Mmcn784z4TyIYDzqlTseXl0fbuu13WT8qaFKlyWlBjFM5LWCjMmwdAxyefUJxR3KVa6lc1XwHb/AmRHIUE6heJxWKErYZCOIuLt7vOkG3kSKy5uey7NmAs76BQ0AwOWihohgVfbNzINf85G2Wr597WEHPPWwx7nYbdaueug+8C4OaPbsYfNMI4G9wNfFnzJUePM8otNDzyCPYxY0g76sghu4b+ICJknHoq7R98gG/jxsj6cGYzQFaVEb7ZV+RRGHt+Ps7Jk2n/5JMerTmX1Swj2ZbMlGzDH+CvrAS7HVtBYs7esF9he01HYFyza8YMvGvWAFooDFe0UNBsF26/j++qa6mo79jh13++W8/P3zkDv62ZP/uzOPDHH8LImZFzjUobxa3zbmVl/Uru+/o+AN7f/D4hFeLosUfT+eWXeFauJOfii3YJLSFM1rk/Qmw2Gh/f1kEtfNMGSN5UjyU5GfvInqU64pEyfz7uZcuZmDSaZm9zJDN6ec1y9snfJxLJ5avcjH3kiIQ/r3ApjO0JR43GZdZBAi0Uhiva0azpFy1uH3e8/xwfbL2fgMXDiVXjKPQmJ3SsoLATwIkfh/hxEqDT6uGxokY6bCHuStqbg89fCLaefoCjxx3NWdVnsXD1QuYUzuGdincYmz7WiO1/5Aqs2dlknHzywF7sIGPPzyf9xB/S/PIr5F5zDbasLCZlbtMKbBVVWPpRjRQModD42GNM2mhoVGXNZUimUNJcwvHFxwNGuK77629Imjo14XHDZqb+ZjJ3J5zEBtufzawZXLRQ0CTElmY3T731JEuaH6U8uYNxyk8gJLw5agP31bUyyxu/VHM0IYudkMVB0Opkvd3OjRkhOizwp5Enc+hRv4eoUtXesnLEIjjMGP2b5tzE17Vf8+slv6bN18bFe12Md8MGOv73EbnXXoMlKWkwLn1QyVmwgJaXXjZ6Ilx1Fcn2ZEanjWZLWxXBknKSjzyiX+Mlz56FOJ3krqyEsUZYaouvBSDiZG5+7nkC1dVk/eH3iY+73344xo+PVDzdXpL2mg4iWNPTh20QwPcdLRR2Aeo72nH7VI8kLoBAUOENhPD4
g3j9QfzuVppbN7GlvYQxSWNJsuzYjVKFgmxa8Sbf+v7NOxkKV5LiGhnNRYfdRFPhnlz67uX8xL6Few+7l/lF8xMe9/Otn3P9B9eTbE/miSP+0cVsEmxro+5vf6PpqaexJCczZuFCXHvtidPq5E+H/ImzF59NUAU5auxRNP7pMcTlIuucc3boOocK56RJpBxyME1PPU3Oj3+MxelkavZUHM2dBJtqeq15FAtLUhLJs2fj/+IbXBNclLeUU9FagdPqZM+cPQl1dlL/4IMkz5kTcUwngmv6dCa89e/+Xl4PrKkpOCdOJE67FM0wQAuFBAkEQ3gDob533JFzhBSbGjopraqmZeMKAjVroPU7nixYQ6clxMxOC7M6LEzvtOBUxhO1Az/pdJJk6eTLFMV/UpP5wpVEUASbUszweJnr8XCA28OeXl9CX7gCaq1W1jnsrHI6eTY9lVaXlRPS9uTGw+8gO8sIScwHHjv2Ma747xX85P2fcPfBd3PE2L6fbN8ofYPffvJbxmWM44EjH4g0dFdK0bp4MTV33UWwvoHM00+n45NP2HzJJYx96kmcEydSnFHM/zvw//FR1UdM8GVSungxWeecE7exza5AzkUXs2nBAlpee42sM8/kp/v+lIb2D4A7t8uxmzJ/PrV33cXeoamUtZTR7G1m77y9cVgd1D/9OMH6evLu+2tCDYQGg7wbrkf5EtMsNTsfMZqc7ZrMnj1bLVu2bFDPsbXFzZ3vPs6nzY9gDwnFLfmMb83HFdwWlmcnQLJ4cWG8ksVL0OKmw+4j128hOZTYP5+VEGMtNYySegACwKWFBXyb5OTggIulNg9tonAomB1I4sCAC6fY+MDh4wtLK34UBZYUDkmewh4p4/jOU8XX7gpKfTUoIMXiZK+kUWRaU0gSO06LHafYSBI7drFRE2im1FdLmbeW1pA7Mq9ZmdP4xUG/Y2q2YYNWoRBt77yDv6qKjFNPpTPFylXvXsWq+lXcPv92fjjhhzGvTynFwysf5m9f/439C/fnL4f9hTRHGgDeDRuovu12Or/8kqTp0yn87W9xTd8L38aNVJx3HoIw9pmnu5R5rrnzLhqfeIIJb7/da/ew4Y5SiorTTifkdlP85mLEYqFh4UJq77iTSZ9+gi07u1/jedatp/ykk/hkwUyeGFdFi6+Fy/a+jCsmXEDJkUfhmrE3Yx56aJCuRrOLEPempDWFOKwoq2LJfx/lc/+rfJ3mY0rQR7IK8XWem1W5FRzW6ebUtnbmuj1YgTaLnS+TU/kiKYkvk+yU2MB8mCc9JIwKWigKWhgVsjAuYOUAvw1rj+/FSjBjfxpG7knmmL25p+EzlpW9zt/bfsgevlzsUydTUqB417+S9yrf59POGsDoznXGuHP4QfEP2NMxDv+mzYTa2zhl7jSsGRk0eZr4ovoLPt/yOV/Xfk25rxp3wI076CYQCkTO7rA4mJQ1iaNGzWZK9hSmZk9lctZkUuxGJUulFB2ffErtn+/B+50RVlh3/z/IOvtsHjj/j1y/4jZ+veTXNHoamZw1GV/QhzfojbyW1yxncdliTig+gdvm3YbdakcFAtT99T4aHnsMS0oKhbfeSuYZp0eiYhxjxzLmkUfYdP4FbFpwEWOfeRp7QQHB1laan3+e9GOP3aUFAhihmtkXX8yWn/2M9g//R9rhh+HdsAFrTk6/BQIYIazWvFwmrXfTNKIJMPwJjY8tJNTSQt5PfzrQl6DZjfheagpry77io5WvMNE+kkxr17opPncbgTWLabSu5J7cNFotFi5yTeHSqZfjHDuLCn81L5e9zuvl/6bJ20xhcgGFyYWsalhNQAWwW+wc317Msf/rIH1zE3XzJvP1/HxWuxrZ2LqR2k6jy+j8ovncedCdZDhjl3p+o/QNfvXxL/l/X09hwtvfgcUCIcN8ZUlJwTl1Kh3j8wmkOMlpDBDYtBnfpk0Em5q6jOOcNAnXvvuSPGtfXPvui72oqIvZwB/y4wl48AQ8ZCVlxfRb
ALhXrqT2nj/T+fnn2EeOJO+n15I0bRr1Dz1M65tvInY7aaefyl/32MhbHUtjjiEIl0y/hGv2uQYRIdjSQtX1N9Dx6adknHIK+Tf9LO5N0L1yJZsuXICtsJCxTz1J80svUXfPnxn/8kskTZsW85hdCeX3U3L0MThGjWLsk09QfsaZWFJTGPvYY9s13paf/4LGD97lR1d5sFrsfHTsYrYceyIpBx7IqPv+OsCz1+yCxNUUvpdC4Z8Lr2HphvepyBdsKUH29nqZ4fWxt9dLfiDI73Py+SjFznzPCK53H4pjydd4Vq1CXC7SDj+c9BOOx3nA/vyv9hNe3vAyrd5W5hTMZv7WNHKe+xDvsuVYs7Jw7b037Z98AoEAKQceSNaPzsEybw5vVLzJnV/eyYiUEfz1sL8yKaurM3F1/Wou+Pf53PBxBvt+XE3WBeeTf+ONeEtK8K5Zg+e7NXi++w7PunUojwfbiEIcY8biGDMGx9ixOMaOQVwuPCtW0PnV17i//ppQu9HVxJqXi3N8MfbRo3CMHoNjzGjso0djHzUKsVoJuT0or8f463ETbG+n+YUXafvPf7BmZZF75RVknn02lqisVl9FBfUPPUzL668bdfqPO5jgeSfiGFmEw+ogyZqEw+og2Z4cMRf5KirYfMWV+KqqGHHLb8k8/fQ+v7eOpUvZfOllOIqLCdbX45w0iTGPPtLv73+40vDYQmrvvJNxzz/HxgsXkHnG6RT+6lfbNVbLG2+w5aab+fkCK+l7z+SeVTNoXLiQ4tdfMx29mu85WihEs/npx2m/3ej26U2ysjnfQklugI0FQnUm7FVp5diNGaRsMmz7SXvvTdoRR+DfuoW2t/5DsKUFa0YGacceS8YJxxNsa6P+nw/iWbECW34+OT++mMwzzsCSnIy/ppbmF18wwgBra7GNGEHWWWdSecSeXP/1b+nwd3D7/Ns5ZtwxgJGpe/brZ3Lmmy0cuLSD7AULyP/5zTGdgioYRAWDXW7QsVDBIN6SEjqXL8fz7Qp8mzbhq9xMsK4+oc9LkpPJWbCA7Isv6rUipa+yioaHH6b55ZcRIPOM08m5/HLs3bJmOz7/nMqfXoeIMOpv93UpH90X7R99xOarfwJ+P6Mf+Rep8xOPeBruBNvbKTn0MBzFxXhWrGDE729PSFjGIlBfz4YDD+Klw5MoOP1sDrr+WdKPPZaRd+outxpAC4WuhNxuvOvX41m7Du+6tXjWrcezdg2qo9PYQQTXrH1JP/po0o46CvuIEZFjlc9H+6ef0rr4Tdreew/lNhyy9qIici69lIxTT4l5k1Z+P20ffEDTs8/S+dnniMuF85QTuGvCWpYE1nDxXhdz1cyruPzty9jvia847JsAOZdeQt4NNwxalEiosxPf5kr8lZvxV1ailMKS5MLiSkLCf51JOCdP6ld0j3/LFuoffIjml15CLBYyzzqLnEsvwZ6fT9OiRVTf/nsc48cx+oHt65Hc9uGHuJd/Rd4N1w9ZBM1gUfOnP9H4yKMAjHv+uS4ZwP2l7JRTCSY7SZ44mdaXXmbCW//e6T2pNcMWLRT6QimFv6oKX3k5SWaxsr4IdXbS9sEHiAhpRx2F2O19HgPgWb+exkceoWXxmyBC+QGjuW+PTQSK8jn1pWoOW6nIueJy8n760136puerrKLhwX/S/PIriM1G8qx96fj0M1IOOZiie+7RdfBj4K+upuTIoyAQYMryZTvUrrL27rtpWPg4iJB52qmMuPXWgZuoZldHC4XhiL+qioaFj9P8wgsoj4eqXKGoXpF79dXk/uTqXVogROPbvJn6B/5Jy+uvk33eeeTf9LNdqkbRzmbrLbfiWbmS8S+/tEPjdHz2GZsuuhhxOpnwzts9zHia7zVaKAxnAk1NND35FE0vvkD2j84l94rLh3pKg0LI5+vT/6ExfEAohdh2LGI85PVSctjhZJ52Gvk3Dn0bUs2wQgsFjeb7SKijA3G5hkUbUs2wQievaTTfR3bEJ6H5fqIf
HzQajUYTQQsFjUaj0UTQQkGj0Wg0EbRQ0Gg0Gk0ELRQ0Go1GE2GXDkkVkTpgYx+75QKJFfnZvdDX/f3j+3rt+rr7T71S6thYG3ZpoZAIIrJMKTV7qOexs9HX/f3j+3rt+roHFm0+0mg0Gk0ELRQ0Go1GE+H7IBS+r81o9XV///i+Xru+7gFkt/cpaDQajSZxvg+agkaj0WgSZLcWCiJyrIisE5ESEfnFUM9nsBCRR0WkVkRWRa3LFpH/isgG82/irdN2EURktIh8ICLfichqEfmpuX63vnYRSRKRpSLyrXndvzPXjxeRL8zf+3MislvWKRcRq4h8LSKLzeXd/rpFpEJEVorINyKyzFw3KL/z3VYoiIgVuB84DpgGnCMi04Z2VoPGQqB7zPEvgPeUUpOA98zl3Y0AcKNSahpwAHC1+R3v7tfuBQ5XSs0AZgLHisgBwJ3AX5RSE4Em4MdDN8VB5afAmqjl78t1H6aUmhkVhjoov/PdVigA+wElSqkypZQPWAScNMRzGhSUUh8Bjd1WnwQ8br5/HDh5Z85pZ6CU2qqU+sp834ZxoyhiN792ZdBuLtrNlwIOB1401+921w0gIqOA44F/mcvC9+C64zAov/PdWSgUAZujlivNdd8XCpRSW8331cBu3YtRRMYB+wBf8D24dtOE8g1QC/wXKAWalVIBc5fd9fd+L3AzEDKXc/h+XLcC3hGR5SJymbluUH7nusnO9wCllBKR3TbMTERSgZeA65RSrdG9rXfXa1dKBYGZIpIJvAJMHdoZDT4icgJQq5RaLiKHDvF0djYHKqWqRCQf+K+IrI3eOJC/891ZU6gCRkctjzLXfV+oEZERAObf2iGez6AgInYMgfC0Uuplc/X34toBlFLNwAfAXCBTRMIPervj730+cKKIVGCYgw8H/sruf90oparMv7UYDwH7MUi/891ZKHwJTDIjExzA2cDrQzynncnrwIXm+wuB14ZwLoOCaU9+BFijlPpz1Kbd+tpFJM/UEBARF3AUhj/lA+B0c7fd7rqVUr9USo1SSo3D+H9+Xyl1Lrv5dYtIioikhd8DRwOrGKTf+W6dvCYiP8CwQVqBR5VSfxjaGQ0OIvIscChG1cQa4BbgVeB5YAxGJdkzlVLdndG7NCJyIPAxsJJtNuZfYfgVdttrF5G9MRyLVowHu+eVUreJSDHGE3Q28DVwnlLKO3QzHTxM89HPlFIn7O7XbV7fK+aiDXhGKfUHEclhEH7nu7VQ0Gg0Gk3/2J3NRxqNRqPpJ1ooaDQajSaCFgoajUajiaCFgkaj0WgiaKGg0Wg0mghaKGh2OiKiROSeqOWficitAzT2QhE5ve89d/g8Z4jIGhH5IMa2SSKyWERKzbIEH4jIwYM9p3iIyMnRxSBF5DYROXKo5qMZ3mihoBkKvMCpIpI71BOJJiorNhF+DFyqlDqs2xhJwJvAQ0qpCUqpWcA1QPHAzbQnZlXgeJyMUSkYAKXUb5VS7w7mfDS7LlooaIaCAEYrweu7b+j+pC8i7ebfQ0XkfyLymoiUicgdInKu2VdgpYhMiBrmSBFZJiLrzXo54QJyfxKRL0VkhYhcHjXuxyLyOvBdjPmcY46/SkTuNNf9FjgQeERE/tTtkHOBz5RSkex5pdQqpdRC89gUMfpfLDV7Apxkrl8gIi+LyH/M+vh3Rc3haBH5TES+EpEXzFpP4Rr7d4rIV8AZInKpeX3fishLIpIsIvOAE4E/iVGLf0L0ZywiR5jzWGnOyxk19u/Mc64Ukanm+kPMcb4xj0vr68vW7FpooaAZKu4HzhWRjH4cMwO4AtgDOB+YrJTaD6OM8jVR+43DqA1zPPBP8+n9x0CLUmoOMAe4VETGm/vvC/xUKTU5+mQiMhKjVv/hGH0L5ojIyUqp24BlwLlKqZu6zXFP4KteruHXGOUZ9gMOw7hZp5jbZgJnAdOBs8RoIpQL/AY4Uim1r3neG6LGa1BK7auUWgS8rJSaY/ZZWAP8WCn1KUY5
hJvMWvylUdeXhNGL4yyl1HSMbNkro8auN8/5APAzc93PgKuVUjOBgwB3L9eq2QXRQkEzJCilWoEngGv7cdiXZg8FL0ap6HfM9SsxBEGY55VSIaXUBqAMo4Lo0cAFYpSb/gKj5PIkc/+lSqnyGOebA3yolKozSzM/DfTLNyAir5haRrhY39HAL8x5fAgkYZQpAKNhSotSyoOhtYzFaB40DfjEPOZCc32Y56Le72VqPSsxNJY9+5jeFKBcKbXeXH682/WF57ycbZ/vJ8CfReRaIDOqZLVmN0GXztYMJfdiPFU/FrUugPmwIiIWILq1YnQ9m1DUcoiuv+XutVsUIMA1Sqm3ozeYNXQ6tmfycVhN1I1VKXWKiMwG7g6fEjhNKbWu2zz2p+v1BTGuSYD/KqXOiXO+6LkvBE5WSn0rIgsw6mHtCOH5hOeCUuoOEXkT+AGGoDpGKbU23gCaXQ+tKWiGDLN41/N0bZ9YAcwy35+I0VWsv5whIhbTz1AMrAPeBq4Uo9Q2IjI5ymwTj6XAISKSazpyzwH+18cxzwDzReTEqHXJUe/fBq4RMZo+iMg+fYz3uTneRHP/FBGZHGffNGCreY3nRq1vM7d1Zx0wLjw2hkmu1+sTkQlKqZVKqTsxKhHv9n0cvm9ooaAZau7BqO4a5mGMG/G3GD0CtucpfhPGDf0t4ArTHPMvDJPMVyKyCniQPjRls6vVLzBKM38LLFdK9VqeWCnlBk4ArjAd4p9h+AR+b+5yO4agWyEiq83l3sarAxYAz4rICuAz4t+I/w/DNPYJEP30vgi4yXQMRxzy5udyEfCCaXIKAf/sbT7AdaY5bAXgx/iMNbsRukqqRqPRaCJoTUGj0Wg0EbRQ0Gg0Gk0ELRQ0Go1GE0ELBY1Go9FE0EJBo9FoNBG0UNBoNBpNBC0UNBqNRhNBCwWNRqPRRPj/ytmJl+peabYAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "num_generations = 50\n", + "num_rollouts = 20\n", + "print_every_k_gens = 5\n", + "\n", + "rng = jax.random.PRNGKey(0)\n", + "es_logging = ESLog(param_reshaper.total_params,\n", + " num_generations,\n", + " top_k=5,\n", + " maximize=True)\n", + "\n", + "# No es_params!\n", + "state = strategy.initialize(rng)\n", + "\n", + "for gen in range(num_generations):\n", + " rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)\n", + " x, state = strategy.ask(rng_ask, state)\n", + " fitness = evaluator.rollout(rng_eval, x).mean(axis=1)\n", + " state = strategy.tell(x, fitness, state)\n", + " if gen % print_every_k_gens == 0:\n", + " print(\"Generation: \", gen, \"Performance: \", state.best_fitness)\n", + " #break\n", + " \n", + "es_logging.plot(log, \"CartPole Augmented Random Search\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc13efd2", + "metadata": {}, "outputs": [], "source": [] } diff --git a/examples/01_classic_benchmark.ipynb b/examples/01_classic_benchmark.ipynb index cd8f39c..b1dae81 100755 --- a/examples/01_classic_benchmark.ipynb +++ b/examples/01_classic_benchmark.ipynb @@ -33,52 +33,67 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n", - "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CMA-ES - # Gen: 10|Fitness: 0.13|Params: [0.6441135 0.41466928]\n", - "CMA-ES - # Gen: 20|Fitness: 0.00|Params: [0.97413015 0.9518173 ]\n", - "CMA-ES - # Gen: 30|Fitness: 0.00|Params: [0.9981632 0.9965331]\n", - "CMA-ES - # Gen: 40|Fitness: 0.00|Params: [0.9999719 0.9999461]\n", - "CMA-ES - # Gen: 50|Fitness: 0.00|Params: [0.9999997 0.9999994]\n" - ] + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from evosax import CMA_ES\n", - "from evosax.problems import ClassicFitness\n", + "from evosax.problems import BBOBFitness\n", "\n", "# Instantiate the problem evaluator\n", - "rosenbrock = ClassicFitness(\"rosenbrock\", num_dims=2)\n", - "\n", + "rosenbrock = BBOBFitness(\"RosenbrockOriginal\", num_dims=2, seed_id=2)\n", + "rosenbrock.visualize(plot_log_fn=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CMA-ES - # Gen: 10|Fitness: 0.11798|Params: [-0.24922156 -0.45996755]\n", + "CMA-ES - # Gen: 20|Fitness: 0.06408|Params: [-0.25254983 -0.44303334]\n", + "CMA-ES - # Gen: 30|Fitness: 0.00020|Params: [-0.00756136 -0.01385564]\n", + "CMA-ES - # Gen: 40|Fitness: 0.00000|Params: [-0.00087966 -0.00171674]\n", + "CMA-ES - # Gen: 50|Fitness: 0.00000|Params: [-3.9389306e-06 -8.1345934e-06]\n" + ] + } + ], + "source": [ "# Instantiate the search strategy\n", "rng = jax.random.PRNGKey(0)\n", "strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)\n", - "state = strategy.initialize(rng)\n", + "es_params = strategy.default_params.replace(init_min=-2, init_max=2)\n", + "\n", + "state = strategy.initialize(rng, es_params)\n", "\n", "# Run ask-eval-tell loop - NOTE: By default minimization\n", "for t in range(50):\n", " rng, rng_gen, rng_eval = jax.random.split(rng, 3)\n", - " x, state = strategy.ask(rng_gen, state)\n", + " x, state = strategy.ask(rng_gen, state, es_params)\n", " fitness = rosenbrock.rollout(rng_eval, x)\n", - " state = strategy.tell(x, fitness, state)\n", + " state = strategy.tell(x, fitness, state, es_params)\n", "\n", " if (t + 1) % 10 == 0:\n", - " print(\"CMA-ES - # Gen: {}|Fitness: {:.2f}|Params: {}\".format(\n", + " print(\"CMA-ES - # Gen: {}|Fitness: {:.5f}|Params: {}\".format(\n", 
" t+1, state.best_fitness, state.best_member))" ] }, @@ -91,96 +106,110 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "SimpleES - # Gen: 5|Fitness: 0.44|Params: [0.41951638 0.14410228]\n", - "SimpleES - # Gen: 10|Fitness: 0.04|Params: [0.91297907 0.8160112 ]\n", - "SimpleES - # Gen: 15|Fitness: 0.01|Params: [0.98460174 0.97844493]\n", - "SimpleES - # Gen: 20|Fitness: 0.01|Params: [0.98460174 0.97844493]\n", - "SimpleES - # Gen: 25|Fitness: 0.01|Params: [0.98460174 0.97844493]\n", - "SimpleES - # Gen: 30|Fitness: 0.01|Params: [0.98460174 0.97844493]\n", + "SimpleES - # Gen: 5|Fitness: 2.41|Params: [0.5749427 1.3363407]\n", + "SimpleES - # Gen: 10|Fitness: 0.05|Params: [-0.02394086 -0.06951416]\n", + "SimpleES - # Gen: 15|Fitness: 0.02|Params: [0.052019 0.11855024]\n", + "SimpleES - # Gen: 20|Fitness: 0.02|Params: [0.052019 0.11855024]\n", + "SimpleES - # Gen: 25|Fitness: 0.02|Params: [0.052019 0.11855024]\n", + "SimpleES - # Gen: 30|Fitness: 0.02|Params: [0.052019 0.11855024]\n", "====================\n", - "SimpleGA - # Gen: 5|Fitness: 6.79|Params: [-0.012256 -0.24003565]\n", - "SimpleGA - # Gen: 10|Fitness: 0.68|Params: [0.21533592 0.02063736]\n", - "SimpleGA - # Gen: 15|Fitness: 0.39|Params: [0.4103716 0.14900509]\n", - "SimpleGA - # Gen: 20|Fitness: 0.18|Params: [0.5903524 0.33600026]\n", - "SimpleGA - # Gen: 25|Fitness: 0.17|Params: [0.6199676 0.39935672]\n", - "SimpleGA - # Gen: 30|Fitness: 0.13|Params: [0.64335036 0.40990546]\n", + "SimpleGA - # Gen: 5|Fitness: 0.05|Params: [-0.23124489 -0.40957353]\n", + "SimpleGA - # Gen: 10|Fitness: 0.03|Params: [-0.13031456 -0.23322381]\n", + "SimpleGA - # Gen: 15|Fitness: 0.02|Params: [-0.11532619 -0.20849138]\n", + "SimpleGA - # Gen: 20|Fitness: 0.00|Params: [-0.00291448 -0.00076399]\n", + "SimpleGA - # Gen: 25|Fitness: 0.00|Params: [-0.00291448 -0.00076399]\n", + "SimpleGA - # Gen: 30|Fitness: 
0.00|Params: [0.0479427 0.09704389]\n", "====================\n", - "PSO - # Gen: 5|Fitness: 1.11|Params: [-0.01428866 0.02790421]\n", - "PSO - # Gen: 10|Fitness: 0.03|Params: [1.0889671 1.1718146]\n", - "PSO - # Gen: 15|Fitness: 0.01|Params: [1.109518 1.2260276]\n", - "PSO - # Gen: 20|Fitness: 0.01|Params: [1.07492 1.1620886]\n", - "PSO - # Gen: 25|Fitness: 0.01|Params: [1.07492 1.1620886]\n", - "PSO - # Gen: 30|Fitness: 0.01|Params: [1.07492 1.1620886]\n", + "PSO - # Gen: 5|Fitness: 0.32|Params: [-0.01428866 0.02790421]\n", + "PSO - # Gen: 10|Fitness: 0.19|Params: [-0.32952115 -0.5220759 ]\n", + "PSO - # Gen: 15|Fitness: 0.12|Params: [-0.33950207 -0.56108034]\n", + "PSO - # Gen: 20|Fitness: 0.09|Params: [-0.30322644 -0.5158702 ]\n", + "PSO - # Gen: 25|Fitness: 0.06|Params: [-0.23692334 -0.41880134]\n", + "PSO - # Gen: 30|Fitness: 0.06|Params: [-0.23692334 -0.41880134]\n", "====================\n", - "DE - # Gen: 5|Fitness: 0.33|Params: [0.4517086 0.18653232]\n", - "DE - # Gen: 10|Fitness: 0.06|Params: [0.7975259 0.6230468]\n", - "DE - # Gen: 15|Fitness: 0.00|Params: [0.95278853 0.90621567]\n", - "DE - # Gen: 20|Fitness: 0.00|Params: [0.9835618 0.9694447]\n", - "DE - # Gen: 25|Fitness: 0.00|Params: [1.0125908 1.0251266]\n", - "DE - # Gen: 30|Fitness: 0.00|Params: [1.0005985 1.0011392]\n", + "DE - # Gen: 5|Fitness: 0.37|Params: [-0.6068972 -0.8382896]\n", + "DE - # Gen: 10|Fitness: 0.03|Params: [-0.15622607 -0.2968421 ]\n", + "DE - # Gen: 15|Fitness: 0.00|Params: [-0.01819804 -0.03231879]\n", + "DE - # Gen: 20|Fitness: 0.00|Params: [0.0003629 0.00136572]\n", + "DE - # Gen: 25|Fitness: 0.00|Params: [0.0003629 0.00136572]\n", + "DE - # Gen: 30|Fitness: 0.00|Params: [0.00027412 0.00099771]\n", "====================\n", - "Sep_CMA_ES - # Gen: 5|Fitness: 5.25|Params: [-1.2778556 1.6572908]\n", - "Sep_CMA_ES - # Gen: 10|Fitness: 5.25|Params: [-1.2778556 1.6572908]\n", - "Sep_CMA_ES - # Gen: 15|Fitness: 5.25|Params: [-1.2778556 1.6572908]\n", - "Sep_CMA_ES - # Gen: 
20|Fitness: 5.25|Params: [-1.2778556 1.6572908]\n", - "Sep_CMA_ES - # Gen: 25|Fitness: 5.25|Params: [-1.2778556 1.6572908]\n", - "Sep_CMA_ES - # Gen: 30|Fitness: 5.25|Params: [-1.2778556 1.6572908]\n", + "Sep_CMA_ES - # Gen: 5|Fitness: 4.09|Params: [-1.5836709 -0.5336474]\n", + "Sep_CMA_ES - # Gen: 10|Fitness: 4.09|Params: [-1.5836709 -0.5336474]\n", + "Sep_CMA_ES - # Gen: 15|Fitness: 4.09|Params: [-1.5836709 -0.5336474]\n", + "Sep_CMA_ES - # Gen: 20|Fitness: 4.09|Params: [-1.5836709 -0.5336474]\n", + "Sep_CMA_ES - # Gen: 25|Fitness: 4.09|Params: [-1.5836709 -0.5336474]\n", + "Sep_CMA_ES - # Gen: 30|Fitness: 4.09|Params: [-1.5836709 -0.5336474]\n", "====================\n", - "Full_iAMaLGaM - # Gen: 5|Fitness: 0.25|Params: [0.6604285 0.39947018]\n", - "Full_iAMaLGaM - # Gen: 10|Fitness: 0.13|Params: [0.69210684 0.4968851 ]\n", - "Full_iAMaLGaM - # Gen: 15|Fitness: 0.04|Params: [0.7911089 0.6271153]\n", - "Full_iAMaLGaM - # Gen: 20|Fitness: 0.01|Params: [0.8877124 0.78315926]\n", - "Full_iAMaLGaM - # Gen: 25|Fitness: 0.00|Params: [0.97602683 0.9512631 ]\n", - "Full_iAMaLGaM - # Gen: 30|Fitness: 0.00|Params: [0.99968135 0.999422 ]\n", + "Full_iAMaLGaM - # Gen: 5|Fitness: 0.01|Params: [-0.10261801 -0.19658661]\n", + "Full_iAMaLGaM - # Gen: 10|Fitness: 0.00|Params: [-0.0065736 -0.01341229]\n", + "Full_iAMaLGaM - # Gen: 15|Fitness: 0.00|Params: [9.990766e-05 1.880897e-04]\n", + "Full_iAMaLGaM - # Gen: 20|Fitness: 0.00|Params: [2.0616037e-05 4.2061394e-05]\n", + "Full_iAMaLGaM - # Gen: 25|Fitness: 0.00|Params: [1.1449511e-06 2.9405437e-06]\n", + "Full_iAMaLGaM - # Gen: 30|Fitness: 0.00|Params: [-9.1895004e-07 -1.6921636e-06]\n", "====================\n", - "Indep_iAMaLGaM - # Gen: 5|Fitness: 0.17|Params: [0.61008203 0.38497373]\n", - "Indep_iAMaLGaM - # Gen: 10|Fitness: 0.14|Params: [0.69136626 0.5000103 ]\n", - "Indep_iAMaLGaM - # Gen: 15|Fitness: 0.14|Params: [0.69136626 0.5000103 ]\n", - "Indep_iAMaLGaM - # Gen: 20|Fitness: 0.14|Params: [0.69136626 0.5000103 ]\n", - 
"Indep_iAMaLGaM - # Gen: 25|Fitness: 0.14|Params: [0.63670754 0.41385096]\n", - "Indep_iAMaLGaM - # Gen: 30|Fitness: 0.14|Params: [0.63670754 0.41385096]\n", + "Indep_iAMaLGaM - # Gen: 5|Fitness: 0.05|Params: [0.14939058 0.30428362]\n", + "Indep_iAMaLGaM - # Gen: 10|Fitness: 0.02|Params: [0.13289806 0.28522342]\n", + "Indep_iAMaLGaM - # Gen: 15|Fitness: 0.02|Params: [0.13289806 0.28522342]\n", + "Indep_iAMaLGaM - # Gen: 20|Fitness: 0.02|Params: [0.13289806 0.28522342]\n", + "Indep_iAMaLGaM - # Gen: 25|Fitness: 0.02|Params: [0.13289806 0.28522342]\n", + "Indep_iAMaLGaM - # Gen: 30|Fitness: 0.02|Params: [0.13289806 0.28522342]\n", "====================\n", - "MA_ES - # Gen: 5|Fitness: 840.14|Params: [ 1.6348459 -0.22510758]\n", - "MA_ES - # Gen: 10|Fitness: 839.89|Params: [ 1.6386213 -0.21230145]\n", - "MA_ES - # Gen: 15|Fitness: 839.11|Params: [ 1.6380427 -0.21285582]\n", - "MA_ES - # Gen: 20|Fitness: 839.04|Params: [ 1.6380086 -0.21284352]\n", - "MA_ES - # Gen: 25|Fitness: 839.04|Params: [ 1.6380068 -0.21284315]\n", - "MA_ES - # Gen: 30|Fitness: 839.04|Params: [ 1.6380068 -0.21284278]\n", + "MA_ES - # Gen: 5|Fitness: 0.33|Params: [-0.5359075 -0.7642154]\n", + "MA_ES - # Gen: 10|Fitness: 0.33|Params: [-0.5359075 -0.7642154]\n", + "MA_ES - # Gen: 15|Fitness: 0.33|Params: [-0.5359075 -0.7642154]\n", + "MA_ES - # Gen: 20|Fitness: 0.33|Params: [-0.5359075 -0.7642154]\n", + "MA_ES - # Gen: 25|Fitness: 0.33|Params: [-0.5359075 -0.7642154]\n", + "MA_ES - # Gen: 30|Fitness: 0.33|Params: [-0.5359075 -0.7642154]\n", "====================\n", - "LM_MA_ES - # Gen: 5|Fitness: 6.08|Params: [-1.30967 1.6290693]\n", - "LM_MA_ES - # Gen: 10|Fitness: 6.08|Params: [-1.30967 1.6290693]\n", - "LM_MA_ES - # Gen: 15|Fitness: 6.08|Params: [-1.30967 1.6290693]\n", - "LM_MA_ES - # Gen: 20|Fitness: 6.08|Params: [-1.30967 1.6290693]\n", - "LM_MA_ES - # Gen: 25|Fitness: 6.08|Params: [-1.30967 1.6290693]\n", - "LM_MA_ES - # Gen: 30|Fitness: 6.08|Params: [-1.30967 1.6290693]\n", + "LM_MA_ES - # 
Gen: 5|Fitness: 7.78|Params: [-2.7078676 1.8501627]\n", + "LM_MA_ES - # Gen: 10|Fitness: 7.45|Params: [-2.7272193 1.9920657]\n", + "LM_MA_ES - # Gen: 15|Fitness: 7.42|Params: [-2.7236936 1.9769124]\n", + "LM_MA_ES - # Gen: 20|Fitness: 7.42|Params: [-2.7237751 1.9702088]\n", + "LM_MA_ES - # Gen: 25|Fitness: 7.41|Params: [-2.721651 1.9702324]\n", + "LM_MA_ES - # Gen: 30|Fitness: 7.41|Params: [-2.7209196 1.9683607]\n", "====================\n", - "RmES - # Gen: 5|Fitness: 1.81|Params: [ 0.17560971 -0.0752801 ]\n", - "RmES - # Gen: 10|Fitness: 0.12|Params: [0.7972345 0.60836744]\n", - "RmES - # Gen: 15|Fitness: 0.09|Params: [0.9898771 1.0095383]\n", - "RmES - # Gen: 20|Fitness: 0.01|Params: [0.9717485 0.9363907]\n", - "RmES - # Gen: 25|Fitness: 0.01|Params: [0.94771284 0.9043704 ]\n", - "RmES - # Gen: 30|Fitness: 0.00|Params: [1.0041738 1.006372 ]\n", + "RmES - # Gen: 5|Fitness: 0.37|Params: [-0.59044695 -0.84810144]\n", + "RmES - # Gen: 10|Fitness: 0.37|Params: [-0.59044695 -0.84810144]\n", + "RmES - # Gen: 15|Fitness: 0.11|Params: [-0.33058548 -0.5489675 ]\n", + "RmES - # Gen: 20|Fitness: 0.11|Params: [-0.33058548 -0.5489675 ]\n", + "RmES - # Gen: 25|Fitness: 0.11|Params: [-0.33058548 -0.5489675 ]\n", + "RmES - # Gen: 30|Fitness: 0.09|Params: [-0.28109062 -0.47365618]\n", "====================\n", - "GLD - # Gen: 5|Fitness: 2.20|Params: [-0.1850586 0.12321383]\n", - "GLD - # Gen: 10|Fitness: 1.01|Params: [-0.00243703 0.00842408]\n", - "GLD - # Gen: 15|Fitness: 0.40|Params: [0.364061 0.1313454]\n", - "GLD - # Gen: 20|Fitness: 0.27|Params: [0.5125 0.2448907]\n", - "GLD - # Gen: 25|Fitness: 0.16|Params: [0.6050571 0.37277922]\n", - "GLD - # Gen: 30|Fitness: 0.10|Params: [0.68510354 0.4659995 ]\n", + "GLD - # Gen: 5|Fitness: 0.01|Params: [-0.1030587 -0.19583313]\n", + "GLD - # Gen: 10|Fitness: 0.01|Params: [-0.07785733 -0.14456296]\n", + "GLD - # Gen: 15|Fitness: 0.00|Params: [-0.03990307 -0.08063483]\n", + "GLD - # Gen: 20|Fitness: 0.00|Params: [-0.0303201 
-0.0579391]\n", + "GLD - # Gen: 25|Fitness: 0.00|Params: [-0.03103314 -0.06153299]\n", + "GLD - # Gen: 30|Fitness: 0.00|Params: [-0.0139228 -0.02847508]\n", "====================\n", - "SimAnneal - # Gen: 5|Fitness: 19.24|Params: [-1.131186 0.8961963]\n", - "SimAnneal - # Gen: 10|Fitness: 3.70|Params: [-0.8958996 0.83507967]\n", - "SimAnneal - # Gen: 15|Fitness: 3.08|Params: [-0.7558189 0.5737612]\n", - "SimAnneal - # Gen: 20|Fitness: 2.34|Params: [-0.52904546 0.27448243]\n", - "SimAnneal - # Gen: 25|Fitness: 1.89|Params: [-0.36856776 0.14812674]\n", - "SimAnneal - # Gen: 30|Fitness: 1.23|Params: [-0.07835681 0.03237222]\n", + "SimAnneal - # Gen: 5|Fitness: 114.55|Params: [-1.7434927 0.60876817]\n", + "SimAnneal - # Gen: 10|Fitness: 29.11|Params: [-1.990768 0.48308632]\n", + "SimAnneal - # Gen: 15|Fitness: 4.59|Params: [-2.1422086 0.30636734]\n", + "SimAnneal - # Gen: 20|Fitness: 4.36|Params: [-2.0813289 0.18672767]\n", + "SimAnneal - # Gen: 25|Fitness: 4.25|Params: [-2.0612166 0.12882923]\n", + "SimAnneal - # Gen: 30|Fitness: 3.98|Params: [-1.9781697 -0.0172411]\n", + "====================\n", + "GESMR_GA - # Gen: 5|Fitness: 9.00|Params: [-1.181883 -1.2426325]\n", + "GESMR_GA - # Gen: 10|Fitness: 1.43|Params: [-1.1773776 -0.99034745]\n", + "GESMR_GA - # Gen: 15|Fitness: 0.43|Params: [-0.57461786 -0.78734714]\n", + "GESMR_GA - # Gen: 20|Fitness: 0.27|Params: [-0.5132652 -0.7691257]\n", + "GESMR_GA - # Gen: 25|Fitness: 0.24|Params: [-0.4847008 -0.7374786]\n", + "GESMR_GA - # Gen: 30|Fitness: 0.21|Params: [-0.46073768 -0.7048114 ]\n", + "====================\n", + "SAMR_GA - # Gen: 5|Fitness: 0.59|Params: [0.68800676 1.8831477 ]\n", + "SAMR_GA - # Gen: 10|Fitness: 0.57|Params: [0.64244306 1.7375212 ]\n", + "SAMR_GA - # Gen: 15|Fitness: 0.57|Params: [0.64244306 1.7375212 ]\n", + "SAMR_GA - # Gen: 20|Fitness: 0.57|Params: [0.64244306 1.7375212 ]\n", + "SAMR_GA - # Gen: 25|Fitness: 0.52|Params: [0.6906921 1.8796452]\n", + "SAMR_GA - # Gen: 30|Fitness: 0.52|Params: 
[0.6906921 1.8796452]\n", "====================\n" ] } @@ -191,7 +220,7 @@ "\n", "for s_name in [\"SimpleES\", \"SimpleGA\", \"PSO\", \"DE\", \"Sep_CMA_ES\",\n", " \"Full_iAMaLGaM\", \"Indep_iAMaLGaM\", \"MA_ES\", \"LM_MA_ES\",\n", - " \"RmES\", \"GLD\", \"SimAnneal\"]:\n", + " \"RmES\", \"GLD\", \"SimAnneal\", \"GESMR_GA\", \"SAMR_GA\"]:\n", " strategy = Strategies[s_name](popsize=20, num_dims=2)\n", " es_params = strategy.default_params\n", " es_params = es_params.replace(init_min=-2, init_max=2)\n", @@ -213,69 +242,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# xNES on Sinusoidal Task" + "# Try out one of the many `evosax` algorithms!" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "xNES - # Gen: 500|Fitness: -0.00000|Params: [ 9991.45 -9987.809]\n", - "xNES - # Gen: 1000|Fitness: -0.00000|Params: [ 9951.659 -9911.333]\n", - "xNES - # Gen: 1500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 2000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 2500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 3000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 3500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 4000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 4500|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n", - "xNES - # Gen: 5000|Fitness: -1.00000|Params: [ 8.5644424e-05 -5.0786706e-03]\n" - ] + "data": { + "text/plain": [ + "dict_keys(['SimpleGA', 'SimpleES', 'CMA_ES', 'DE', 'PSO', 'OpenES', 'PGPE', 'PBT', 'PersistentES', 'ARS', 'Sep_CMA_ES', 'BIPOP_CMA_ES', 'IPOP_CMA_ES', 'Full_iAMaLGaM', 'Indep_iAMaLGaM', 'MA_ES', 'LM_MA_ES', 'RmES', 'GLD', 'SimAnneal', 'SNES', 'xNES', 'ESMC', 'DES', 'SAMR_GA', 'GESMR_GA', 'GuidedES', 'ASEBO', 
'CR_FM_NES', 'MR15_GA'])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from evosax.strategies import xNES\n", - "\n", - "def f(x):\n", - " \"\"\"Taken from https://github.com/chanshing/xnes\"\"\" \n", - " r = jnp.sum(x ** 2)\n", - " return -jnp.sin(r) / r\n", - "\n", - "batch_func = jax.vmap(f, in_axes=0)\n", - "\n", - "rng = jax.random.PRNGKey(0)\n", - "strategy = xNES(popsize=50, num_dims=2)\n", - "es_params = strategy.default_params\n", - "es_params = es_params.replace(use_adaptive_sampling=True, \n", - " use_fitness_shaping=True,\n", - " eta_bmat=0.01,\n", - " eta_sigma_init=0.1)\n", - "\n", - "state = strategy.initialize(rng, es_params)\n", - "# Set mean to a bad initial guess\n", - "state = state.replace(mean = jnp.array([9999.0, -9999.0]))\n", - "num_iters = 5000\n", - "for t in range(num_iters):\n", - " rng, rng_iter = jax.random.split(rng)\n", - " y, state = strategy.ask(rng_iter, state, es_params)\n", - " fitness = batch_func(y)\n", - " state = strategy.tell(y, fitness, state, es_params)\n", - " if (t + 1) % 500 == 0:\n", - " print(\"xNES - # Gen: {}|Fitness: {:.5f}|Params: {}\".format(\n", - " t+1, state.best_fitness, state.best_member))\n" + "Strategies.keys()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -295,6 +283,11 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } } }, "nbformat": 4, diff --git a/examples/02_mlp_control.ipynb b/examples/02_mlp_control.ipynb index 061912e..f689e92 100755 --- a/examples/02_mlp_control.ipynb +++ b/examples/02_mlp_control.ipynb @@ -35,14 +35,6 @@ "execution_count": 1, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream 
size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n", - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -57,7 +49,7 @@ "\n", "from evosax import OpenES, ParameterReshaper, FitnessShaper, NetworkMapper\n", "from evosax.utils import ESLog\n", - "from evosax.problems import GymFitness\n", + "from evosax.problems import GymnaxFitness\n", "\n", "rng = jax.random.PRNGKey(0)\n", "network = NetworkMapper[\"MLP\"](\n", @@ -83,8 +75,8 @@ "metadata": {}, "outputs": [], "source": [ - "evaluator = GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", - "evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)" + "evaluator = GymnaxFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", + "evaluator.set_apply_fn(network.apply)" ] }, { @@ -95,7 +87,7 @@ { "data": { "text/plain": [ - "EvoParams(opt_params=OptParams(lrate_init=0.01, lrate_decay=0.999, lrate_limit=0.001, momentum=0.9, beta_1=None, beta_2=None, eps=None, max_speed=None), sigma_init=0.04, sigma_decay=0.999, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" + "EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=0.0, beta_1=None, beta_2=None, beta_3=None, eps=None, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" ] }, "execution_count": 3, @@ -119,29 +111,26 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", - " warnings.warn('jax.tree_util.tree_multimap() is deprecated. 
Please use jax.tree_util.tree_map() '\n" + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:740: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n", + " abs_value_flat = jax.tree_leaves(abs_value)\n", + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:741: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n", + " value_flat = jax.tree_leaves(value)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Generation: 0 Generation: 21.875\n", - "Generation: 20 Generation: 80.25\n", - "Generation: 40 Generation: 82.75\n", - "Generation: 60 Generation: 166.1875\n", - "Generation: 80 Generation: 200.0\n", - "Generation: 100 Generation: 200.0\n", - "Generation: 120 Generation: 200.0\n", - "Generation: 140 Generation: 200.0\n", - "Generation: 160 Generation: 200.0\n", - "Generation: 180 Generation: 200.0\n" + "Generation: 0 Generation: 22.875\n", + "Generation: 20 Generation: 81.75\n", + "Generation: 40 Generation: 200.0\n", + "Generation: 60 Generation: 200.0\n", + "Generation: 80 Generation: 200.0\n" ] } ], "source": [ - "num_generations = 200\n", + "num_generations = 100\n", "print_every_k_gens = 20\n", "\n", "es_logging = ESLog(param_reshaper.total_params,\n", @@ -151,7 +140,7 @@ "log = es_logging.initialize()\n", "\n", "fit_shaper = FitnessShaper(centered_rank=True,\n", - " z_score=True,\n", + " z_score=False,\n", " w_decay=0.1,\n", " maximize=True)\n", "\n", @@ -188,7 +177,7 @@ }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -248,22 +237,22 @@ "metadata": {}, "outputs": [], "source": [ - "evaluator = GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", - "evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply, network.initialize_carry)" + "evaluator = GymnaxFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", + "evaluator.set_apply_fn(network.apply, network.initialize_carry)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=0.999, lrate_limit=0.01, momentum=None, beta_1=0.99, beta_2=0.999, eps=1e-08, max_speed=None), sigma_init=0.05, sigma_decay=0.999, sigma_limit=0.01, sigma_lrate=0.2, sigma_max_change=0.2, init_min=-0.1, init_max=0.1, clip_min=-10, clip_max=10)" + "EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, sigma_lrate=0.2, sigma_max_change=0.2, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -272,68 +261,33 @@ "from evosax import PGPE\n", "\n", "popsize = 100\n", - "strategy = PGPE(param_reshaper.total_params, popsize,\n", + "strategy = PGPE(popsize, param_reshaper.total_params,\n", " elite_ratio=0.1, opt_name=\"adam\")\n", "\n", "# Update basic parameters of PGPE strategy\n", - "es_params = strategy.default_params.replace(\n", - " sigma_init=0.05, # Initial scale of isotropic Gaussian noise\n", - " sigma_decay=0.999, # Multiplicative decay factor\n", - " sigma_limit=0.01, # Smallest possible scale\n", - " sigma_lrate=0.2, # Learning rate for scale\n", - " sigma_max_change=0.2, # clips adaptive sigma to 20%\n", - " init_min=-0.1, # Range of parameter mean initialization - Min\n", - " 
init_max=0.1, # Range of parameter mean initialization - Max\n", - " clip_min=-10, # Range of parameter proposals - Min\n", - " clip_max=10 # Range of parameter proposals - Max\n", - ")\n", - "\n", - "# Update optimizer-specific parameters of Adam\n", - "es_params = es_params.replace(opt_params=es_params.opt_params.replace(\n", - " lrate_init=0.05, # Initial learning rate\n", - " lrate_decay=0.999, # Multiplicative decay factor\n", - " lrate_limit=0.01, # Smallest possible lrate\n", - " beta_1=0.99, # Adam - beta_1\n", - " beta_2=0.999, # Adam - beta_2\n", - " eps=1e-8, # eps constant,\n", - " )\n", - ")\n", - "\n", + "es_params = strategy.default_params\n", "es_params " ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", - " warnings.warn('jax.tree_util.tree_multimap() is deprecated. 
Please use jax.tree_util.tree_map() '\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Generation: 0 Performance: 22.3125\n", - "Generation: 20 Performance: 40.8125\n", - "Generation: 40 Performance: 83.0625\n", - "Generation: 60 Performance: 188.6875\n", - "Generation: 80 Performance: 198.5625\n", - "Generation: 100 Performance: 200.0\n", - "Generation: 120 Performance: 200.0\n", - "Generation: 140 Performance: 200.0\n", - "Generation: 160 Performance: 200.0\n", - "Generation: 180 Performance: 200.0\n" + "Generation: 0 Performance: 21.875\n", + "Generation: 20 Performance: 44.625\n", + "Generation: 40 Performance: 194.4375\n", + "Generation: 60 Performance: 199.3125\n", + "Generation: 80 Performance: 200.0\n" ] } ], "source": [ - "num_generations = 200\n", + "num_generations = 100\n", "print_every_k_gens = 20\n", "\n", "es_logging = ESLog(param_reshaper.total_params,\n", @@ -362,7 +316,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -372,13 +326,13 @@ " )" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] diff --git a/examples/03_cnn_mnist.ipynb b/examples/03_cnn_mnist.ipynb index f0ffcdf..b7163c0 100755 --- a/examples/03_cnn_mnist.ipynb +++ b/examples/03_cnn_mnist.ipynb @@ -28,14 +28,6 @@ "execution_count": 1, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n", - "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, { "name": "stdout", "output_type": "stream", @@ -192,8 +184,8 @@ "train_evaluator = VisionFitness(\"MNIST\", batch_size=1024, test=False)\n", "test_evaluator = VisionFitness(\"MNIST\", batch_size=10000, test=True, n_devices=1)\n", "\n", - "train_evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)\n", - "test_evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)" + "train_evaluator.set_apply_fn(network.apply)\n", + "test_evaluator.set_apply_fn(network.apply)" ] }, { @@ -205,26 +197,7 @@ "from evosax import OpenES\n", "strategy = OpenES(popsize=100, num_dims=param_reshaper.total_params, opt_name=\"adam\")\n", "# Update basic parameters of PGPE strategy\n", - "es_params = strategy.default_params.replace(\n", - " sigma_init=0.01, # Initial scale of isotropic Gaussian noise\n", - " sigma_decay=0.999, # Multiplicative decay factor\n", - " sigma_limit=0.01, # Smallest possible scale\n", - " init_min=0.0, # Range of parameter mean initialization - Min\n", - " init_max=0.0, # Range of parameter mean initialization - Max\n", - " clip_min=-10, # Range of parameter proposals - Min\n", - " clip_max=10 # Range of parameter proposals - Max\n", - ")\n", - "\n", - "# Update optimizer-specific parameters of Adam\n", - "es_params = es_params.replace(opt_params=es_params.opt_params.replace(\n", - " lrate_init=0.001, # Initial learning rate\n", - " lrate_decay=0.9999, # 
Multiplicative decay factor\n", - " lrate_limit=0.0001, # Smallest possible lrate\n", - " beta_1=0.99, # Adam - beta_1\n", - " beta_2=0.999, # Adam - beta_2\n", - " eps=1e-8, # eps constant,\n", - " )\n", - ")" + "es_params = strategy.default_params" ] }, { @@ -235,9 +208,9 @@ "source": [ "from evosax import FitnessShaper\n", "fit_shaper = FitnessShaper(centered_rank=True,\n", - " z_score=True,\n", + " z_score=False,\n", " w_decay=0.1,\n", - " maximize=True)" + " maximize=False)" ] }, { diff --git a/examples/04_lrate_pes.ipynb b/examples/04_lrate_pes.ipynb index 95abf5e..481196d 100755 --- a/examples/04_lrate_pes.ipynb +++ b/examples/04_lrate_pes.ipynb @@ -33,15 +33,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n" - ] - } - ], + "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", @@ -85,15 +77,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] + "data": { + "text/plain": [ + "EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), T=100, K=10, sigma_init=0.1, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -103,7 +98,7 @@ "\n", "strategy = PersistentES(popsize=popsize, num_dims=2)\n", "es_params = strategy.default_params.replace(\n", - " T=100, K=10\n", + " T=100, K=10, sigma_init=0.1\n", ")\n", "\n", "rng = jax.random.PRNGKey(5)\n", @@ -111,7 +106,9 @@ "\n", "# Initialize inner parameters\n", "t = 0\n", - "xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])" + "xs = jnp.ones((popsize, 2)) * jnp.array([1.0, 1.0])\n", + "\n", + "es_params" ] }, { @@ -123,23 +120,23 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0 [ 0.01 -0.01] 2423.4482\n", - "500 [ 0.08029711 -0.7827603 ] 2423.3083\n", - "1000 [ 0.17781787 -0.6900547 ] 2423.4324\n", - "1500 [ 1.7823172 -0.5671913] 1363.6532\n", - "2000 [ 2.6007807 -0.43554983] 585.6527\n", - "2500 [ 2.7024412 -0.4344938] 576.2598\n", - "3000 [ 2.737832 -0.47443515] 576.0917\n", - "3500 [ 2.7500708 -0.5119995] 565.1986\n", - "4000 [ 2.750747 -0.51908237] 574.359\n", - "4500 [ 2.765111 -0.5706135] 573.79443\n" + "0 [ 0.05 -0.05] 2423.374\n", + "500 [ 0.13214235 -2.474788 ] 2423.2078\n", + "1000 [ 3.9050057 -4.4652762] 1183.7357\n", + "1500 [ 2.5583386 -4.036586 ] 582.6147\n", + "2000 [ 2.7078283 -3.8439238] 564.5876\n", + "2500 [ 2.744315 -2.5619094] 559.23505\n", + "3000 [ 2.7431633 -3.8979192] 566.58826\n", + "3500 [ 2.7665381 -4.55985 ] 558.7182\n", + "4000 [ 2.7644894 -3.5615964] 556.5793\n", + "4500 [ 2.7446108 
-4.953268 ] 559.6667\n" ] } ], @@ -168,13 +165,6 @@ " )\n", " print(i, state.mean, L)\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/examples/05_quadratic_pbt.ipynb b/examples/05_quadratic_pbt.ipynb index bf34df8..8e08111 100755 --- a/examples/05_quadratic_pbt.ipynb +++ b/examples/05_quadratic_pbt.ipynb @@ -33,15 +33,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n" - ] - } - ], + "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", @@ -70,15 +62,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "from evosax.strategies import PBT\n", "\n", @@ -123,7 +107,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 4, diff --git a/examples/06_restart_es.ipynb b/examples/06_restart_es.ipynb index 00024c1..39c3308 100644 --- a/examples/06_restart_es.ipynb +++ b/examples/06_restart_es.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -49,7 +49,7 @@ "\n", "from evosax import OpenES, ParameterReshaper, FitnessShaper, NetworkMapper\n", "from evosax.utils import ESLog\n", - "from evosax.problems import GymFitness\n", + "from evosax.problems import GymnaxFitness\n", "\n", "rng = jax.random.PRNGKey(0)\n", "network = NetworkMapper[\"MLP\"](\n", @@ -71,26 +71,26 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "evaluator = GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", - "evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)" + "evaluator = GymnaxFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", + "evaluator.set_apply_fn(network.apply)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "WrapperParams(strategy_params=EvoParams(opt_params=OptParams(lrate_init=0.01, lrate_decay=0.999, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, eps=1e-08, max_speed=None), sigma_init=0.04, sigma_decay=0.999, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38), restart_params=RestartParams(min_num_gens=50, min_fitness_spread=1e-12, popsize_multiplier=2, copy_mean=False))" + "WrapperParams(strategy_params=EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, 
eps=1e-08, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38), restart_params=RestartParams(min_num_gens=50, min_fitness_spread=1e-12, popsize_multiplier=2, copy_mean=False))" ] }, - "execution_count": 12, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -113,16 +113,16 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "WrapperParams(strategy_params=EvoParams(opt_params=OptParams(lrate_init=0.01, lrate_decay=0.999, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, eps=1e-08, max_speed=None), sigma_init=0.04, sigma_decay=0.999, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38), restart_params=RestartParams(min_num_gens=50, min_fitness_spread=1e-12, popsize_multiplier=2, copy_mean=True))" + "WrapperParams(strategy_params=EvoParams(opt_params=OptParams(lrate_init=0.05, lrate_decay=1.0, lrate_limit=0.001, momentum=None, beta_1=0.99, beta_2=0.999, beta_3=None, eps=1e-08, max_speed=None), sigma_init=0.03, sigma_decay=1.0, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38), restart_params=RestartParams(min_num_gens=50, min_fitness_spread=1e-12, popsize_multiplier=2, copy_mean=True))" ] }, - "execution_count": 13, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -135,29 +135,39 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 5, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:740: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. 
Use jax.tree_util.tree_leaves instead.\n", + " abs_value_flat = jax.tree_leaves(abs_value)\n", + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:741: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n", + " value_flat = jax.tree_leaves(value)\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Generation: 0 Perf (Best): 25.3125 Perf (Mean): 22.210625 0.89572144\n", - "Generation: 20 Perf (Best): 28.0625 Perf (Mean): 19.783125 0.8995422\n", - "Generation: 40 Perf (Best): 28.0625 Perf (Mean): 21.81625 0.46944278\n", - "--> Restarted Strategy: Gen 50\n", + "Generation: 0 Perf (Best): 23.25 Perf (Mean): 21.77375 0.7692213\n", + "Generation: 20 Perf (Best): 68.0 Perf (Mean): 59.039997 3.990374\n", + "Generation: 40 Perf (Best): 197.25 Perf (Mean): 185.15125 9.462757\n", + "--> Restarted Strategy: Gen 51\n", "--> New Popsize: 200\n", - "Generation: 60 Perf (Best): 30.875 Perf (Mean): 18.859375 0.8144148\n", - "Generation: 80 Perf (Best): 31.875 Perf (Mean): 22.13125 0.83103853\n", - "Generation: 100 Perf (Best): 41.875 Perf (Mean): 24.884687 1.4744185\n", + "Generation: 60 Perf (Best): 200.0 Perf (Mean): 197.47156 3.7782266\n", + "Generation: 80 Perf (Best): 200.0 Perf (Mean): 199.58093 1.214557\n", + "Generation: 100 Perf (Best): 200.0 Perf (Mean): 200.0 0.0\n", "--> Restarted Strategy: Gen 101\n", "--> New Popsize: 400\n", - "Generation: 120 Perf (Best): 61.5 Perf (Mean): 36.045624 5.693629\n", - "Generation: 140 Perf (Best): 108.25 Perf (Mean): 79.3975 9.712329\n", - "Generation: 160 Perf (Best): 176.25 Perf (Mean): 139.6253 18.515297\n", - "Generation: 180 Perf (Best): 200.0 Perf (Mean): 173.19374 9.615619\n", - "--> Restarted Strategy: Gen 189\n", - "--> New Popsize: 800\n" + "Generation: 120 Perf (Best): 200.0 Perf (Mean): 200.0 0.0\n", + "Generation: 140 Perf (Best): 200.0 Perf (Mean): 199.98969 0.20599204\n", + "--> Restarted 
Strategy: Gen 151\n", + "--> New Popsize: 800\n", + "Generation: 160 Perf (Best): 200.0 Perf (Mean): 200.0 0.0\n", + "Generation: 180 Perf (Best): 200.0 Perf (Mean): 200.0 0.0\n" ] } ], @@ -203,12 +213,12 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -234,16 +244,16 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "WrapperParams(strategy_params=EvoParams(mu_eff=DeviceArray(26.966648, dtype=float32), c_1=DeviceArray(1.2144131e-06, dtype=float32), c_mu=DeviceArray(3.0331763e-05, dtype=float32), c_sigma=DeviceArray(0.02204519, dtype=float32), d_sigma=DeviceArray(1.0220451, dtype=float32), c_c=DeviceArray(0.00312667, dtype=float32), chi_n=DeviceArray(35.798042, dtype=float32), c_m=1.0, sigma_init=0.065, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38), restart_params=RestartParams(min_num_gens=50, min_fitness_spread=1, popsize_multiplier=2, tol_x=1e-12, tol_x_up=10000.0, tol_condition_C=100000000000000.0))" + "WrapperParams(strategy_params=EvoParams(mu_eff=DeviceArray(26.966648, dtype=float32), c_1=DeviceArray(1.2144133e-06, dtype=float32), c_mu=DeviceArray(3.0331763e-05, dtype=float32), c_sigma=DeviceArray(0.02204519, dtype=float32), d_sigma=DeviceArray(1.0220451, dtype=float32), c_c=DeviceArray(0.00312667, dtype=float32), chi_n=DeviceArray(35.798046, dtype=float32, weak_type=True), c_m=1.0, sigma_init=1.0, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38), restart_params=RestartParams(min_num_gens=50, min_fitness_spread=1, popsize_multiplier=2, tol_x=1e-12, tol_x_up=10000.0, tol_condition_C=100000000000000.0, copy_mean=True))" ] }, - "execution_count": 31, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -265,12 +275,20 @@ "cell_type": "code", "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ParameterReshaper: 1282 parameters detected for optimization.\n" + ] + } + ], "source": [ "# Use single device due to odd/even popsizes -> makes it hard to make sure even division\n", "param_reshaper = ParameterReshaper(params, n_devices=1)\n", - "evaluator = 
GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16, n_devices=1)\n", - "evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)" + "evaluator = GymnaxFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16, n_devices=1)\n", + "evaluator.set_apply_fn(network.apply)" ] }, { @@ -278,44 +296,53 @@ "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:740: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n", + " abs_value_flat = jax.tree_leaves(abs_value)\n", + "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/core/scope.py:741: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.\n", + " value_flat = jax.tree_leaves(value)\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Generation: 0 Perf (Best): 24.1875 Perf (Mean): 21.789375\n", - "Generation: 20 Perf (Best): 29.5625 Perf (Mean): 20.15625\n", - "Generation: 40 Perf (Best): 29.5625 Perf (Mean): 21.512499\n", - "Generation: 60 Perf (Best): 31.5625 Perf (Mean): 17.129375\n", - "Generation: 80 Perf (Best): 33.75 Perf (Mean): 24.57625\n", - "Generation: 100 Perf (Best): 53.4375 Perf (Mean): 31.818748\n", - "Generation: 120 Perf (Best): 130.8125 Perf (Mean): 76.581245\n", - "Generation: 140 Perf (Best): 200.0 Perf (Mean): 171.45375\n", - "Generation: 160 Perf (Best): 200.0 Perf (Mean): 199.22499\n", - "--> Restarted Strategy: Gen 180\n", - "--> New Popsize: 200\n", - "Generation: 180 Perf (Best): 200.0 Perf (Mean): 19.42125\n", - "Generation: 200 Perf (Best): 200.0 Perf (Mean): 20.525936\n", - "Generation: 220 Perf (Best): 200.0 Perf (Mean): 20.189062\n", - "Generation: 240 Perf (Best): 200.0 Perf (Mean): 23.519999\n", - "Generation: 260 Perf (Best): 200.0 Perf 
(Mean): 21.384375\n", - "Generation: 280 Perf (Best): 200.0 Perf (Mean): 21.339375\n", - "Generation: 300 Perf (Best): 200.0 Perf (Mean): 25.099375\n", - "Generation: 320 Perf (Best): 200.0 Perf (Mean): 29.201874\n", - "Generation: 340 Perf (Best): 200.0 Perf (Mean): 30.66\n", - "Generation: 360 Perf (Best): 200.0 Perf (Mean): 29.360624\n", - "Generation: 380 Perf (Best): 200.0 Perf (Mean): 42.391872\n", - "Generation: 400 Perf (Best): 200.0 Perf (Mean): 59.089687\n", - "Generation: 420 Perf (Best): 200.0 Perf (Mean): 70.78906\n", - "Generation: 440 Perf (Best): 200.0 Perf (Mean): 82.59062\n", - "Generation: 460 Perf (Best): 200.0 Perf (Mean): 117.78406\n", - "Generation: 480 Perf (Best): 200.0 Perf (Mean): 171.02657\n", - "Generation: 500 Perf (Best): 200.0 Perf (Mean): 173.53874\n", - "Generation: 520 Perf (Best): 200.0 Perf (Mean): 194.27187\n", - "Generation: 540 Perf (Best): 200.0 Perf (Mean): 197.54688\n", - "Generation: 560 Perf (Best): 200.0 Perf (Mean): 199.68718\n", - "--> Restarted Strategy: Gen 571\n", - "--> New Popsize: 283\n", - "Generation: 580 Perf (Best): 200.0 Perf (Mean): 19.240063\n" + "Generation: 0 Perf (Best): 149.4375 Perf (Mean): 16.563124\n", + "Generation: 20 Perf (Best): 200.0 Perf (Mean): 49.830624\n", + "Generation: 40 Perf (Best): 200.0 Perf (Mean): 178.595\n", + "Generation: 60 Perf (Best): 200.0 Perf (Mean): 195.30812\n", + "Generation: 80 Perf (Best): 200.0 Perf (Mean): 194.295\n", + "Generation: 100 Perf (Best): 200.0 Perf (Mean): 196.60187\n", + "Generation: 120 Perf (Best): 200.0 Perf (Mean): 186.63875\n", + "Generation: 140 Perf (Best): 200.0 Perf (Mean): 187.41937\n", + "Generation: 160 Perf (Best): 200.0 Perf (Mean): 196.69\n", + "Generation: 180 Perf (Best): 200.0 Perf (Mean): 182.97624\n", + "Generation: 200 Perf (Best): 200.0 Perf (Mean): 175.15187\n", + "Generation: 220 Perf (Best): 200.0 Perf (Mean): 178.8075\n", + "Generation: 240 Perf (Best): 200.0 Perf (Mean): 182.48563\n", + "Generation: 260 Perf (Best): 200.0 Perf 
(Mean): 190.94063\n", + "Generation: 280 Perf (Best): 200.0 Perf (Mean): 181.2175\n", + "Generation: 300 Perf (Best): 200.0 Perf (Mean): 155.97063\n", + "Generation: 320 Perf (Best): 200.0 Perf (Mean): 150.36\n", + "Generation: 340 Perf (Best): 200.0 Perf (Mean): 116.10562\n", + "Generation: 360 Perf (Best): 200.0 Perf (Mean): 64.515\n", + "Generation: 380 Perf (Best): 200.0 Perf (Mean): 25.566874\n", + "Generation: 400 Perf (Best): 200.0 Perf (Mean): 23.734999\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/y4/1lwbxdz55wzg_83326j5cjk40000gn/T/ipykernel_41931/3292922320.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mfitness\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrollout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_eval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreshaped_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mfit_re\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_shaper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfitness\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfit_re\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mes_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0mlog\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mes_logging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfitness\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/flax/struct.py\u001b[0m in \u001b[0;36mclz_from_iterable\u001b[0;34m(meta, data)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeta\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0mclz_from_iterable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m \u001b[0mmeta_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeta_fields\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmeta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0mdata_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_fields\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -359,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 
10, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/examples/07_brax_control.ipynb b/examples/07_brax_control.ipynb index 5ac4e46..1902e31 100644 --- a/examples/07_brax_control.ipynb +++ b/examples/07_brax_control.ipynb @@ -20,7 +20,7 @@ "%config InlineBackend.figure_format = 'retina'\n", "\n", "!pip install -q git+https://github.com/RobertTLange/evosax.git@main\n", - "!pip install -q brax" + "!pip install -q brax evojax" ] }, { @@ -32,162 +32,172 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ParameterReshaper: 6248 parameters detected for optimization.\n" - ] - } - ], + "outputs": [], "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "from evosax import OpenES, ParameterReshaper, FitnessShaper, NetworkMapper\n", - "from evosax.utils import ESLog\n", - "from evosax.problems import BraxFitness\n", + "import numpy as np\n", + "from evojax.obs_norm import ObsNormalizer\n", + "from evojax.sim_mgr import SimManager\n", + "from evojax.task.brax_task import BraxTask\n", + "from evojax.policy import MLPPolicy\n", "\n", - "# Instantiate brax rollout wrapper & network architecture\n", - "evaluator = BraxFitness(\"ant\", num_env_steps=1000, num_rollouts=16)\n", - "\n", - "rng = jax.random.PRNGKey(0)\n", - "network = NetworkMapper[\"MLP\"](\n", - " num_hidden_units=32,\n", - " num_hidden_layers=4,\n", - " num_output_units=evaluator.action_shape,\n", - " hidden_activation=\"tanh\",\n", - " output_activation=\"tanh\",\n", - ")\n", - "pholder = jnp.zeros((1, evaluator.input_shape[0]))\n", - "params = network.init(\n", - " rng,\n", - " x=pholder,\n", - " rng=rng,\n", - ")\n", - "\n", - "param_reshaper = ParameterReshaper(params)\n", - "\n", - "# Set mapping dictionary for parallelization\n", - "evaluator.set_apply_fn(param_reshaper.vmap_dict, network.apply)" + "from evosax import Strategies\n", + "from 
evosax.utils.evojax_wrapper import Evosax2JAX_Wrapper" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "EvoParams(opt_params=OptParams(lrate_init=0.01, lrate_decay=0.999, lrate_limit=0.001, momentum=0.9, beta_1=None, beta_2=None, eps=None, max_speed=None), sigma_init=0.04, sigma_decay=0.999, sigma_limit=0.01, init_min=0.0, init_max=0.0, clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "strategy = OpenES(popsize=256,\n", - " num_dims=param_reshaper.total_params,\n", - " opt_name=\"adam\")\n", - "strategy.default_params" + "def get_brax_task(\n", + " env_name = \"ant\",\n", + " hidden_dims = [32, 32, 32, 32],\n", + "):\n", + " train_task = BraxTask(env_name, test=False)\n", + " test_task = BraxTask(env_name, test=True)\n", + " policy = MLPPolicy(\n", + " input_dim=train_task.obs_shape[0],\n", + " output_dim=train_task.act_shape[0],\n", + " hidden_dims=hidden_dims,\n", + " )\n", + " return train_task, test_task, policy" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Generation: 0 Generation: 203.22612\n", - "Generation: 20 Generation: 203.62665\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/var/folders/y4/1lwbxdz55wzg_83326j5cjk40000gn/T/ipykernel_75476/1777371750.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mask\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_ask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mreshaped_params\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparam_reshaper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mfitness\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrollout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_eval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreshaped_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0mfit_re\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_shaper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfitness\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstrategy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfit_re\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Dropbox/core-code/develop-jax/evosax/evosax/problems/control_brax.py\u001b[0m in \u001b[0;36mrollout\u001b[0;34m(self, rng_input, policy_params)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\"\"\"Placeholder fn call for rolling out a population for 
multi-evals.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0mrng_pop\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrng_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_rollouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m scores, all_obs, masks = jax.jit(self.rollout_map)(\n\u001b[0m\u001b[1;32m 110\u001b[0m \u001b[0mrng_pop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m )\n", - " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 471\u001b[0m \u001b[0min_type\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfer_lambda_input_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[0mflat_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mannotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 473\u001b[0;31m out_flat = xla.xla_call(\n\u001b[0m\u001b[1;32m 474\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs_flat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 475\u001b[0m 
\u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1763\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1764\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1765\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1766\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1767\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_bind_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1780\u001b[0m \u001b[0mfun_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mannotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0min_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1781\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1782\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1783\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 676\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 677\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 678\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 679\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 183\u001b[0m keep_unused, *arg_specs)\n\u001b[1;32m 184\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcompiled_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjax_debug_nans\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjax_debug_infs\u001b[0m \u001b[0;31m# compiled_fun can only raise in this case\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - 
"\u001b[0;32m~/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_execute_compiled\u001b[0;34m(name, compiled, input_handler, output_buffer_counts, result_handlers, effects, kept_var_idx, *args)\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0meffects\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m \u001b[0minput_bufs_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoken_handler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_add_tokens\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meffects\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_bufs_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m \u001b[0mout_bufs_flat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 616\u001b[0m \u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_bufs_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 617\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0moutput_buffer_counts\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "START EVOLVING 6248 PARAMS.\n", + "{'num_gens': 1} {'train_perf': 984.8515625, 'test_perf': 996.6640625}\n", + "{'num_gens': 50} {'train_perf': 977.61962890625, 'test_perf': 995.9266357421875}\n", + "{'num_gens': 100} {'train_perf': 972.63818359375, 'test_perf': 998.8201904296875}\n", + "{'num_gens': 150} {'train_perf': 974.7550048828125, 'test_perf': 1005.531982421875}\n", + "{'num_gens': 200} 
{'train_perf': 1276.9852294921875, 'test_perf': 1715.9610595703125}\n", + "{'num_gens': 250} {'train_perf': 1773.773681640625, 'test_perf': 2281.2216796875}\n", + "{'num_gens': 300} {'train_perf': 2309.93212890625, 'test_perf': 2937.6201171875}\n", + "{'num_gens': 350} {'train_perf': 2772.38134765625, 'test_perf': 3408.61474609375}\n", + "{'num_gens': 400} {'train_perf': 3173.67919921875, 'test_perf': 3793.0986328125}\n", + "{'num_gens': 450} {'train_perf': 3442.1396484375, 'test_perf': 4159.99365234375}\n", + "{'num_gens': 500} {'train_perf': 3810.6884765625, 'test_perf': 4592.73876953125}\n", + "{'num_gens': 550} {'train_perf': 4118.63671875, 'test_perf': 4821.9951171875}\n", + "{'num_gens': 600} {'train_perf': 4364.1015625, 'test_perf': 5058.63818359375}\n", + "{'num_gens': 650} {'train_perf': 4587.93603515625, 'test_perf': 5283.1171875}\n", + "{'num_gens': 700} {'train_perf': 4855.1455078125, 'test_perf': 5531.912109375}\n", + "{'num_gens': 750} {'train_perf': 5086.2080078125, 'test_perf': 5737.22265625}\n", + "{'num_gens': 800} {'train_perf': 5173.3076171875, 'test_perf': 5803.7421875}\n", + "{'num_gens': 850} {'train_perf': 5386.2861328125, 'test_perf': 6014.3095703125}\n", + "{'num_gens': 900} {'train_perf': 5541.9794921875, 'test_perf': 6128.41064453125}\n", + "{'num_gens': 950} {'train_perf': 5708.3310546875, 'test_perf': 6317.2353515625}\n", + "{'num_gens': 1000} {'train_perf': 5864.361328125, 'test_perf': 6467.5126953125}\n" ] } ], "source": [ - "num_generations = 1000\n", - "print_every_k_gens = 20\n", - "\n", - "es_logging = ESLog(param_reshaper.total_params,\n", - " num_generations,\n", - " top_k=5,\n", - " maximize=True)\n", - "log = es_logging.initialize()\n", - "\n", - "fit_shaper = FitnessShaper(centered_rank=True,\n", - " z_score=True,\n", - " w_decay=0.1,\n", - " maximize=True)\n", - "\n", - "state = strategy.initialize(rng)\n", + "train_task, test_task, policy = get_brax_task(\"ant\")\n", + "solver = Evosax2JAX_Wrapper(\n", + " 
Strategies[\"OpenES\"],\n", + " param_size=policy.num_params,\n", + " pop_size=256,\n", + " es_config={\"maximize\": True,\n", + " \"centered_rank\": True,\n", + " \"lrate_init\": 0.01,\n", + " \"lrate_decay\": 0.999,\n", + " \"lrate_limit\": 0.001},\n", + " es_params={\"sigma_init\": 0.05,\n", + " \"sigma_decay\": 0.999,\n", + " \"sigma_limit\": 0.01},\n", + " seed=0,\n", + ")\n", + "obs_normalizer = ObsNormalizer(\n", + " obs_shape=train_task.obs_shape, dummy=not True\n", + ")\n", + "sim_mgr = SimManager(\n", + " policy_net=policy,\n", + " train_vec_task=train_task,\n", + " valid_vec_task=test_task,\n", + " seed=0,\n", + " obs_normalizer=obs_normalizer,\n", + " pop_size=256,\n", + " use_for_loop=False,\n", + " n_repeats=16,\n", + " test_n_repeats=1,\n", + " n_evaluations=128\n", + ")\n", "\n", - "for gen in range(num_generations):\n", - " rng, rng_init, rng_ask, rng_eval = jax.random.split(rng, 4)\n", - " x, state = strategy.ask(rng_ask, state)\n", - " reshaped_params = param_reshaper.reshape(x)\n", - " fitness = evaluator.rollout(rng_eval, reshaped_params).mean(axis=1)\n", - " fit_re = fit_shaper.apply(x, fitness)\n", - " state = strategy.tell(x, fit_re, state)\n", - " log = es_logging.update(log, x, fitness)\n", - " \n", - " if gen % print_every_k_gens == 0:\n", - " print(\"Generation: \", gen, \"Generation: \", log[\"log_top_1\"][gen])" + "print(f\"START EVOLVING {policy.num_params} PARAMS.\")\n", + "# Run ES Loop.\n", + "for gen_counter in range(1000):\n", + " params = solver.ask()\n", + " scores, _ = sim_mgr.eval_params(params=params, test=False)\n", + " solver.tell(fitness=scores)\n", + " if gen_counter == 0 or (gen_counter + 1) % 50 == 0:\n", + " test_scores, _ = sim_mgr.eval_params(\n", + " params=solver.best_params, test=True\n", + " )\n", + " print(\n", + " {\n", + " \"num_gens\": gen_counter + 1,\n", + " },\n", + " {\n", + " \"train_perf\": float(np.nanmean(scores)),\n", + " \"test_perf\": float(np.nanmean(test_scores)),\n", + " },\n", + " )" ] }, { - 
"cell_type": "markdown", + "cell_type": "code", + "execution_count": 22, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'num_gens': 1000} {'train_perf': 5864.361328125, 'test_perf': 6430.080078125}\n" + ] + } + ], "source": [ - "# Visualize Learning Curve and Policy" + "test_scores, _ = sim_mgr.eval_params(\n", + " params=solver.best_params, test=True\n", + " )\n", + "print(\n", + " {\n", + " \"num_gens\": gen_counter + 1,\n", + " },\n", + " {\n", + " \"train_perf\": float(np.nanmean(scores)),\n", + " \"test_perf\": float(np.nanmean(test_scores)),\n", + " },\n", + ")" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# Plot the learning curve over generations\n", - "es_logging.plot(log, \"Ant MLP OpenAI-ES\")" + "# Visualize Learning Curve and Policy" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Cumulative reward: -2459.6707\n" + "Cumulative reward: 6469.828\n" ] }, { @@ -211,11 +221,11 @@ " \n", " \n", " \n", "
\n", " \n", @@ -226,7 +236,7 @@ "" ] }, - "execution_count": 31, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -235,27 +245,34 @@ "from IPython.display import HTML\n", "from brax import envs\n", "from brax.io import html\n", + "import jax\n", "\n", - "env = envs.create(env_name=\"ant\")\n", - "jit_env_reset = jax.jit(env.reset)\n", - "jit_env_step = jax.jit(env.step)\n", - "jit_inference_fn = jax.jit(network.apply)\n", + "env = envs.create(env_name=\"ant\", legacy_spring=True)\n", + "task_reset_fn = jax.jit(env.reset)\n", + "policy_reset_fn = jax.jit(policy.reset)\n", + "step_fn = jax.jit(env.step)\n", + "act_fn = jax.jit(policy.get_actions)\n", + "obs_norm_fn = jax.jit(obs_normalizer.normalize_obs)\n", "\n", - "net_params = param_reshaper.reshape_single(state.mean)\n", + "best_params = solver.best_params\n", + "obs_params = sim_mgr.obs_params\n", "\n", + "total_reward = 0\n", "rollout = []\n", - "rng = jax.random.PRNGKey(seed=0)\n", - "env_state = jit_env_reset(rng=rng)\n", - "cum_reward = 0\n", - "for _ in range(1000):\n", - " rollout.append(env_state)\n", - " act_rng, rng = jax.random.split(rng)\n", - " norm_obs = evaluator.obs_normalizer.normalize_obs(env_state.obs, evaluator.obs_params)\n", - " act = jit_inference_fn(net_params, env_state.obs, act_rng)\n", - " env_state = jit_env_step(env_state, act)\n", - " cum_reward += env_state.reward\n", + "rng = jax.random.PRNGKey(seed=42)\n", + "task_state = task_reset_fn(rng=rng)\n", + "policy_state = policy_reset_fn(task_state)\n", + "while not task_state.done:\n", + " rollout.append(task_state)\n", + " task_state = task_state.replace(\n", + " obs=obs_norm_fn(task_state.obs[None, :], obs_params).reshape(1, 87))\n", + " act, policy_state = act_fn(task_state, best_params[None, :], policy_state)\n", + " task_state = task_state.replace(\n", + " obs=obs_norm_fn(task_state.obs[None, :], obs_params).reshape(87,))\n", + " task_state = step_fn(task_state, act[0])\n", + " total_reward = 
total_reward + task_state.reward\n", "\n", - "print(\"Cumulative reward:\", cum_reward)\n", + "print(\"Cumulative reward:\", total_reward)\n", "HTML(html.render(env.sys, [s.qp for s in rollout]))" ] }, @@ -269,9 +286,9 @@ ], "metadata": { "kernelspec": { - "display_name": "mle-toolbox", + "display_name": "snippets", "language": "python", - "name": "mle-toolbox" + "name": "snippets" }, "language_info": { "codemirror_mode": { @@ -283,7 +300,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.11" } }, "nbformat": 4, diff --git a/examples/08_encodings.ipynb b/examples/08_encodings.ipynb deleted file mode 100644 index eb82481..0000000 --- a/examples/08_encodings.ipynb +++ /dev/null @@ -1,276 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 08 - Indirect Encodings\n", - "### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/08_encodings.ipynb)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "%config InlineBackend.figure_format = 'retina'\n", - "\n", - "!pip install -q git+https://github.com/RobertTLange/evosax.git@main\n", - "!pip install -q gymnax" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experimental (!!!) - Random Encodings" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n", - "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ParameterReshaper: 4610 parameters detected for optimization.\n" - ] - }, - { - "data": { - "text/plain": [ - "DeviceArray(4610, dtype=int32)" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from evosax import NetworkMapper\n", - "from evosax.problems import GymFitness\n", - "from evosax.utils import ParameterReshaper\n", - "\n", - "rng = jax.random.PRNGKey(0)\n", - "# Run Strategy on CartPole MLP\n", - "evaluator = GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", - "\n", - "network = NetworkMapper[\"MLP\"](\n", - " num_hidden_units=64,\n", - " num_hidden_layers=2,\n", - " num_output_units=2,\n", - " hidden_activation=\"relu\",\n", - " output_activation=\"categorical\",\n", - ")\n", - "pholder = jnp.zeros((1, evaluator.input_shape[0]))\n", - "params = network.init(\n", - " rng,\n", - " x=pholder,\n", - " rng=rng,\n", - ")\n", - "\n", - "reshaper = ParameterReshaper(params)\n", - "reshaper.total_params" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from evosax.utils import FitnessShaper\n", - "from evosax.experimental.decodings import RandomDecoder\n", - "\n", - "# Only optimize 10 parameters!\n", - "num_encoding_dims = 6\n", - "reshaper = RandomDecoder(num_encoding_dims, params)\n", - "evaluator.set_apply_fn(reshaper.vmap_dict, network.apply)\n", - "\n", - "fit_shaper = FitnessShaper(maximize=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/rob/anaconda3/envs/mle-toolbox/lib/python3.9/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. 
Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", - " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "20 133.08125 200.0 50.5719 -200.0\n", - "40 149.9775 200.0 45.9242 -200.0\n", - "60 157.07562 200.0 51.393032 -200.0\n", - "80 151.68312 200.0 53.497288 -200.0\n", - "100 160.70312 200.0 50.2572 -200.0\n" - ] - } - ], - "source": [ - "from evosax import DE\n", - "\n", - "strategy = DE(\n", - " num_dims=reshaper.total_params,\n", - " popsize=100,\n", - ")\n", - "state = strategy.initialize(rng)\n", - "\n", - "for t in range(100):\n", - " rng, rng_eval, rng_iter = jax.random.split(rng, 3)\n", - " x, state = strategy.ask(rng_iter, state)\n", - " x_re = reshaper.reshape(x)\n", - " fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)\n", - " fit_re = fit_shaper.apply(x, fitness)\n", - " state = strategy.tell(x, fit_re, state)\n", - "\n", - " if (t + 1) % 20 == 0:\n", - " print(\n", - " t + 1,\n", - " fitness.mean(),\n", - " fitness.max(),\n", - " fitness.std(),\n", - " state.best_fitness,\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experimental (!!!) 
- Hypernetwork Encodings" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ParameterReshaper: 2306 parameters detected for optimization.\n" - ] - }, - { - "data": { - "text/plain": [ - "DeviceArray(2306, dtype=int32)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from evosax.experimental.decodings import HyperDecoder\n", - "\n", - "reshaper = HyperDecoder(\n", - " params,\n", - " hypernet_config={\n", - " \"num_latent_units\": 3, # Latent units per module kernel/bias\n", - " \"num_hidden_units\": 2, # Hidden dimensionality of a_i^j embedding\n", - " },\n", - " )\n", - "reshaper.total_params" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "20 19.089375 30.3125 6.1910863 -33.4375\n", - "40 29.787498 195.5625 32.49928 -200.0\n", - "60 31.540625 200.0 44.170444 -200.0\n", - "80 28.501875 200.0 47.56071 -200.0\n", - "100 28.136875 200.0 44.376225 -200.0\n" - ] - } - ], - "source": [ - "strategy = DE(\n", - " num_dims=reshaper.total_params,\n", - " popsize=100,\n", - ")\n", - "state = strategy.initialize(rng)\n", - "\n", - "for t in range(100):\n", - " rng, rng_eval, rng_iter = jax.random.split(rng, 3)\n", - " x, state = strategy.ask(rng_iter, state)\n", - " x_re = reshaper.reshape(x)\n", - " fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)\n", - " fit_re = fit_shaper.apply(x, fitness)\n", - " state = strategy.tell(x, fit_re, state)\n", - "\n", - " if (t + 1) % 20 == 0:\n", - " print(\n", - " t + 1,\n", - " fitness.mean(),\n", - " fitness.max(),\n", - " fitness.std(),\n", - " state.best_fitness\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": 
"mle-toolbox", - "language": "python", - "name": "mle-toolbox" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/examples/09_exp_batch_es.ipynb b/examples/09_exp_batch_es.ipynb deleted file mode 100644 index 450369f..0000000 --- a/examples/09_exp_batch_es.ipynb +++ /dev/null @@ -1,367 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 09 - Batch Strategy Rollouts\n", - "### [Last Update: June 2022][![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RobertTLange/evosax/blob/main/examples/09_exp_batch_es.ipynb)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%matplotlib inline\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "%config InlineBackend.figure_format = 'retina'\n", - "\n", - "!pip install -q git+https://github.com/RobertTLange/evosax.git@main\n", - "!pip install -q gymnax" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experimental (!!!) - Subpopulation Batch ES Rollouts" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":228: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n", - "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ParameterReshaper: 4610 parameters detected for optimization.\n" - ] - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from evosax import NetworkMapper\n", - "from evosax.problems import GymFitness\n", - "from evosax.utils import ParameterReshaper, FitnessShaper\n", - "\n", - "rng = jax.random.PRNGKey(0)\n", - "# Run Strategy on CartPole MLP\n", - "evaluator = GymFitness(\"CartPole-v1\", num_env_steps=200, num_rollouts=16)\n", - "\n", - "network = NetworkMapper[\"MLP\"](\n", - " num_hidden_units=64,\n", - " num_hidden_layers=2,\n", - " num_output_units=2,\n", - " hidden_activation=\"relu\",\n", - " output_activation=\"categorical\",\n", - ")\n", - "pholder = jnp.zeros((1, evaluator.input_shape[0]))\n", - "params = network.init(\n", - " rng,\n", - " x=pholder,\n", - " rng=rng,\n", - ")\n", - "\n", - "reshaper = ParameterReshaper(params)\n", - "evaluator.set_apply_fn(reshaper.vmap_dict, network.apply)\n", - "\n", - "fit_shaper = FitnessShaper(maximize=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from evosax.experimental.subpops import BatchStrategy\n", - "\n", - "strategy = BatchStrategy(\n", - " strategy_name=\"DE\",\n", - " num_dims=reshaper.total_params,\n", - " popsize=100,\n", - " num_subpops=5,\n", - " communication=\"best_subpop\",\n", - ")\n", - "params = strategy.default_params\n", - "state = strategy.initialize(rng, params)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 22.464375 26.3125 2.0493376 [-26.3125 -26.3125 -26.3125 -26.3125 -26.3125]\n", - "2 22.575624 29.75 2.9526908 [-29.75 -29.75 -29.75 -29.75 -29.75]\n", - "3 22.914999 29.125 3.4180415 [-29.75 -29.75 -29.75 -29.75 -29.75]\n", - "4 19.238125 28.9375 
2.5196369 [-29.75 -29.75 -29.75 -29.75 -29.75]\n", - "5 19.704374 33.0625 2.316076 [-33.0625 -33.0625 -33.0625 -33.0625 -33.0625]\n", - "6 23.7925 61.875 9.7088585 [-61.875 -61.875 -61.875 -61.875 -61.875]\n", - "7 35.21 118.5625 16.621597 [-118.5625 -118.5625 -118.5625 -118.5625 -118.5625]\n", - "8 38.021873 86.375 18.679571 [-118.5625 -118.5625 -118.5625 -118.5625 -118.5625]\n", - "9 45.83875 148.75 31.13269 [-148.75 -148.75 -148.75 -148.75 -148.75]\n", - "10 36.0625 125.6875 28.167828 [-148.75 -148.75 -148.75 -148.75 -148.75]\n", - "11 44.895 182.9375 38.524178 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]\n", - "12 49.030624 170.0 36.70624 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]\n", - "13 47.264374 170.75 32.65505 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]\n", - "14 47.146248 174.8125 35.011383 [-182.9375 -182.9375 -182.9375 -182.9375 -182.9375]\n", - "15 57.025623 200.0 45.128643 [-200. -200. -200. -200. -200.]\n", - "16 87.83625 200.0 69.39789 [-200. -200. -200. -200. -200.]\n", - "17 73.627495 200.0 65.97414 [-200. -200. -200. -200. -200.]\n", - "18 75.694374 200.0 58.033886 [-200. -200. -200. -200. -200.]\n", - "19 73.02125 200.0 65.48465 [-200. -200. -200. -200. -200.]\n", - "20 82.159996 200.0 70.50161 [-200. -200. -200. -200. -200.]\n" - ] - } - ], - "source": [ - "for t in range(20):\n", - " rng, rng_eval, rng_iter = jax.random.split(rng, 3)\n", - " x, state = strategy.ask(rng_iter, state, params)\n", - " x_re = reshaper.reshape(x)\n", - " fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)\n", - " fit_re = fit_shaper.apply(x, fitness)\n", - " state = strategy.tell(x, fit_re, state, params)\n", - "\n", - " if t % 1 == 0:\n", - " print(\n", - " t + 1,\n", - " fitness.mean(),\n", - " fitness.max(),\n", - " fitness.std(),\n", - " state.best_fitness, # Best fitness in all subpops\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experimental (!!!) 
- Subpopulation Meta-Batch ES Rollouts" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "EvoParams(mu_eff=DeviceArray(1.6496499, dtype=float32), c_1=DeviceArray(0.15949409, dtype=float32), c_mu=DeviceArray(0.02899084, dtype=float32), c_sigma=DeviceArray(0.42194194, dtype=float32), d_sigma=DeviceArray(1.421942, dtype=float32), c_c=DeviceArray(0.63072497, dtype=float32), chi_n=DeviceArray(1.2542727, dtype=float32), weights=DeviceArray([ 0.73042274, 0.2695773 , 0. , -0.726532 ,\n", - " -1.2900741 ], dtype=float32), weights_truncated=DeviceArray([0.73042274, 0.2695773 , 0. , 0. , 0. ], dtype=float32), c_m=1.0, sigma_init=0.065, init_min=DeviceArray([0.8, 0.9], dtype=float32), init_max=DeviceArray([0.8, 0.9], dtype=float32), clip_min=-3.4028235e+38, clip_max=3.4028235e+38)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from evosax.experimental.subpops import MetaStrategy\n", - "\n", - "meta_strategy = MetaStrategy(\n", - " meta_strategy_name=\"CMA_ES\",\n", - " inner_strategy_name=\"DE\",\n", - " meta_params=[\"diff_w\", \"cross_over_rate\"],\n", - " num_dims=reshaper.total_params,\n", - " popsize=100,\n", - " num_subpops=5,\n", - " meta_strategy_kwargs={\"elite_ratio\": 0.5},\n", - " )\n", - "meta_es_params = meta_strategy.default_params_meta\n", - "meta_es_params.replace(\n", - " clip_min=jnp.array([0, 0]), clip_max=jnp.array([2, 1])\n", - ")\n", - "meta_es_params" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1 20.289999 23.5 1.0141777 [-22.125 -23.5 -23.5 -23.5 -22.125]\n", - "[0.879098 0.76078224 0.7285294 0.8667383 0.9170291 ]\n", - "[0.87730575 0.89257807 0.88649285 0.8704477 0.95789057]\n", - "====================\n", - "2 23.526875 27.8125 2.8445435 [-27.75 -27.75 -27.5625 -27.8125 -27.5625]\n", - 
"[0.7529136 0.75194293 0.75470865 0.73406416 0.7236362 ]\n", - "[0.92120713 0.80959445 0.904763 0.9089725 0.9178807 ]\n", - "====================\n", - "3 18.4025 23.375 1.7486227 [-27.75 -27.75 -27.5625 -27.8125 -27.5625]\n", - "[0.76000106 0.7763824 0.80698913 0.7326284 0.79374725]\n", - "[0.8445878 0.90652514 0.9469857 0.9224319 0.8527581 ]\n", - "====================\n", - "4 20.2475 28.875 2.2587662 [-27.75 -27.75 -27.5625 -28.875 -27.5625]\n", - "[0.8238988 0.7451631 0.7800138 0.8030797 0.7789296]\n", - "[0.81761867 0.87493116 0.85717034 0.8252281 0.8858736 ]\n", - "====================\n", - "5 21.651875 26.75 2.3301566 [-27.75 -27.75 -27.5625 -28.875 -27.5625]\n", - "[0.7483712 0.78120756 0.7842538 0.8036731 0.8382279 ]\n", - "[0.8533484 0.85126436 0.81118304 0.87271136 0.7978533 ]\n", - "====================\n", - "6 24.300625 32.5625 3.8705084 [-30.1875 -30.75 -28.5625 -32.5625 -29.625 ]\n", - "[0.7154043 0.73717016 0.74947 0.75264627 0.7836797 ]\n", - "[0.8599149 0.8408388 0.80433404 0.9227266 0.866442 ]\n", - "====================\n", - "7 22.035 33.8125 4.2435365 [-33.8125 -30.75 -28.875 -32.5625 -30.3125]\n", - "[0.73832107 0.7330418 0.80643374 0.7836552 0.74030274]\n", - "[0.8556646 0.80183303 0.82818526 0.83224803 0.84638065]\n", - "====================\n", - "8 22.17375 49.0 7.518158 [-45.5625 -42.0625 -28.875 -49. 
-45.4375]\n", - "[0.7866129 0.73554915 0.7203534 0.752935 0.69754535]\n", - "[0.84873676 0.8862246 0.83316815 0.79738003 0.8420979 ]\n", - "====================\n", - "9 23.22 70.8125 12.524904 [-60.6875 -70.8125 -28.875 -66.5 -45.4375]\n", - "[0.73071593 0.7588424 0.7373127 0.78221434 0.819327 ]\n", - "[0.8632609 0.85133994 0.88077104 0.8476571 0.82447743]\n", - "====================\n", - "10 27.818125 95.25 17.123335 [-60.6875 -70.8125 -28.875 -67.0625 -95.25 ]\n", - "[0.74657243 0.73996896 0.74008286 0.77336246 0.76170486]\n", - "[0.8573377 0.8602473 0.8592791 0.8566864 0.87690324]\n", - "====================\n", - "11 33.4875 170.375 33.512127 [-147.625 -149.6875 -28.875 -89.5625 -170.375 ]\n", - "[0.74657285 0.77310294 0.7667397 0.79113615 0.7745692 ]\n", - "[0.8604234 0.8604967 0.801528 0.87977314 0.8678907 ]\n", - "====================\n", - "12 41.148125 180.875 36.242935 [-147.625 -180.875 -28.875 -89.5625 -170.375 ]\n", - "[0.784878 0.7705825 0.7628718 0.7921372 0.7565861]\n", - "[0.8717209 0.84912974 0.8673709 0.847437 0.8717182 ]\n", - "====================\n", - "13 47.984375 200.0 39.957558 [-147.625 -200. -28.875 -89.5625 -180.5 ]\n", - "[0.76560074 0.76622653 0.7893174 0.77328974 0.7705352 ]\n", - "[0.8389833 0.852461 0.86471146 0.8482863 0.86811644]\n", - "====================\n", - "14 44.32375 200.0 47.0014 [-147.625 -200. -28.875 -94.5 -200. ]\n", - "[0.7761689 0.7649889 0.78513336 0.76904625 0.775042 ]\n", - "[0.85951185 0.854263 0.87067384 0.8597369 0.86205065]\n", - "====================\n", - "15 52.1825 200.0 46.55368 [-147.625 -200. -28.875 -94.5 -200. ]\n", - "[0.76948667 0.76638454 0.7677278 0.79087144 0.7603321 ]\n", - "[0.8666052 0.86009985 0.84476924 0.8650915 0.85453224]\n", - "====================\n", - "16 52.930622 200.0 51.577286 [-149.9375 -200. -28.875 -132.625 -200. 
]\n", - "[0.75860137 0.77249414 0.75932413 0.76306224 0.7624783 ]\n", - "[0.8530121 0.8678422 0.8509262 0.85580117 0.8567307 ]\n", - "====================\n", - "17 51.984375 192.9375 48.96752 [-175.875 -200. -28.875 -132.625 -200. ]\n", - "[0.7763821 0.7706568 0.7801008 0.7734966 0.77879196]\n", - "[0.8634914 0.85994005 0.87226665 0.85772806 0.87526596]\n", - "====================\n", - "18 63.015 200.0 56.433624 [-190.25 -200. -28.875 -183.125 -200. ]\n", - "[0.77782947 0.7740208 0.77715224 0.7852487 0.77923805]\n", - "[0.8703792 0.86372644 0.8684063 0.8587013 0.8679136 ]\n", - "====================\n", - "19 60.641247 200.0 54.19631 [-190.25 -200. -28.875 -183.125 -200. ]\n", - "[0.7771916 0.78043425 0.78965497 0.77828103 0.783827 ]\n", - "[0.8698033 0.8648427 0.86594695 0.8754562 0.8779571 ]\n", - "====================\n", - "20 54.751247 200.0 51.718987 [-190.25 -200. -28.875 -200. -200. ]\n", - "[0.77753127 0.78207856 0.78387165 0.7841341 0.7895221 ]\n", - "[0.87135184 0.87689304 0.8684612 0.8671708 0.8700651 ]\n", - "====================\n" - ] - } - ], - "source": [ - "# META: Initialize the meta strategy state\n", - "inner_es_params = meta_strategy.default_params\n", - "meta_state = meta_strategy.initialize_meta(rng, meta_es_params)\n", - "\n", - "# META: Get altered inner es hyperparams (placeholder for init)\n", - "inner_es_params, meta_state = meta_strategy.ask_meta(\n", - " rng, meta_state, meta_es_params, inner_es_params\n", - ")\n", - "\n", - "# INNER: Initialize the inner batch ES\n", - "state = meta_strategy.initialize(rng, inner_es_params)\n", - "\n", - "for t in range(20):\n", - " rng, rng_eval, rng_iter = jax.random.split(rng, 3)\n", - "\n", - " # META: Get altered inner es hyperparams\n", - " inner_es_params, meta_state = meta_strategy.ask_meta(\n", - " rng, meta_state, meta_es_params, inner_es_params\n", - " )\n", - "\n", - " # INNER: Ask for inner candidate params to evaluate on problem\n", - " x, state = meta_strategy.ask(rng_iter, state, 
inner_es_params)\n", - "\n", - " # INNER: Update using pseudo fitness\n", - " x_re = reshaper.reshape(x)\n", - " fitness = evaluator.rollout(rng_eval, x_re).mean(axis=1)\n", - " fit_re = fit_shaper.apply(x, fitness)\n", - " state = meta_strategy.tell(x, fit_re, state, inner_es_params)\n", - "\n", - " # META: Update the meta strategy\n", - " meta_state = meta_strategy.tell_meta(\n", - " inner_es_params, fit_re, meta_state, meta_es_params\n", - " )\n", - "\n", - " if t % 1 == 0:\n", - " print(\n", - " t + 1,\n", - " fitness.mean(),\n", - " fitness.max(),\n", - " fitness.std(),\n", - " state.best_fitness, # Best fitness in all subpops\n", - " )\n", - " print(inner_es_params.diff_w)\n", - " print(inner_es_params.cross_over_rate)\n", - " print(20 * \"=\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mle-toolbox", - "language": "python", - "name": "mle-toolbox" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/tests/conftest.py b/tests/conftest.py index a9235a4..e251f8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,6 @@ def pytest_generate_tests(metafunc): "ARS", "PBT", "PersistentES", - "xNES", "Sep_CMA_ES", "Full_iAMaLGaM", "Indep_iAMaLGaM", @@ -26,34 +25,70 @@ def pytest_generate_tests(metafunc): "LM_MA_ES", "RmES", "GLD", + "xNES", + "SNES", + "ESMC", + "DES", + "SAMR_GA", + # "GESMR_GA", + "GuidedES", + "ASEBO", + "CR_FM_NES", + "MR15_GA", ], ) else: - metafunc.parametrize("strategy_name", ["Full_iAMaLGaM"]) + metafunc.parametrize("strategy_name", ["SNES"]) if "classic_name" in metafunc.fixturenames: if metafunc.config.getoption("all"): 
metafunc.parametrize( "classic_name", [ - "rosenbrock", - "quadratic", - "ackley", - "griewank", - "rastrigin", - "schwefel", - "himmelblau", - "six-hump", + "Sphere", + "EllipsoidalOriginal", + "RastriginOriginal", + "BuecheRastrigin", + "LinearSlope", + # Part 2: Functions with low or moderate conditions + "AttractiveSector", + "StepEllipsoidal", + "RosenbrockOriginal", + "RosenbrockRotated", + # Part 3: Functions with high conditioning and unimodal + "EllipsoidalRotated", + "Discus", + "BentCigar", + "SharpRidge", + "DifferentPowers", + # Part 4: Multi-modal functions with adequate global structure + "RastriginRotated", + "Weierstrass", + "SchaffersF7", + "SchaffersF7IllConditioned", + "GriewankRosenbrock", + # Part 5: Multi-modal functions with weak global structure + "Schwefel", + "Lunacek", + "Gallagher101Me", + "Gallagher21Hi", + # "Katsuura", + # Part 6: Additional low-d functions (not in BBOB) + "Linear", + "Ackley", + "DixonPrice", ], ) else: - metafunc.parametrize("classic_name", ["rosenbrock"]) + metafunc.parametrize("classic_name", ["Sphere"]) if "env_name" in metafunc.fixturenames: if metafunc.config.getoption("all"): metafunc.parametrize( "env_name", - ["CartPole-v1", "ant"], + [ + "CartPole-v1", + ], ) else: metafunc.parametrize("env_name", ["CartPole-v1"]) diff --git a/tests/test_fitness_rollout.py b/tests/test_fitness_rollout.py index cce595d..73d8c5b 100644 --- a/tests/test_fitness_rollout.py +++ b/tests/test_fitness_rollout.py @@ -2,9 +2,8 @@ import jax.numpy as jnp from evosax import CMA_ES, ARS, ParameterReshaper, NetworkMapper from evosax.problems import ( - ClassicFitness, - GymFitness, - BraxFitness, + BBOBFitness, + GymnaxFitness, VisionFitness, SequenceFitness, ) @@ -12,9 +11,7 @@ def test_classic_rollout(classic_name: str): rng = jax.random.PRNGKey(0) - evaluator = ClassicFitness( - classic_name, num_dims=2, num_rollouts=2, noise_std=0.1 - ) + evaluator = BBOBFitness(classic_name, num_dims=2) strategy = CMA_ES(popsize=20, num_dims=2, 
elite_ratio=0.5) params = strategy.default_params state = strategy.initialize(rng, params) @@ -23,29 +20,19 @@ def test_classic_rollout(classic_name: str): rng, rng_gen, rng_eval = jax.random.split(rng, 3) x, state = strategy.ask(rng_gen, state, params) fitness = evaluator.rollout(rng_eval, x) - assert fitness.shape == (20, 2) + assert fitness.shape == (20,) def test_env_ffw_rollout(env_name: str): rng = jax.random.PRNGKey(0) - if env_name in ["CartPole-v1"]: - evaluator = GymFitness(env_name, num_env_steps=100, num_rollouts=10) - network = NetworkMapper["MLP"]( - num_hidden_units=64, - num_hidden_layers=2, - num_output_units=evaluator.action_shape, - hidden_activation="relu", - output_activation="categorical", - ) - else: - evaluator = BraxFitness(env_name, num_env_steps=100, num_rollouts=10) - network = NetworkMapper["MLP"]( - num_hidden_units=64, - num_hidden_layers=2, - num_output_units=evaluator.action_shape, - hidden_activation="tanh", - output_activation="tanh", - ) + evaluator = GymnaxFitness(env_name, num_env_steps=100, num_rollouts=10) + network = NetworkMapper["MLP"]( + num_hidden_units=64, + num_hidden_layers=2, + num_output_units=evaluator.action_shape, + hidden_activation="relu", + output_activation="categorical", + ) pholder = jnp.zeros((1, evaluator.input_shape[0])) net_params = network.init( rng, @@ -53,7 +40,7 @@ def test_env_ffw_rollout(env_name: str): rng=rng, ) reshaper = ParameterReshaper(net_params) - evaluator.set_apply_fn(reshaper.vmap_dict, network.apply) + evaluator.set_apply_fn(network.apply) strategy = ARS(popsize=20, num_dims=reshaper.total_params, elite_ratio=0.5) state = strategy.initialize(rng) @@ -69,21 +56,12 @@ def test_env_ffw_rollout(env_name: str): def test_env_rec_rollout(env_name: str): rng = jax.random.PRNGKey(0) - if env_name in ["CartPole-v1"]: - evaluator = GymFitness(env_name, num_env_steps=100, num_rollouts=10) - network = NetworkMapper["LSTM"]( - num_hidden_units=64, - num_output_units=evaluator.action_shape, - 
output_activation="categorical", - ) - - else: - evaluator = BraxFitness(env_name, num_env_steps=100, num_rollouts=10) - network = NetworkMapper["LSTM"]( - num_hidden_units=64, - num_output_units=evaluator.action_shape, - output_activation="tanh", - ) + evaluator = GymnaxFitness(env_name, num_env_steps=100, num_rollouts=10) + network = NetworkMapper["LSTM"]( + num_hidden_units=64, + num_output_units=evaluator.action_shape, + output_activation="categorical", + ) pholder = jnp.zeros((1, evaluator.input_shape[0])) carry_init = network.initialize_carry() @@ -94,9 +72,7 @@ def test_env_rec_rollout(env_name: str): rng=rng, ) reshaper = ParameterReshaper(net_params) - evaluator.set_apply_fn( - reshaper.vmap_dict, network.apply, network.initialize_carry - ) + evaluator.set_apply_fn(network.apply, network.initialize_carry) strategy = ARS(popsize=20, num_dims=reshaper.total_params, elite_ratio=0.5) state = strategy.initialize(rng) @@ -134,7 +110,7 @@ def test_vision_fitness(): ) reshaper = ParameterReshaper(net_params) - evaluator.set_apply_fn(reshaper.vmap_dict, network.apply) + evaluator.set_apply_fn(network.apply) strategy = ARS(popsize=4, num_dims=reshaper.total_params, elite_ratio=0.5) state = strategy.initialize(rng) @@ -162,14 +138,9 @@ def test_sequence_fitness(): rng=rng, ) param_reshaper = ParameterReshaper(params) - evaluator.set_apply_fn( - param_reshaper.vmap_dict, - network.apply, - network.initialize_carry, - ) + evaluator.set_apply_fn(network.apply, network.initialize_carry) - strategy = ARS(param_reshaper.total_params, 4) - (param_reshaper.total_params) + strategy = ARS(4, param_reshaper.total_params) es_state = strategy.initialize(rng) x, es_state = strategy.ask(rng, es_state) diff --git a/tests/test_param_reshape.py b/tests/test_param_reshape.py index 42bb987..be0fbe5 100644 --- a/tests/test_param_reshape.py +++ b/tests/test_param_reshape.py @@ -1,10 +1,21 @@ import jax import jax.numpy as jnp -from flax import linen as nn from evosax.networks import LSTM, 
MLP, CNN from evosax import ParameterReshaper +def test_flat_vector(): + rng = jax.random.PRNGKey(0) + vec_params = jax.random.normal(rng, (2,)) + reshaper = ParameterReshaper(vec_params) + assert reshaper.total_params == 2 + + # Test population batch matrix reshaping + test_params = jnp.zeros((100, 2)) + out = reshaper.reshape(test_params) + assert out.shape == (100, 2) + + def test_reshape_lstm(): rng = jax.random.PRNGKey(1) network = LSTM( diff --git a/tests/test_strategy_api.py b/tests/test_strategy_api.py index a00f5b4..3d7318e 100644 --- a/tests/test_strategy_api.py +++ b/tests/test_strategy_api.py @@ -1,12 +1,15 @@ import jax from evosax import Strategies -from evosax.problems import ClassicFitness +from evosax.problems import BBOBFitness def test_strategy_ask(strategy_name): # Loop over all strategies and test ask API rng = jax.random.PRNGKey(0) - popsize = 20 + if strategy_name == "ESMC": + popsize = 21 + else: + popsize = 20 strategy = Strategies[strategy_name](popsize=popsize, num_dims=2) params = strategy.default_params state = strategy.initialize(rng, params) @@ -19,12 +22,15 @@ def test_strategy_ask(strategy_name): def test_strategy_ask_tell(strategy_name): # Loop over all strategies and test ask API rng = jax.random.PRNGKey(0) - popsize = 20 + if strategy_name == "ESMC": + popsize = 21 + else: + popsize = 20 strategy = Strategies[strategy_name](popsize=popsize, num_dims=2) params = strategy.default_params state = strategy.initialize(rng, params) x, state = strategy.ask(rng, state, params) - evaluator = ClassicFitness("rosenbrock", num_dims=2) + evaluator = BBOBFitness("Sphere", num_dims=2) fitness = evaluator.rollout(rng, x) state = strategy.tell(x, fitness, state, params) return diff --git a/tests/test_strategy_run.py b/tests/test_strategy_run.py index 59b117c..bdacaf2 100644 --- a/tests/test_strategy_run.py +++ b/tests/test_strategy_run.py @@ -1,7 +1,7 @@ import jax import jax.numpy as jnp from evosax import Strategies -from evosax.problems import 
ClassicFitness +from evosax.problems import BBOBFitness from evosax.utils import FitnessShaper from functools import partial @@ -13,8 +13,11 @@ def test_strategy_run(strategy_name): rng = jax.random.PRNGKey(0) Strat = Strategies[strategy_name] # PBT also returns copy ID integer - treat separately - popsize = 20 - evaluator = ClassicFitness("rosenbrock", 2) + if strategy_name == "ESMC": + popsize = 21 + else: + popsize = 20 + evaluator = BBOBFitness("Sphere", 2) fitness_shaper = FitnessShaper() batch_eval = evaluator.rollout @@ -39,8 +42,11 @@ def test_strategy_scan(strategy_name): rng = jax.random.PRNGKey(0) Strat = Strategies[strategy_name] # PBT also returns copy ID integer - treat separately - popsize = 20 - evaluator = ClassicFitness("rosenbrock", 2) + if strategy_name == "ESMC": + popsize = 21 + else: + popsize = 20 + evaluator = BBOBFitness("Sphere", 2) fitness_shaper = FitnessShaper() batch_eval = evaluator.rollout