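The official GitHub repository for the paper R$^2$AG: Incorporating Retrieval Information into Retrieval Augmented Generation (Findings of EMNLP 2024).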
Retrieval augmented generation (RAG) has been applied in many scenarios to augment large language models (LLMs) with external documents provided by retrievers. However, a semantic gap exists between LLMs and retrievers due to differences in their training objectives and architectures. This misalignment forces LLMs to passively accept the documents provided by the retrievers, which can lead to incomprehension during generation: the LLMs are left to distinguish these documents using only their inherent knowledge. This paper proposes R$^2$AG, which incorporates retrieval information into retrieval augmented generation to bridge this gap.
This guide walks you through processing the datasets and training R$^2$AG.
Clone our repository and install the following packages:
torch==2.1.2
transformers==4.40.2
sentence-transformers==2.6.1
trl==0.8.1
peft==0.10.0
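Because the scripts were written against these exact versions, it can help to confirm your environment before running anything. The sketch below is a hypothetical convenience check (not part of the repository) that uses only the standard library `importlib.metadata`; it ignores local build suffixes such as `+cu121`, which CUDA builds of torch report.

```python
from importlib import metadata

# Versions pinned in the package list above.
PINNED = {
    "torch": "2.1.2",
    "transformers": "4.40.2",
    "sentence-transformers": "2.6.1",
    "trl": "0.8.1",
    "peft": "0.10.0",
}


def find_mismatches(pinned, get_version=metadata.version):
    """Return {package: installed version or None} for every package that differs from its pin.

    A local build suffix such as "+cu121" is stripped before comparing.
    """
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        base = installed.split("+")[0] if installed else None
        if base != wanted:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(PINNED)
    for name, installed in problems.items():
        print(f"{name}: expected {PINNED[name]}, found {installed or 'not installed'}")
    if not problems:
        print("All pinned packages match.")
```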
For the NQ datasets, download the data from the lost-in-the-middle repository. Because the LongBench version of DuReader does not include labels for relevant documents, you can download the raw version of DuReader and follow the instructions in label_dureader.ipynb to generate the labels. Alternatively, download the preprocessed .pkl file directly from here. For other datasets, refer to their official websites.
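The NQ files are gzip-compressed JSON Lines. As a hedged sketch, the helpers below read such a file and locate gold documents. They assume the lost-in-the-middle layout, where each record has a "question" and a "ctxs" list whose items carry an "isgold" flag; check your files and adjust the keys if they differ. Both function names are hypothetical.

```python
import gzip
import json


def read_jsonl_gz(path):
    """Yield one dict per non-empty line of a gzip-compressed JSON Lines file."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


def gold_positions(record):
    """Return the indices of documents flagged as gold.

    Assumes a "ctxs" list whose items have an "isgold" boolean (lost-in-the-middle layout).
    """
    return [i for i, ctx in enumerate(record.get("ctxs", [])) if ctx.get("isgold")]
```

For example, `gold_positions(next(read_jsonl_gz("10_total_documents/nq-open-10_total_documents_gold_at_0.jsonl.gz")))` would be expected to return `[0]` for a `gold_at_0` file, provided the assumed layout holds.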
To directly improve RAG performance, you can fine-tune a retriever using retrieval/finetune_retriever.py. This step is optional but recommended for the best R$^2$AG results.
cd retrieval
python finetune_retriever.py \
--dataset_name nq_10 \
--input_path 10_total_documents/nq-open-10_total_documents_gold_at_0.jsonl.gz \
--model_name google-bert/bert-base-uncased \
--train_batch_size 32 \
--num_epoch 4
where
- dataset_name: The name of the dataset to use.
- input_path: The file path of the dataset.
- model_name: The name or path of the retriever model to fine-tune. The fine-tuned retriever is saved in the retrieval/retrievers folder.

For the HotpotQA, MuSiQue, and 2Wiki datasets, use train_data_path and test_data_path instead. Detailed commands are available in the scripts/retriever folder.
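Retriever fine-tuning needs question/passage pairs with relevance labels. The exact training objective of finetune_retriever.py is not described here, so the following is only a sketch of one plausible data-preparation step: a hypothetical `build_pairs` helper that turns records (assumed to follow the lost-in-the-middle layout with "question", "ctxs", "title", "text", and "isgold") into `(question, passage, label)` triples, optionally subsampling negatives.

```python
import random


def build_pairs(records, neg_per_query=None, seed=0):
    """Turn QA records into (question, passage, label) triples.

    Hypothetical helper. label is 1 for gold passages and 0 otherwise.
    If neg_per_query is set, at most that many negatives are kept per question.
    """
    rng = random.Random(seed)  # fixed seed keeps negative sampling reproducible
    pairs = []
    for rec in records:
        question = rec["question"]
        positives, negatives = [], []
        for ctx in rec.get("ctxs", []):
            passage = f"{ctx.get('title', '')}\n{ctx.get('text', '')}".strip()
            (positives if ctx.get("isgold") else negatives).append(passage)
        if neg_per_query is not None and len(negatives) > neg_per_query:
            negatives = rng.sample(negatives, neg_per_query)
        pairs.extend((question, p, 1) for p in positives)
        pairs.extend((question, n, 0) for n in negatives)
    return pairs
```

Triples in this shape map directly onto `sentence_transformers.InputExample(texts=[question, passage], label=float(label))`, which is one common input format for training a cross-encoder; whether the repository uses that format is an assumption.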
Before training, you need to extract retrieval features from the retriever. You can use either a fine-tuned or an off-the-shelf retrieval model, such as google-bert/bert-base-uncased, BAAI/bge-reranker-large, or facebook/contriever.
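The exact features produced by feature_extraction.py are defined in the paper and the script. Purely as a hedged illustration of what "retrieval features" derived from embeddings can look like, the sketch below computes two simple cosine-similarity signals per ranked document: its relevance to the query and its mean similarity to adjacent documents. The function names and the choice of features are assumptions, not the repository's implementation.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def retrieval_features(query_emb, doc_embs):
    """Illustrative per-document features, with documents given in ranked order.

    relevance: cos(query, doc_i)
    neighbor:  mean cos(doc_i, doc_j) over the adjacent documents j = i-1 and i+1
    """
    feats = []
    n = len(doc_embs)
    for i, d in enumerate(doc_embs):
        neighbors = [doc_embs[j] for j in (i - 1, i + 1) if 0 <= j < n]
        neighbor_sim = sum(cosine(d, e) for e in neighbors) / len(neighbors) if neighbors else 0.0
        feats.append({"relevance": cosine(query_emb, d), "neighbor": neighbor_sim})
    return feats
```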
For NQ datasets:
cd retrieval
python feature_extraction.py \
--dataset_name nq_10 \
--input_path 10_total_documents/nq-open-10_total_documents_gold_at_0.jsonl.gz \
--save_path dataset/nq-open-10_total_documents_gold_at_0_bert.pkl \
--model_name google-bert/bert-base-uncased
where
- save_path: The file path for the new dataset, typically a .pkl file.
- model_name: The name or path of the retriever (fine-tuned or not).
For HotpotQA, 2Wiki and MuSiQue datasets, refer to the scripts/feature_extraction folder for specific commands.
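Before training, it can be worth sanity-checking the saved .pkl. Its internal layout is not documented in this guide, so the hypothetical helper below only reports the top-level type, length, and (for a dict, or a list whose first item is a dict) a few keys. Note that `pickle.load` can execute arbitrary code, so only load files you created or trust.

```python
import pickle


def describe_pickle(path, max_keys=10):
    """Load a pickle and return a short summary of its top-level structure."""
    with open(path, "rb") as f:
        obj = pickle.load(f)
    summary = {"type": type(obj).__name__}
    if hasattr(obj, "__len__"):
        summary["len"] = len(obj)
    # Look inside the first element of a list/tuple, since datasets are often lists of records.
    sample = obj[0] if isinstance(obj, (list, tuple)) and obj else obj
    if isinstance(sample, dict):
        summary["keys"] = list(sample.keys())[:max_keys]
    return summary
```

For example, `describe_pickle("retrieval/dataset/nq-open-10_total_documents_gold_at_0_bert.pkl")` gives a quick view of what feature extraction wrote.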
Alternatively, you can obtain retrieval features from OpenAI embeddings using retrieval/feature_extraction_openai.py. Before running it, configure base_url and api_key in the script at Lines 38-39.
cd retrieval
python feature_extraction_openai.py \
--dataset_name nq_10 \
--input_path 10_total_documents/nq-open-10_total_documents_gold_at_0.jsonl.gz \
--model_name text-embedding-3-large \
--emb_save_path dataset/nq_10_openai_embedding_large.pkl \
--dataset_save_path dataset/nq-open-10_total_documents_gold_openai_large.pkl
where
- emb_save_path: The .pkl file path for saving the embeddings.
- dataset_save_path: The .pkl file path for saving the dataset.
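For orientation, the sketch below shows how embeddings can be requested in batches with the official `openai` Python package (v1 or later), whose client takes the same `base_url` and `api_key` settings the script asks you to configure. Reading them from the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables, the batch size, and both function names are assumptions of this sketch; feature_extraction_openai.py may batch and store results differently.

```python
import os


def make_client(base_url=None, api_key=None):
    """Create an OpenAI client; falls back to OPENAI_BASE_URL / OPENAI_API_KEY (assumption)."""
    from openai import OpenAI  # imported lazily so the rest of this sketch runs without the package

    return OpenAI(
        base_url=base_url or os.environ.get("OPENAI_BASE_URL"),
        api_key=api_key or os.environ.get("OPENAI_API_KEY"),
    )


def embed_texts(client, texts, model="text-embedding-3-large", batch_size=64):
    """Embed texts in batches, returning one vector per input in input order."""
    vectors = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        response = client.embeddings.create(model=model, input=batch)
        # Sort by each item's index so the output order always matches the input order.
        vectors.extend(item.embedding for item in sorted(response.data, key=lambda d: d.index))
    return vectors
```

Typical use would be `embed_texts(make_client(), ["who wrote hamlet", "..."])`; keeping the client as a parameter also makes the batching logic easy to test without network access.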
runner.py is the central script for training, generation, and evaluation. For example, you can train R$^2$AG on NQ as follows:
python runner.py \
--dataset_name nq_10 \
--input_path retrieval/dataset/nq-open-10_total_documents_gold_at_0_bert.pkl \
--model_name meta-llama/Llama-2-7b-hf \
--use_training \
--save_model \
--output_dir output/rrag/Rrag-Llama-2-7b \
--freeze_llm \
--num_k 10 \
--use_rrag \
--use_evaluation \
--save_results
where
- input_path: The path to the preprocessed dataset, typically a .pkl file containing retrieval features.
- model_name: The name or path of the LLM used for training. This can be a model from the Hugging Face Hub or a local path.
- use_training: Enables training mode.
- save_model: If set, the trained model is saved to output_dir.
- output_dir: The directory in which the model is saved.
- freeze_llm: If set, the LLM's parameters are frozen during training.
- num_k: The number of top-ranked documents to use.
- use_rrag: Enables R$^2$AG.
- use_evaluation: Enables evaluation mode.
- save_results: If set, generation results are saved to the output folder.
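When sweeping over models, datasets, or `num_k`, it can be convenient to assemble the runner.py arguments programmatically. The helper below is hypothetical; it uses only the flag names documented above and requires `output_dir` whenever training (and thus `--save_model`) is enabled. Setting `use_rrag=False, train=False` yields the flag set of the standard-RAG command.

```python
def runner_command(input_path, model_name, dataset_name="nq_10", num_k=10,
                   use_rrag=True, train=True, output_dir=None):
    """Assemble a runner.py argument list from the documented flags (hypothetical helper)."""
    if train and output_dir is None:
        raise ValueError("output_dir is required when training with --save_model")
    cmd = ["python", "runner.py",
           "--dataset_name", dataset_name,
           "--input_path", input_path,
           "--model_name", model_name]
    if train:
        cmd += ["--use_training", "--save_model", "--output_dir", output_dir]
    cmd += ["--freeze_llm", "--num_k", str(num_k)]
    if use_rrag:
        cmd.append("--use_rrag")
    cmd += ["--use_evaluation", "--save_results"]
    return cmd
```

The resulting list can be printed with `shlex.join(cmd)` for inspection or executed with `subprocess.run(cmd, check=True)`; passing a list rather than a single string avoids shell-quoting issues with paths.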
You can choose different LLMs, datasets, and other parameters. For example, the following command, which omits --use_training and --use_rrag, produces standard RAG results:
python runner.py \
--dataset_name nq_10 \
--input_path retrieval/dataset/nq-open-10_total_documents_gold_at_0_bert.pkl \
--model_name meta-llama/Llama-2-7b-hf \
--freeze_llm \
--num_k 10 \
--use_evaluation \
--save_results
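This guide does not state which metric runner.py reports. As an assumption, a common choice for open-domain QA such as NQ (used, for instance, in the lost-in-the-middle evaluation) is subspan exact match: a prediction counts as correct if any normalized gold answer appears inside the normalized prediction. The sketch below implements that metric with SQuAD-style normalization so saved results can be re-scored independently; it is not the repository's evaluation code.

```python
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace (SQuAD-style)."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def subspan_em(prediction, answers):
    """1.0 if any normalized gold answer occurs inside the normalized prediction, else 0.0."""
    pred = normalize_answer(prediction)
    return float(any(normalize_answer(a) in pred for a in answers))


def accuracy(predictions, answer_lists):
    """Mean subspan exact match over paired predictions and gold-answer lists."""
    if not predictions:
        return 0.0
    return sum(subspan_em(p, a) for p, a in zip(predictions, answer_lists)) / len(predictions)
```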
For the HotpotQA, 2Wiki, and MuSiQue datasets, refer to the scripts/train folder for the specific commands to train R$^2$AG.
@inproceedings{Ye2024R2AG,
title={{R}$^2${AG}: Incorporating Retrieval Information into Retrieval Augmented Generation},
author={Fuda Ye and Shuangyin Li and Yongqi Zhang and Lei Chen},
booktitle={Findings of the Association for Computational Linguistics: EMNLP 2024},
year={2024},
doi={10.18653/v1/2024.findings-emnlp.678}
}