AdaSent is an approach to creating domain-specialized sentence encoders for few-shot sentence classification. It combines Domain-Adaptive Pre-training (DAPT) and Sentence Embedding Pre-Training (SEPT) in a modular fashion. First, DAPT is performed on a base pre-trained language model (PLM). Separately, an adapter is trained on the same PLM through general-domain SEPT. The adapter stores the sentence-specialization abilities and can be plugged into PLMs adapted to various domains, turning them into domain-specialized sentence encoders. SetFit is then run on these encoders for downstream few-shot classification training.
📄Paper
├── data <- Path to store data for training and inference when using this project
├── models <- Path to store trained models when using this project
├── scripts <- Scripts for training and inference
├── example.sh <- Example script for AdaSent training
├── LICENSE
├── NOTICE <- Copyright information
├── README.md <- The top-level README for developers using this project
└── requirements.txt <- Requirements
- Python 3.8
Install the requirements:
pip install -r requirements.txt
AdaSent training consists of three parts: DAPT, SEPT and SetFit. Each of them is explained below.
An example script including all three parts can be found in example.sh. The script should be run from the project root.
To train AdaSent, you need to prepare:
- A .txt file containing unlabeled examples (one example per line) for DAPT
- A .csv file with labeled training data and a .csv file with evaluation data for the few-shot classification task

Then fill in the paths in the script.
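A quick way to confirm that the prepared files have the expected format is sketched below with the Python standard library. check_unlabeled_txt and check_labeled_csv are hypothetical helpers, not part of this repository; the column names text and label are assumptions that match the --text_col and --label_col values used in the SetFit commands.

```python
import csv
from pathlib import Path


def check_unlabeled_txt(path):
    """Return the non-empty lines of a DAPT .txt file (one example per line)."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    examples = [line for line in lines if line.strip()]
    if not examples:
        raise ValueError(f"{path} contains no examples")
    return examples


def check_labeled_csv(path, text_col="text", label_col="label"):
    """Check that a SetFit .csv file has the text and label columns; return its rows."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        missing = {text_col, label_col} - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"{path} is missing columns: {sorted(missing)}")
        return list(reader)
```

A leading unnamed index column, as written by pandas' to_csv with its default index=True, is just an extra field and does not affect the check.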
Taking the task mteb/amazon_massive_scenario as an example, you can create these data files with the following code:
import os
import pandas as pd
from datasets import load_dataset
from setfit import sample_dataset
# Paths to store files
unlabeled_text_file = 'data/DAPT/amazon_massive_scenario.txt'
labeled_train_file = 'data/SetFit/amazon_massive_scenario_train.csv'
eval_file = 'data/SetFit/amazon_massive_scenario_eval.csv'
# Create the output directories if they do not exist yet
os.makedirs(os.path.dirname(unlabeled_text_file), exist_ok=True)
os.makedirs(os.path.dirname(labeled_train_file), exist_ok=True)
# Load the English subset of the dataset from Hugging Face
dataset = load_dataset('mteb/amazon_massive_scenario', 'en')
# Write the training-set texts into a text file (one example per line) for DAPT
with open(unlabeled_text_file, 'w') as f:
    f.write('\n'.join(dataset['train']['text']))
# Sample 8 labeled examples per class out of the original training set
train_dataset = sample_dataset(dataset['train'],
                               label_column='label',
                               num_samples=8)
eval_dataset = dataset['test']
# Save train and evaluation data to .csv files
df = pd.DataFrame({'text': train_dataset['text'], 'label': train_dataset['label']})
df.to_csv(labeled_train_file)
df = pd.DataFrame({'text': eval_dataset['text'], 'label': eval_dataset['label']})
df.to_csv(eval_file)
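sample_dataset draws a fixed number of examples per label, so with num_samples=8 the training file holds 8 rows for each class. The plain-Python sketch below illustrates that idea (seeded shuffle, then take the first examples of each label); it is not setfit's implementation, and sample_per_class is a hypothetical name.

```python
import random
from collections import defaultdict


def sample_per_class(texts, labels, num_samples=8, seed=42):
    """Pick up to `num_samples` (text, label) pairs per label after a seeded shuffle."""
    indices = list(range(len(texts)))
    random.Random(seed).shuffle(indices)
    taken = defaultdict(int)  # how many examples have been kept for each label
    sampled = []
    for i in indices:
        if taken[labels[i]] < num_samples:
            taken[labels[i]] += 1
            sampled.append((texts[i], labels[i]))
    return sampled
```

Classes with fewer than num_samples examples simply contribute all of their examples.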
First, we need to train a domain-adapted PLM. The following command trains a DistilRoBERTa model on task-specific unlabeled data with masked language modeling (MLM). The --model_name_or_path can be any other local or Hugging Face model path:
python scripts/DAPT/train_mlm.py \
--train_file data/DAPT/amazon_massive_scenario.txt \
--model_name_or_path distilroberta-base \
--max_seq_length 512 \
--max_train_steps 2000 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 16 \
--output_dir models/distilroberta-base_dapt_amazon_massive_scenario \
--line_by_line True
This will save the DAPT-ed model in the specified --output_dir.
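The masking settings of scripts/DAPT/train_mlm.py are not listed here. As an illustration, the sketch below assumes the standard BERT-style scheme (also the default of Hugging Face's DataCollatorForLanguageModeling): 15% of non-special tokens are selected as prediction targets, of which 80% are replaced by the mask token, 10% by a random token and 10% kept unchanged. mask_tokens is a hypothetical helper; labels use -100 for positions that do not contribute to the loss, following the Hugging Face convention.

```python
import random


def mask_tokens(token_ids, mask_id, vocab_size, special_ids=(), mlm_prob=0.15, seed=0):
    """BERT-style masking (assumed scheme): returns (inputs, labels)."""
    rng = random.Random(seed)
    inputs, labels = list(token_ids), [-100] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if tok in special_ids or rng.random() >= mlm_prob:
            continue  # not selected: no loss at this position
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id  # 80%: replace with the mask token
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels
```

For the command above, --per_device_train_batch_size 16 with --gradient_accumulation_steps 16 gives an effective batch of 16 × 16 = 256 sequences per optimizer step on a single device, so --max_train_steps 2000 corresponds to 2000 × 256 = 512,000 sequences.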
We train a SEPT adapter on the same base PLM as in DAPT (DistilRoBERTa in our case). After training, this adapter can be inserted into any other domain-adapted PLM built from that base model.
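The SEPT command in this section uses --adapter_config parallel. As a rough illustration of why one adapter can be combined with different DAPT-ed models, here is a toy sketch of a parallel bottleneck adapter: a small down-project, ReLU, up-project branch reads the same input as a frozen sub-layer, and its scaled output is added to that sub-layer's output. The sizes, weights, scaling value and the two "domain" sub-layers are made up for illustration; the real placement and scaling come from the adapter configuration and are not reproduced here.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def relu(vector):
    return [max(0.0, v) for v in vector]


def with_parallel_adapter(x, sublayer, down, up, scaling=1.0):
    """Frozen sub-layer output plus a scaled bottleneck adapter applied to the same input x."""
    base = sublayer(x)  # e.g. a feed-forward block of a DAPT-ed PLM (frozen)
    delta = matvec(up, relu(matvec(down, x)))  # down-project, ReLU, up-project
    return [b + scaling * d for b, d in zip(base, delta)]


# Hypothetical sub-layers from two differently adapted copies of the same base model
def domain_a_ffn(v):
    return list(v)


def domain_b_ffn(v):
    return [2 * a for a in v]


# One adapter (hidden size 2, bottleneck size 1) reused with both sub-layers
down = [[1.0, 0.0]]
up = [[2.0], [0.0]]
print(with_parallel_adapter([1.0, -1.0], domain_a_ffn, down, up))  # [3.0, -1.0]
print(with_parallel_adapter([1.0, -1.0], domain_b_ffn, down, up))  # [4.0, -2.0]
```

Plugging in only requires matching hidden sizes; training the adapter on the same base PLM that DAPT starts from presumably also keeps the representations it sees similar.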
In our work, we found that SEPT on the three datasets AllNLI, Sentence compression and Stackexchange duplicate question (see the datasets' information here) yields the largest improvements on downstream few-shot classification tasks.
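The SEPT command has separate batch sizes for pair data (--batch_size_pairs) and triplet data (--batch_size_triplets). A common objective for such data is a contrastive loss with in-batch negatives, known as MultipleNegativesRankingLoss in sentence-transformers. We assume, without it being stated here, that SEPT uses an objective of this kind; the plain-Python sketch below shows the idea with a similarity scale of 20 (the sentence-transformers default for that loss). in_batch_negatives_loss is a hypothetical name.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def in_batch_negatives_loss(anchors, positives, negatives=(), scale=20.0):
    """Mean cross-entropy where anchor i must rank positives[i] above all other candidates.

    Candidates are every positive in the batch plus any hard negatives (the third
    element of triplets), so the other examples' positives act as negatives for free.
    """
    candidates = list(positives) + list(negatives)
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, c) for c in candidates]
        top = max(logits)
        log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
        total += log_sum_exp - logits[i]  # -log softmax probability of the true positive
    return total / len(anchors)
```

Larger batches supply more negatives per anchor, which is one reason batch sizes matter for this kind of objective.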
The following command trains a SEPT adapter on an unadapted DistilRoBERTa model and saves the adapter at --model_save_path. The above-mentioned datasets will be downloaded automatically. More information about the parameters can be found in the script scripts/SEPT/train_sept.py:
python scripts/SEPT/train_sept.py \
--model_name_or_path distilroberta-base \
--use_adapter True \
--adapter_config parallel \
--adapter_name sept \
--max_seq_length 512 \
--batch_size_pairs 64 \
--batch_size_triplets 64 \
--num_epochs 1 \
--learning_rate 1e-4 \
--pooling_mode mean \
--model_save_path models/distilroberta-base_sept_adapter \
--use_amp True
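With --pooling_mode mean, a sentence embedding is the average of the token embeddings, with padding positions excluded via the attention mask. A minimal pure-Python sketch of that computation (mean_pool is a hypothetical name; library implementations do the same with tensors):

```python
def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors whose attention-mask entry is 1 (padding is ignored)."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            summed = [s + v for s, v in zip(summed, vec)]
            count += 1
    return [s / max(count, 1) for s in summed]  # max() guards against an all-padding input
```

For example, mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0]) ignores the padded third token and returns [2.0, 3.0].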
After the DAPT-ed PLM and the SEPT-ed adapter are trained, we combine them and train a SetFit model with 8 shots per class. Here, you need to specify the paths to the training and evaluation .csv files in --train_dataset_path and --eval_dataset_path, respectively:
python scripts/SetFit/train_setfit.py \
--model_name_or_path models/distilroberta-base_dapt_amazon_massive_scenario \
--adapter_path models/distilroberta-base_sept_adapter \
--batch_size 16 \
--num_epochs 1 \
--num_samples 8 \
--num_iterations 20 \
--adapter_name sept \
--model_save_path models/adasent_setfit_amazon_massive_scenario \
--train_dataset_path data/SetFit/amazon_massive_scenario_train.csv \
--eval_dataset_path data/SetFit/amazon_massive_scenario_eval.csv \
--text_col text \
--label_col label
The --model_name_or_path should be the path of the DAPT-ed PLM, and the --adapter_path should be the path of the SEPT-ed adapter. The trained SetFit model will be saved at --model_save_path. More information about the parameters can be found in the script scripts/SetFit/train_setfit.py.
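SetFit first fine-tunes the sentence encoder (here, the DAPT-ed PLM with the SEPT adapter) contrastively on sentence pairs generated from the few-shot examples, and then fits a classification head on the resulting embeddings (a scikit-learn logistic regression by default). The sketch below shows how such pairs could be generated, roughly following the sampling of early setfit releases: in each of --num_iterations rounds, every example gets one positive partner (same label, target 1.0) and one negative partner (different label, target 0.0). generate_pairs is a hypothetical helper, and newer setfit versions sample pairs differently.

```python
import random
from collections import defaultdict


def generate_pairs(texts, labels, num_iterations=20, seed=0):
    """Return (text_a, text_b, target) tuples: target 1.0 for same label, 0.0 otherwise."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)
    pairs = []
    for _ in range(num_iterations):
        for text, label in zip(texts, labels):
            # Positive partner: any example with the same label (possibly the text itself)
            pairs.append((text, rng.choice(by_label[label]), 1.0))
            # Negative partner: any example with a different label
            others = [t for other, group in by_label.items() if other != label for t in group]
            if others:
                pairs.append((text, rng.choice(others), 0.0))
    return pairs
```

With 8 shots per class and --num_iterations 20, this scheme yields 2 × 20 × 8 = 320 pairs per class.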
The unlabeled data used in DAPT can also be used for self-training in SetFit to further improve the results. First, the SetFit model body (DAPT-ed Transformer + SEPT adapter in our case), trained on the few-shot labeled data, encodes both the labeled and the unlabeled data. Then we run self-training on the encoded data with the SelfTrainingClassifier from scikit-learn.
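scikit-learn's SelfTrainingClassifier wraps a base classifier and, by default, repeatedly adds unlabeled samples whose highest predicted probability reaches threshold=0.75, for at most max_iter=10 rounds or until no new sample qualifies. The plain-Python sketch below mimics that loop on toy 2-D vectors standing in for SetFit embeddings. centroid_proba is a hypothetical stand-in base classifier, and the settings used by scripts/SetFit/train_setfit_with_self_training.py are not stated here.

```python
import math


def centroid_proba(train_x, train_y, x):
    """Toy base classifier: softmax over negative squared distances to class centroids."""
    classes = sorted(set(train_y))
    scores = []
    for c in classes:
        members = [v for v, y in zip(train_x, train_y) if y == c]
        centroid = [sum(col) / len(members) for col in zip(*members)]
        scores.append(-sum((a - b) ** 2 for a, b in zip(x, centroid)))
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return classes, [e / total for e in exps]


def self_train(labeled_x, labeled_y, unlabeled_x, threshold=0.75, max_iter=10):
    """Repeatedly pseudo-label unlabeled points whose top probability reaches `threshold`."""
    xs, ys = list(labeled_x), list(labeled_y)
    pool = list(unlabeled_x)
    for _ in range(max_iter):
        confident, rest = [], []
        for x in pool:  # predictions use the classifier fitted on the current labeled set
            classes, probs = centroid_proba(xs, ys, x)
            best = max(range(len(probs)), key=probs.__getitem__)
            if probs[best] >= threshold:
                confident.append((x, classes[best]))
            else:
                rest.append(x)
        if not confident:
            break  # nothing new to add: stop early
        for x, y in confident:
            xs.append(x)
            ys.append(y)
        pool = rest
    return xs, ys
```

Points that stay ambiguous (for example, equidistant from both centroids) are never pseudo-labeled, which is the purpose of the threshold.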
Run the following command to do self-training with the unlabeled data at --unlabeled_file_path; here we use the same file as in DAPT:
python scripts/SetFit/train_setfit_with_self_training.py \
--model_name_or_path models/distilroberta-base_dapt_amazon_massive_scenario \
--adapter_path models/distilroberta-base_sept_adapter \
--adapter_name sept \
--unlabeled_file_path data/DAPT/amazon_massive_scenario.txt \
--batch_size 16 \
--num_epochs 1 \
--num_samples 8 \
--num_iterations 20 \
--train_dataset_path data/SetFit/amazon_massive_scenario_train.csv \
--eval_dataset_path data/SetFit/amazon_massive_scenario_eval.csv \
--text_col text \
--label_col label
To run inference, load the trained SetFit model and activate the adapter:
from setfit import SetFitModel
# Load the SetFit model saved with "--model_save_path" in the SetFit training step
model = SetFitModel.from_pretrained("models/adasent_setfit_amazon_massive_scenario")
# Activate the AdaSent adapter with the "--adapter_name" used in training
model.model_body[0].auto_model.set_active_adapters("sept")
# Run inference
preds = model(["set an alarm at 7 tomorrow", "order a pizza"])
Dataset | Link |
---|---|
MTEB Benchmark | MTEB |
ADE | ade_corpus_v2 |
RCT | armanc/pubmed-rct20k |
FPB | financial_phrasebank |
TFNS | zeroshot/twitter-financial-news-sentiment |
TFNT | zeroshot/twitter-financial-news-topic |
LED | lex_glue |
The pre-trained SEPT adapter for distilroberta-base is available on Hugging Face.
Please use the following citation:
@inproceedings{huang-etal-2023-adasent,
title = "{A}da{S}ent: Efficient Domain-Adapted Sentence Embeddings for Few-Shot Classification",
author = "Huang, Yongxin and
Wang, Kexin and
Dutta, Sourav and
Patel, Raj and
Glava{\v{s}}, Goran and
Gurevych, Iryna",
editor = "Bouamor, Houda and
Pino, Juan and
Bali, Kalika",
booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
month = dec,
year = "2023",
address = "Singapore",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2023.emnlp-main.208",
pages = "3420--3434",
}
Contact person: Yongxin Huang
https://www.ukp.tu-darmstadt.de/
Don't hesitate to send us an e-mail or report an issue if something is broken or if you have further questions.
This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.