This repository includes the original implementation of "PROMPT WAYWARDNESS: The Curious Case of Discretized Interpretation of Continuous Prompts" by Daniel Khashabi, Xinxi Lyu, Sewon Min, Lianhui Qin, Kyle Richardson, Sameer Singh, Sean Welleck, Hannaneh Hajishirzi, Tushar Khot, Ashish Sabharwal, Yejin Choi.
This code provides commands to run the models and reproduce the numbers reported in the paper. The code is taken and modified from the Channel LM Prompting repo.
Please open an issue if you have any questions about the paper or the code.
If you find our code or paper useful, please cite the paper:
@inproceedings{khashabi2021waywardness,
title={{PROMPT WAYWARDNESS: The Curious Case of Discretized Interpretation of Continuous Prompts}},
author = {Khashabi, Daniel and Lyu, Xinxi and Min, Sewon and Qin, Lianhui and Richardson, Kyle and Singh, Sameer and Welleck, Sean and Hajishirzi, Hannaneh and Khot, Tushar and Sabharwal, Ashish and Choi, Yejin},
booktitle={Proceedings of NAACL},
year={2022}
}
- Installation
- Download & Preprocess Data
- Default Commands
- Reproducing Main Results (Section 4.2 of the paper)
- Reproducing Analysis (Section 4.3 of the paper)
You can run the channel model and the direct model for each of these methods. Please see Section 3 of the paper for more details about these formulations.
$ conda create -n waywardness python=3.8
$ conda activate waywardness
$ conda install pytorch=1.7.1 -c pytorch
$ pip install transformers==4.3.0
We use (and modify) the data and the preprocessing script from Gao et al. ACL 2021 (paper, code) and Zhang et al. NeurIPS 2015 (paper, data).
To download the k-shot data (already preprocessed):
Download the data (65.6MB) from this link. Please place data-processed.zip under the same directory as the code and unzip it.
To download the original data and preprocess yourself:
Download the data (14MB) from this link. Please place data-processed.zip under the same directory as the code and unzip it.
Then run python3 generative_k_shot_data.py, and you are done!
Optionally, you can specify arguments such as:
- --data_dir: directory for the original data (default is data/original).
- --output_dir: directory for the preprocessed data (default is data).
To check the data:
You can see the list of the five datasets used in the paper by running ls data/k-shot. Each dataset consists of five different splits based on five different seeds (test sets are the same).
We also used sentences sampled from The PILE, stored under the prompts directory. Please make sure to cite their paper when you use this data.
python3 main.py \
--task {SST-2|sst-5|agnews|trec|subj} \
--prompt_group {NI|PILE} \
--split test \
--data_dir data \
--out_dir out \
--method direct \
--prompt_tune \
--do_train \
--gamma {0.01|0}
Useful notes:
- You can adjust --batch_size if you run into OOM issues (default is 8).
- To train with an individual prompt, you can replace --prompt_group with --prompt_task.
- Once you have trained the model, you can specify --do_check to load the existing checkpoint without retraining the model.
- Please note that GPU parallelization is not implemented for inference.
- To save a log file, please specify --log_file.
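If you launch many runs, it can help to assemble the default command programmatically. The following is a minimal sketch, not part of the repository: build_command is a hypothetical helper, and it only uses flags that appear in the command and notes above (--batch_size and --log_file are added only when requested).

```python
import shlex


def build_command(task, prompt_group="NI", gamma="0.01",
                  batch_size=None, log_file=None):
    """Return the default training command as a list of arguments.

    Hypothetical helper: the flag names are copied from the README,
    everything else is an illustration.
    """
    cmd = [
        "python3", "main.py",
        "--task", task,
        "--prompt_group", prompt_group,
        "--split", "test",
        "--data_dir", "data",
        "--out_dir", "out",
        "--method", "direct",
        "--prompt_tune",
        "--do_train",
        "--gamma", str(gamma),
    ]
    # Optional flags mentioned in the notes.
    if batch_size is not None:
        cmd += ["--batch_size", str(batch_size)]
    if log_file is not None:
        cmd += ["--log_file", log_file]
    return cmd


if __name__ == "__main__":
    # Print a copy-pasteable shell command with a smaller batch size.
    print(shlex.join(build_command("SST-2", batch_size=4, log_file="sst2.log")))
```

The returned list can be passed to subprocess.run directly, or printed with shlex.join as shown to inspect the command before running it.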
This section is for reproducing the results of the main experiments in Section 4.2 of the paper.
Run the default commands.
This section is for reproducing the results of the analysis experiments in Section 4.3 of the paper.
Run the default commands, but fix --prompt_group NI and vary --gamma {0|0.0001|0.0005|0.001|0.003|0.005|0.01|0.03}.
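The gamma sweep above amounts to one run per task and gamma value. As a rough sketch (gamma_sweep is a hypothetical helper, not part of the repository), the commands can be generated like this; the task names, gamma values, and flags are copied from this README, and gamma values are kept as strings so they appear exactly as written.

```python
import shlex

# Values copied from the README.
TASKS = ["SST-2", "sst-5", "agnews", "trec", "subj"]
GAMMAS = ["0", "0.0001", "0.0005", "0.001", "0.003", "0.005", "0.01", "0.03"]


def gamma_sweep(tasks=TASKS, gammas=GAMMAS):
    """Yield one shell command string per (task, gamma) pair.

    Hypothetical helper: --prompt_group is fixed to NI, as the
    analysis instructions require.
    """
    for task in tasks:
        for gamma in gammas:
            args = [
                "python3", "main.py",
                "--task", task,
                "--prompt_group", "NI",
                "--split", "test",
                "--data_dir", "data",
                "--out_dir", "out",
                "--method", "direct",
                "--prompt_tune",
                "--do_train",
                "--gamma", gamma,
            ]
            yield shlex.join(args)


if __name__ == "__main__":
    for command in gamma_sweep():
        print(command)
```

Redirecting the printed commands to a file gives a simple job list that can be run sequentially or handed to a scheduler.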
Run the default commands, but fix --prompt_group PILE and vary --pile_len {4|7|14|28|56}.
Run the default commands, but fix --prompt_group PILE --gamma 0.01,0.005,0.003 and vary --gpt2 gpt2-{small|medium|large|xl}.
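The model-size experiment above is a grid over model sizes and gamma values. The sketch below enumerates that grid under one explicit assumption: the comma-separated --gamma 0.01,0.005,0.003 is read as three separate runs, each with a single gamma value. Whether main.py also accepts the comma-separated form directly is not stated here. model_size_grid is a hypothetical helper, not part of the repository.

```python
import itertools
import shlex

# Values copied from the README; the --gpt2 names follow its brace pattern.
MODEL_SIZES = ["gpt2-small", "gpt2-medium", "gpt2-large", "gpt2-xl"]
# Assumption: the comma-separated gamma list means three separate runs.
GAMMAS = ["0.01", "0.005", "0.003"]


def model_size_grid(task, model_sizes=MODEL_SIZES, gammas=GAMMAS):
    """Return one shell command string per (model size, gamma) pair."""
    commands = []
    for model, gamma in itertools.product(model_sizes, gammas):
        args = [
            "python3", "main.py",
            "--task", task,
            "--prompt_group", "PILE",
            "--split", "test",
            "--data_dir", "data",
            "--out_dir", "out",
            "--method", "direct",
            "--prompt_tune",
            "--do_train",
            "--gamma", gamma,
            "--gpt2", model,
        ]
        commands.append(shlex.join(args))
    return commands


if __name__ == "__main__":
    for command in model_size_grid("SST-2"):
        print(command)
```

This yields 12 commands per task (4 model sizes times 3 gamma values).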
Run the default commands, but fix --prompt_group TRUE.