Fantastic Rewards and How to Tame Them: A Case Study on Reward Learning for Task-Oriented Dialogue Systems
To install the required packages, first create and activate a `fantastic_reward` environment in conda.
Then execute the following command:
bash install_packages.sh
Our data setup follows the CASPI paper.
Please download the pre-processed data from here.
Unzip the downloaded file and put the resulting folder `ExpStoreddata` into the folder `damd_multiwoz`.
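Before launching training, it can help to confirm that the unzipped data landed where the steps above say it should. The following is a minimal sketch, not part of the released code: the helper name `missing_data_dirs` is hypothetical, and it assumes that `damd_multiwoz` sits directly under the repository root.

```python
from pathlib import Path


def missing_data_dirs(repo_root="."):
    """Return the expected data directories that do not exist yet.

    Hypothetical helper; assumes damd_multiwoz lives directly under repo_root.
    """
    # Per the setup steps: the unzipped ExpStoreddata folder goes into damd_multiwoz.
    expected = [Path(repo_root) / "damd_multiwoz" / "ExpStoreddata"]
    return [str(p) for p in expected if not p.is_dir()]


if __name__ == "__main__":
    missing = missing_data_dirs()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Data layout looks as expected.")
```

Run it from the repository root. An empty result only means the folder exists; it does not check the folder's contents.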
For our variant of RewardNet+GS, run:
bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 3 --REWARD_LOSS "listNet" --LISTMLE_TEMP 1 --LISTNET_POW 1 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 0.1 --REW_MODEL_EXP '0'
where `${EXP_IDX}` is the index of the experiment, such as `"2023"`.
For our variant of RewardMLE+GS, run:
bash ./run_multiple_seeds.sh --EXP_IDX ${EXP_IDX} --REWARD_SAMPLES 5 --REWARD_LOSS "listMLE" --LISTMLE_TEMP 1 --LISTNET_POW 0 --POLICY_TRAIN_DATA_FRAC 1 --NEG_REW_WEIGHT 1.0 --REW_MODEL_EXP '0'
where `${EXP_IDX}` is again the index of the experiment.
To facilitate reproducibility, we release a checkpoint for each of the variants RewardNet+GS and RewardMLE+GS, trained with seed `999` of the five tested seeds (111, 333, 555, 777, 999).
To evaluate the checkpoints, please try the following steps.
Here `Exp1` corresponds to the variant RewardNet+GS and `Exp2` to RewardMLE+GS.
- Download and unzip the checkpoints from here.
- Download and unzip the processed data from here. Put the resulting folders into the folder `damd_multiwoz`.
- Try the following command:
python train.py --model_path "experiments/Exp${EXP_IDX}/all_sd999/" \
--mode 'test' --context_window 2 --pretrained_checkpoint bart-large-cnn \
--back_bone bart --cfg seed=999 cuda_device=0 batch_size=8 early_stop_count=7 \
--caspi_returns_file="fn_Gs_10_0.0_resp_soft.json" --caspi_wt=5. \
--caspi_data_file=data_for_damd.json --caspi_val_fraction=.5 --caspi --data_folder "Exp${EXP_IDX}data/s999_K10_GAMMA0.0" \
--exp_idx ${EXP_IDX}
where `${EXP_IDX}` should be replaced by `1` or `2`.
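To make the `${EXP_IDX}` substitution concrete, the evaluation command above can be assembled programmatically. This is only a sketch: the function name `build_eval_command` is hypothetical, every flag and value is copied from the command above, and the seed is fixed at 999 because that is the seed used in that command.

```python
import shlex

SEED = 999  # seed used in the evaluation command above


def build_eval_command(exp_idx):
    """Assemble the evaluation command for Exp1 (RewardNet+GS) or Exp2 (RewardMLE+GS)."""
    if exp_idx not in (1, 2):
        raise ValueError("exp_idx must be 1 (RewardNet+GS) or 2 (RewardMLE+GS)")
    return [
        "python", "train.py",
        "--model_path", f"experiments/Exp{exp_idx}/all_sd{SEED}/",
        "--mode", "test",
        "--context_window", "2",
        "--pretrained_checkpoint", "bart-large-cnn",
        "--back_bone", "bart",
        "--cfg", f"seed={SEED}", "cuda_device=0", "batch_size=8", "early_stop_count=7",
        "--caspi_returns_file=fn_Gs_10_0.0_resp_soft.json",
        "--caspi_wt=5.",
        "--caspi_data_file=data_for_damd.json",
        "--caspi_val_fraction=.5",
        "--caspi",
        "--data_folder", f"Exp{exp_idx}data/s{SEED}_K10_GAMMA0.0",
        "--exp_idx", str(exp_idx),
    ]


if __name__ == "__main__":
    # Print a shell-ready command line for each released checkpoint.
    for idx in (1, 2):
        print(shlex.join(build_eval_command(idx)))
```

The printed lines can be pasted into a shell. The list form can also be passed to `subprocess.run` without further quoting.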
The following table shows the standardized evaluation results of our `RewardNet+GS` model.
Detailed numbers are provided in `Example_generation/result_standard_eval.json`.
BLEU | Inform | Success | Combined Score | Av. len. | CBE | #uniq. words | #uniq. 3-grams |
---|---|---|---|---|---|---|---|
17.6 | 87.6 | 81.5 | 102.2 | 13.22 | 1.99 | 423 | 3942 |
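As a sanity check on the table above, the Combined Score can be recomputed from the other columns. The sketch below assumes the usual MultiWOZ definition, Combined Score = BLEU + 0.5 × (Inform + Success); the function name is illustrative and not taken from the codebase.

```python
def combined_score(bleu, inform, success):
    """Combined Score under the usual MultiWOZ definition (assumed here)."""
    return bleu + 0.5 * (inform + success)


if __name__ == "__main__":
    # Values taken from the table row above.
    score = combined_score(bleu=17.6, inform=87.6, success=81.5)
    print(f"{score:.2f}")
```

With the rounded table entries this gives about 102.15, which agrees with the reported 102.2 up to rounding of the inputs.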
Examples of generated dialogues on the test split of MultiWOZ2.0 can be found at `Example_generation/gen_test_formatted.json`.
This codebase builds on the following codebases and datasets: