Code repository for the ACL 2023 paper "RL4F: Generating Natural Language Feedback with Reinforcement Learning". Afra Feyza Akyurek, Ekin Akyurek, Aman Madaan, Ashwin Kalyan, Peter Clark, Derry Wijaya, Niket Tandon. Check out the project page for a brief introduction.
This codebase is primarily based on the RL4LMs repository. We provide custom data classes and reward functions to implement RL4F.
git clone https://github.com/feyzaakyurek/rl4f.git
cd rl4f
pip install -e .
rl4lms/data_pools/custom_text_generation_pools.py
: This file contains the custom dataset loading classes. Make sure to specify the correct data paths in the respective classes.
scripts/training/task_configs/
: The YAML config files are stored under this path. They specify the training and evaluation arguments as well as the model and output paths.
rl4f_scripts/
: A sample sh script for supervised critique generation.
openai_key
: If you run RL4F with one of the OpenAI models, place your API key in a file and give the path to that file in your YAML config. Note that RL4F runs that use the OpenAI API might incur significant charges.
wandb_key
: We track our runs using wandb. Specify your API key here; it is used in the sh script.
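A minimal sketch of setting up the two key files, assuming they live in the repository root under the names listed above. The placeholder values are not real keys, and exporting `WANDB_API_KEY` (the environment variable the wandb client reads) is an assumption about how you might pass the key along; the repository's sh script may read the file differently.

```shell
# File names follow the layout above; the placeholder values are NOT real keys.
OPENAI_KEY_FILE="openai_key"
WANDB_KEY_FILE="wandb_key"

# Create each file only if it does not exist yet, so real keys are never overwritten.
[ -f "$OPENAI_KEY_FILE" ] || printf '%s\n' "paste-your-openai-key-here" > "$OPENAI_KEY_FILE"
[ -f "$WANDB_KEY_FILE" ] || printf '%s\n' "paste-your-wandb-key-here" > "$WANDB_KEY_FILE"

# Restrict both files to the current user.
chmod 600 "$OPENAI_KEY_FILE" "$WANDB_KEY_FILE"

# Assumption: export the key via WANDB_API_KEY, which the wandb client picks up.
WANDB_API_KEY="$(cat "$WANDB_KEY_FILE")"
export WANDB_API_KEY
```

Keeping the keys in separate files rather than in the YAML configs makes it easier to keep them out of version control.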
Download data and checkpoints from this link.
All scripts can be found under rl4f_scripts. For example, check out the rl4f_scripts/run_alphabetize_sup.sh script, which warm-starts a pretrained T5-large with supervised critique generation for alphabetization. Alternatively, you can load the released checkpoint from the drive link above. For PPO training, specify the checkpoint in scripts/training/task_configs/alphabetize/t5large_ppo_on_supervised.yaml and run rl4f_scripts/run_alphabetize_ppo.sh.
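The two-stage workflow for alphabetization can be sketched as follows, assuming the commands are run from the repository root. `run_stage` is a hypothetical helper introduced here for illustration and is not part of the repository; it only guards against a missing script so that PPO training does not start unless the supervised stage has finished.

```shell
# Hypothetical helper: run a script only if it exists, and report a missing one.
run_stage() {
  if [ -f "$1" ]; then
    bash "$1"
  else
    echo "missing script: $1" >&2
    return 1
  fi
}

# Stage 1 (supervised warm start) must succeed before stage 2 (PPO) starts.
# Skip stage 1 if you load the released checkpoint instead.
if run_stage rl4f_scripts/run_alphabetize_sup.sh &&
   run_stage rl4f_scripts/run_alphabetize_ppo.sh; then
  echo "both stages finished"
else
  echo "pipeline stopped early" >&2
fi
```

Remember to set the checkpoint in the PPO YAML config before the second stage runs.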
If you receive an error saying that your torch installation does not support sm_86, try uninstalling torch and reinstalling it with conda, using the cudatoolkit version that matches your environment. For example:
conda install pytorch==1.11.0 cudatoolkit=11.3 -c pytorch
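To check whether a reinstall helped, you can print the architectures the installed torch build was compiled for. This is a sketch that assumes the interpreter is available as `python3` (override it through the `PYTHON` variable, which is introduced here for illustration); if sm_86 is missing from the printed list, the installed build was not compiled for that architecture.

```shell
# Assumption: the interpreter is called python3; set PYTHON=python to override.
PYTHON="${PYTHON:-python3}"

if "$PYTHON" -c "import torch" 2>/dev/null; then
  # torch.cuda.get_arch_list() lists the CUDA architectures this build supports.
  TORCH_REPORT="$("$PYTHON" -c 'import torch; print(torch.__version__, torch.version.cuda, torch.cuda.get_arch_list())')"
else
  TORCH_REPORT="torch is not importable in this environment"
fi
echo "$TORCH_REPORT"
```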