Synthetic Experience Replay (SynthER) is a diffusion-based approach to arbitrarily upsample an RL agent's collected experience, leading to large gains in sample efficiency and scaling benefits. We integrate SynthER into a variety of offline and online algorithms in this codebase, including SAC, TD3+BC, IQL, EDAC, and CQL. For further details, please see the paper:
Synthetic Experience Replay; Cong Lu*, Philip J. Ball*, Yee Whye Teh, Jack Parker-Holder. Published at NeurIPS, 2023.
To install, clone the repository and run the following:
git submodule update --init --recursive
pip install -r requirements.txt
The code was tested on Python 3.8 and 3.9. If you don't have MuJoCo installed, follow the instructions here: https://github.com/openai/mujoco-py#install-mujoco.
Diffusion model training (this automatically generates samples and saves them):
python3 synther/diffusion/train_diffuser.py --dataset halfcheetah-medium-replay-v2
Baseline without SynthER (e.g. on TD3+BC):
python3 synther/corl/algorithms/td3_bc.py --config synther/corl/yaml/td3_bc/halfcheetah/medium_replay_v2.yaml --checkpoints_path corl_logs/
Offline RL training with SynthER:
# Generating diffusion samples on the fly.
python3 synther/corl/algorithms/td3_bc.py --config synther/corl/yaml/td3_bc/halfcheetah/medium_replay_v2.yaml --checkpoints_path corl_logs/ --name SynthER --diffusion.path path/to/model-100000.pt
# Using saved diffusion samples.
python3 synther/corl/algorithms/td3_bc.py --config synther/corl/yaml/td3_bc/halfcheetah/medium_replay_v2.yaml --checkpoints_path corl_logs/ --name SynthER --diffusion.path path/to/samples.npz
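Before pointing --diffusion.path at a saved samples file, it can help to check what the archive contains. The following is a minimal sketch using NumPy only; the helper name summarize_npz is hypothetical and not part of this repository, and it makes no assumption about which array names the training script writes.

```python
import numpy as np


def summarize_npz(path):
    """Return {array_name: (shape, dtype_string)} for every array in an .npz file.

    Hypothetical helper for inspection only; it is not part of the SynthER codebase.
    """
    summary = {}
    # np.load on an .npz archive returns an NpzFile; each array is read on access.
    with np.load(path) as data:
        for name in data.files:
            array = data[name]
            summary[name] = (array.shape, str(array.dtype))
    return summary


# Example usage (the path is a placeholder, as in the commands above):
#   for name, (shape, dtype) in summarize_npz("path/to/samples.npz").items():
#       print(f"{name}: shape={shape}, dtype={dtype}")
```

The printed shapes give a quick check that the number of generated samples matches what you expect.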
Baselines (SAC, REDQ):
# SAC.
python3 synther/online/online_exp.py --env quadruped-walk-v0 --results_folder online_logs/ --exp_name SAC --gin_config_files 'config/online/sac.gin'
# REDQ.
python3 synther/online/online_exp.py --env quadruped-walk-v0 --results_folder online_logs/ --exp_name REDQ --gin_config_files 'config/online/redq.gin'
SynthER (SAC):
# DMC environments.
python3 synther/online/online_exp.py --env quadruped-walk-v0 --results_folder online_logs/ --exp_name SynthER --gin_config_files 'config/online/sac_synther_dmc.gin' --gin_params 'redq_sac.utd_ratio = 20' 'redq_sac.num_samples = 1000000'
# OpenAI environments (different gin config).
python3 synther/online/online_exp.py --env HalfCheetah-v2 --results_folder online_logs/ --exp_name SynthER --gin_config_files 'config/online/sac_synther_openai.gin' --gin_params 'redq_sac.utd_ratio = 20' 'redq_sac.num_samples = 1000000'
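To launch several online runs that differ only in their gin bindings, the commands above can be assembled programmatically. This is a sketch using only the flags shown in this README; the helper build_online_command, the experiment names, and the UTD values other than 20 are illustrative assumptions rather than settings from the repository.

```python
import shlex


def build_online_command(env, exp_name, gin_config, gin_params=None,
                         results_folder="online_logs/"):
    """Assemble the argument list for synther/online/online_exp.py.

    Hypothetical helper; it only mirrors the flags used in the commands above.
    """
    argv = [
        "python3", "synther/online/online_exp.py",
        "--env", env,
        "--results_folder", results_folder,
        "--exp_name", exp_name,
        "--gin_config_files", gin_config,
    ]
    if gin_params:
        # Each binding is passed as its own argument after --gin_params.
        argv.append("--gin_params")
        argv.extend(f"{key} = {value}" for key, value in gin_params.items())
    return argv


commands = [
    build_online_command(
        env="quadruped-walk-v0",
        exp_name=f"SynthER_utd{utd}",  # illustrative naming scheme
        gin_config="config/online/sac_synther_dmc.gin",
        gin_params={"redq_sac.utd_ratio": utd, "redq_sac.num_samples": 1000000},
    )
    for utd in (1, 20)  # illustrative values; 20 is the one used above
]

for argv in commands:
    # shlex.join quotes each argument so the line can be pasted into a shell.
    print(shlex.join(argv))
```

Each list can also be passed directly to subprocess.run, which avoids shell quoting altogether.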
Our codebase has everything you need for diffusion with low-dimensional data along with example integrations with RL algorithms.
For a custom use-case, we recommend starting from the training script and SimpleDiffusionGenerator class in synther/diffusion/train_diffuser.py. You can modify the hyperparameters specified in config/resmlp_denoiser.gin to suit your own needs.
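For custom low-dimensional data, one common preparation step is to flatten each transition into a single vector and standardize the columns before training a diffusion model. The sketch below illustrates that idea with NumPy; the column layout and the helpers flatten_transitions and standardize are assumptions made for illustration, so check synther/diffusion/train_diffuser.py for the format the repository actually expects.

```python
import numpy as np


def flatten_transitions(obs, actions, rewards, next_obs, terminals):
    """Stack per-transition fields into one (N, D) float32 array.

    Assumed layout (illustrative): [obs, action, reward, next_obs, terminal].
    obs, actions and next_obs must be 2-D; rewards and terminals are 1-D.
    """
    columns = [
        np.asarray(obs, dtype=np.float32),
        np.asarray(actions, dtype=np.float32),
        np.asarray(rewards, dtype=np.float32).reshape(-1, 1),
        np.asarray(next_obs, dtype=np.float32),
        np.asarray(terminals, dtype=np.float32).reshape(-1, 1),
    ]
    return np.concatenate(columns, axis=1)


def standardize(data, eps=1e-6):
    """Return column-wise standardized data plus the statistics used.

    Keep mean and std so generated samples can be mapped back to the
    original scale with `samples * std + mean`.
    """
    mean = data.mean(axis=0)
    std = data.std(axis=0) + eps  # eps guards against constant columns
    return (data - mean) / std, mean, std
```

Keeping the normalization statistics alongside the model makes it straightforward to un-normalize generated samples before adding them to a replay buffer.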
- Our codebase uses wandb for logging; you will need to set --wandb-entity across the repository.
- Our pixel-based experiments are based on a modified version of the V-D4RL repository. The latent representations are derived from the trunks of the actor and critic.
SynthER builds upon many works and open-source codebases in both diffusion modelling and reinforcement learning. We would like to particularly thank the authors of:
Please contact Cong Lu or Philip Ball for any queries. We welcome any suggestions or contributions!