
Many updates to be released in v.0.1.0 #34

Merged
merged 14 commits into from
Dec 5, 2022
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
examples/experimental
bbob.py
# Standard ROB excludes
.sync-config.cson
.vim-arsync
27 changes: 23 additions & 4 deletions CHANGELOG.md
@@ -1,28 +1,47 @@
### Work-in-Progress

- [ ] Make xNES work with all optimizers (currently only GD)
- Implement more strategies
- [ ] Large-scale CMA-ES variants
- [ ] [LM-CMA](https://www.researchgate.net/publication/282612269_LM-CMA_An_alternative_to_L-BFGS_for_large-scale_black_Box_optimization)
- [ ] [VkD-CMA](https://hal.inria.fr/hal-01306551v1/document), [Code](https://gist.github.com/youheiakimoto/2fb26c0ace43c22b8f19c7796e69e108)
- [ ] [sNES](https://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf) (separable version of xNES)
- [ ] [ASEBO](https://proceedings.neurips.cc/paper/2019/file/88bade49e98db8790df275fcebb37a13-Paper.pdf)
- [ ] [RBO](http://proceedings.mlr.press/v100/choromanski20a/choromanski20a.pdf)

- Encoding methods - via special reshape wrappers
- [ ] Discrete Cosine Transform
- [ ] Wavelet Based Encoding (van Steenkiste, 2016)
- [ ] Hypernetworks (Ha - start with simple MLP)
- [ ] CNN Hypernetwork (Ha - start with simple MLP)
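The encoding wrappers above all share one idea: evolve a small genome and decode it into the full parameter vector. As a concrete illustration of the planned Discrete Cosine Transform item, here is a minimal numpy sketch of a DCT-style decoder — the dimensions, names, and basis construction are illustrative assumptions, not evosax code:

```python
import numpy as np

# Hypothetical sizes: decode k DCT coefficients into d raw parameters
d, k = 64, 8

# Orthonormal DCT-II basis (first k rows), built directly with numpy
n = np.arange(d)
basis = np.array([np.cos(np.pi * (n + 0.5) * j / d) for j in range(k)])
basis[0] *= 1.0 / np.sqrt(2.0)
basis *= np.sqrt(2.0 / d)

def decode(coeffs):
    """Map a k-dim genome to the full d-dim parameter vector."""
    return coeffs @ basis

genome = np.zeros(k)
genome[0] = 1.0
params = decode(genome)  # a smooth d-dim vector from only k numbers
```

Because the basis rows are orthonormal, low-frequency genomes decode to smooth, large parameter vectors — which is the point of such reshape wrappers.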

### [v0.1.0] - [TBD]

##### Added

- Adds a `total_env_steps` counter to both `GymFitness` and `BraxFitness` for easier sample-efficiency comparisons with RL algorithms.
- Support for new strategies/genetic algorithms
- SAMR-GA (Clune et al., 2008)
- GESMR-GA (Kumar et al., 2022)
- SNES (Wierstra et al., 2014)
- DES (Lange et al., 2022)
- Guided ES (Maheswaranathan et al., 2018)
- ASEBO (Choromanski et al., 2019)
- CR-FM-NES (Nomura & Ono, 2022)
- MR15-GA (Rechenberg, 1978)
- Adds full set of BBOB low-dimensional functions (`BBOBFitness`)
- Adds 2D visualizer animating sampled points (`BBOBVisualizer`)
- Adds `Evosax2JAXWrapper` to wrap all evosax strategies
- Adds Adan optimizer (Xie et al., 2022)
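All of the strategies added above follow the same ask–evaluate–tell pattern on black-box objectives such as the new BBOB functions. The following is a minimal numpy sketch of that pattern using a toy elitist ES on a sphere function — it is illustrative only, not the evosax API or any of the cited algorithms:

```python
import numpy as np

def sphere(x):
    """BBOB-style sphere function, batched over the last axis."""
    return np.sum(x ** 2, axis=-1)

rng = np.random.default_rng(0)
num_dims, popsize, num_elite = 8, 32, 8
mean = rng.normal(size=num_dims)
sigma = 0.5
start_fitness = sphere(mean)

for gen in range(50):
    population = mean + sigma * rng.normal(size=(popsize, num_dims))  # "ask"
    fitness = sphere(population)                                      # evaluate
    elite = population[np.argsort(fitness)[:num_elite]]               # "tell"
    mean = elite.mean(axis=0)
    sigma *= 0.95  # simple step-size decay

final_fitness = sphere(mean)
```

The listed strategies differ in how they update `mean` and `sigma` (or a full covariance), but plug into the same loop shape.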

##### Changed

- `ParameterReshaper` can now be applied directly from within the strategy. Simply provide a `pholder_params` pytree at strategy instantiation (instead of `num_dims`).
- `FitnessShaper` can also be applied directly from within the strategy. This makes it easier to track the best-performing member across generations and addresses issue #32. Simply provide the fitness shaping settings as args to the strategy (`maximize`, `centered_rank`, ...).
- Removes Brax fitness (use EvoJAX version instead)
- Adds lrate and sigma schedules to strategy instantiation.
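The `pholder_params` change means the strategy infers the search dimensionality from a parameter pytree instead of an explicit `num_dims`. The bookkeeping this automates looks roughly like the sketch below — a hypothetical two-leaf dict, plain numpy, not evosax's actual `ParameterReshaper` implementation:

```python
import numpy as np

# Hypothetical placeholder pytree (a dict of arrays) standing in for
# network params; the real ParameterReshaper handles nested pytrees.
pholder_params = {"w": np.zeros((4, 2)), "b": np.zeros(2)}

shapes = {k: v.shape for k, v in pholder_params.items()}
sizes = {k: v.size for k, v in pholder_params.items()}
total_params = sum(sizes.values())  # what `num_dims` used to specify

def reshape_single(flat):
    """Split one flat genome back into the pytree layout."""
    out, offset = {}, 0
    for k in shapes:
        out[k] = flat[offset:offset + sizes[k]].reshape(shapes[k])
        offset += sizes[k]
    return out

genome = np.arange(total_params, dtype=float)
params = reshape_single(genome)
```

The strategy then searches over flat vectors of length `total_params` and hands reshaped pytrees to the fitness function.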

##### Fixed

- Fixed reward masking in `GymFitness`. Using `jnp.sum(dones) >= 1` as the mask for the cumulative return computation zeroed out the final timestep, which is wrong. This caused problems in sparse-reward gym environments (e.g. Mountain Car).
- Fixed PGPE sample indexing.
- Fixed weight decay, which was erroneously multiplied by -1 when maximizing.
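The reward-masking bug is easiest to see on a one-episode example. The numpy sketch below is an illustrative reconstruction of the masking logic, not the exact `GymFitness` code: the buggy mask flips to 1 on the terminal step itself, discarding a sparse terminal reward, while the fixed mask only zeroes out steps *after* the first done:

```python
import numpy as np

rewards = np.array([0.0, 0.0, 0.0, 1.0])  # sparse reward at the terminal step
dones   = np.array([0.0, 0.0, 0.0, 1.0])

# Buggy mask: active from the first done onward -> final timestep is zeroed
buggy_mask = (np.cumsum(dones) >= 1).astype(float)
buggy_return = np.sum(rewards * (1.0 - buggy_mask))

# Fixed mask: only steps strictly after the first done are masked out
prev_dones = np.concatenate([[0.0], np.cumsum(dones)[:-1]])
fixed_mask = (prev_dones >= 1).astype(float)
fixed_return = np.sum(rewards * (1.0 - fixed_mask))
```

With the buggy mask the episode return is 0.0; with the fix it is 1.0 — exactly the failure mode seen in Mountain Car-style environments.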

### [v0.0.9] - 15/06/2022

182 changes: 98 additions & 84 deletions README.md

Large diffs are not rendered by default.

49 changes: 39 additions & 10 deletions evosax/__init__.py
@@ -1,4 +1,4 @@
from .strategy import Strategy
from .strategy import Strategy, EvoState, EvoParams
from .strategies import (
SimpleGA,
SimpleES,
@@ -9,7 +9,6 @@
PGPE,
PBT,
PersistentES,
xNES,
ARS,
Sep_CMA_ES,
BIPOP_CMA_ES,
@@ -21,6 +20,16 @@
RmES,
GLD,
SimAnneal,
SNES,
xNES,
ESMC,
DES,
SAMR_GA,
GESMR_GA,
GuidedES,
ASEBO,
CR_FM_NES,
MR15_GA,
)
from .utils import FitnessShaper, ParameterReshaper, ESLog
from .networks import NetworkMapper
@@ -37,7 +46,6 @@
"PGPE": PGPE,
"PBT": PBT,
"PersistentES": PersistentES,
"xNES": xNES,
"ARS": ARS,
"Sep_CMA_ES": Sep_CMA_ES,
"BIPOP_CMA_ES": BIPOP_CMA_ES,
@@ -49,9 +57,27 @@
"RmES": RmES,
"GLD": GLD,
"SimAnneal": SimAnneal,
"SNES": SNES,
"xNES": xNES,
"ESMC": ESMC,
"DES": DES,
"SAMR_GA": SAMR_GA,
"GESMR_GA": GESMR_GA,
"GuidedES": GuidedES,
"ASEBO": ASEBO,
"CR_FM_NES": CR_FM_NES,
"MR15_GA": MR15_GA,
}

__all__ = [
"Strategies",
"EvoState",
"EvoParams",
"FitnessShaper",
"ParameterReshaper",
"ESLog",
"NetworkMapper",
"ProblemMapper",
"Strategy",
"SimpleGA",
"SimpleES",
@@ -62,7 +88,6 @@
"PGPE",
"PBT",
"PersistentES",
"xNES",
"ARS",
"Sep_CMA_ES",
"BIPOP_CMA_ES",
@@ -74,10 +99,14 @@
"RmES",
"GLD",
"SimAnneal",
"Strategies",
"FitnessShaper",
"ParameterReshaper",
"ESLog",
"NetworkMapper",
"ProblemMapper",
"SNES",
"xNES",
"ESMC",
"DES",
"SAMR_GA",
"GESMR_GA",
"GuidedES",
"ASEBO",
"CR_FM_NES",
"MR15_GA",
]
2 changes: 0 additions & 2 deletions evosax/experimental/decodings/decoder.py
@@ -8,13 +8,11 @@ def __init__(
self,
num_encoding_dims: int,
placeholder_params: Union[chex.ArrayTree, chex.Array],
identity: bool = False,
n_devices: Optional[int] = None,
):
self.num_encoding_dims = num_encoding_dims
self.total_params = num_encoding_dims
self.placeholder_params = placeholder_params
self.identity = identity
if n_devices is None:
self.n_devices = jax.local_device_count()
else:
17 changes: 12 additions & 5 deletions evosax/experimental/decodings/hyper.py
Expand Up @@ -5,8 +5,8 @@
from flax.core import unfreeze
from typing import Union, Optional
from .decoder import Decoder
from ...utils import ParameterReshaper
from ...networks import HyperNetworkMLP
from .hyper_networks import HyperNetworkMLP
from ...utils import ParameterReshaper, ravel_pytree


class HyperDecoder(Decoder):
@@ -39,24 +39,31 @@ def __init__(
super().__init__(
hyper_reshaper.total_params,
placeholder_params,
identity,
n_devices,
)
self.hyper_reshaper = hyper_reshaper
self.vmap_dict = self.hyper_reshaper.vmap_dict

def reshape(self, x: chex.Array) -> chex.ArrayTree:
"""Perform reshaping for random projection case."""
"""Perform reshaping for hypernetwork case."""
# 0. Reshape genome into params for hypernetwork
x_params = self.hyper_reshaper.reshape(x)
# 1. Project parameters to raw dimensionality using hypernetwork
hyper_x = jax.jit(jax.vmap(self.hyper_network.apply))(x_params)
return hyper_x

def reshape_single(self, x: chex.Array) -> chex.ArrayTree:
"""Reshape a single flat vector using random projection matrix."""
"""Reshape a single flat vector using hypernetwork."""
# 0. Reshape genome into params for hypernetwork
x_params = self.hyper_reshaper.reshape_single(x)
# 1. Project parameters to raw dimensionality using hypernetwork
hyper_x = jax.jit(self.hyper_network.apply)(x_params)
return hyper_x

def flatten(self, x: chex.ArrayTree) -> chex.Array:
"""Reshaping pytree parameters into flat array."""
return jax.vmap(ravel_pytree)(x)

def flatten_single(self, x: chex.ArrayTree) -> chex.Array:
"""Reshaping pytree parameters into flat array."""
return ravel_pytree(x)
11 changes: 6 additions & 5 deletions evosax/experimental/decodings/random.py
@@ -12,17 +12,14 @@ def __init__(
placeholder_params: Union[chex.ArrayTree, chex.Array],
rng: chex.PRNGKey = jax.random.PRNGKey(0),
rademacher: bool = False,
identity: bool = False,
n_devices: Optional[int] = None,
):
"""Random Projection Decoder (Gaussian/Rademacher random matrix)."""
super().__init__(
num_encoding_dims, placeholder_params, identity, n_devices
)
super().__init__(num_encoding_dims, placeholder_params, n_devices)
self.rademacher = rademacher
# Instantiate base reshaper class
self.base_reshaper = ParameterReshaper(
placeholder_params, identity, n_devices, verbose=False
placeholder_params, n_devices, verbose=False
)
self.vmap_dict = self.base_reshaper.vmap_dict

@@ -35,6 +32,10 @@ def __init__(
self.project_matrix = jax.random.rademacher(
rng, (self.num_encoding_dims, self.base_reshaper.total_params)
)
print(
"RandomDecoder: Encoding parameters to optimize -"
f" {num_encoding_dims}"
)

def reshape(self, x: chex.Array) -> chex.ArrayTree:
"""Perform reshaping for random projection case."""
2 changes: 0 additions & 2 deletions evosax/networks/__init__.py
@@ -1,7 +1,6 @@
from .mlp import MLP
from .cnn import CNN, All_CNN_C
from .lstm import LSTM
from .hyper_networks import HyperNetworkMLP


# Helper that returns model based on string name
@@ -18,5 +17,4 @@
"All_CNN_C",
"LSTM",
"NetworkMapper",
"HyperNetworkMLP",
]
15 changes: 6 additions & 9 deletions evosax/problems/__init__.py
@@ -1,22 +1,19 @@
from .control_brax import BraxFitness
from .control_gym import GymFitness
from .control_gym import GymnaxFitness
from .vision import VisionFitness
from .classic import ClassicFitness
from .bbob import BBOBFitness
from .sequence import SequenceFitness

ProblemMapper = {
"Gym": GymFitness,
"Brax": BraxFitness,
"Gymnax": GymnaxFitness,
"Vision": VisionFitness,
"Classic": ClassicFitness,
"BBOB": BBOBFitness,
"Sequence": SequenceFitness,
}

__all__ = [
"BraxFitness",
"GymFitness",
"GymnaxFitness",
"VisionFitness",
"ClassicFitness",
"BBOBFitness",
"SequenceFitness",
"ProblemMapper",
]