This is the official PyTorch implementation of the paper "Discovering the Gems in Early Layers: Accelerating Long-Context LLMs with 1000x Input Token Reduction".
We propose an algorithm that uses early layers of an LLM as filters to select and compress input tokens, significantly reducing the context length for subsequent processing.
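As a rough illustration of the selection step, the sketch below assumes that the filter layer exposes, for each attention head, the pre-softmax scores of the last input token against every input position. It sums the per-head attention probabilities, keeps the k highest-scoring positions, and returns them in their original order. The function names and toy scores are hypothetical; this is a simplified sketch of the idea, not the code in this repository.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of attention logits."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def select_tokens(last_query_scores_per_head: Sequence[Sequence[float]], k: int) -> List[int]:
    """Return the k most-attended input positions, in their original order.

    Hypothetical input: one row of pre-softmax scores per head, giving the
    last input token's attention logits against every input position.
    """
    n = len(last_query_scores_per_head[0])
    # Sum the attention probabilities over all heads.
    summed = [0.0] * n
    for head_scores in last_query_scores_per_head:
        for i, p in enumerate(softmax(head_scores)):
            summed[i] += p
    k = min(k, n)
    top = sorted(range(n), key=lambda i: summed[i], reverse=True)[:k]
    # Restore the original token order for the compressed prompt.
    return sorted(top)


if __name__ == "__main__":
    # Toy logits for two heads over five input positions.
    scores = [[0.0, 3.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 0.0, 4.0]]
    print(select_tokens(scores, 2))  # prints [1, 4]
```

Returning the indices sorted by position keeps the selected tokens in the same sequence as the source text, rather than in order of importance.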
The code depends on Hugging Face transformers version 4.43.3.
transformers==4.43.3
flash-attn==2.6.3
Make sure to install the PyTorch build that is correct for your system (e.g., one that matches your CUDA version).
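Because the code is pinned to specific package versions, a quick check can catch mismatches before running anything. The helper below is a hypothetical convenience, not part of the repository; it uses only the standard-library importlib.metadata and compares installed versions against the two pins listed above (PyTorch is not pinned, so it is not checked).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions pinned in this README.
PINNED = {"transformers": "4.43.3", "flash-attn": "2.6.3"}


def find_mismatches(pins: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if missing} for every unmet pin."""
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in pins.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    for pkg, found in find_mismatches(PINNED).items():
        print(f"{pkg}: expected {PINNED[pkg]}, found {found or 'not installed'}")
```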
git clone https://github.com/SalesforceAIResearch/GemFilter.git
cd GemFilter
conda create --name gemfilter python=3.12
conda activate gemfilter
pip install torch torchvision torchaudio
pip install -r requirements.txt
python setup.py develop
python needle_eval.py \
    --model hf_model_id \
    --modified gemfilter \
    --topk 1024 \
    --ctx_len 32000
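To run the same evaluation for several context lengths, a small wrapper can build the command shown above. This is a hypothetical helper: it uses only the flags that appear in that command, the context lengths in the example are arbitrary, and whether needle_eval.py already provides its own sweep is not covered here. With dry_run=True it only prints the commands.

```python
import shlex
import subprocess
from typing import List, Sequence


def needle_command(model: str, topk: int, ctx_len: int) -> List[str]:
    """Build the needle_eval.py invocation, using only the flags shown above."""
    return [
        "python", "needle_eval.py",
        "--model", model,
        "--modified", "gemfilter",
        "--topk", str(topk),
        "--ctx_len", str(ctx_len),
    ]


def run_sweep(model: str, topk: int, ctx_lens: Sequence[int],
              dry_run: bool = True) -> List[List[str]]:
    """Print (and optionally execute) one evaluation command per context length."""
    commands = [needle_command(model, topk, n) for n in ctx_lens]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # Stop at the first failing run instead of silently continuing.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    # "hf_model_id" is the placeholder from the command above; replace it with a real model id.
    run_sweep("hf_model_id", topk=1024, ctx_lens=[16000, 32000])
```

Building the command as a list and passing it to subprocess.run avoids shell quoting issues; shlex.join is only used to print a copy-pasteable version.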
GemFilter can be easily integrated into any transformer model. Follow the comments marked with [GemFilter] in the code to construct your own models.
The detailed GemFilter algorithm is in gem_filter_utils.py and my_generation.py.
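At a high level, the approach has two passes: an early layer scores the input tokens, and the full model then runs only on the selected tokens. The sketch below shows that orchestration with two hypothetical callables, early_layer_scores and generate, standing in for the model-specific code; it does not reproduce the actual interfaces in gem_filter_utils.py or my_generation.py.

```python
import heapq
from typing import Callable, List, Sequence


def gemfilter_style_generate(
    input_ids: Sequence[int],
    early_layer_scores: Callable[[Sequence[int]], Sequence[float]],
    generate: Callable[[List[int]], List[int]],
    topk: int,
) -> List[int]:
    """Two-pass sketch: score with an early layer, keep top-k tokens, run the full model."""
    # Pass 1 (hypothetical hook): run only the first few layers and return one
    # importance score per input position.
    scores = early_layer_scores(input_ids)
    if len(scores) != len(input_ids):
        raise ValueError("expected one score per input token")
    # Keep the top-k positions, then restore their original order.
    k = min(topk, len(input_ids))
    keep = sorted(heapq.nlargest(k, range(len(input_ids)), key=lambda i: scores[i]))
    compressed = [input_ids[i] for i in keep]
    # Pass 2: the full model only sees the compressed prompt.
    return generate(compressed)


if __name__ == "__main__":
    ids = [10, 11, 12, 13, 14]
    out = gemfilter_style_generate(
        ids,
        early_layer_scores=lambda x: [0.1, 0.9, 0.2, 0.5, 0.8],  # stub scorer
        generate=lambda c: c + [99],  # stub "model" that appends one token
        topk=3,
    )
    print(out)  # prints [11, 13, 14, 99]
```

Keeping the model behind plain callables makes the driver easy to test with stubs before wiring it to a real transformer.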
Evaluate on the Needle-in-a-Haystack benchmark. See more details here.
Evaluate on the LongBench benchmark. See more details here.
If you find this project helpful, please consider citing our paper 😊
@article{smn+24,
title={Discovering the Gems in Early Layers: Accelerating Long-Context LLMs with 1000x Input Token Reduction},
author={Shi, Zhenmei and Ming, Yifei and Nguyen, Xuan-Phi and Liang, Yingyu and Joty, Shafiq},
journal={arXiv preprint arXiv:2409.17422},
year={2024}
}