This repository contains the code, data, and checkpoints for our paper published at ICML 2023:
Semi-Offline Reinforcement Learning for Optimized Text Generation
Changyu Chen, Xiting Wang, Yiqiao Jin, Victor Ye Dong, Li Dong, Jie Cao, Yi Liu, Rui Yan
Paper: http://arxiv.org/abs/2306.09712
@article{chen2023semi,
title={Semi-Offline Reinforcement Learning for Optimized Text Generation},
author={Chen, Changyu and Wang, Xiting and Jin, Yiqiao and Dong, Victor Ye and Dong, Li and Cao, Jie and Liu, Yi and Yan, Rui},
journal={arXiv preprint arXiv:2306.09712},
year={2023}
}
Our semi-offline method is illustrated in (c2): we use static data as the starting point and perform exploration with a single forward propagation (FP).
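As a rough illustration of the idea, the sketch below builds one semi-offline sample: it starts from a static ground-truth sequence, replaces a fraction of its positions (cf. `--mask_rate`) with a mask token, and fills only the masked positions with tokens sampled from a single forward pass. The toy `forward` function, the `MASK` token string, and the per-position sampling scheme are assumptions made for illustration, not the repository's implementation.

```python
import random
from typing import Callable, Dict, List

MASK = "[mask]"  # hypothetical mask token string

Dist = Dict[str, float]  # token -> probability at one position


def semi_offline_sample(
    target: List[str],
    forward: Callable[[List[str]], List[Dist]],
    mask_rate: float,
    rng: random.Random,
) -> List[str]:
    """Mask a fraction of the static tokens, then fill them from ONE forward pass."""
    n_mask = round(mask_rate * len(target))
    masked_pos = set(rng.sample(range(len(target)), n_mask))
    masked = [MASK if i in masked_pos else tok for i, tok in enumerate(target)]

    # A single forward propagation yields a distribution for every position.
    dists = forward(masked)

    sample = []
    for i, tok in enumerate(target):
        if i in masked_pos:
            tokens, probs = zip(*dists[i].items())
            sample.append(rng.choices(tokens, weights=probs, k=1)[0])
        else:
            sample.append(tok)  # keep the static (ground-truth) token
    return sample


def toy_forward(tokens: List[str]) -> List[Dist]:
    # Hypothetical stand-in for a model: uniform over a tiny vocabulary.
    vocab = ["the", "cat", "sat", "on", "mat"]
    return [{w: 1 / len(vocab) for w in vocab} for _ in tokens]


if __name__ == "__main__":
    ref = "the cat sat on the mat".split()
    print(semi_offline_sample(ref, toy_forward, mask_rate=0.4, rng=random.Random(0)))
```

Because unmasked positions are copied from the static data, the sample stays close to the ground truth while the masked positions provide exploration, and no autoregressive decoding loop is needed.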
git clone https://github.com/ChangyuChen347/semi-offline-RL.git
cd semi-offline-RL
conda create -n semi-offline-rl python=3.8
conda activate semi-offline-rl
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
python -c "import nltk; nltk.download('punkt'); nltk.download('stopwords')"
Before training, download the base-model checkpoint and the data. Then place the checkpoint in the ./model_output directory and the data in the ./static_data directory.
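A small pre-flight check can catch misplaced files before launching a long run. The sketch below assumes the paths used in the example CNN/DM training command (model_output/cnndm_base_model, static_data/cnndm/cnndm_train.tsv, static_data/cnndm/cnndm_valid.tsv); adjust them for other datasets. The helper name `missing_paths` is hypothetical.

```python
from pathlib import Path
from typing import Iterable, List

# Paths taken from the example CNN/DM training command; adjust per dataset.
EXPECTED = [
    "model_output/cnndm_base_model",
    "static_data/cnndm/cnndm_train.tsv",
    "static_data/cnndm/cnndm_valid.tsv",
]


def missing_paths(root: str, expected: Iterable[str] = EXPECTED) -> List[str]:
    """Return the expected paths (relative to root) that do not exist."""
    base = Path(root)
    return [p for p in expected if not (base / p).exists()]


if __name__ == "__main__":
    missing = missing_paths(".")
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("All expected checkpoint and data paths are present.")
```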
To train, run one of the bash scripts in the train directory:
bash train/run_cnndm.sh
or
CUDA_VISIBLE_DEVICES=0 python main.py \
--do_train \
--scene bart_cnndm_generation \
--use_logit True \
--report_to tensorboard \
--seed 2022 \
--smooth 0.1 \
--trainer rl \
--learning_rate 0.000001 \
--num_train_epochs 60 \
--max_grad_norm 1 \
--print_every 1000 \
--save_every 4000 \
--eval_steps 2000 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--per_device_eval_batch_size 16 \
--length_normalize_4_rl True \
--training_length_penalty 1 \
--train_dir static_data/cnndm/cnndm_train.tsv \
--eval_dir static_data/cnndm/cnndm_valid.tsv \
--cand_pos_remove_sp_tk True \
--recover model_output/cnndm_base_model \
--exp_name demo_cnndm \
--rewards 'rouge' \
--rouge_type 12l \
--rl_weight 20 \
--sample_num 63 \
--mask_rate 0.4 \
--kd_inputs_worst True \
--eval_metrics rouges,rouge \
--seq_decode_model bart
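With these settings, each optimizer update accumulates per_device_train_batch_size × gradient_accumulation_steps × (number of GPUs) training examples; on the single GPU selected by CUDA_VISIBLE_DEVICES=0 that is 1 × 8 × 1 = 8. The arithmetic below assumes the usual Hugging Face Trainer semantics for these two flags, which is an assumption about this codebase.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Examples contributing to one optimizer update (Trainer-style semantics assumed)."""
    return per_device * grad_accum * num_gpus


# Values from the CNN/DM command above.
print(effective_batch_size(per_device=1, grad_accum=8, num_gpus=1))  # 8 on one GPU
print(effective_batch_size(per_device=1, grad_accum=8, num_gpus=4))  # 32 on four GPUs
```

If you change the number of GPUs, you may want to adjust --gradient_accumulation_steps to keep the effective batch size comparable.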
Training Parameters:
Basic setting:
- --learning_rate: Sets the learning rate for training.
- --num_train_epochs: Specifies the number of training epochs.
- --max_grad_norm: Sets the maximum gradient norm for gradient clipping.
- --print_every: Prints training progress every specified number of steps.
- --save_every: Saves the model every specified number of steps.
- --eval_steps: Evaluates the model every specified number of steps.
- --per_device_train_batch_size: Sets the training batch size per GPU.
- --gradient_accumulation_steps: Accumulates gradients over the specified number of steps before each optimizer update.
- --per_device_eval_batch_size: Sets the evaluation batch size per GPU.
- --length_normalize_4_rl: Applies length normalization for reinforcement learning.
- --cand_pos_remove_sp_tk: Removes special tokens (pad/eos) from the candidate positions.
- --exp_name: Specifies the experiment name.
- --eval_metrics: Specifies the evaluation metrics.
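--length_normalize_4_rl, together with --training_length_penalty (set to 1 in the command above), points to length-normalized sequence scores. A common form, used for example in BRIO, divides the summed token log-probabilities by the sequence length raised to a penalty α. The sketch below assumes that form; the exact formula in this repository may differ.

```python
import math
from typing import Sequence


def length_normalized_score(token_logprobs: Sequence[float], length_penalty: float = 1.0) -> float:
    """Sum of token log-probabilities divided by len ** length_penalty."""
    if not token_logprobs:
        raise ValueError("empty sequence")
    return sum(token_logprobs) / (len(token_logprobs) ** length_penalty)


short_seq = [math.log(0.5)] * 2  # 2 tokens, each with probability 0.5
long_seq = [math.log(0.5)] * 4   # 4 tokens, each with probability 0.5

# Penalty 0 leaves the raw sum, so the longer sequence scores lower.
print(length_normalized_score(short_seq, 0.0), length_normalized_score(long_seq, 0.0))
# Penalty 1 gives the mean log-probability: both are about log(0.5).
print(length_normalized_score(short_seq, 1.0), length_normalized_score(long_seq, 1.0))
```

Without normalization, summed log-probabilities systematically favor shorter outputs; dividing by length removes most of that bias.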
Model and Task:
- --scene: Specifies the scene (task) for the model. The config files are in ./config/SceneConfigs/.
- --train_dir: Specifies the path of the training data file.
- --eval_dir: Specifies the path of the evaluation data file.
- --recover: Specifies the path of the checkpoint to load.
RL setting:
- --rewards: Specifies the reward metric.
- --rouge_type: Sets the ROUGE type (12l for ROUGE-1, ROUGE-2, and ROUGE-L).
- --rl_weight: Sets the weight of the reinforcement learning loss.
- --sample_num: Sets the number of samples for RL.
- --mask_rate: Sets the masking rate for both SFT and RL.
- --kd_inputs_worst: Uses worst-case inputs for knowledge distillation.
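For intuition about --rewards 'rouge' with --rouge_type 12l, the sketch below scores a candidate against a reference with simplified ROUGE-1, ROUGE-2, and ROUGE-L F1 (lowercased whitespace tokens, no stemming) and averages the three. Both the simplified scoring and the plain average are assumptions; the repository relies on a full ROUGE implementation and may combine the three scores differently. The function names are hypothetical.

```python
from collections import Counter
from typing import List


def _f1(overlap: int, n_cand: int, n_ref: int) -> float:
    if overlap == 0 or n_cand == 0 or n_ref == 0:
        return 0.0
    p, r = overlap / n_cand, overlap / n_ref
    return 2 * p * r / (p + r)


def rouge_n_f1(cand: List[str], ref: List[str], n: int) -> float:
    # Clipped n-gram overlap via multiset intersection.
    c = Counter(tuple(cand[i:i + n]) for i in range(len(cand) - n + 1))
    r = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
    overlap = sum((c & r).values())
    return _f1(overlap, sum(c.values()), sum(r.values()))


def rouge_l_f1(cand: List[str], ref: List[str]) -> float:
    # Longest common subsequence via dynamic programming.
    dp = [[0] * (len(ref) + 1) for _ in range(len(cand) + 1)]
    for i, a in enumerate(cand, 1):
        for j, b in enumerate(ref, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if a == b else max(dp[i - 1][j], dp[i][j - 1])
    return _f1(dp[-1][-1], len(cand), len(ref))


def rouge_12l_reward(candidate: str, reference: str) -> float:
    """Hypothetical reward: mean of ROUGE-1, ROUGE-2 and ROUGE-L F1."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    return (rouge_n_f1(cand, ref, 1) + rouge_n_f1(cand, ref, 2) + rouge_l_f1(cand, ref)) / 3


print(rouge_12l_reward("the cat sat on the mat", "the cat sat on the mat"))  # 1.0
```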
- The evaluation (word tokenization and metric computation) of CNN/DM and XSum follows BRIO: the predictions are first lowercased and tokenized with the PTB tokenizer provided by Stanford (download here), and then the ROUGE scores are computed with the standard ROUGE Perl package (download here).
After downloading the two files, you can set the environment variables using the following commands:
export _ROUGE_PATH=./ROUGE-RELEASE-1.5.5
export CLASSPATH=./stanford-corenlp-3.8.0.jar
To use the ROUGE Perl package, you may need to install XML::DOM and XML::Parser. Alternatively, you can use the "-p" flag to compute the scores in Python for a quick start. Note that the Python results may differ slightly from the Perl results.
- The evaluation of SQuAD follows LMQG.
To compute the METEOR score for SQuAD, download paraphrase-en.gz and place it in the ./lmqg/automatic_evaluation_tool/meteor/data/ directory.
To evaluate, run one of the bash scripts in the evaluation directory:
bash evaluation/eval_cnn.sh
or
exp_name=rl
output_dir_path=eval_output
model_dir_path=model_output
dataset=cnndm
export _ROUGE_PATH=./ROUGE-RELEASE-1.5.5
export CLASSPATH=./stanford-corenlp-3.8.0.jar
bash evaluation/run_test_cnndm.sh ${exp_name} ${output_dir_path} ${model_dir_path}
python extract_prediction.py --dataset ${dataset} --exp_name ${exp_name} --output_dir_path ${output_dir_path}
cat ${output_dir_path}/${dataset}/${exp_name}/pred.txt | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > ${output_dir_path}/${dataset}/${exp_name}/pred.txt.token
cat ${output_dir_path}/${dataset}/${exp_name}/ref.txt | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > ${output_dir_path}/${dataset}/${exp_name}/ref.txt.token
python cal_rouge.py --ref ${output_dir_path}/${dataset}/${exp_name}/ref.txt.token --hyp ${output_dir_path}/${dataset}/${exp_name}/pred.txt.token
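Because the tokenizer is run with -preserveLines, each line of pred.txt should correspond to the same line of ref.txt, so a line-count mismatch silently corrupts the scores. A quick sanity check, assuming the ${output_dir_path}/${dataset}/${exp_name}/ layout used in the commands above (the helper name `check_alignment` is hypothetical):

```python
from pathlib import Path


def count_lines(path: Path) -> int:
    with path.open(encoding="utf-8") as f:
        return sum(1 for _ in f)


def check_alignment(output_dir_path: str, dataset: str, exp_name: str) -> int:
    """Raise if pred.txt and ref.txt differ in line count; return the count."""
    run_dir = Path(output_dir_path) / dataset / exp_name
    n_pred = count_lines(run_dir / "pred.txt")
    n_ref = count_lines(run_dir / "ref.txt")
    if n_pred != n_ref:
        raise ValueError(f"{n_pred} predictions vs {n_ref} references in {run_dir}")
    return n_pred
```

For the shell variables in the example above, this would be called as `check_alignment("eval_output", "cnndm", "rl")` before running cal_rouge.py.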
The BASE models are supervised fine-tuning (SFT) models trained with the [mask] token. The RL models are our RL-trained checkpoints.
Dataset | BASE (M-FT) | RL |
---|---|---|
CNN/DM | cnndm_bart_base_model | cnndm_bart_rl_model |
SAMSum | samsum_bart_base_model | samsum_bart_rl_model |
SQuAD | t5_squad_base_model | t5_squad_rl_model |
XSum | xsum_pegasus_base_model | xsum_pegasus_rl_model |
The training datasets (*_train.tsv) contain the source, the ground truth, and the ordered candidates.
Dataset | Train | Validation | Test |
---|---|---|---|
CNN/DM | cnn_train.tsv | cnn_valid.tsv | cnn_test.tsv |
SAMSum | samsum_train.tsv | samsum_valid.tsv | samsum_test.tsv |
SQuAD | squad_train.tsv | squad_valid.tsv | squad_test.tsv |
XSum | xsum_train.tsv | xsum_valid.tsv | xsum_test.tsv |
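The exact column layout of the *_train.tsv files is not documented here. The sketch below assumes one example per row, no header row, the source in the first column, the ground truth in the second, and the ordered candidates in the remaining columns; check a few rows of the actual files before relying on it. The names `Example` and `read_train_tsv` are hypothetical.

```python
import csv
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class Example:
    source: str
    target: str
    candidates: List[str]  # assumed to be stored already ordered


def read_train_tsv(path: str) -> Iterator[Example]:
    """Yield examples from a tab-separated file (assumed column layout)."""
    with open(path, encoding="utf-8", newline="") as f:
        # QUOTE_NONE keeps literal quote characters that appear in the text.
        for row in csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE):
            if len(row) < 2:
                continue  # skip blank or malformed lines
            yield Example(source=row[0], target=row[1], candidates=row[2:])
```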