Tired of having to handle asynchronous processes for neuroevolution? Do you want to leverage massive vectorization and high-throughput accelerators for evolution strategies (ES)? evosax allows you to leverage JAX, XLA compilation and auto-vectorization/parallelization to scale ES to your favorite accelerators. The API is based on the classical ask, evaluate, tell cycle of ES. Both ask and tell calls are compatible with jit, vmap/pmap and lax.scan. It includes a vast set of both classic (e.g. CMA-ES, Differential Evolution, etc.) and modern neuroevolution (e.g. OpenAI-ES, Augmented RS, etc.) strategies. You can get started here:
import jax
from evosax import CMA_ES

# Instantiate the search strategy
rng = jax.random.PRNGKey(0)
strategy = CMA_ES(popsize=20, num_dims=2, elite_ratio=0.5)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

# Run ask-eval-tell loop - NOTE: By default minimization!
for t in range(num_generations):
    rng, rng_gen, rng_eval = jax.random.split(rng, 3)
    x, state = strategy.ask(rng_gen, state, es_params)
    fitness = ...  # Your population evaluation fct
    state = strategy.tell(x, fitness, state, es_params)

# Get best overall population member & its fitness
state["best_member"], state["best_fitness"]
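The fitness = ... placeholder in the loop above stands for any function that maps the population array of shape (popsize, num_dims) to one score per member. As a minimal sketch, assuming the 2D Rosenbrock function as the objective, the evaluation can be vectorized with jax.vmap. The names rosenbrock and evaluate_population are illustrative and not part of evosax, and the random x only stands in for the candidates returned by strategy.ask.

```python
import jax
import jax.numpy as jnp


def rosenbrock(x):
    """Rosenbrock function for a single candidate of shape (num_dims,)."""
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)


# Vectorize over the leading population axis and compile once
evaluate_population = jax.jit(jax.vmap(rosenbrock))

# Stand-in for the output of strategy.ask: shape (popsize, num_dims)
x = jax.random.normal(jax.random.PRNGKey(0), (20, 2))
fitness = evaluate_population(x)  # shape (popsize,), lower is better
```

Because the objective is written for a single candidate, the same function can be reused with a different population size without changes.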
Strategy | Reference | Import | Example |
---|---|---|---|
OpenES | Salimans et al. (2017) | OpenES | |
PGPE | Sehnke et al. (2010) | PGPE | |
ARS | Mania et al. (2018) | ARS | |
CMA-ES | Hansen & Ostermeier (2001) | CMA_ES | |
Simple Gaussian | Rechenberg (1978) | SimpleES | |
Simple Genetic | Such et al. (2017) | SimpleGA | |
x-NES | Wierstra et al. (2014) | xNES | |
Particle Swarm Optimization | Kennedy & Eberhart (1995) | PSO | |
Differential Evolution | Storn & Price (1997) | DE | |
Persistent ES | Vicol et al. (2021) | PersistentES | |
Population-Based Training | Jaderberg et al. (2017) | PBT | |
Sep-CMA-ES | Ros & Hansen (2008) | Sep_CMA_ES | |
BIPOP-CMA-ES | Hansen (2009) | BIPOP_CMA_ES | |
IPOP-CMA-ES | Auger & Hansen (2005) | IPOP_CMA_ES | |
Full-iAMaLGaM | Bosman et al. (2013) | Full_iAMaLGaM | |
Independent-iAMaLGaM | Bosman et al. (2013) | Indep_iAMaLGaM | |
MA-ES | Beyer & Sendhoff (2017) | MA_ES | |
LM-MA-ES | Loshchilov et al. (2017) | LM_MA_ES | |
RmES | Li & Zhang (2017) | RmES | TBC |
GLD | Golovin et al. (2019) | GLD | TBC |
The latest evosax release can directly be installed from PyPI:
pip install evosax
If you want to get the most recent commit, please install directly from the repository:
pip install git+https://github.com/RobertTLange/evosax.git@main
For details on running JAX on your accelerators, see the JAX documentation.
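As a quick check that JAX sees your accelerator after installation, you can list the available devices. This sketch uses only standard JAX calls; the output depends on your machine, and JAX falls back to the CPU if no accelerator is found.

```python
import jax

# Backend JAX selected by default: "cpu", "gpu" or "tpu"
backend = jax.default_backend()
# All devices visible to JAX on that backend
devices = jax.devices()
print(backend, devices)
```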
- Classic ES Tasks: API introduction on Rosenbrock function (CMA-ES, Simple GA, etc.).
- CartPole-Control: OpenES & PGPE on the CartPole-v1 gym task (MLP/LSTM controller).
- MNIST-Classifier: OpenES on MNIST with CNN network.
- LRateTune-PES: Persistent ES on meta-learning problem as in Vicol et al. (2021).
- Quadratic-PBT: PBT on toy quadratic problem as in Jaderberg et al. (2017).
- Restart-Wrappers: Custom restart wrappers as e.g. used in (B)IPOP-CMA-ES.
- Strategy Diversity: evosax implements more than 10 classical and modern neuroevolution strategies. All of them follow the same simple ask/eval/tell API and come with tailored tools such as the ClipUp optimizer, parameter reshaping into PyTrees and fitness shaping (see below).
- Vectorization/Parallelization of ask/tell Calls: Both ask and tell calls can leverage jit and vmap/pmap. This enables vectorized/parallel rollouts of different evolution strategies.
import jax
import jax.numpy as jnp
from evosax import ARS

# E.g. vectorize over different lrate decays
strategy = ARS(popsize=100, num_dims=20)
es_params = {
    "lrate_decay": jnp.array([0.999, 0.99, 0.9]),
    ...
}
# Map over axis 0 of lrate_decay
map_dict = {
    "lrate_decay": 0,
    ...
}

# Vmap-composed batch initialize, ask and tell functions
batch_init = jax.vmap(strategy.initialize, in_axes=(None, map_dict))
batch_ask = jax.vmap(strategy.ask, in_axes=(None, 0, map_dict))
batch_tell = jax.vmap(strategy.tell, in_axes=(0, 0, 0, map_dict))
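The map_dict above relies on jax.vmap accepting a dictionary as an in_axes entry: each key is either mapped over an array axis or marked None to be broadcast. The following sketch isolates that mechanism with a toy function instead of an evosax strategy. The function decayed_lrate and the key lrate_init are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def decayed_lrate(step, params):
    """Toy stand-in for a strategy call: learning rate after `step` decays."""
    return params["lrate_init"] * params["lrate_decay"] ** step


es_params = {
    "lrate_init": 0.1,                             # shared by all runs
    "lrate_decay": jnp.array([0.999, 0.99, 0.9]),  # one value per run
}
# 0 = map over leading axis, None = broadcast the same value to every run
map_dict = {"lrate_init": None, "lrate_decay": 0}

batch_lrate = jax.vmap(decayed_lrate, in_axes=(None, map_dict))
lrates = batch_lrate(10, es_params)  # one learning rate per decay value
```

Every key of the parameter dictionary needs an entry in map_dict, which is why the snippet above elides the remaining keys rather than omitting them.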
- Scan Through Evolution Rollouts: You can also lax.scan through entire init, ask, eval, tell loops for fast compilation of ES loops:
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def run_es_loop(rng, num_steps):
    """Run evolution ask-eval-tell loop."""
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    def es_step(state_input, tmp):
        """Helper es step to lax.scan through."""
        rng, state = state_input
        rng, rng_iter = jax.random.split(rng)
        x, state = strategy.ask(rng_iter, state, es_params)
        fitness = ...
        state = strategy.tell(x, fitness, state, es_params)
        return [rng, state], fitness[jnp.argmin(fitness)]

    _, scan_out = jax.lax.scan(es_step,
                               [rng, state],
                               [jnp.zeros(num_steps)])
    return jnp.min(scan_out)
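To see the scan pattern run end to end without depending on a particular evosax version, here is a self-contained sketch with a toy strategy. The ask and tell functions below are hypothetical stand-ins (a fixed-variance Gaussian search that moves the mean to the average of the elite members), not evosax's implementation, and the sphere function is an assumed objective.

```python
from functools import partial

import jax
import jax.numpy as jnp

POPSIZE, NUM_DIMS, SIGMA, NUM_ELITE = 16, 2, 0.1, 4


def ask(rng, mean):
    """Toy ask: sample a population around the current mean."""
    return mean + SIGMA * jax.random.normal(rng, (POPSIZE, NUM_DIMS))


def tell(x, fitness):
    """Toy tell: new mean is the average of the NUM_ELITE lowest-fitness members."""
    elite_idx = jnp.argsort(fitness)[:NUM_ELITE]
    return x[elite_idx].mean(axis=0)


def sphere(x):
    """Assumed objective: squared distance to the origin, per member."""
    return jnp.sum(x ** 2, axis=-1)


@partial(jax.jit, static_argnums=(1,))
def run_toy_es(rng, num_steps):
    def es_step(carry, _):
        rng, mean = carry
        rng, rng_ask = jax.random.split(rng)
        x = ask(rng_ask, mean)
        fitness = sphere(x)
        mean = tell(x, fitness)
        return (rng, mean), jnp.min(fitness)

    init = (rng, jnp.full((NUM_DIMS,), 2.0))
    (_, final_mean), best_per_gen = jax.lax.scan(
        es_step, init, None, length=num_steps
    )
    return final_mean, best_per_gen


final_mean, best_per_gen = run_toy_es(jax.random.PRNGKey(0), 50)
```

Passing None with length=num_steps is an alternative to scanning over a dummy array of zeros; both drive the same number of iterations.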
- Population Parameter Reshaping: We provide a ParameterReshaper wrapper to reshape flat parameter vectors into PyTrees. The wrapper is compatible with JAX neural network libraries such as Flax/Haiku and makes it easier to evaluate network populations afterwards.
import jax.numpy as jnp
from flax import linen as nn
from evosax import ParameterReshaper

class MLP(nn.Module):
    num_hidden_units: int
    ...

    @nn.compact
    def __call__(self, obs):
        ...
        return ...

network = MLP(64)
# Initialize with a dummy observation matching the signature of __call__
policy_params = network.init(rng, jnp.zeros((4,)))

# Initialize reshaper based on placeholder network shapes
param_reshaper = ParameterReshaper(policy_params["params"])

# Get population candidates & reshape into stacked pytrees
x, state = strategy.ask(...)
x_shaped = param_reshaper.reshape(x)
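The idea behind the reshaping step can be sketched with plain JAX utilities. This is not how ParameterReshaper is implemented; it is an illustration that uses jax.flatten_util.ravel_pytree and jax.vmap, with a hand-written parameter dictionary whose layer names and shapes are assumptions.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Placeholder parameters of a tiny network (shapes chosen for illustration)
placeholder_params = {
    "dense": {"kernel": jnp.zeros((4, 8)), "bias": jnp.zeros((8,))},
    "out": {"kernel": jnp.zeros((8, 2)), "bias": jnp.zeros((2,))},
}

# flat_params is a 1D vector; unravel maps such a vector back to the pytree
flat_params, unravel = ravel_pytree(placeholder_params)
num_dims = flat_params.shape[0]  # search dimensionality for the strategy

# Stand-in for a population returned by ask: shape (popsize, num_dims)
x = jnp.ones((20, num_dims))
# Every leaf of x_shaped gains a leading population axis
x_shaped = jax.vmap(unravel)(x)
```

The stacked pytree can then be passed to a network's apply function under jax.vmap to evaluate all population members in one call.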
- Flexible Fitness Shaping: By default, evosax assumes that the fitness objective is to be minimized. If you would like to maximize instead, perform rank centering or z-scoring, or add weight regularization, you can use the FitnessShaper:
from evosax import FitnessShaper

# Instantiate jittable fitness shaper (e.g. for Open ES)
fit_shaper = FitnessShaper(centered_rank=True,
                           z_score=True,
                           weight_decay=0.01,
                           maximize=True)

# Shape the evaluated fitness scores
fit_shaped = fit_shaper.apply(x, fitness)
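For intuition, the transformations named above can be written in a few lines of JAX. This is a sketch of the general techniques under common definitions; the exact formulas, ordering and scaling inside FitnessShaper may differ, and all function names here are illustrative.

```python
import jax.numpy as jnp


def centered_rank(fitness):
    """Map fitness to evenly spaced ranks in [-0.5, 0.5] (lowest fitness -> -0.5)."""
    ranks = jnp.argsort(jnp.argsort(fitness))
    return ranks / (fitness.shape[0] - 1) - 0.5


def z_score(fitness):
    """Standardize fitness to zero mean and unit variance."""
    return (fitness - jnp.mean(fitness)) / (jnp.std(fitness) + 1e-10)


def flip_for_maximization(fitness):
    """Negate scores so that a minimizing strategy maximizes the original objective."""
    return -fitness


raw_fitness = jnp.array([3.0, 1.0, 2.0])
# The member with the highest raw score receives the lowest shaped value
shaped = centered_rank(flip_for_maximization(raw_fitness))
```

Rank-based shaping makes the update invariant to the scale of the raw scores, which helps when fitness values contain outliers.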
- Strategy Restart Wrappers: You can also choose from a set of different restart mechanisms, which will relaunch a strategy (with e.g. new population size) based on termination criteria. Note: For all restart strategies which alter the population size the ask and tell methods will have to be re-compiled at the time of change.
import jax
from evosax import CMA_ES
from evosax.restarts import BIPOP_Restarter

# Define a termination criterion (kwargs - fitness, state, params)
def std_criterion(fitness, state, params):
    """Restart strategy if fitness std across population is small."""
    return fitness.std() < 0.001

# Instantiate Base CMA-ES & wrap with BIPOP restarts
# Pass strategy-specific kwargs separately (e.g. elite_ratio or opt_name)
strategy = CMA_ES(popsize=popsize, num_dims=num_dims, elite_ratio=elite_ratio)
re_strategy = BIPOP_Restarter(
    strategy,
    stop_criteria=[std_criterion],
    strategy_kwargs={"elite_ratio": elite_ratio}
)
state = re_strategy.initialize(rng, es_params)

# ask/tell loop - restarts are automatically handled
rng, rng_gen, rng_eval = jax.random.split(rng, 3)
x, state = re_strategy.ask(rng_gen, state, es_params)
fitness = ...  # Your population evaluation fct
state = re_strategy.tell(x, fitness, state, es_params)
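Further termination criteria can follow the same (fitness, state, params) signature as std_criterion. The sketch below adds a spread-based criterion and a hypothetical helper, should_restart, that shows one plausible way to combine several criteria (restart as soon as any of them fires). How the restart wrappers combine criteria internally is not stated here, so treat the helper as an assumption.

```python
import jax.numpy as jnp


def spread_criterion(fitness, state, params):
    """Restart if best and worst fitness in the population nearly coincide."""
    return jnp.max(fitness) - jnp.min(fitness) < 1e-3


def should_restart(fitness, state, params, stop_criteria):
    """Hypothetical helper: restart as soon as any criterion fires."""
    checks = jnp.array([crit(fitness, state, params) for crit in stop_criteria])
    return jnp.any(checks)


# A collapsed population triggers a restart, a diverse one does not
collapsed = should_restart(jnp.array([1.0, 1.0001]), None, None, [spread_criterion])
diverse = should_restart(jnp.array([0.0, 1.0]), None, None, [spread_criterion])
```

Criteria written with jnp operations return JAX booleans, so they stay usable inside jitted code.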
- Rob's MLC Research Jam Talk: Small motivation talk at the ML Collective Research Jam.
- Rob's 02/2021 Blog: Tutorial on CMA-ES & leveraging JAX's primitives.
- Evojax: JAX-ES library by Google Brain with great rollout wrappers.
- QDax: Quality-Diversity algorithms in JAX.
If you use evosax in your research, please cite it as follows:
@software{evosax2022github,
author = {Robert Tjarko Lange},
title = {{evosax}: JAX-based Evolution Strategies},
url = {http://github.com/RobertTLange/evosax},
year = {2022},
}
You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing.