Paper Link: https://arxiv.org/abs/2411.07404
First, clone the repo:

```bash
git clone [email protected]:kdu4108/context-vs-prior-finetuning.git
```

Then create an environment and install dependencies via virtualenv/pip:

```bash
python3.11 -m venv env
source env/bin/activate
pip install -r requirements.txt
```
Here are the steps to regenerate the experiments in the paper. The key steps are (1) accessing or regenerating the dataset and (2) running the main entry point, `main.py`.

To regenerate the datasets:

1. Run all cells in `preprocessing/preprocess_fakepedia_train_and_dev.ipynb`.
2. Run `python preprocessing/generate_arithmetic.py`.
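If you prefer to script these two steps, here is a minimal sketch (assuming `nbformat` and `nbclient` are installed; running the notebook interactively in Jupyter works just as well):

```python
import subprocess

import nbformat
from nbclient import NotebookClient

# Step 1: execute every cell of the Fakepedia preprocessing notebook headlessly.
nb = nbformat.read("preprocessing/preprocess_fakepedia_train_and_dev.ipynb", as_version=4)
NotebookClient(nb).execute()

# Step 2: generate the arithmetic dataset.
subprocess.run(["python", "preprocessing/generate_arithmetic.py"], check=True)
```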
The main entry point to run a single experiment is `main.py`. The most important arguments to this script are:

- `DATASET_NAME` (positional argument; determines which dataset to run the experiment on. Must exactly match the name of a dataset defined in `preprocessing/datasets.py`).
- `--SUBSPLIT` (the subsplit of the dataset to use).
- `--MODEL_ID` (the model name in Hugging Face).
- `--CONTEXT_WEIGHT_FORMAT` (the context weight format to use for training examples, e.g., `float` or `instruction`).
- `--EVALS` (the list of evals to run the model on. Must be a list of dicts, each containing at least `dataset_name`, `subsplit`, `k_demonstrations`, and `context_weight_format`. Optionally also include `do_steering`).
- `--PROJECTION-PATH` (path to a saved projection for training).

The remaining arguments (visible via `python main.py --help`) are either dataset-specific (e.g., specify the `--QUERY_ID` if running an experiment with `DATASET_NAME="YagoECQ"`), allow for control over other experiment details (e.g., which query types to use, the model's batch size for inference, how to sample entities, etc.), or steering-specific hyperparameters (e.g., steering values).
An example command to finetune a model and evaluate on the same dataset is:

```bash
python main.py BaseFakepedia -M meta-llama/Meta-Llama-3.1-8B-Instruct -S 3 -TS 2048 -TSS 1000 -P -CWF float -O -EV '{"dataset_name": "BaseFakepedia", "subsplit": "nodup_relpid", "k_demonstrations": 0, "context_weight_format": "float", "do_steering": False}'
```
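Note that the example passes `--EVALS` as a Python-literal string (`False` rather than JSON's `false`). If you launch runs from Python, a hedged sketch (assuming `main.py` accepts Python-literal dicts, as the example suggests):

```python
import subprocess

# Build the eval spec as a dict and pass its Python repr (preserves `False`).
eval_spec = {
    "dataset_name": "BaseFakepedia",
    "subsplit": "nodup_relpid",
    "k_demonstrations": 0,
    "context_weight_format": "float",
    "do_steering": False,
}
subprocess.run(
    ["python", "main.py", "BaseFakepedia",
     "-M", "meta-llama/Meta-Llama-3.1-8B-Instruct",
     "-S", "3", "-TS", "2048", "-TSS", "1000",
     "-P", "-CWF", "float", "-O",
     "-EV", repr(eval_spec)],
    check=True,
)
```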
Run the following command to submit all experiments for a given model (`llama`, `mistral`, or `gemma`):

```bash
python generate_run_scripts.py --model-id llama --add-default --add-steering --add-oos-datasets --add-training && bash run_scripts.sh
```
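To sweep all three model families, the same command can be looped; for example:

```python
import subprocess

# Generate and run the full experiment grid for each supported model family.
for model_id in ("llama", "mistral", "gemma"):
    subprocess.run(
        ["python", "generate_run_scripts.py", "--model-id", model_id,
         "--add-default", "--add-steering", "--add-oos-datasets", "--add-training"],
        check=True,
    )
    subprocess.run(["bash", "run_scripts.sh"], check=True)
```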
The scripts for the interpretability analysis are in the `analysis` directory. The analysis is mainly based on the `nnsight` and `nnpatch` libraries.

Check the notebooks `notebooks/analysis_llama.ipynb`, `notebooks/analysis_mistral.ipynb`, and `notebooks/analysis_gemma.ipynb` for the analysis in Sections 5 and 6 of the paper. To regenerate the plots, first generate the orthogonal projections using the notebooks mentioned above, then run all experiments, and finally run the `analysis/plots_das.ipynb` notebook.
You can also use the existing projections, which are hosted on Hugging Face:

- `Meta-Llama-3.1-8B-Instruct-L16`: projection for the Meta-Llama-3.1 family of models, layer 16. Recommended steering values: `prior=6`, `context=-6`.
- `gemma-2-9b-it-L27`: projection for the gemma-2-9b-it family of models, layer 27. Recommended steering values: `prior=-100`, `context=150`.
- `Mistral-7B-Instruct-v0.3-L16`: projection for the Mistral family of models, layer 16. Recommended steering values: `prior=5`, `context=-5`.
Check the `analysis/demo_steering.ipynb` notebook for a demo of how to use the steering hook.
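The notebook contains the actual hook; as a rough sketch of the idea (assuming `nnsight`'s tracing API and an assumed `nnpatch` import path for the projection class), steering adds a scaled copy of the subspace direction to the residual stream at the projection's layer:

```python
from nnsight import LanguageModel
# Import path assumed; adjust to wherever nnpatch exposes the class.
from nnpatch import LowRankOrthogonalProjection

proj = LowRankOrthogonalProjection.from_pretrained(
    "jkminder/CTXPRIOR-Projection-Meta-Llama-3.1-8B-Instruct-L16"
)
direction = proj.weight[:, 0]  # rank-1 basis vector, shape [4096]

model = LanguageModel("meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto")
steer = 6.0  # recommended "prior" value for this projection; -6.0 steers toward context

with model.trace("Paris is the capital of"):
    # Shift the layer-16 residual stream along the subspace direction.
    # (Move `direction` to the model's device/dtype first if they differ.)
    model.model.layers[16].output[0][:] += steer * direction
    logits = model.lm_head.output.save()
```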
If you want the basis vector of the subspace, use the following snippet to get it:

```python
# Import path assumed; adjust to wherever nnpatch exposes the class.
from nnpatch import LowRankOrthogonalProjection

proj = LowRankOrthogonalProjection.from_pretrained("jkminder/CTXPRIOR-Projection-Meta-Llama-3.1-8B-Instruct-L16")
u = proj.weight
u.shape  # [4096, 1]
```
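As a quick sanity check (assumption: the basis is stored orthonormal, so its column should have unit norm):

```python
import torch

# If the basis is orthonormal, each column has norm ~1.0 (assumption).
print(torch.linalg.vector_norm(u, dim=0))
```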