This repo is the official code release for the ICLR 2024 conference paper:
Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning
We propose LaMo, an offline RL framework that leverages pre-trained Language Models (LMs) for low-level Motion control. On sparse-reward tasks, LaMo achieves strong results and surpasses the recent strong algorithms CQL, IQL, TD3+BC, and DT. On dense-reward tasks, LaMo significantly improves on Decision Transformer and closes the gap between value-based methods and DT-based methods. Notably, in low-data scenarios, our method demonstrates strong few-shot learning ability, which can be attributed to the inductive bias from pre-trained LMs.
We look into the relationship between the performance of various algorithms and the scale of data. As depicted in the Figure, LaMo is capable of achieving excellent performance even with relatively small datasets. For example, in Hopper, LaMo surpasses the performance of CQL and DT when the sample ratio of data is 0.5% and maintains this advantage consistently as the sample ratio increases.
Below, we visualize 8 tasks across 3 domains that we consider.
- D4RL
  - MuJoCo: Hopper, Walker2d, HalfCheetah, Reacher2d
  - Kitchen
- Atari: Breakout, Qbert, Pong
We can only guarantee reproducibility with the environment configuration described below.
First, download the file from this link and extract it with tar -xvf the_file_name in the ~/.mujoco folder. Then run the following commands.
cd experiment-d4rl
conda create -n lamo-d4rl python=3.8.17
After that, add the following lines to your ~/.bashrc file:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/YOUR_PATH_TO_THIS/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
Remember to run source ~/.bashrc to make the changes take effect.
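If MuJoCo later fails to load, a quick check of LD_LIBRARY_PATH can rule out a missing export. The following is a minimal sketch, not part of this repo; the helper name missing_library_dirs is hypothetical, and the two suffixes are taken from the export lines above.

```python
import os


def missing_library_dirs(ld_library_path, required_suffixes):
    """Return the required suffixes that no entry of ld_library_path ends with."""
    entries = [p.rstrip("/") for p in ld_library_path.split(":") if p]
    return [
        suffix
        for suffix in required_suffixes
        if not any(entry.endswith(suffix.rstrip("/")) for entry in entries)
    ]


# Suffixes copied from the two export lines above.
REQUIRED = [".mujoco/mujoco210/bin", "/usr/lib/nvidia"]

if __name__ == "__main__":
    missing = missing_library_dirs(os.environ.get("LD_LIBRARY_PATH", ""), REQUIRED)
    if missing:
        print("LD_LIBRARY_PATH is missing entries ending with:", missing)
    else:
        print("LD_LIBRARY_PATH contains both required directories.")
```

Run it in a fresh shell so that it sees the values set by ~/.bashrc.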
Install D4RL by following the guidance in D4RL.
Downgrade the dm-control and mujoco packages:
pip install mujoco==2.3.7
pip install dm-control==1.0.14
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
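Since reproducibility depends on the pinned versions, it can help to confirm what actually got installed. The sketch below is an illustration only: the function name version_mismatches is hypothetical, and the pinned versions are copied from the pip commands above.

```python
from importlib import metadata

# Versions copied from the pip install commands above.
PINNED = {
    "mujoco": "2.3.7",
    "dm-control": "1.0.14",
    "torch": "1.13.1+cu117",
}


def version_mismatches(pinned, get_version=metadata.version):
    """Map each package whose installed version differs to (wanted, found).

    found is None when the package is not installed at all.
    """
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in version_mismatches(PINNED).items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result means the three pinned packages match; anything printed is a candidate cause for irreproducible results.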
To download the original D4RL data, run:
cd data
python download_d4rl_datasets.py
To reproduce our experiments on downsampled data, use our pre-processed data directly, available at this link.
You can also generate more downsampled data by modifying line 10 of 'data/mujoco/ratio_dataset.py' and line 10 of 'data/kitchen/ratio_dataset.py' as
suffix = [your data version name]
and then run
cd data
cd mujoco
python ratio_dataset.py
cd ..
cd kitchen
python ratio_dataset.py
cd ..
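For intuition about what downsampling by a sample ratio means, here is a minimal sketch. It is not the code in ratio_dataset.py: the function name downsample_trajectories is hypothetical, and sampling whole trajectories (rather than individual transitions) with a seeded generator is an assumption made for illustration.

```python
import random


def downsample_trajectories(trajectories, sample_ratio, seed=0):
    """Keep a reproducible random subset of trajectories (assumed scheme)."""
    if not 0.0 < sample_ratio <= 1.0:
        raise ValueError("sample_ratio must be in (0, 1]")
    if not trajectories:
        return []
    rng = random.Random(seed)  # fixed seed, so the same subset is drawn each time
    n_keep = max(1, round(len(trajectories) * sample_ratio))
    kept = sorted(rng.sample(range(len(trajectories)), n_keep))
    return [trajectories[i] for i in kept]


if __name__ == "__main__":
    fake_dataset = [f"trajectory_{i}" for i in range(1000)]
    subset = downsample_trajectories(fake_dataset, sample_ratio=0.005, seed=0)
    print(len(subset), "of", len(fake_dataset), "trajectories kept")
```

Fixing the seed is what makes two runs with the same ratio train on the same subset.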
You can also try generating the data with a PPO agent that you train yourself (only Reacher2d is supported), using the code provided in 'data/data_generation_PPO'.
First make sure you have the dependencies to install Atari.
sudo apt install cmake
sudo apt install zlib1g-dev
Then run the following commands.
cd experiment-atari
conda env create -f env.yml
The dataset will be downloaded automatically and cached locally by the package d4rl-atari once you launch an experiment. To reproduce our results on downsampled datasets, set the seed to match ours (3 seeds: 0, 1, and 2); our implementation in experiment-atari/buffer.py ensures that the downsampled dataset is then identical to ours.
After installing the packages and data, to reproduce our results on D4RL, you only need to run scripts provided in this link.
If you encounter errors when running those scripts, try
dos2unix [the-script-name].sh
If you encounter errors about D4RL or MuJoCo when running, these tips 1,2 may help.
If you want to view results on Weights & Biases, you need to modify lines 435 and 436 of 'experiment.py' as:
entity=[your-group-name],
project=[your-project-name],
You can also design your own script, as in 'run.sh'.
cd experiment-d4rl
bash run.sh [env_name] [dataset_name] [sample_ratio] [description] [seed] [gpu]
An example is:
bash run.sh hopper medium 0.1 reproduce 0 0
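To run several seeds or sample ratios in sequence, the positional arguments of run.sh can be filled in programmatically. This is a sketch under assumptions: the helper names build_run_command and sweep are hypothetical, the argument order is the one shown above, and the working directory assumes you launch from the repository root.

```python
import itertools
import subprocess


def build_run_command(env_name, dataset_name, sample_ratio, description, seed, gpu):
    """Build: bash run.sh [env_name] [dataset_name] [sample_ratio] [description] [seed] [gpu]"""
    return [
        "bash", "run.sh",
        env_name, dataset_name, str(sample_ratio), description, str(seed), str(gpu),
    ]


def sweep(env_name, dataset_name, sample_ratios, seeds,
          description="reproduce", gpu=0, dry_run=True):
    """Print (dry_run=True) or execute one run per (sample_ratio, seed) pair."""
    commands = [
        build_run_command(env_name, dataset_name, ratio, description, seed, gpu)
        for ratio, seed in itertools.product(sample_ratios, seeds)
    ]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # Assumes the current directory is the repository root.
            subprocess.run(cmd, check=True, cwd="experiment-d4rl")
    return commands


if __name__ == "__main__":
    sweep("hopper", "medium", sample_ratios=[0.1], seeds=[0, 1, 2])
```

The default is a dry run that only prints the commands; pass dry_run=False to launch them one after another.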
Trying more configurations is encouraged! Important arguments are explained below:
-w # enable wandb
--sample_ratio your_sample_ratio # determine the size of the data you are training on, like 0.1
--data_suffix your_data_version_name # you could downsample the data by yourself, default is "d1"
--mlp_embedding # use MLP as embeddings and projections
--adapt_mode # otherwise fully fine-tuning
--adapt_embed # fine-tune embeddings and projections when adapt_mode is ON
--lora # fine-tune low rank matrices of Transformer when adapt_mode is ON
--pretrained_lm language_model_name # you could try 'gpt2' and 'gpt2-medium'
--co_training # use language loss as auxiliary objective
--co_lambda # the weight of language loss, like 0.1
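As a rough picture of how --co_training and --co_lambda interact, the language loss is added to the decision loss with weight co_lambda. The sketch below is an illustration, not the code in experiment.py; the function name co_training_loss and its argument names are hypothetical.

```python
def co_training_loss(action_loss, language_loss, co_lambda=0.1, co_training=True):
    """Combine the decision loss with a weighted auxiliary language loss.

    Works with plain floats, and equally with tensors that support + and *.
    """
    if not co_training:
        # Without --co_training, only the decision (action prediction) loss is used.
        return action_loss
    return action_loss + co_lambda * language_loss


if __name__ == "__main__":
    print(co_training_loss(action_loss=1.0, language_loss=2.0, co_lambda=0.5))
```

A larger co_lambda puts more weight on preserving language modeling ability relative to fitting the control data.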
To reproduce our results on Breakout with one click, run the following commands
cd experiment-atari
bash run.sh
Since we use Hydra to manage the configuration of the experiments on Atari, you can override hyperparameters conveniently. If you want to run experiments on more environments, add the configuration for the corresponding environment under experiment-atari/cfgs/env. Refer to the documentation of Hydra for more details. Here are a few important hyperparameters:
env # environment name (breakout, qbert, pong, or any atari environment you want to explore)
pretrained_lm # gpt2, gpt2-medium or none
seed # 0, 1, 2
sample_ratio # the ratio of dataset you train on
model.random_initialize # randomly initialize the weight of the model (overwrite the pretrained weight) or not
model.adapt_cfg.use_adapt # use adapt mode or not (as opposed to full fine-tuning)
model.adapt_cfg.adapt_embed # unfreeze embedding or not
model.lora_cfg.use_lora # use lora or not
model.lora_cfg.lora_attn_dim # the dimension of lora
model.context_len # the context length of the transformer model
train.lr # learning rate
train.weight_decay # weight decay
train.batch_size # batch size
nlp_train.co_training # use language joint training or not
nlp_train.co_lambda # the weight of language joint training loss
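Hydra accepts overrides on the command line as key=value pairs, with dots addressing nested config groups. The following sketch formats a dictionary of the hyperparameters listed above into such pairs; the helper name hydra_overrides is hypothetical, and the values shown are examples rather than recommended settings. Append the resulting strings to the Python launch command used inside run.sh.

```python
def hydra_overrides(params):
    """Format a flat dict of config keys into Hydra-style key=value overrides."""
    def fmt(value):
        # Hydra parses lowercase true/false as booleans.
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    return [f"{key}={fmt(value)}" for key, value in params.items()]


if __name__ == "__main__":
    overrides = hydra_overrides({
        "env": "qbert",
        "pretrained_lm": "gpt2",
        "seed": 1,
        "sample_ratio": 0.1,
        "model.lora_cfg.use_lora": True,
    })
    print(" ".join(overrides))
```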
LaMo is based on many open-source projects, including Decision Transformer, Can Wikipedia Help Offline Reinforcement Learning, LoRA, DeFog, d4rl-atari. We thank all these authors for their nicely open sourced code and their great contributions to the community.
LaMo is licensed under the MIT license. See the LICENSE file for details.
If you find our work useful, please consider citing:
@article{Shi2024LaMo,
title={Unleashing the Power of Pre-trained Language Models for Offline Reinforcement Learning},
author={Ruizhe Shi and Yuyao Liu and Yanjie Ze and Simon S. Du and Huazhe Xu},
journal={International Conference on Learning Representations},
year={2024}
}