This repository contains the code for SRformer: Segmented Recurrent Transformer.
This project is built on 🤗 Transformers. We modify the T5 and BART models for text summarization. Both are encoder-decoder models, and we apply the proposed segmented recurrent attention in the cross-attention blocks of their decoders. Please see the Hugging Face Transformers documentation for general instructions.
In /examples/pytorch/summarization/, you can find run_commands.sh, which lists example commands.
For example, to run T5-small on CNN/DailyMail:
$ python examples/pytorch/summarization/run_summarization_no_trainer.py --model_name_or_path t5-small --dataset_name ccdv/cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir ./results/t5-small/ --per_device_train_batch_size=16 --per_device_eval_batch_size=16 --num_train_epochs=25 --gpus=0
To use a pretrained BART model, set model_name_or_path to ainize/bart-base-cnn. To run on other datasets, change dataset_name to XSUM, ccdv/arxiv-summarization, or ccdv/mediasum.
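For instance, a BART run on XSUM could look like the following (a sketch: it mirrors the T5 command above, drops the T5-specific --source_prefix, and assumes the lowercase dataset identifier and the output directory, which are placeholders to adjust as needed):
$ python examples/pytorch/summarization/run_summarization_no_trainer.py --model_name_or_path ainize/bart-base-cnn --dataset_name xsum --output_dir ./results/bart-base/ --per_device_train_batch_size=16 --per_device_eval_batch_size=16 --num_train_epochs=25 --gpus=0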
In run_summarization_no_trainer.py, you can change cache_dir to your own directory; this is where the dataset will be downloaded.
After training your own model, you can load it by setting model_name_or_path to the local checkpoint path. Please remember to also set config_name and tokenizer_name. If you only want to evaluate the model, set --evaluation_only=True.
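For example, to evaluate a T5-small checkpoint produced by the training command above (a sketch, assuming the checkpoint was saved to ./results/t5-small/ as in that command, and that config_name and tokenizer_name should point back to the base t5-small model):
$ python examples/pytorch/summarization/run_summarization_no_trainer.py --model_name_or_path ./results/t5-small/ --config_name t5-small --tokenizer_name t5-small --dataset_name ccdv/cnn_dailymail --dataset_config "3.0.0" --source_prefix "summarize: " --output_dir ./results/t5-small/ --per_device_eval_batch_size=16 --evaluation_only=True --gpus=0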
We modify the T5 and BART models into SRformers; please see src/transformers/models/t5/modeling_t5.py and src/transformers/models/bart/modeling_bart.py. To run the regular T5 and BART models, set split_crossattn=False for the cross-attention layers.
Some settings, such as the segment size, are defined directly in the model source files. We did not add these parameters to a separate configuration file, so pretrained models and their configuration files can be loaded directly.
Please note that in src/transformers/generation_utils.py, we add the timestep t as an input to the transformer model. The memory of the RAF neurons is reset when t=0.
If you use this repo, please use the following citation:
@misc{long2023segmented,
title={Segmented Recurrent Transformer: An Efficient Sequence-to-Sequence Model},
author={Yinghan Long and Sayeed Shafayet Chowdhury and Kaushik Roy},
booktitle={EMNLP},
year={2023},
eprint={2305.16340},
archivePrefix={arXiv},
primaryClass={cs.CL}
}