From 4443bf53f50cc225c738230d551ef2d205a07634 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 8 May 2024 09:43:56 +0200 Subject: [PATCH] [`v3`] Update example scripts to the new v3 training format (#2622) * Update example scripts to the new v3 training format * Add distillation training examples * Add Matryoshka training examples * Add NLI training examples * Add STS training scripts * Fix accidentally overriding eval set * Update paraphrases multi-dataset training script * Convert regular dicts to DatasetDict on Trainer init * Update Quora duplicate training scripts * Update "other" training scripts * Update multilingual conversion script * Add example scripts to Evaluators * Add example to ST class itself * Update docs formatting slightly * Fix model card snippet * Add short docstring for similarity_fn_name property --- .../adaptive_layer/adaptive_layer_nli.py | 184 ++++----- .../adaptive_layer/adaptive_layer_sts.py | 176 ++++----- ...aining_stsbenchmark_avg_word_embeddings.py | 157 ++++---- .../training_stsbenchmark_bilstm.py | 147 +++---- .../training_stsbenchmark_bow.py | 151 ++++---- .../training_stsbenchmark_cnn.py | 161 ++++---- ...ing_stsbenchmark_tf-idf_word_embeddings.py | 156 ++++---- .../train_sts_indomain_bm25.py | 177 +++++---- .../train_sts_indomain_nlpaug.py | 178 ++++----- .../distillation/dimensionality_reduction.py | 73 ++-- .../distillation/model_distillation.py | 265 +++++++------ .../model_distillation_layer_reduction.py | 229 +++++++++++ .../training/matryoshka/2d_matryoshka_nli.py | 179 ++++----- .../training/matryoshka/2d_matryoshka_sts.py | 174 ++++----- .../training/matryoshka/matryoshka_nli.py | 179 ++++----- .../matryoshka/matryoshka_nli_reduced_dim.py | 188 ++++----- .../training/matryoshka/matryoshka_sts.py | 196 +++++----- .../multilingual/make_multilingual.py | 358 +++++++++--------- examples/training/nli/training_nli.py | 190 +++++----- examples/training/nli/training_nli_v2.py | 203 +++++----- examples/training/nli/training_nli_v3.py | 203 ++++------ .../training/other/training_multi-task.py | 219 +++++------ .../other/training_wikipedia_sections.py | 177 ++++----- examples/training/paraphrases/training.py | 208 ++++++---- .../training_MultipleNegativesRankingLoss.py | 247 ++++++------ .../training_OnlineContrastiveLoss.py | 248 ++++++------ .../training_multi-task-learning.py | 292 +++++++------- .../training/sts/training_stsbenchmark.py | 166 ++++---- ...training_stsbenchmark_continue_training.py | 158 ++++---- sentence_transformers/SentenceTransformer.py | 26 ++ .../BinaryClassificationEvaluator.py | 52 +++ .../EmbeddingSimilarityEvaluator.py | 31 +- .../InformationRetrievalEvaluator.py | 83 ++++ .../evaluation/MSEEvaluator.py | 34 +- .../evaluation/ParaphraseMiningEvaluator.py | 39 ++ .../evaluation/TranslationEvaluator.py | 30 ++ .../evaluation/TripletEvaluator.py | 35 +- sentence_transformers/model_card_template.md | 2 +- sentence_transformers/trainer.py | 5 + 39 files changed, 3209 insertions(+), 2767 deletions(-) create mode 100644 examples/training/distillation/model_distillation_layer_reduction.py diff --git a/examples/training/adaptive_layer/adaptive_layer_nli.py b/examples/training/adaptive_layer/adaptive_layer_nli.py index 4747941d0..955bc2e6c 100644 --- a/examples/training/adaptive_layer/adaptive_layer_nli.py +++ b/examples/training/adaptive_layer/adaptive_layer_nli.py @@ -1,8 +1,7 @@ """ The system trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) 
on the SNLI + MultiNLI (AllNLI) dataset -with AdaptiveLayerLoss using MultipleNegativesRankingLoss. This trains a model at output dimensions [768, 512, 256, 128, 64]. -Entailments are positive pairs and the contradiction on AllNLI dataset is added as a hard negative. -At every 10% training steps, the model is evaluated on the STS benchmark dataset +with AdaptiveLayerLoss using MultipleNegativesRankingLoss. Entailing texts are used as positive pairs and contradictory +texts are seen as negative pairs. At every 100 training steps, the model is evaluated on the STS benchmark dataset. Usage: python adaptive_layer_nli.py @@ -11,147 +10,112 @@ python adaptive_layer_nli.py pretrained_transformer_model_name """ -import math +import traceback from datasets import load_dataset -from sentence_transformers import models, losses, datasets -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction import logging from datetime import datetime import sys -import os -import gzip -import csv -import random - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" -train_batch_size = 128 # The larger you select this, the better the results (usually). But it requires more GPU memory -max_seq_length = 75 -num_epochs = 1 - -# Save path of the model -model_save_path = ( - "output/adaptive_layer_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) - - -# Here we define our SentenceTransformer model -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean") -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Check if dataset exists. If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -# Read the AllNLI.tsv.gz file and create the training dataset -logging.info("Read AllNLI train dataset") +from sentence_transformers.training_args import BatchSamplers -def add_to_samples(sent1, sent2, label): - if sent1 not in train_data: - train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} - train_data[sent1][label].add(sent2) +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) +model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" +batch_size = 128 # The larger you select this, the better the results (usually). 
But it requires more GPU memory +num_train_epochs = 1 -train_data = {} -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - sent1 = row["sentence1"].strip() - sent2 = row["sentence2"].strip() - - add_to_samples(sent1, sent2, row["label"]) - add_to_samples(sent2, sent1, row["label"]) # Also add the opposite - - -train_samples = [] -for sent1, others in train_data.items(): - if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: - train_samples.append( - InputExample( - texts=[sent1, random.choice(list(others["entailment"])), random.choice(list(others["contradiction"]))] - ) - ) - train_samples.append( - InputExample( - texts=[random.choice(list(others["entailment"])), sent1, random.choice(list(others["contradiction"]))] - ) - ) - -logging.info("Train samples: {}".format(len(train_samples))) +# Save path of the model +output_dir = f"output/adaptive_layer_nli_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) -# Special data loader that avoid duplicates within a batch -train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) +# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train") +eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") +logging.info(train_dataset) +# If you wish, you can limit the number of training samples +# train_dataset = train_dataset.select(range(5000)) -# Our training loss -train_loss = losses.MultipleNegativesRankingLoss(model) -train_loss = losses.AdaptiveLayerLoss(model, train_loss) +# 3. Define our training loss +inner_train_loss = losses.MultipleNegativesRankingLoss(model) +train_loss = losses.AdaptiveLayerLoss(model, inner_train_loss) -stsb_dev = load_dataset("mteb/stsbenchmark-sts", split="validation") +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") dev_evaluator = EmbeddingSimilarityEvaluator( - stsb_dev["sentence1"], - stsb_dev["sentence2"], - [score / 5 for score in stsb_dev["score"]], + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], main_similarity=SimilarityFunction.COSINE, name="sts-dev", ) -# Configure the training -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - +# 5. 
Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="adaptive-layer-nli", # Will be used in W&B if `wandb` is installed +) -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=int(len(train_dataloader) * 0.1), - warmup_steps=warmup_steps, - output_path=model_save_path, - use_amp=False, # Set to True, if your GPU supports FP16 operations ) +trainer.train() - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - - -model = SentenceTransformer(model_save_path) -stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test") +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") test_evaluator = EmbeddingSimilarityEvaluator( - stsb_test["sentence1"], - stsb_test["sentence2"], - [score / 5 for score in stsb_test["score"]], + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], main_similarity=SimilarityFunction.COSINE, name="sts-test", ) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: model.push_to_hub(f"{model_name}-nli-adaptive-layer") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{model_name}-nli-adaptive-layer')`." 
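# A minimal inference sketch of what the adaptive-layer training above enables: truncating the
# transformer to fewer layers at load time. This assumes a BERT/RoBERTa-style backbone whose layers
# live at `model[0].auto_model.encoder.layer`; the attribute path differs for other architectures,
# and `final_output_dir` refers to the save directory produced by the script above.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(final_output_dir)
# Keep only the first 3 transformer layers for faster inference at a small accuracy cost
model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:3]
embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!"])
print(embeddings.shape)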
) diff --git a/examples/training/adaptive_layer/adaptive_layer_sts.py b/examples/training/adaptive_layer/adaptive_layer_sts.py index b47b1c7c7..b95ac63b2 100644 --- a/examples/training/adaptive_layer/adaptive_layer_sts.py +++ b/examples/training/adaptive_layer/adaptive_layer_sts.py @@ -1,6 +1,6 @@ """ This examples trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) for the STSbenchmark from scratch. -It uses AdaptiveLayerLoss with the powerful CoSENTLoss to train models that perform well at output dimensions [768, 512, 256, 128, 64]. +It uses AdaptiveLayerLoss with the powerful CoSENTLoss to train models that perform well even when removing some layers. It generates sentence embeddings that can be compared using cosine-similarity to measure the similarity. Usage: @@ -10,118 +10,108 @@ python adaptive_layer_sts.py pretrained_transformer_model_name """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util -from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction import logging from datetime import datetime import sys -import os -import gzip -import csv - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - - -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert-base-uncased" - -# Read the dataset -train_batch_size = 16 -num_epochs = 4 -model_save_path = ( - "output/adaptive_layer_sts_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +batch_size = 16 +num_train_epochs = 4 + +# Save path of the model +output_dir = f"output/adaptive_layer_sts_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) + +# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) + +# 3. 
Define our training loss +# CoSENTLoss (https://sbert.net/docs/package_reference/losses.html#cosentloss) needs two text columns and one +# similarity score column (between 0 and 1) +inner_train_loss = losses.CoSENTLoss(model=model) +train_loss = losses.AdaptiveLayerLoss(model, inner_train_loss) + +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) -# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="adaptive-layer-sts", # Will be used in W&B if `wandb` is installed ) -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) - - -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) -train_loss = losses.CoSENTLoss(model=model) -train_loss = losses.AdaptiveLayerLoss(model, train_loss) - - -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - - -# Configure the training. We skip evaluation in this example -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 6. 
Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, ) +trainer.train() -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -test_evaluator(model, output_path=model_save_path) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: model.push_to_hub(f"{model_name}-sts-adaptive-layer") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{model_name}-sts-adaptive-layer')`." ) diff --git a/examples/training/avg_word_embeddings/training_stsbenchmark_avg_word_embeddings.py b/examples/training/avg_word_embeddings/training_stsbenchmark_avg_word_embeddings.py index bb965df98..a6bb8fe79 100644 --- a/examples/training/avg_word_embeddings/training_stsbenchmark_avg_word_embeddings.py +++ b/examples/training/avg_word_embeddings/training_stsbenchmark_avg_word_embeddings.py @@ -7,102 +7,115 @@ for available word embeddings files """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import models, losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer +import traceback +from datasets import load_dataset +from sentence_transformers import models, losses +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample import logging from datetime import datetime -import os -import csv -import gzip - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -# Read the dataset -batch_size = 32 -model_save_path = "output/training_stsbenchmark_avg_word_embeddings-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - -# Check if dataset exists. 
If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -logging.info("Read STSbenchmark train dataset") -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) +num_train_epochs = 1 +batch_size = 32 +output_dir = "output/training_stsbenchmark_avg_word_embeddings-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +# 1. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) +# 2. Define the model # Map tokens to traditional word embeddings like GloVe word_embedding_model = models.WordEmbeddings.from_text_file("glove.6B.300d.txt.gz") # Apply mean pooling to get one fixed sized sentence vector pooling_model = models.Pooling( word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, + pooling_mode="mean", ) # Add two trainable feed-forward networks (DAN) sent_embeddings_dimension = pooling_model.get_sentence_embedding_dimension() dan1 = models.Dense(in_features=sent_embeddings_dimension, out_features=sent_embeddings_dimension) dan2 = models.Dense(in_features=sent_embeddings_dimension, out_features=sent_embeddings_dimension) - model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dan1, dan2]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size) +# 3. Define our training loss +# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and +# one similarity score column (between 0 and 1) train_loss = losses.CosineSimilarityLoss(model=model) -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - -# Configure the training -num_epochs = 10 -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 4. Define an evaluator for use during training.
This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="glove-mean-pooling-sts", # Will be used in W&B if `wandb` is installed +) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -model.evaluate(test_evaluator) +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = "glove-mean-pooling-sts" +try: + model.push_to_hub(model_name) +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}')`." + ) diff --git a/examples/training/avg_word_embeddings/training_stsbenchmark_bilstm.py b/examples/training/avg_word_embeddings/training_stsbenchmark_bilstm.py index 296f755c7..4df3d8567 100644 --- a/examples/training/avg_word_embeddings/training_stsbenchmark_bilstm.py +++ b/examples/training/avg_word_embeddings/training_stsbenchmark_bilstm.py @@ -5,53 +5,32 @@ Note, you can also pass BERT embeddings to the BiLSTM. 
""" -from torch.utils.data import DataLoader -import math -from sentence_transformers import models, losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer +import traceback +from datasets import load_dataset +from sentence_transformers import models, losses +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample import logging from datetime import datetime -import os -import csv -import gzip - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -# Read the dataset -batch_size = 32 -model_save_path = "output/training_stsbenchmark_bilstm-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) +num_train_epochs = 1 +batch_size = 32 +output_dir = "output/training_stsbenchmark_bilstm-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) +# 1. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) +# 2. Define the model # Map tokens to traditional word embeddings like GloVe word_embedding_model = models.WordEmbeddings.from_text_file("glove.6B.300d.txt.gz") @@ -60,44 +39,68 @@ # Apply mean pooling to get one fixed sized sentence vector pooling_model = models.Pooling( lstm.get_word_embedding_dimension(), - pooling_mode_mean_tokens=False, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=True, + pooling_mode="mean", ) - - model = SentenceTransformer(modules=[word_embedding_model, lstm, pooling_model]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size) +# 3. 
Define our training loss +# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and +# one similarity score column (between 0 and 1) train_loss = losses.CosineSimilarityLoss(model=model) -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - -# Configure the training -num_epochs = 10 -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="glove-bilstm-sts", # Will be used in W&B if `wandb` is installed +) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -model.evaluate(evaluator) +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 8. (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = "glove-bilstm-sts" +try: + model.push_to_hub(model_name) +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}')`."
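# In the v3 API, evaluators can also be called on a model directly and return a dictionary of metrics.
# A small sketch, assuming EmbeddingSimilarityEvaluator keys follow the "<name>_<metric>_<similarity>"
# pattern (e.g. "sts-dev_spearman_cosine"):
results = dev_evaluator(model)
for metric, value in results.items():
    logging.info(f"{metric}: {value:.4f}")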
+ ) diff --git a/examples/training/avg_word_embeddings/training_stsbenchmark_bow.py b/examples/training/avg_word_embeddings/training_stsbenchmark_bow.py index 503de464d..951e006a1 100644 --- a/examples/training/avg_word_embeddings/training_stsbenchmark_bow.py +++ b/examples/training/avg_word_embeddings/training_stsbenchmark_bow.py @@ -5,56 +5,35 @@ To make the model trainable, we add multiple dense layers to create a Deep Averaging Network (DAN). """ -from torch.utils.data import DataLoader +import traceback +from datasets import load_dataset import math from sentence_transformers import models, losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample from sentence_transformers.models.tokenizer.WordTokenizer import ENGLISH_STOP_WORDS import logging from datetime import datetime import os -import csv -import gzip - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -# Read the dataset -batch_size = 32 -model_save_path = "output/training_tf-idf_word_embeddings-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) +num_train_epochs = 1 +batch_size = 32 +output_dir = "output/training_tf-idf_word_embeddings-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -##### Construction of the SentenceTransformer Model ##### +# 1. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) +# 2. Define the model # Wikipedia document frequency for words wiki_doc_freq = "wikipedia_doc_frequencies.txt" if not os.path.exists(wiki_doc_freq): @@ -83,8 +62,6 @@ if len(vocab) >= max_vocab_size: break -##### Construction of the SentenceTransformer Model ##### - # Create the BoW model. 
Because we set word_weights to the IDF values and cumulative_term_frequency=True, we # get tf-idf vectors. Set word_weights to an empty dict and cumulative_term_frequency=False to get a 1-hot sentence encoding bow = models.BoW(vocab=vocab, word_weights=weights, cumulative_term_frequency=True) @@ -96,36 +73,74 @@ model = SentenceTransformer(modules=[bow, dan1, dan2]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size) +# 3. Define our training loss +# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and +# one similarity score column (between 0 and 1) train_loss = losses.CosineSimilarityLoss(model=model) -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - -# Configure the training -num_epochs = 10 -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="wikipedia-tf-idf-bow", # Will be used in W&B if `wandb` is installed +) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -model.evaluate(test_evaluator) +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = "wikipedia-tf-idf-bow" +try: + model.push_to_hub(model_name) +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}')`." + ) diff --git a/examples/training/avg_word_embeddings/training_stsbenchmark_cnn.py b/examples/training/avg_word_embeddings/training_stsbenchmark_cnn.py index a7c822f52..07c743ac3 100644 --- a/examples/training/avg_word_embeddings/training_stsbenchmark_cnn.py +++ b/examples/training/avg_word_embeddings/training_stsbenchmark_cnn.py @@ -5,56 +5,37 @@ """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import models, losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer +import sys +import traceback +from datasets import load_dataset +from sentence_transformers import models, losses +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample import logging from datetime import datetime -import os -import csv -import gzip - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -# Read the dataset -batch_size = 32 -model_save_path = "output/training_stsbenchmark_cnn-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) +model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased" +num_train_epochs = 1 +batch_size = 32 +output_dir = "output/training_stsbenchmark_cnn-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) +# 1. 
Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) +# 2. Define the model # Map tokens to vectors using BERT -word_embedding_model = models.Transformer("bert-base-uncased") +word_embedding_model = models.Transformer(model_name) cnn = models.CNN( in_word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(), @@ -65,44 +46,78 @@ # Apply mean pooling to get one fixed sized sentence vector pooling_model = models.Pooling( cnn.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, + pooling_mode="mean", ) - - model = SentenceTransformer(modules=[word_embedding_model, cnn, pooling_model]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size) +# 3. Define our training loss +# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and +# one similarity score column (between 0 and 1) train_loss = losses.CosineSimilarityLoss(model=model) -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - -# Configure the training -num_epochs = 10 -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="cnn", # Will be used in W&B if `wandb` is installed +) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -model.evaluate(test_evaluator) +# 6.
Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + model.push_to_hub(f"{model_name}-cnn") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-cnn')`." + ) diff --git a/examples/training/avg_word_embeddings/training_stsbenchmark_tf-idf_word_embeddings.py b/examples/training/avg_word_embeddings/training_stsbenchmark_tf-idf_word_embeddings.py index f45a4e7d3..f11657e8c 100644 --- a/examples/training/avg_word_embeddings/training_stsbenchmark_tf-idf_word_embeddings.py +++ b/examples/training/avg_word_embeddings/training_stsbenchmark_tf-idf_word_embeddings.py @@ -9,28 +9,34 @@ https://public.ukp.informatik.tu-darmstadt.de/reimers/embeddings/wikipedia_doc_frequencies.txt """ -from torch.utils.data import DataLoader +import traceback +from datasets import load_dataset import math from sentence_transformers import models, losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample import logging from datetime import datetime import os -import csv -import gzip -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -# Read the dataset +num_train_epochs = 1 batch_size = 32 -model_save_path = "output/training_tf-idf_word_embeddings-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +output_dir = "output/training_tf-idf_word_embeddings-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + +# 1. 
Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) +# 2. Define the model # Wikipedia document frequency for words wiki_doc_freq = "wikipedia_doc_frequencies.txt" if not os.path.exists(wiki_doc_freq): @@ -38,32 +44,6 @@ "https://public.ukp.informatik.tu-darmstadt.de/reimers/embeddings/wikipedia_doc_frequencies.txt", wiki_doc_freq ) -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) - -##### Construction of the SentenceTransformer Model ##### - # Map tokens to traditional word embeddings like GloVe word_embedding_model = models.WordEmbeddings.from_text_file("glove.6B.300d.txt.gz") @@ -84,52 +64,86 @@ # Initialize the WordWeights model. This model must be between the WordEmbeddings and the Pooling model word_weights = models.WordWeights(vocab=vocab, word_weights=word_weights, unknown_word_weight=unknown_word_weight) - # Apply mean pooling to get one fixed sized sentence vector pooling_model = models.Pooling( word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, + pooling_mode="mean", ) # Add two trainable feed-forward networks (DAN) sent_embeddings_dimension = pooling_model.get_sentence_embedding_dimension() dan1 = models.Dense(in_features=sent_embeddings_dimension, out_features=sent_embeddings_dimension) dan2 = models.Dense(in_features=sent_embeddings_dimension, out_features=sent_embeddings_dimension) - model = SentenceTransformer(modules=[word_embedding_model, word_weights, pooling_model, dan1, dan2]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size) +# 3. 
Define our training loss +# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and +# one similarity score column (between 0 and 1) train_loss = losses.CosineSimilarityLoss(model=model) -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - -# Configure the training -num_epochs = 10 -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="glove-wikipedia-tf-idf", # Will be used in W&B if `wandb` is installed +) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -model.evaluate(test_evaluator) +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = "glove-wikipedia-tf-idf" +try: + model.push_to_hub(model_name) +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}')`."
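# The commit also documents the `similarity_fn_name` property on SentenceTransformer. A short sketch of
# comparing embeddings with the model's own similarity function (assuming the default "cosine" setting):
sentences = ["A man is eating food.", "A man is eating a piece of bread."]
embeddings = model.encode(sentences)
print(model.similarity_fn_name)                  # "cosine" unless configured otherwise
print(model.similarity(embeddings, embeddings))  # 2x2 matrix of pairwise similarity scores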
+ ) diff --git a/examples/training/data_augmentation/train_sts_indomain_bm25.py b/examples/training/data_augmentation/train_sts_indomain_bm25.py index 1787388bb..121dbb9bc 100644 --- a/examples/training/data_augmentation/train_sts_indomain_bm25.py +++ b/examples/training/data_augmentation/train_sts_indomain_bm25.py @@ -1,6 +1,6 @@ """ The script shows how to train Augmented SBERT (In-Domain) strategy for STSb dataset with BM25 sampling. -We utlise easy and practical elasticsearch (https://www.elastic.co/) for BM25 sampling. +We utilise easy and practical elasticsearch (https://www.elastic.co/) for BM25 sampling. Installations: For this example, elasticsearch to be installed (pip install elasticsearch) @@ -26,28 +26,28 @@ """ +import traceback +from datasets import load_dataset, Dataset, concatenate_datasets from torch.utils.data import DataLoader -from sentence_transformers import models, losses, util +from sentence_transformers import losses from sentence_transformers.cross_encoder import CrossEncoder from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator -from sentence_transformers import LoggingHandler, SentenceTransformer +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator from sentence_transformers.readers import InputExample from elasticsearch import Elasticsearch from datetime import datetime import logging -import csv import sys import tqdm import math -import gzip -import os -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # suppressing INFO messages for elastic-search logger tracer = logging.getLogger("elasticsearch") @@ -62,42 +62,23 @@ num_epochs = 1 max_seq_length = 128 -###### Read Datasets ###### - -# Check if dataset exists. 
If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - cross_encoder_path = ( "output/cross-encoder/stsb_indomain_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) -bi_encoder_path = ( +sentence_transformer_path = ( "output/bi-encoder/stsb_augsbert_BM25_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) -###### Cross-encoder (simpletransformers) ###### -logging.info("Loading sentence-transformers model: {}".format(model_name)) -# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for cross-encoder model +# Use a Hugging Face model (like BERT, RoBERTa, XLNet, XLM-R) for loading the CrossEncoder and SentenceTransformer cross_encoder = CrossEncoder(model_name, num_labels=1) - - -###### Bi-encoder (sentence-transformers) ###### -logging.info("Loading bi-encoder model: {}".format(model_name)) -# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) - -bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model]) +sentence_transformer = SentenceTransformer(model_name) +sentence_transformer.max_seq_length = max_seq_length ##################################################################### @@ -108,31 +89,27 @@ logging.info("Step 1: Train cross-encoder: ({}) with STSbenchmark".format(model_name)) -gold_samples = [] -dev_samples = [] -test_samples = [] - -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - - if row["split"] == "dev": - dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - elif row["split"] == "test": - test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - else: - # As we want to get symmetric scores, i.e. 
CrossEncoder(A,B) = CrossEncoder(B,A), we pass both combinations to the train set - gold_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - gold_samples.append(InputExample(texts=[row["sentence2"], row["sentence1"]], label=score)) +# Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) +gold_samples = [ + InputExample(texts=[sentence1, sentence2], label=data["score"]) + for data in train_dataset + for sentence1, sentence2 in [(data["sentence1"], data["sentence2"]), (data["sentence2"], data["sentence1"])] +] # We wrap gold_samples (which is a List[InputExample]) into a pytorch DataLoader train_dataloader = DataLoader(gold_samples, shuffle=True, batch_size=batch_size) - # We add an evaluator, which evaluates the performance during training -evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name="sts-dev") +evaluator = CECorrelationEvaluator( + sentence_pairs=[[data["sentence1"], data["sentence2"]] for data in eval_dataset], + scores=[data["score"] for data in eval_dataset], + name="sts-dev", +) # Configure the training warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up @@ -215,39 +192,81 @@ # Convert the dataset to a DataLoader ready for training logging.info("Read STSbenchmark gold and silver train dataset") -silver_samples = list( - InputExample(texts=[data[0], data[1]], label=score) for data, score in zip(silver_data, silver_scores) +silver_samples = Dataset.from_dict( + { + "sentence1": [data[0] for data in silver_data], + "sentence2": [data[1] for data in silver_data], + "score": silver_scores, + } ) +train_dataset = concatenate_datasets([train_dataset, silver_samples]) - -train_dataloader = DataLoader(gold_samples + silver_samples, shuffle=True, batch_size=batch_size) -train_loss = losses.CosineSimilarityLoss(model=bi_encoder) +train_loss = losses.CosineSimilarityLoss(model=sentence_transformer) logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") +evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) -# Configure the training.
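The `silver_data` pairs and `silver_scores` consumed in the hunk above come from the BM25 sampling and cross-encoder labeling steps, which this diff leaves as unchanged context. For orientation, a minimal sketch of those steps, assuming a locally running Elasticsearch instance (the index name and top-k size are illustrative, not taken from this patch):

es = Elasticsearch("http://localhost:9200")
unique_sentences = list(set(train_dataset["sentence1"] + train_dataset["sentence2"]))
for idx, sentence in enumerate(unique_sentences):
    es.index(index="stsb", id=idx, document={"sentence": sentence})
silver_data = []
for sentence in unique_sentences:
    # Retrieve the top BM25 matches and pair them with the query sentence
    hits = es.search(index="stsb", query={"match": {"sentence": sentence}}, size=3)["hits"]["hits"]
    for hit in hits:
        if hit["_source"]["sentence"] != sentence:
            silver_data.append([sentence, hit["_source"]["sentence"]])
# Label the unannotated BM25 pairs with the cross-encoder trained on the gold data
silver_scores = cross_encoder.predict(silver_data)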
-warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) +# Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=sentence_transformer_path, + # Optional training parameters: + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="augmentation-indomain-bm25-sts", # Will be used in W&B if `wandb` is installed +) -# Train the bi-encoder model -bi_encoder.fit( - train_objectives=[(train_dataloader, train_loss)], +# Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=sentence_transformer, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=bi_encoder_path, ) +trainer.train() -###################################################################### -# -# Evaluate Augmented SBERT performance on STS benchmark (test) dataset -# -###################################################################### -# load the stored augmented-sbert model -bi_encoder = SentenceTransformer(bi_encoder_path) -logging.info("Read STSbenchmark test dataset") -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -test_evaluator(bi_encoder, output_path=bi_encoder_path) +# Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(sentence_transformer) + +# Save the trained & evaluated model locally +final_output_dir = f"{sentence_transformer_path}/final" +sentence_transformer.save(final_output_dir) + +# (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + sentence_transformer.push_to_hub(f"{model_name}-augmentation-indomain-bm25-sts") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-augmentation-indomain-bm25-sts')`." 
+ ) diff --git a/examples/training/data_augmentation/train_sts_indomain_nlpaug.py b/examples/training/data_augmentation/train_sts_indomain_nlpaug.py index b6da91409..735e5cf36 100644 --- a/examples/training/data_augmentation/train_sts_indomain_nlpaug.py +++ b/examples/training/data_augmentation/train_sts_indomain_nlpaug.py @@ -29,26 +29,23 @@ python train_sts_indomain_nlpaug.py """ -from torch.utils.data import DataLoader +import traceback +from datasets import load_dataset, Dataset, concatenate_datasets import torch -import math -from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util +from sentence_transformers import SentenceTransformer, losses from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample import nlpaug.augmenter.word as naw import logging from datetime import datetime import sys -import os -import gzip -import csv import tqdm -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased" @@ -56,54 +53,21 @@ batch_size = 16 num_epochs = 1 -###### Read Datasets ###### - -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - - -model_save_path = ( +output_dir = ( "output/bi-encoder/stsb_indomain_eda_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) -###### Bi-encoder (sentence-transformers) ###### -logging.info("Loading SBERT model: {}".format(model_name)) # Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, -) - -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Convert the dataset to a DataLoader ready for training -gold_samples = [] -dev_samples = [] -test_samples = [] +model = SentenceTransformer(model_name) -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 
1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - gold_samples.append(inp_example) +# Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) ################################################################################## # @@ -125,17 +89,24 @@ # aug = naw.SynonymAug(aug_src='wordnet') #### Synonym replacement using BERT #### -aug = naw.ContextualWordEmbsAug(model_path=model_name, action="insert", device=device) - -silver_samples = [] -progress = tqdm.tqdm(unit="docs", total=len(gold_samples)) - -for sample in gold_samples: - augmented_texts = aug.augment(sample.texts) - inp_example = InputExample(texts=augmented_texts, label=sample.label) - silver_samples.append(inp_example) +aug = naw.ContextualWordEmbsAug(model_path=model_name, action="insert") + +silver_samples = { + "sentence1": [], + "sentence2": [], + "score": [], +} +progress = tqdm.tqdm(unit="docs", total=len(train_dataset)) + +for sample in train_dataset: + augmented_texts = aug.augment([sample["sentence1"], sample["sentence2"]]) + silver_samples["sentence1"].append(augmented_texts[0]) + silver_samples["sentence2"].append(augmented_texts[1]) + silver_samples["score"].append(sample["score"]) progress.update(1) +silver_dataset = Dataset.from_dict(silver_samples) + progress.reset() progress.close() logging.info("Textual augmentation completed....") @@ -147,36 +118,73 @@ # ################################################################### -logging.info("Read STSbenchmark (gold + silver) training dataset") -train_dataloader = DataLoader(gold_samples + silver_samples, shuffle=True, batch_size=batch_size) +train_dataset = concatenate_datasets([train_dataset, silver_dataset]) train_loss = losses.CosineSimilarityLoss(model=model) - logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - - -# Configure the training.
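For readers unfamiliar with nlpaug: `ContextualWordEmbsAug` with `action="insert"` uses a masked language model to insert contextually plausible words into the input texts, which is how the silver pairs above are produced. A minimal usage sketch (the augmented output shown is only illustrative):

import nlpaug.augmenter.word as naw

aug = naw.ContextualWordEmbsAug(model_path="bert-base-uncased", action="insert")
augmented_texts = aug.augment(["A man is playing a guitar.", "A woman is playing a flute."])
# `augment` returns one augmented string per input, e.g.:
# ["a man is playing quietly a guitar.", "a young woman is playing a flute."]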
-warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) +evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +# Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="augmentation-indomain-nlpaug-sts", # Will be used in W&B if `wandb` is installed +) -# Train the SBERT model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=model_save_path, ) +trainer.train() -########################################################## -# -# Evaluate SBERT performance on STS benchmark test dataset -# -########################################################## -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -test_evaluator(model, output_path=model_save_path) +# Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) + +# Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + model.push_to_hub(f"{model_name}-augmentation-indomain-nlpaug-sts") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-augmentation-indomain-nlpaug-sts')`." + ) diff --git a/examples/training/distillation/dimensionality_reduction.py b/examples/training/distillation/dimensionality_reduction.py index 82b6a2916..3ba9432a9 100644 --- a/examples/training/distillation/dimensionality_reduction.py +++ b/examples/training/distillation/dimensionality_reduction.py @@ -15,71 +15,44 @@ without further changes needed. 
""" +from datasets import load_dataset from sklearn.decomposition import PCA -from sentence_transformers import SentenceTransformer, LoggingHandler, util, evaluation, models, InputExample +from sentence_transformers import SentenceTransformer, models import logging -import os -import gzip -import csv import random import numpy as np import torch -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -logger = logging.getLogger(__name__) -#### /print debug information to stdout +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # Model for which we apply dimensionality reduction -model = SentenceTransformer("all-MiniLM-L6-v2") +model_name = "all-MiniLM-L6-v2" +model = SentenceTransformer(model_name) # New size for the embeddings new_dimension = 128 - -# We use AllNLI as a source of sentences to compute PCA -nli_dataset_path = "datasets/AllNLI.tsv.gz" - -# We use the STS benchmark dataset to see how much performance we loose by using the dimensionality reduction -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - - # We measure the performance of the original model # and later we will measure the performance with the reduces dimension size -logger.info("Read STSbenchmark test dataset") -eval_examples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "test": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - eval_examples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -# Evaluate the original model on the STS benchmark dataset -stsb_evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(eval_examples, name="sts-benchmark-test") - -logger.info("Original model performance:") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +stsb_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + name="sts-test", +) + +logging.info("Original model performance:") stsb_evaluator(model) ######## Reduce the embedding dimensions ######## -# Read sentences from NLI dataset -nli_sentences = set() -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - nli_sentences.add(row["sentence1"]) - nli_sentences.add(row["sentence2"]) +train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train") -nli_sentences = list(nli_sentences) +nli_sentences = train_dataset["sentence1"] + train_dataset["sentence2"] random.shuffle(nli_sentences) # To determine the PCA matrix, we need some example sentence embeddings. 
@@ -103,12 +76,16 @@ model.add_module("dense", dense) # Evaluate the model with the reduced embedding size -logger.info("Model with {} dimensions:".format(new_dimension)) +logging.info("Model with {} dimensions:".format(new_dimension)) stsb_evaluator(model) # If you like, you can store the model on disk like this: -# model.save('models/my-128dim-model') +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +model.save(f"{model_name}-128dim") # You can then load the adapted model that produces 128-dimensional embeddings like this: # model = SentenceTransformer(f"{model_name}-128dim") + +# Or you can push the model to the Hugging Face Hub +# model.push_to_hub(f'{model_name}-128dim') diff --git a/examples/training/distillation/model_distillation.py b/examples/training/distillation/model_distillation.py index f8e6bf333..b33f777ac 100644 --- a/examples/training/distillation/model_distillation.py +++ b/examples/training/distillation/model_distillation.py @@ -20,19 +20,21 @@ of the teacher performance, while being 2.3 times faster. """ -from torch.utils.data import DataLoader +import traceback +from datasets import load_dataset, concatenate_datasets, Dataset +import pandas as pd from sentence_transformers import models, losses, evaluation -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample -from sentence_transformers.datasets import ParallelSentencesDataset +from sentence_transformers import LoggingHandler, SentenceTransformer import logging from datetime import datetime -import os -import gzip -import csv -import random from sklearn.decomposition import PCA import torch +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments + #### Just some code to print debug information to stdout logging.basicConfig( @@ -45,45 +47,16 @@ teacher_model_name = "stsb-roberta-base-v2" teacher_model = SentenceTransformer(teacher_model_name) -output_path = "output/model-distillation-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - -use_layer_reduction = True - -# There are two options to create a light and fast student model: -if use_layer_reduction: - # 1) Create a smaller student model by using only some of the teacher layers - student_model = SentenceTransformer(teacher_model_name) - - # Get the transformer model - auto_model = student_model._first_module().auto_model - - # Which layers to keep from the teacher model. We equally spread the layers to keep over the original teacher - # layers_to_keep = [5] - # layers_to_keep = [3, 7] - # layers_to_keep = [3, 7, 11] - layers_to_keep = [1, 4, 7, 10] # Keep 4 layers from the teacher - # layers_to_keep = [0, 2, 4, 6, 8, 10] - # layers_to_keep = [0, 1, 3, 4, 6, 7, 9, 10] - - logging.info("Remove layers from student. Only keep these layers: {}".format(layers_to_keep)) - new_layers = torch.nn.ModuleList( - [layer_module for i, layer_module in enumerate(auto_model.encoder.layer) if i in layers_to_keep] - ) - auto_model.encoder.layer = new_layers - auto_model.config.num_hidden_layers = len(layers_to_keep) -else: - # 2) The other option is to train a small model like TinyBERT to imitate the teacher.
- # You can find some small BERT models here: https://huggingface.co/nreimers - word_embedding_model = models.Transformer("nreimers/TinyBERT_L-4_H-312_v2") - pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) - student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) +output_dir = "output/model-distillation-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +# We will train a small model like TinyBERT to imitate the teacher. +# You can find some small BERT models here: https://huggingface.co/nreimers +student_model_name = "nreimers/TinyBERT_L-4_H-312_v2" +student_model = SentenceTransformer(student_model_name) inference_batch_size = 64 train_batch_size = 64 - # We use AllNLI as a source of sentences for the distillation nli_dataset_path = "datasets/AllNLI.tsv.gz" @@ -94,71 +67,73 @@ sts_dataset_path = "datasets/stsbenchmark.tsv.gz" -# Download datasets if needed -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) +logging.info("Load the AllNLI dataset") +# Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +nli_train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train") +nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="dev") +# Concatenate all sentences into a new column "sentence" -if not os.path.exists(wikipedia_dataset_path): - util.http_get("https://sbert.net/datasets/wikipedia-en-sentences.txt.gz", wikipedia_dataset_path) -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) +def combine_sentences(batch): + return {"sentence": batch["sentence1"] + batch["sentence2"]} -# We need sentences to train our distillation. 
Here, we use sentences from AllNLI and from WikiPedia -train_sentences_nli = set() -dev_sentences_nli = set() -train_sentences_wikipedia = [] -dev_sentences_wikipedia = [] +nli_train_dataset = nli_train_dataset.map( + combine_sentences, batched=True, remove_columns=nli_train_dataset.column_names +) +nli_eval_dataset = nli_eval_dataset.map(combine_sentences, batched=True, remove_columns=nli_eval_dataset.column_names) -# Read ALLNLI -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "dev": - dev_sentences_nli.add(row["sentence1"]) - dev_sentences_nli.add(row["sentence2"]) - else: - train_sentences_nli.add(row["sentence1"]) - train_sentences_nli.add(row["sentence2"]) -train_sentences_nli = list(train_sentences_nli) -random.shuffle(train_sentences_nli) +def deduplicate(dataset): + df = pd.DataFrame(dataset) + df = df.drop_duplicates() + return Dataset.from_pandas(df, preserve_index=False) -dev_sentences_nli = list(dev_sentences_nli) -random.shuffle(dev_sentences_nli) -dev_sentences_nli = dev_sentences_nli[0:5000] # Limit dev sentences to 5k -# Read Wikipedia sentences file -with gzip.open(wikipedia_dataset_path, "rt", encoding="utf8") as fIn: - wikipeda_sentences = [line.strip() for line in fIn] +nli_train_dataset = deduplicate(nli_train_dataset) +nli_eval_dataset = deduplicate(nli_eval_dataset) +logging.info(nli_train_dataset) -dev_sentences_wikipedia = wikipeda_sentences[ - 0:5000 -] # Use the first 5k sentences from the wikipedia file for development -train_sentences_wikipedia = wikipeda_sentences[5000:] +logging.info("Load the STSB dataset") +# Load the STSB eval/test datasets: https://huggingface.co/datasets/sentence-transformers/stsb +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +stsb_test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(stsb_eval_dataset) -# We use the STS benchmark dataset to measure the performance of student model im comparison to the teacher model -logging.info("Read STSbenchmark dev dataset") -dev_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "dev": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 
1 - dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) -dev_evaluator_sts = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") +logging.info("Load the Wikipedia dataset") +# Load the Wikipedia dataset: https://huggingface.co/datasets/sentence-transformers/wikipedia-en-sentences +wikipedia_train_dataset = load_dataset("sentence-transformers/wikipedia-en-sentences", split="train") +# Take 5000 random sentences from the Wikipedia dataset for evaluation +wikipedia_train_dataset_dict = wikipedia_train_dataset.train_test_split(test_size=5000) +wikipedia_train_dataset = wikipedia_train_dataset_dict["train"] +wikipedia_eval_dataset = wikipedia_train_dataset_dict["test"] +logging.info(wikipedia_train_dataset) -logging.info("Teacher Performance:") -dev_evaluator_sts(teacher_model) +# Concatenate the NLI and Wikipedia datasets for training +train_dataset: Dataset = concatenate_datasets([nli_train_dataset, wikipedia_train_dataset]) +# Create a relatively small dataset for evaluation +eval_dataset: Dataset = concatenate_datasets( + [nli_eval_dataset.select(range(5000)), wikipedia_eval_dataset.select(range(5000))] +) + +# Create an STSB evaluator +dev_evaluator_stsb = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +logging.info("Teacher Performance") +dev_evaluator_stsb(teacher_model) # Student model has fewer dimensions. Compute PCA for the teacher to reduce the dimensions if student_model.get_sentence_embedding_dimension() < teacher_model.get_sentence_embedding_dimension(): logging.info("Student model has fewer dimensions than the teacher. Compute PCA for down projection") - pca_sentences = train_sentences_nli[0:20000] + train_sentences_wikipedia[0:20000] + pca_sentences = nli_train_dataset[:20000]["sentence"] + wikipedia_train_dataset[:20000]["sentence"] pca_embeddings = teacher_model.encode(pca_sentences, convert_to_numpy=True) pca = PCA(n_components=student_model.get_sentence_embedding_dimension()) pca.fit(pca_embeddings) @@ -174,37 +149,97 @@ teacher_model.add_module("dense", dense) logging.info("Teacher Performance with {} dimensions:".format(teacher_model.get_sentence_embedding_dimension())) - dev_evaluator_sts(teacher_model) + dev_evaluator_stsb(teacher_model) -# We train the student_model such that it creates sentence embeddings similar to the embeddings from the teacher_model -# For this, we need a large set of sentences. These sentences are embedded using the teacher model, -# and the student tries to mimic these embeddings. 
It is the same approach as used in: https://arxiv.org/abs/2004.09813 -train_data = ParallelSentencesDataset( - student_model=student_model, - teacher_model=teacher_model, - batch_size=inference_batch_size, - use_embedding_cache=False, -) -train_data.add_dataset([[sent] for sent in train_sentences_nli], max_sentence_length=256) -train_data.add_dataset([[sent] for sent in train_sentences_wikipedia], max_sentence_length=256) + +# Use the teacher model to get the gold embeddings +def map_embeddings(batch): + return { + "label": teacher_model.encode( + batch["sentence"], batch_size=inference_batch_size, show_progress_bar=False + ).tolist() + } + + +train_dataset = train_dataset.select(range(200000)) +train_dataset = train_dataset.map(map_embeddings, batched=True, batch_size=50000) +# Optionally, save the dataset to disk to speed up future runs +train_dataset.save_to_disk("datasets/distillation_train_dataset") +# from datasets import DatasetDict, load_from_disk + +# train_dataset = load_from_disk("datasets/distillation_train_dataset") +# if isinstance(train_dataset, DatasetDict): +# train_dataset = train_dataset["train"] +eval_dataset = eval_dataset.map(map_embeddings, batched=True, batch_size=50000) -train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size) train_loss = losses.MSELoss(model=student_model) # We create an evaluator that measures the Mean Squared Error (MSE) between the teacher and the student embeddings -dev_sentences = dev_sentences_nli + dev_sentences_wikipedia -dev_evaluator_mse = evaluation.MSEEvaluator(dev_sentences, dev_sentences, teacher_model=teacher_model) - -# Train the student model to imitate the teacher -student_model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluation.SequentialEvaluator([dev_evaluator_sts, dev_evaluator_mse]), - epochs=1, - warmup_steps=1000, - evaluation_steps=5000, - output_path=output_path, - save_best_model=True, - optimizer_params={"lr": 1e-4, "eps": 1e-6}, - use_amp=True, +eval_sentences = eval_dataset["sentence"] +dev_evaluator_mse = evaluation.MSEEvaluator(eval_sentences, eval_sentences, teacher_model=teacher_model) +dev_evaluator = evaluation.SequentialEvaluator([dev_evaluator_stsb, dev_evaluator_mse]) + +# Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + metric_for_best_model="eval_sts-dev_spearman_cosine", + load_best_model_at_end=True, + learning_rate=1e-4, + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + save_total_limit=2, + logging_steps=100, + run_name="distillation", # Will be used in W&B if `wandb` is installed +) + +# Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=student_model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, ) +trainer.train() + +# Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_test_dataset["sentence1"], + sentences2=stsb_test_dataset["sentence2"], + scores=stsb_test_dataset["score"], +
main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(student_model) + +# Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +student_model.save(final_output_dir) + +# (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +if "/" in student_model_name: + student_model_name = student_model_name.split("/")[-1] +if "/" in teacher_model_name: + teacher_model_name = teacher_model_name.split("/")[-1] +repo_id = f"{student_model_name}-distilled-from-{teacher_model_name}" +try: + student_model.push_to_hub(repo_id) +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub({repo_id!r})`." + ) diff --git a/examples/training/distillation/model_distillation_layer_reduction.py b/examples/training/distillation/model_distillation_layer_reduction.py new file mode 100644 index 000000000..7a0d76f5d --- /dev/null +++ b/examples/training/distillation/model_distillation_layer_reduction.py @@ -0,0 +1,229 @@ +""" +This file contains an example of how to make a SentenceTransformer model faster and lighter. + +This is achieved by using Knowledge Distillation: We use a well-working teacher model to train +a fast and light student model. The student model learns to imitate the produced +sentence embeddings from the teacher. We train this on a diverse set of sentences we got +from SNLI + MultiNLI + Wikipedia. + +After the distillation is finished, the student model produces nearly the same embeddings as the +teacher; however, it will be much faster. + +There are two options to initialize the student: +Option 1: Train a light transformer model like TinyBERT to imitate the teacher +Option 2: We take the teacher model and keep only certain layers, for example, only 4 layers. + +Option 2 usually works better, as we keep most of the weights from the teacher. In Option 1, we have to tune all +weights in the student from scratch. This script implements Option 2; model_distillation.py shows Option 1. + +There is a performance-speed trade-off. However, we found that a student with 4 instead of 12 layers keeps about 99.4% +of the teacher performance, while being 2.3 times faster.
+""" + +import traceback +from datasets import load_dataset, concatenate_datasets, Dataset +import pandas as pd +from sentence_transformers import losses, evaluation +from sentence_transformers import SentenceTransformer +import logging +from datetime import datetime +import torch + +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments + + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) + + +# Teacher Model: Model we want to distill to a smaller model +teacher_model_name = "mixedbread-ai/mxbai-embed-large-v1" +teacher_model = SentenceTransformer(teacher_model_name) + +output_dir = "output/model-distillation-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + +# Create a smaller student model by using only some of the teacher layers +student_model = SentenceTransformer(teacher_model_name) + +# Get the transformer model +auto_model = student_model._first_module().auto_model + +# Which layers to keep from the teacher model. We equally spread the layers to keep over the original teacher +# layers_to_keep = [5] +# layers_to_keep = [3, 7] +# layers_to_keep = [3, 7, 11] +# layers_to_keep = [0, 2, 4, 6, 8, 10] +# layers_to_keep = [0, 1, 3, 4, 6, 7, 9, 10] +# Keep every third layer: +layers_to_keep = [0, 3, 6, 9, 12, 15, 18, 21] + +logging.info("Remove layers from student. Only keep these layers: {}".format(layers_to_keep)) +new_layers = torch.nn.ModuleList( + [layer_module for i, layer_module in enumerate(auto_model.encoder.layer) if i in layers_to_keep] +) +auto_model.encoder.layer = new_layers +auto_model.config.num_hidden_layers = len(layers_to_keep) +print( + f"Number of parameters in the Teacher model: {sum(p.numel() for p in teacher_model.parameters() if p.requires_grad)}" +) +print( + f"Number of parameters in the Student model: {sum(p.numel() for p in student_model.parameters() if p.requires_grad)}" +) + +inference_batch_size = 128 +train_batch_size = 64 + +logging.info("Load the AllNLI dataset") +# Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +nli_train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train") +nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="dev") +# Concatenate all sentences into a new column "sentence" + + +def combine_sentences(batch): + return {"sentence": batch["sentence1"] + batch["sentence2"]} + + +nli_train_dataset = nli_train_dataset.map( + combine_sentences, batched=True, remove_columns=nli_train_dataset.column_names +) +nli_eval_dataset = nli_eval_dataset.map(combine_sentences, batched=True, remove_columns=nli_eval_dataset.column_names) + + +def deduplicate(dataset): + df = pd.DataFrame(dataset) + df = df.drop_duplicates() + return Dataset.from_pandas(df, preserve_index=False) + + +nli_train_dataset = deduplicate(nli_train_dataset) +nli_eval_dataset = deduplicate(nli_eval_dataset) +logging.info(nli_train_dataset) + +logging.info("Load the STSB dataset") +# Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +stsb_test_dataset = load_dataset("sentence-transformers/stsb", split="test") 
+logging.info(stsb_eval_dataset) + +logging.info("Load the Wikipedia dataset") +# Load the Wikipedia dataset: https://huggingface.co/datasets/sentence-transformers/wikipedia-en-sentences +wikipedia_train_dataset = load_dataset("sentence-transformers/wikipedia-en-sentences", split="train") +# Take 5000 random sentences from the Wikipedia dataset for evaluation +wikipedia_train_dataset_dict = wikipedia_train_dataset.train_test_split(test_size=5000) +wikipedia_train_dataset = wikipedia_train_dataset_dict["train"] +wikipedia_eval_dataset = wikipedia_train_dataset_dict["test"] +logging.info(wikipedia_train_dataset) + +# Concatenate the NLI and Wikipedia datasets for training +train_dataset: Dataset = concatenate_datasets([nli_train_dataset, wikipedia_train_dataset]) +# Create a relatively small dataset for evaluation +eval_dataset: Dataset = concatenate_datasets( + [nli_eval_dataset.select(range(5000)), wikipedia_eval_dataset.select(range(5000))] +) + + +# Use the teacher model to get the gold embeddings +def map_embeddings(batch): + return { + "label": teacher_model.encode( + batch["sentence"], batch_size=inference_batch_size, show_progress_bar=False + ).tolist() + } + + +train_dataset = train_dataset.map(map_embeddings, batched=True, batch_size=50000) +# Optionally, save the dataset to disk to speed up future runs +train_dataset.save_to_disk("datasets/distillation_train_dataset") +# from datasets import DatasetDict, load_from_disk + +# train_dataset = load_from_disk("datasets/distillation_train_dataset") +# if isinstance(train_dataset, DatasetDict): +# train_dataset = train_dataset["train"] +eval_dataset = eval_dataset.map(map_embeddings, batched=True, batch_size=50000) + +# Prepare the training loss +train_loss = losses.MSELoss(model=student_model) + +# Create an STSB evaluator +dev_evaluator_stsb = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +logging.info("Running STSB evaluation on the teacher model") +dev_evaluator_stsb(teacher_model) + +# We create an evaluator that measures the Mean Squared Error (MSE) between the teacher and the student embeddings +eval_sentences = eval_dataset["sentence"] +dev_evaluator_mse = evaluation.MSEEvaluator(eval_sentences, eval_sentences, teacher_model=teacher_model) +dev_evaluator = evaluation.SequentialEvaluator([dev_evaluator_stsb, dev_evaluator_mse]) + +# Run the evaluator before training to get a baseline performance of the student model +dev_evaluator(student_model) + +# Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + metric_for_best_model="eval_sts-dev_spearman_cosine", + load_best_model_at_end=True, + learning_rate=1e-4, + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=5000, + save_strategy="steps", + save_steps=5000, + save_total_limit=2, + logging_steps=1000, + run_name="distillation-layer-reduction", # Will be used in W&B if `wandb` is installed +) + +# Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=student_model, +
args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, +) +trainer.train() + +# Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_test_dataset["sentence1"], + sentences2=stsb_test_dataset["sentence2"], + scores=stsb_test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(student_model) + +# Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +student_model.save(final_output_dir) + +# (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = teacher_model_name if "/" not in teacher_model_name else teacher_model_name.split("/")[-1] +try: + student_model.push_to_hub(f"{model_name}-{len(layers_to_keep)}-layers") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-{len(layers_to_keep)}-layers')`." + ) diff --git a/examples/training/matryoshka/2d_matryoshka_nli.py b/examples/training/matryoshka/2d_matryoshka_nli.py index 504b14286..0b6acd2ec 100644 --- a/examples/training/matryoshka/2d_matryoshka_nli.py +++ b/examples/training/matryoshka/2d_matryoshka_nli.py @@ -11,147 +11,112 @@ python 2d_matryoshka_nli.py pretrained_transformer_model_name """ -import math +import traceback from datasets import load_dataset -from sentence_transformers import models, losses, datasets -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction import logging from datetime import datetime import sys -import os -import gzip -import csv -import random - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" -train_batch_size = 128 # The larger you select this, the better the results (usually). But it requires more GPU memory -max_seq_length = 75 -num_epochs = 1 - -# Save path of the model -model_save_path = ( - "output/2d_matryoshka_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) - - -# Here we define our SentenceTransformer model -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean") -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Check if dataset exists. 
If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -# Read the AllNLI.tsv.gz file and create the training dataset -logging.info("Read AllNLI train dataset") +from sentence_transformers.training_args import BatchSamplers -def add_to_samples(sent1, sent2, label): - if sent1 not in train_data: - train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} - train_data[sent1][label].add(sent2) +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) +model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" +batch_size = 128 # The larger you select this, the better the results (usually). But it requires more GPU memory +num_train_epochs = 1 -train_data = {} -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - sent1 = row["sentence1"].strip() - sent2 = row["sentence2"].strip() - - add_to_samples(sent1, sent2, row["label"]) - add_to_samples(sent2, sent1, row["label"]) # Also add the opposite - - -train_samples = [] -for sent1, others in train_data.items(): - if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: - train_samples.append( - InputExample( - texts=[sent1, random.choice(list(others["entailment"])), random.choice(list(others["contradiction"]))] - ) - ) - train_samples.append( - InputExample( - texts=[random.choice(list(others["entailment"])), sent1, random.choice(list(others["contradiction"]))] - ) - ) - -logging.info("Train samples: {}".format(len(train_samples))) +# Save path of the model +output_dir = f"output/2d_matryoshka_nli_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) -# Special data loader that avoid duplicates within a batch -train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) +# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train") +eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") +logging.info(train_dataset) +# If you wish, you can limit the number of training samples +# train_dataset = train_dataset.select(range(5000)) -# Our training loss -train_loss = losses.MultipleNegativesRankingLoss(model) -train_loss = losses.Matryoshka2dLoss(model, train_loss, [768, 512, 256, 128, 64]) +# 3. Define our training loss +inner_train_loss = losses.MultipleNegativesRankingLoss(model) +train_loss = losses.Matryoshka2dLoss(model, inner_train_loss, [768, 512, 256, 128, 64]) -stsb_dev = load_dataset("mteb/stsbenchmark-sts", split="validation") +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. 
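The `Matryoshka2dLoss` from step 3 trains the model so that truncated embeddings (and embeddings computed with fewer layers) remain useful on their own. After training, the model can therefore be loaded at a reduced embedding size; a usage sketch, where the path and dimension are illustrative:

model_64 = SentenceTransformer("output/2d_matryoshka_nli_distilroberta-base/final", truncate_dim=64)
embeddings = model_64.encode(["The weather is so nice!", "It's so sunny outside!"])
print(embeddings.shape)  # => (2, 64)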
+stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") dev_evaluator = EmbeddingSimilarityEvaluator( - stsb_dev["sentence1"], - stsb_dev["sentence2"], - [score / 5 for score in stsb_dev["score"]], + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], main_similarity=SimilarityFunction.COSINE, name="sts-dev", ) -# Configure the training -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="2d-matryoshka-nli", # Will be used in W&B if `wandb` is installed +) -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=int(len(train_dataloader) * 0.1), - warmup_steps=warmup_steps, - output_path=model_save_path, - use_amp=False, # Set to True, if your GPU supports FP16 operations ) +trainer.train() - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - - -model = SentenceTransformer(model_save_path) -stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test") +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") test_evaluator = EmbeddingSimilarityEvaluator( - stsb_test["sentence1"], - stsb_test["sentence2"], - [score / 5 for score in stsb_test["score"]], + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], main_similarity=SimilarityFunction.COSINE, name="sts-test", ) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: model.push_to_hub(f"{model_name}-nli-2d-matryoshka") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. 
To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{model_name}-nli-2d-matryoshka')`." ) diff --git a/examples/training/matryoshka/2d_matryoshka_sts.py b/examples/training/matryoshka/2d_matryoshka_sts.py index a0be05a97..3880c8c93 100644 --- a/examples/training/matryoshka/2d_matryoshka_sts.py +++ b/examples/training/matryoshka/2d_matryoshka_sts.py @@ -10,118 +10,108 @@ python 2d_matryoshka_sts.py pretrained_transformer_model_name """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util -from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction import logging from datetime import datetime import sys -import os -import gzip -import csv - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - - -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert-base-uncased" - -# Read the dataset -train_batch_size = 16 -num_epochs = 4 -model_save_path = ( - "output/2d_matryoshka_sts_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +batch_size = 16 +num_train_epochs = 4 + +# Save path of the model +output_dir = f"output/2d_matryoshka_sts_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) + +# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) + +# 3. 
Define our training loss +# CoSENTLoss (https://sbert.net/docs/package_reference/losses.html#cosentloss) needs two text columns and one +# similarity score column (between 0 and 1) +inner_train_loss = losses.CoSENTLoss(model=model) +train_loss = losses.Matryoshka2dLoss(model, inner_train_loss, [768, 512, 256, 128, 64]) + +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) -# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="2d-matryoshka-sts", # Will be used in W&B if `wandb` is installed ) -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) - - -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) -train_loss = losses.CoSENTLoss(model=model) -train_loss = losses.Matryoshka2dLoss(model, train_loss, [768, 512, 256, 128, 64]) - - -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - - -# Configure the training. We skip evaluation in this example -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 6. 
Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, ) +trainer.train() -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", +) +test_evaluator(model) -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -test_evaluator(model, output_path=model_save_path) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: model.push_to_hub(f"{model_name}-sts-2d-matryoshka") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{model_name}-sts-2d-matryoshka')`." 
) diff --git a/examples/training/matryoshka/matryoshka_nli.py b/examples/training/matryoshka/matryoshka_nli.py index c4369d13a..07a1d24d5 100644 --- a/examples/training/matryoshka/matryoshka_nli.py +++ b/examples/training/matryoshka/matryoshka_nli.py @@ -11,103 +11,56 @@ python matryoshka_nli.py pretrained_transformer_model_name """ -import math +import traceback from datasets import load_dataset -from sentence_transformers import models, losses, datasets -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SequentialEvaluator, SimilarityFunction import logging from datetime import datetime import sys -import os -import gzip -import csv -import random - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout + +from sentence_transformers.training_args import BatchSamplers + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" -train_batch_size = 128 # The larger you select this, the better the results (usually). But it requires more GPU memory -max_seq_length = 75 -num_epochs = 1 +batch_size = 128 # The larger you select this, the better the results (usually). But it requires more GPU memory +num_train_epochs = 1 matryoshka_dims = [768, 512, 256, 128, 64] # Save path of the model -model_save_path = ( - "output/matryoshka_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) - - -# Here we define our SentenceTransformer model -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean") -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Check if dataset exists. If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -# Read the AllNLI.tsv.gz file and create the training dataset -logging.info("Read AllNLI train dataset") - - -def add_to_samples(sent1, sent2, label): - if sent1 not in train_data: - train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} - train_data[sent1][label].add(sent2) - +output_dir = f"output/matryoshka_nli_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" -train_data = {} -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - sent1 = row["sentence1"].strip() - sent2 = row["sentence2"].strip() +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. 
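+# (As a sketch, the one-liner below is roughly equivalent to the manual module setup that the old version of this
+#  script used, i.e. something like:
+#    word_embedding_model = models.Transformer(model_name)
+#    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean")
+#    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+#  with `models` imported from sentence_transformers.)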
+model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) - add_to_samples(sent1, sent2, row["label"]) - add_to_samples(sent2, sent1, row["label"]) # Also add the opposite +# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train") +eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") +logging.info(train_dataset) +# If you wish, you can limit the number of training samples +# train_dataset = train_dataset.select(range(5000)) -train_samples = [] -for sent1, others in train_data.items(): - if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: - train_samples.append( - InputExample( - texts=[sent1, random.choice(list(others["entailment"])), random.choice(list(others["contradiction"]))] - ) - ) - train_samples.append( - InputExample( - texts=[random.choice(list(others["entailment"])), sent1, random.choice(list(others["contradiction"]))] - ) - ) - -logging.info("Train samples: {}".format(len(train_samples))) - - -# Special data loader that avoid duplicates within a batch -train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) - +# 3. Define our training loss +inner_train_loss = losses.MultipleNegativesRankingLoss(model) +train_loss = losses.MatryoshkaLoss(model, inner_train_loss, matryoshka_dims=matryoshka_dims) -# Our training loss -train_loss = losses.MultipleNegativesRankingLoss(model) -train_loss = losses.MatryoshkaLoss(model, train_loss, matryoshka_dims=matryoshka_dims) - -stsb_dev = load_dataset("mteb/stsbenchmark-sts", split="validation") +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") evaluators = [] for dim in matryoshka_dims: evaluators.append( EmbeddingSimilarityEvaluator( - stsb_dev["sentence1"], - stsb_dev["sentence2"], - [score / 5 for score in stsb_dev["score"]], + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], main_similarity=SimilarityFunction.COSINE, name=f"sts-dev-{dim}", truncate_dim=dim, @@ -115,56 +68,68 @@ def add_to_samples(sent1, sent2, label): ) dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0]) -# Configure the training -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - +# 5. 
Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="matryoshka-nli", # Will be used in W&B if `wandb` is installed +) -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=int(len(train_dataloader) * 0.1), - warmup_steps=warmup_steps, - output_path=model_save_path, - use_amp=False, # Set to True, if your GPU supports FP16 operations ) +trainer.train() - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - - -model = SentenceTransformer(model_save_path) -stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test") +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") evaluators = [] for dim in matryoshka_dims: evaluators.append( EmbeddingSimilarityEvaluator( - stsb_test["sentence1"], - stsb_test["sentence2"], - [score / 5 for score in stsb_test["score"]], + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], main_similarity=SimilarityFunction.COSINE, name=f"sts-test-{dim}", truncate_dim=dim, ) ) test_evaluator = SequentialEvaluator(evaluators) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: model.push_to_hub(f"{model_name}-nli-matryoshka") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{model_name}-nli-matryoshka')`." 
)
diff --git a/examples/training/matryoshka/matryoshka_nli_reduced_dim.py b/examples/training/matryoshka/matryoshka_nli_reduced_dim.py
index 6413ab593..b1e470b78 100644
--- a/examples/training/matryoshka/matryoshka_nli_reduced_dim.py
+++ b/examples/training/matryoshka/matryoshka_nli_reduced_dim.py
@@ -15,105 +15,63 @@ python matryoshka_nli_reduced_dim.py pretrained_transformer_model_name
 """
-import math
+import traceback
 from datasets import load_dataset
-from sentence_transformers import models, losses, datasets
-from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
+from sentence_transformers import losses, models
+from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
 from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SequentialEvaluator, SimilarityFunction
 import logging
 from datetime import datetime
 import sys
-import os
-import gzip
-import csv
-import random
-
-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
-)
-#### /print debug information to stdout
+
+from sentence_transformers.training_args import BatchSamplers
+
+# Set the log level to INFO to get more information
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
 
 model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base"
-train_batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
-max_seq_length = 75
-num_epochs = 1
+batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
+num_train_epochs = 1
 reduced_dim = 256
 matryoshka_dims = [256, 128, 64, 32, 16]
 
 # Save path of the model
-model_save_path = (
-    "output/matryoshka_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+output_dir = (
+    f"output/matryoshka_nli_reduced_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
 )
 
+# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
+# create one with "mean" pooling.
+model = SentenceTransformer(model_name)
+# Reduce the output dimensionality to `reduced_dim` by appending an extra Dense layer
+model.add_module(
+    "reduced_dim", models.Dense(in_features=model.get_sentence_embedding_dimension(), out_features=reduced_dim)
+)
+# If we want, we can limit the maximum sequence length for the model
+# model.max_seq_length = 75
+logging.info(model)
 
-# Here we define our SentenceTransformer model
-word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
-pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean")
-dense = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(), out_features=reduced_dim)
-model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense])
-
-# Check if dataset exists. 
If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -# Read the AllNLI.tsv.gz file and create the training dataset -logging.info("Read AllNLI train dataset") - - -def add_to_samples(sent1, sent2, label): - if sent1 not in train_data: - train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} - train_data[sent1][label].add(sent2) - - -train_data = {} -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - sent1 = row["sentence1"].strip() - sent2 = row["sentence2"].strip() - - add_to_samples(sent1, sent2, row["label"]) - add_to_samples(sent2, sent1, row["label"]) # Also add the opposite - - -train_samples = [] -for sent1, others in train_data.items(): - if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: - train_samples.append( - InputExample( - texts=[sent1, random.choice(list(others["entailment"])), random.choice(list(others["contradiction"]))] - ) - ) - train_samples.append( - InputExample( - texts=[random.choice(list(others["entailment"])), sent1, random.choice(list(others["contradiction"]))] - ) - ) - -logging.info("Train samples: {}".format(len(train_samples))) - - -# Special data loader that avoid duplicates within a batch -train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) +# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train") +eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") +logging.info(train_dataset) +# If you wish, you can limit the number of training samples +# train_dataset = train_dataset.select(range(5000)) -# Our training loss -train_loss = losses.MultipleNegativesRankingLoss(model) -train_loss = losses.MatryoshkaLoss(model, train_loss, matryoshka_dims=matryoshka_dims) +# 3. Define our training loss +inner_train_loss = losses.MultipleNegativesRankingLoss(model) +train_loss = losses.MatryoshkaLoss(model, inner_train_loss, matryoshka_dims=matryoshka_dims) -stsb_dev = load_dataset("mteb/stsbenchmark-sts", split="validation") +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") evaluators = [] for dim in matryoshka_dims: evaluators.append( EmbeddingSimilarityEvaluator( - stsb_dev["sentence1"], - stsb_dev["sentence2"], - [score / 5 for score in stsb_dev["score"]], + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], main_similarity=SimilarityFunction.COSINE, name=f"sts-dev-{dim}", truncate_dim=dim, @@ -121,56 +79,68 @@ def add_to_samples(sent1, sent2, label): ) dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0]) -# Configure the training -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - +# 5. 
Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="matryoshka-nli-reduced", # Will be used in W&B if `wandb` is installed +) -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=int(len(train_dataloader) * 0.1), - warmup_steps=warmup_steps, - output_path=model_save_path, - use_amp=False, # Set to True, if your GPU supports FP16 operations ) +trainer.train() - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - - -model = SentenceTransformer(model_save_path) -stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test") +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") evaluators = [] for dim in matryoshka_dims: evaluators.append( EmbeddingSimilarityEvaluator( - stsb_test["sentence1"], - stsb_test["sentence2"], - [score / 5 for score in stsb_test["score"]], + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], main_similarity=SimilarityFunction.COSINE, name=f"sts-test-{dim}", truncate_dim=dim, ) ) test_evaluator = SequentialEvaluator(evaluators) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: - model.push_to_hub(f"{model_name}-nli-matryoshka-{reduced_dim}") + model.push_to_hub(f"{model_name}-nli-matryoshka-reduced") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " - f"and saving it using `model.push_to_hub('{model_name}-nli-matryoshka-{reduced_dim}')`." + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-nli-matryoshka-reduced')`." 
) diff --git a/examples/training/matryoshka/matryoshka_sts.py b/examples/training/matryoshka/matryoshka_sts.py index 039bdac13..7cb17d715 100644 --- a/examples/training/matryoshka/matryoshka_sts.py +++ b/examples/training/matryoshka/matryoshka_sts.py @@ -10,118 +10,122 @@ python matryoshka_sts.py pretrained_transformer_model_name """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util -from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SequentialEvaluator, SimilarityFunction import logging from datetime import datetime import sys -import os -import gzip -import csv -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - - -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - - -# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert-base-uncased" - -# Read the dataset -train_batch_size = 16 -num_epochs = 4 -model_save_path = ( - "output/matryoshka_sts_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) - -# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, +batch_size = 16 +num_train_epochs = 4 +matryoshka_dims = [768, 512, 256, 128, 64] + +# Save path of the model +output_dir = f"output/matryoshka_sts_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) + +# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +train_dataset = load_dataset("sentence-transformers/stsb", split="train") +eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(train_dataset) + +# 3. 
Define our training loss +# CoSENTLoss (https://sbert.net/docs/package_reference/losses.html#cosentloss) needs two text columns and one +# similarity score column (between 0 and 1) +inner_train_loss = losses.CoSENTLoss(model=model) +train_loss = losses.MatryoshkaLoss(model, loss=inner_train_loss, matryoshka_dims=matryoshka_dims) + +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +evaluators = [] +for dim in matryoshka_dims: + evaluators.append( + EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name=f"sts-dev-{dim}", + truncate_dim=dim, + ) + ) +dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0]) + +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="matryoshka-sts", # Will be used in W&B if `wandb` is installed ) -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Convert the dataset to a DataLoader ready for training -logging.info("Read STSbenchmark train dataset") - -train_samples = [] -dev_samples = [] -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) - - -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) -train_loss = losses.CoSENTLoss(model=model) -train_loss = losses.MatryoshkaLoss(model, train_loss, [768, 512, 256, 128, 64]) - - -logging.info("Read STSbenchmark dev dataset") -evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") - - -# Configure the training. We skip evaluation in this example -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=model_save_path, +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, ) +trainer.train() + + +# 7. 
Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +evaluators = [] +for dim in matryoshka_dims: + evaluators.append( + EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name=f"sts-test-{dim}", + truncate_dim=dim, + ) + ) +test_evaluator = SequentialEvaluator(evaluators) +test_evaluator(model) +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test") -test_evaluator(model, output_path=model_save_path) - -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: model.push_to_hub(f"{model_name}-sts-matryoshka") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " f"and saving it using `model.push_to_hub('{model_name}-sts-matryoshka')`." 
)
diff --git a/examples/training/multilingual/make_multilingual.py b/examples/training/multilingual/make_multilingual.py
index fafe454dd..f31377246 100644
--- a/examples/training/multilingual/make_multilingual.py
+++ b/examples/training/multilingual/make_multilingual.py
@@ -17,20 +17,22 @@ https://arxiv.org/abs/2004.09813
 """
-from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses
-from torch.utils.data import DataLoader
-from sentence_transformers.datasets import ParallelSentencesDataset
+import traceback
+from sentence_transformers import SentenceTransformer, LoggingHandler
 from datetime import datetime
+from datasets import load_dataset, DatasetDict
 
-import os
 import logging
-import sentence_transformers.util
-import csv
-import gzip
-from tqdm.autonotebook import tqdm
+from sentence_transformers.evaluation import (
+    EmbeddingSimilarityEvaluator,
+    MSEEvaluator,
+    SequentialEvaluator,
+    TranslationEvaluator,
+)
+from sentence_transformers.losses import MSELoss
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
 import numpy as np
-import zipfile
-import io
 
 logging.basicConfig(
     format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
@@ -38,220 +40,204 @@ logger = logging.getLogger(__name__)
 
-teacher_model_name = (
-    "paraphrase-distilroberta-base-v2"  # Our monolingual teacher model, we want to convert to multiple languages
-)
-student_model_name = "xlm-roberta-base"  # Multilingual base model we use to imitate the teacher model
+# The teacher model is monolingual; we use it for English embeddings
+teacher_model_name = "paraphrase-distilroberta-base-v2"
+# The student model is multilingual; we train it such that embeddings of non-English texts mimic the teacher model's English embeddings
+student_model_name = "xlm-roberta-base"
 
-max_seq_length = 128  # Student model max. lengths for inputs (number of word pieces)
+student_max_seq_length = 128  # Student model max. length for inputs (number of word pieces)
 train_batch_size = 64  # Batch size for training
 inference_batch_size = 64  # Batch size at inference
 max_sentences_per_language = 500000  # Maximum number of parallel sentences for training
-train_max_sentence_length = 250  # Maximum length (characters) for parallel training sentences
-num_epochs = 5  # Train for x epochs
-num_warmup_steps = 10000  # Warumup steps
-
-num_evaluation_steps = 1000  # Evaluate performance after every xxxx steps
-dev_sentences = 1000  # Number of parallel sentences to be used for development
+num_train_epochs = 5  # Train for x epochs
+num_evaluation_steps = 5000  # Evaluate performance after every xxxx steps
 
 # Define the language codes you would like to extend the model to
 source_languages = set(["en"])  # Our teacher model accepts English (en) sentences
-target_languages = set(
-    ["de", "es", "it", "fr", "ar", "tr"]
-)  # We want to extend the model to these new languages. 
For language codes, see the header of the train file +target_languages = set(["de", "es", "it", "fr", "ar", "tr"]) -output_path = ( +output_dir = ( "output/make-multilingual-" + "-".join(sorted(list(source_languages)) + sorted(list(target_languages))) + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) - -# This function downloads a corpus if it does not exist -def download_corpora(filepaths): - if not isinstance(filepaths, list): - filepaths = [filepaths] - - for filepath in filepaths: - if not os.path.exists(filepath): - print(filepath, "does not exists. Try to download from server") - filename = os.path.basename(filepath) - url = "https://sbert.net/datasets/" + filename - sentence_transformers.util.http_get(url, filepath) - - -# Here we define train train and dev corpora -train_corpus = "datasets/parallel-sentences.tsv.gz" -sts_corpus = "datasets/stsbenchmark.zip" -parallel_sentences_folder = "parallel-sentences/" - -# Check if the file exists. If not, they are downloaded -download_corpora([train_corpus, sts_corpus]) - - -# Create parallel files for the selected language combinations -os.makedirs(parallel_sentences_folder, exist_ok=True) -train_files = [] -dev_files = [] -files_to_create = [] +# 1a. Here we define our SentenceTransformer teacher model. +teacher_model = SentenceTransformer(teacher_model_name) +# If we want, we can limit the maximum sequence length for the model +# teacher_model.max_seq_length = 128 +logging.info(f"Teacher model: {teacher_model}") + +# 1b. Here we define our SentenceTransformer student model. If not already a Sentence Transformer model, +# it will automatically create one with "mean" pooling. +student_model = SentenceTransformer(student_model_name) +# If we want, we can limit the maximum sequence length for the model +student_model.max_seq_length = student_max_seq_length +logging.info(f"Student model: {student_model}") + +# 2. 
Load the parallel sentences training dataset: https://huggingface.co/datasets?other=sentence-transformers&sort=trending&search=parallel-sentences +# NOTE: We can also use multiple datasets if we want +dataset_to_use = "sentence-transformers/parallel-sentences-talks" +# dataset_to_use = "sentence-transformers/parallel-sentences-europarl" +# dataset_to_use = "sentence-transformers/parallel-sentences-global-voices" +# dataset_to_use = "sentence-transformers/parallel-sentences-muse" +# dataset_to_use = "sentence-transformers/parallel-sentences-jw300" +# dataset_to_use = "sentence-transformers/parallel-sentences-news-commentary" +# dataset_to_use = "sentence-transformers/parallel-sentences-opensubtitles" +# dataset_to_use = "sentence-transformers/parallel-sentences-tatoeba" +# dataset_to_use = "sentence-transformers/parallel-sentences-wikimatrix" +# dataset_to_use = "sentence-transformers/parallel-sentences-wikititles" +train_dataset_dict = DatasetDict() +eval_dataset_dict = DatasetDict() for source_lang in source_languages: for target_lang in target_languages: - output_filename_train = os.path.join( - parallel_sentences_folder, "talks-{}-{}-train.tsv.gz".format(source_lang, target_lang) - ) - output_filename_dev = os.path.join( - parallel_sentences_folder, "talks-{}-{}-dev.tsv.gz".format(source_lang, target_lang) - ) - train_files.append(output_filename_train) - dev_files.append(output_filename_dev) - if not os.path.exists(output_filename_train) or not os.path.exists(output_filename_dev): - files_to_create.append( - { - "src_lang": source_lang, - "trg_lang": target_lang, - "fTrain": gzip.open(output_filename_train, "wt", encoding="utf8"), - "fDev": gzip.open(output_filename_dev, "wt", encoding="utf8"), - "devCount": 0, - } + subset = f"{source_lang}-{target_lang}" + try: + train_dataset = load_dataset(dataset_to_use, subset, split="train") + if len(train_dataset) > max_sentences_per_language: + train_dataset = train_dataset.select(range(max_sentences_per_language)) + except Exception as exc: + logging.error(f"Could not load dataset {dataset_to_use}/{source_lang}-{target_lang}: {exc}") + continue + + try: + eval_dataset = load_dataset(dataset_to_use, subset, split="dev") + if len(eval_dataset) > 1000: + eval_dataset = eval_dataset.select(range(1000)) + except Exception: + logging.info( + f"Could not load dataset {dataset_to_use}/{source_lang}-{target_lang} dev split, splitting 1k samples from train" ) + dataset = train_dataset.train_test_split(test_size=1000, shuffle=True) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] -if len(files_to_create) > 0: - print( - "Parallel sentences files {} do not exist. 
Create these files now".format(
-            ", ".join(map(lambda x: x["src_lang"] + "-" + x["trg_lang"], files_to_create))
-        )
-    )
-    with gzip.open(train_corpus, "rt", encoding="utf8") as fIn:
-        reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-        for line in tqdm(reader, desc="Sentences"):
-            for outfile in files_to_create:
-                src_text = line[outfile["src_lang"]].strip()
-                trg_text = line[outfile["trg_lang"]].strip()
+        train_dataset_dict[subset] = train_dataset
+        eval_dataset_dict[subset] = eval_dataset
+logging.info(train_dataset_dict)
 
-                if src_text != "" and trg_text != "":
-                    if outfile["devCount"] < dev_sentences:
-                        outfile["devCount"] += 1
-                        fOut = outfile["fDev"]
-                    else:
-                        fOut = outfile["fTrain"]
-                    fOut.write("{}\t{}\n".format(src_text, trg_text))
+
+# We want the teacher embeddings of the *source* sentences to be very similar to the student embeddings
+# of the *target* sentences.
+def prepare_dataset(batch):
+    return {
+        "non_english": batch["non_english"],
+        "label": teacher_model.encode(batch["english"], batch_size=inference_batch_size, show_progress_bar=False),
+    }
 
-    for outfile in files_to_create:
-        outfile["fTrain"].close()
-        outfile["fDev"].close()
-
-######## Start the extension of the teacher model to multiple languages ########
-logger.info("Load teacher model")
-teacher_model = SentenceTransformer(teacher_model_name)
-
-
-logger.info("Create student model from scratch")
-word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
-# Apply mean pooling to get one fixed sized sentence vector
-pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
-student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
-
-
-###### Read Parallel Sentences Dataset ######
-train_data = ParallelSentencesDataset(
-    student_model=student_model, teacher_model=teacher_model, batch_size=inference_batch_size, use_embedding_cache=True
+column_names = list(train_dataset_dict.values())[0].column_names
+train_dataset_dict = train_dataset_dict.map(
+    prepare_dataset, batched=True, batch_size=30000, remove_columns=column_names
 )
-for train_file in train_files:
-    train_data.load_data(
-        train_file, max_sentences=max_sentences_per_language, max_sentence_length=train_max_sentence_length
-    )
+logging.info(f"Prepared datasets for training: {train_dataset_dict}")
 
-train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
-train_loss = losses.MSELoss(model=student_model)
+# 3. Define our training loss
+# MSELoss (https://sbert.net/docs/package_reference/losses.html#mseloss) needs one text column and one
+# column with embeddings from the teacher model
+train_loss = MSELoss(model=student_model)
 
+# 4. Define evaluators for use during training. This is useful to keep track of alongside the evaluation loss. 
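+# (A sketch of how a single evaluator behaves, assuming an "en-de" subset: every evaluator is a callable that
+#  scores a model, e.g.
+#    mse = MSEEvaluator(["Hello world"], ["Hallo Welt"], teacher_model=teacher_model, name="en-de-sketch")
+#    mse(student_model)  # lower MSE between teacher and student embeddings is better
+#  Below, each language pair gets an MSE evaluator and a translation-accuracy evaluator, plus an STS17 similarity
+#  evaluator where test data is available; all of them are combined into one SequentialEvaluator that averages
+#  their scores.)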
+evaluators = []
 
-#### Evaluate cross-lingual performance on different tasks #####
-evaluators = []  # evaluators has a list of different evaluator classes we call periodically
-
-for dev_file in dev_files:
-    logger.info("Create evaluator for " + dev_file)
-    src_sentences = []
-    trg_sentences = []
-    with gzip.open(dev_file, "rt", encoding="utf8") as fIn:
-        for line in fIn:
-            splits = line.strip().split("\t")
-            if splits[0] != "" and splits[1] != "":
-                src_sentences.append(splits[0])
-                trg_sentences.append(splits[1])
+for subset, eval_dataset in eval_dataset_dict.items():
+    logger.info(f"Creating evaluators for {subset}")
 
     # Mean Squared Error (MSE) measures the (euclidean) distance between teacher and student embeddings
-    dev_mse = evaluation.MSEEvaluator(
-        src_sentences,
-        trg_sentences,
-        name=os.path.basename(dev_file).split(".")[0],
+    dev_mse = MSEEvaluator(
+        eval_dataset["english"],
+        eval_dataset["non_english"],
+        name=subset,
         teacher_model=teacher_model,
         batch_size=inference_batch_size,
     )
     evaluators.append(dev_mse)
 
-    # TranslationEvaluator computes the embeddings for all parallel sentences. It then check if the embedding of source[i] is the closest to target[i] out of all available target sentences
-    dev_trans_acc = evaluation.TranslationEvaluator(
-        src_sentences, trg_sentences, name=os.path.basename(dev_file).split(".")[0], batch_size=inference_batch_size
+    # TranslationEvaluator computes the embeddings for all parallel sentences. It then checks whether the embedding of
+    # source[i] is the closest to target[i] out of all available target sentences
+    dev_trans_acc = TranslationEvaluator(
+        eval_dataset["english"],
+        eval_dataset["non_english"],
+        name=subset,
+        batch_size=inference_batch_size,
     )
     evaluators.append(dev_trans_acc)
 
+    # Try to load this subset from STS17
+    test_dataset = None
+    try:
+        test_dataset = load_dataset("mteb/sts17-crosslingual-sts", subset, split="test")
+    except Exception:
+        try:
+            test_dataset = load_dataset("mteb/sts17-crosslingual-sts", f"{subset[3:]}-{subset[:2]}", split="test")
+            subset = f"{subset[3:]}-{subset[:2]}"
+        except Exception:
+            pass
+    if test_dataset:
+        test_evaluator = EmbeddingSimilarityEvaluator(
+            sentences1=test_dataset["sentence1"],
+            sentences2=test_dataset["sentence2"],
+            scores=[score / 5.0 for score in test_dataset["score"]],  # Convert 0-5 scores to 0-1 scores
+            batch_size=inference_batch_size,
+            name=f"sts17-{subset}-test",
+            show_progress_bar=False,
+        )
+        evaluators.append(test_evaluator)
+
+evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores))
+# Now also prepare the evaluation datasets for training
+eval_dataset_dict = eval_dataset_dict.map(prepare_dataset, batched=True, batch_size=30000, remove_columns=column_names)
+
+# 5. 
Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=num_train_epochs, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + learning_rate=2e-5, + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=num_evaluation_steps, + save_strategy="steps", + save_steps=num_evaluation_steps, + save_total_limit=2, + logging_steps=100, + run_name=f"multilingual-{'-'.join(source_languages)}-{'-'.join(target_languages)}", # Will be used in W&B if `wandb` is installed +) -##### Read cross-lingual Semantic Textual Similarity (STS) data #### -all_languages = list(set(list(source_languages) + list(target_languages))) -sts_data = {} - -# Open the ZIP File of STS2017-extended.zip and check for which language combinations we have STS data -with zipfile.ZipFile(sts_corpus) as zip: - filelist = zip.namelist() - sts_files = [] - - for i in range(len(all_languages)): - for j in range(i, len(all_languages)): - lang1 = all_languages[i] - lang2 = all_languages[j] - filepath = "STS2017-extended/STS.{}-{}.txt".format(lang1, lang2) - if filepath not in filelist: - lang1, lang2 = lang2, lang1 - filepath = "STS2017-extended/STS.{}-{}.txt".format(lang1, lang2) - - if filepath in filelist: - filename = os.path.basename(filepath) - sts_data[filename] = {"sentences1": [], "sentences2": [], "scores": []} - - fIn = zip.open(filepath) - for line in io.TextIOWrapper(fIn, "utf8"): - sent1, sent2, score = line.strip().split("\t") - score = float(score) - sts_data[filename]["sentences1"].append(sent1) - sts_data[filename]["sentences2"].append(sent2) - sts_data[filename]["scores"].append(score) - -for filename, data in sts_data.items(): - test_evaluator = evaluation.EmbeddingSimilarityEvaluator( - data["sentences1"], - data["sentences2"], - data["scores"], - batch_size=inference_batch_size, - name=filename.split(".")[0], - show_progress_bar=False, - ) - evaluators.append(test_evaluator) - - -# Train the model -student_model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores)), - epochs=num_epochs, - warmup_steps=num_warmup_steps, - evaluation_steps=num_evaluation_steps, - output_path=output_path, - save_best_model=True, - optimizer_params={"lr": 2e-5, "eps": 1e-6}, +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=student_model, + args=args, + train_dataset=train_dataset_dict, + eval_dataset=eval_dataset_dict, + loss=train_loss, + evaluator=evaluator, ) +trainer.train() + +# 7. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +student_model.save(final_output_dir) + +# 8. (Optional) save the model to the Hugging Face Hub! 
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = student_model_name if "/" not in student_model_name else student_model_name.split("/")[-1] +try: + student_model.push_to_hub(f"{model_name}-multilingual-{'-'.join(source_languages)}-{'-'.join(target_languages)}") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-multilingual-{'-'.join(source_languages)}-{'-'.join(target_languages)}')`." + ) diff --git a/examples/training/nli/training_nli.py b/examples/training/nli/training_nli.py index 7a657f4d8..2dd26c112 100644 --- a/examples/training/nli/training_nli.py +++ b/examples/training/nli/training_nli.py @@ -10,128 +10,114 @@ python training_nli.py pretrained_transformer_model_name """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import models, losses -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator import logging from datetime import datetime import sys -import os -import gzip -import csv -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout - -# Check if dataset exists. 
If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" -sts_dataset_path = "data/stsbenchmark.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) - -# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base +# You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased" - -# Read the dataset train_batch_size = 16 +output_dir = "output/training_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -model_save_path = ( - "output/training_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -) - - -# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, -) - -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) -# Read the AllNLI.tsv.gz file and create the training dataset +# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +# We'll start with 10k training samples, but you can increase this to get a stronger model logging.info("Read AllNLI train dataset") +train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train").select(range(10000)) +eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev").select(range(1000)) +logging.info(train_dataset) -label2int = {"contradiction": 0, "entailment": 1, "neutral": 2} -train_samples = [] -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - label_id = label2int[row["label"]] - train_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=label_id)) - - -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) +# 3. 
Define our training loss: https://sbert.net/docs/package_reference/losses.html#softmaxloss train_loss = losses.SoftmaxLoss( - model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=len(label2int) + model=model, + sentence_embedding_dimension=model.get_sentence_embedding_dimension(), + num_labels=3, ) - -# Read STSbenchmark dataset and use it as development set -logging.info("Read STSbenchmark dev dataset") -dev_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "dev": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - dev_samples, batch_size=train_batch_size, name="sts-dev" +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +logging.info("Evaluation before training:") +dev_evaluator(model) + +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=100, + save_strategy="steps", + save_steps=100, + save_total_limit=2, + logging_steps=100, + run_name="nli-v1", # Will be used in W&B if `wandb` is installed ) -# Configure the training -num_epochs = 1 - -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=model_save_path, ) - - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "test": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 
1 - test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - test_samples, batch_size=train_batch_size, name="sts-test" +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", ) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. (Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + model.push_to_hub(f"{model_name}-nli-v1") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-nli-v1')`." + ) diff --git a/examples/training/nli/training_nli_v2.py b/examples/training/nli/training_nli_v2.py index e37b0ee6e..0567e0f8b 100644 --- a/examples/training/nli/training_nli_v2.py +++ b/examples/training/nli/training_nli_v2.py @@ -10,23 +10,21 @@ python training_nli_v2.py pretrained_transformer_model_name """ -import math -from sentence_transformers import models, losses, datasets -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator import logging from datetime import datetime import sys -import os -import gzip -import csv -import random - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout + +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" train_batch_size = 128 # The larger you select this, the better the results (usually). 
But it requires more GPU memory @@ -34,121 +32,94 @@ num_epochs = 1 # Save path of the model -model_save_path = ( +output_dir = ( "output/training_nli_v2_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) -# Here we define our SentenceTransformer model -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean") -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Check if dataset exists. If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" -sts_dataset_path = "data/stsbenchmark.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) -# Read the AllNLI.tsv.gz file and create the training dataset +# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +# We'll start with 10k training samples, but you can increase this to get a stronger model logging.info("Read AllNLI train dataset") +train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train").select(range(10000)) +eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev").select(range(1000)) +logging.info(train_dataset) - -def add_to_samples(sent1, sent2, label): - if sent1 not in train_data: - train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} - train_data[sent1][label].add(sent2) - - -train_data = {} -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - sent1 = row["sentence1"].strip() - sent2 = row["sentence2"].strip() - - add_to_samples(sent1, sent2, row["label"]) - add_to_samples(sent2, sent1, row["label"]) # Also add the opposite - - -train_samples = [] -for sent1, others in train_data.items(): - if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: - train_samples.append( - InputExample( - texts=[sent1, random.choice(list(others["entailment"])), random.choice(list(others["contradiction"]))] - ) - ) - train_samples.append( - InputExample( - texts=[random.choice(list(others["entailment"])), sent1, random.choice(list(others["contradiction"]))] - ) - ) - -logging.info("Train samples: {}".format(len(train_samples))) - - -# Special data loader that avoid duplicates within a batch -train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) - - -# Our training loss +# 3. Define our training loss: https://sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss train_loss = losses.MultipleNegativesRankingLoss(model) -# Read STSbenchmark dataset and use it as development set -logging.info("Read STSbenchmark dev dataset") -dev_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "dev": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 
1 - dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - dev_samples, batch_size=train_batch_size, name="sts-dev" +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +logging.info("Evaluation before training:") +dev_evaluator(model) + +# 5. Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=10, + save_strategy="steps", + save_steps=10, + save_total_limit=2, + logging_steps=100, + run_name="nli-v2", # Will be used in W&B if `wandb` is installed ) -# Configure the training -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=int(len(train_dataloader) * 0.1), - warmup_steps=warmup_steps, - output_path=model_save_path, - use_amp=False, # Set to True, if your GPU supports FP16 operations ) - - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "test": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - test_samples, batch_size=train_batch_size, name="sts-test" +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", ) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) + +# 9. 
(Optional) save the model to the Hugging Face Hub! +# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first +model_name = model_name if "/" not in model_name else model_name.split("/")[-1] +try: + model.push_to_hub(f"{model_name}-nli-v2") +except Exception: + logging.error( + f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-nli-v2')`." + ) diff --git a/examples/training/nli/training_nli_v3.py b/examples/training/nli/training_nli_v3.py index 312821730..1844a2588 100644 --- a/examples/training/nli/training_nli_v3.py +++ b/examples/training/nli/training_nli_v3.py @@ -10,23 +10,21 @@ python training_nli_v3.py pretrained_transformer_model_name """ -import math -from sentence_transformers import models, losses, datasets -from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import losses +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator import logging from datetime import datetime import sys -import os -import gzip -import csv -import random - -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout + +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base" train_batch_size = 128 # The larger you select this, the better the results (usually). But it requires more GPU memory @@ -34,136 +32,95 @@ num_epochs = 1 # Save path of the model -model_save_path = ( +output_dir = ( "output/training_nli_v3_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) -# Here we define our SentenceTransformer model -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean") -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -# Check if dataset exists. If not, download and extract it -nli_dataset_path = "data/AllNLI.tsv.gz" -sts_dataset_path = "data/stsbenchmark.tsv.gz" +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - - -# Read the AllNLI.tsv.gz file and create the training dataset +# 2. 
Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +# We'll start with 10k training samples, but you can increase this to get a stronger model logging.info("Read AllNLI train dataset") +train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train").select(range(10000)) +eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev").select(range(1000)) +logging.info(train_dataset) - -def add_to_samples(sent1, sent2, label): - if sent1 not in train_data: - train_data[sent1] = {"contradiction": set(), "entailment": set(), "neutral": set()} - train_data[sent1][label].add(sent2) - - -train_data = {} -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - sent1 = row["sentence1"].strip() - sent2 = row["sentence2"].strip() - - add_to_samples(sent1, sent2, row["label"]) - add_to_samples(sent2, sent1, row["label"]) # Also add the opposite - - -train_samples = [] -for sent1, others in train_data.items(): - if len(others["entailment"]) > 0 and len(others["contradiction"]) > 0: - train_samples.append( - InputExample( - texts=[sent1, random.choice(list(others["entailment"])), random.choice(list(others["contradiction"]))] - ) - ) - train_samples.append( - InputExample( - texts=[random.choice(list(others["entailment"])), sent1, random.choice(list(others["contradiction"]))] - ) - ) - -logging.info("Train samples: {}".format(len(train_samples))) - - -# Special data loader that avoid duplicates within a batch -train_dataloader = datasets.NoDuplicatesDataLoader(train_samples, batch_size=train_batch_size) - - +# 3. Define our training loss: https://sbert.net/docs/package_reference/losses.html#gistembedloss # The guiding model guide_model = SentenceTransformer("all-MiniLM-L6-v2") - -# Our training loss train_loss = losses.GISTEmbedLoss(model, guide_model) - -# Read STSbenchmark dataset and use it as development set -logging.info("Read STSbenchmark dev dataset") -dev_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "dev": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - dev_samples, batch_size=train_batch_size, name="sts-dev" +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", +) +logging.info("Evaluation before training:") +dev_evaluator(model) + +# 5. 
Define the training arguments +args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=output_dir, + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + batch_sampler=BatchSamplers.NO_DUPLICATES, + # Optional tracking/debugging parameters: + eval_strategy="steps", + eval_steps=10, + save_strategy="steps", + save_steps=10, + save_total_limit=2, + logging_steps=100, + run_name="nli-v3", # Will be used in W&B if `wandb` is installed ) -# Configure the training -warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up -logging.info("Warmup-steps: {}".format(warmup_steps)) - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], +# 6. Create the trainer & start training +trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, evaluator=dev_evaluator, - epochs=num_epochs, - evaluation_steps=int(len(train_dataloader) * 0.1), - warmup_steps=warmup_steps, - output_path=model_save_path, - use_amp=False, # Set to True, if your GPU supports FP16 operations ) - - -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## - -test_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "test": - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - test_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)) - -model = SentenceTransformer(model_save_path) -test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples( - test_samples, batch_size=train_batch_size, name="sts-test" +trainer.train() + +# 7. Evaluate the model performance on the STS Benchmark test dataset +test_dataset = load_dataset("sentence-transformers/stsb", split="test") +test_evaluator = EmbeddingSimilarityEvaluator( + sentences1=test_dataset["sentence1"], + sentences2=test_dataset["sentence2"], + scores=test_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-test", ) -test_evaluator(model, output_path=model_save_path) +test_evaluator(model) + +# 8. Save the trained & evaluated model locally +final_output_dir = f"{output_dir}/final" +model.save(final_output_dir) -# Optionally, save the model to the Hugging Face Hub! +# 9. (Optional) save the model to the Hugging Face Hub! # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first model_name = model_name if "/" not in model_name else model_name.split("/")[-1] try: - model.push_to_hub(f"{model_name}-nli-gist") + model.push_to_hub(f"{model_name}-nli-v3") except Exception: logging.error( - "Error uploading model to the Hugging Face Hub. To upload it manually, you can run " - f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({model_save_path!r})` " - f"and saving it using `model.push_to_hub('{model_name}-nli-gist')`." 
+ f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run " + f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` " + f"and saving it using `model.push_to_hub('{model_name}-nli-v3')`." ) diff --git a/examples/training/other/training_multi-task.py b/examples/training/other/training_multi-task.py index d1a0d625a..b9de63c0a 100644 --- a/examples/training/other/training_multi-task.py +++ b/examples/training/other/training_multi-task.py @@ -4,126 +4,129 @@ The system trains BERT on the AllNLI and on the STSbenchmark dataset. """ -from torch.utils.data import DataLoader -import math -from sentence_transformers import models, losses -from sentence_transformers import LoggingHandler, SentenceTransformer, util +import traceback +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample +from sentence_transformers.losses import CosineSimilarityLoss, SoftmaxLoss +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import MultiDatasetBatchSamplers, SentenceTransformerTrainingArguments import logging from datetime import datetime -import gzip -import csv -import os +from datasets import load_dataset -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # Read the dataset model_name = "bert-base-uncased" +num_train_epochs = 1 batch_size = 16 -model_save_path = "output/training_multi-task_" + model_name + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - - -# Check if dataset exists. If not, download and extract it -nli_dataset_path = "datasets/AllNLI.tsv.gz" -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(nli_dataset_path): - util.http_get("https://sbert.net/datasets/AllNLI.tsv.gz", nli_dataset_path) - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - - -# Use BERT for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) - -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, +output_dir = "output/training_multi-task_" + model_name + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) + +# 2a. 
Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli +nli_train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train") +nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev").select(range(1000)) +logging.info(nli_train_dataset) + +# 2b. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb +stsb_train_dataset = load_dataset("sentence-transformers/stsb", split="train") +stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") +stsb_test_dataset = load_dataset("sentence-transformers/stsb", split="test") +logging.info(stsb_train_dataset) + +# 3. Define our training losses +# 3a. SoftmaxLoss for the NLI data (sentence_A, sentence_B, class), see also https://sbert.net/docs/training/loss_overview.html +train_loss_nli = SoftmaxLoss( + model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=3 ) - -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - - -# Convert the dataset to a DataLoader ready for training -logging.info("Read AllNLI train dataset") -label2int = {"contradiction": 0, "entailment": 1, "neutral": 2} -train_nli_samples = [] -with gzip.open(nli_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["split"] == "train": - label_id = label2int[row["label"]] - train_nli_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=label_id)) - - -train_dataloader_nli = DataLoader(train_nli_samples, shuffle=True, batch_size=batch_size) -train_loss_nli = losses.SoftmaxLoss( - model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=len(label2int) +# 3b. CosineSimilarityLoss for the STSB data (sentence_A, sentence_B, similarity score between 0 and 1) +train_loss_sts = CosineSimilarityLoss(model=model) + +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=stsb_eval_dataset["sentence1"], + sentences2=stsb_eval_dataset["sentence2"], + scores=stsb_eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", ) -logging.info("Read STSbenchmark train dataset") -train_sts_samples = [] -dev_sts_samples = [] -test_sts_samples = [] -with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 
1
-        inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)
-
-        if row["split"] == "dev":
-            dev_sts_samples.append(inp_example)
-        elif row["split"] == "test":
-            test_sts_samples.append(inp_example)
-        else:
-            train_sts_samples.append(inp_example)
-
-
-train_dataloader_sts = DataLoader(train_sts_samples, shuffle=True, batch_size=batch_size)
-train_loss_sts = losses.CosineSimilarityLoss(model=model)
-
-
-logging.info("Read STSbenchmark dev dataset")
-evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_sts_samples, name="sts-dev")
-
-# Configure the training
-num_epochs = 4
-
-warmup_steps = math.ceil(len(train_dataloader_sts) * num_epochs * 0.1)  # 10% of train data for warm-up
-logging.info("Warmup-steps: {}".format(warmup_steps))
-
-
-# Here we define the two train objectives: train_dataloader_nli with train_loss_nli (i.e., SoftmaxLoss for NLI data)
-# and train_dataloader_sts with train_loss_sts (i.e., CosineSimilarityLoss for STSbenchmark data)
-# You can pass as many (dataloader, loss) tuples as you like. They are iterated in a round-robin way.
-train_objectives = [(train_dataloader_nli, train_loss_nli), (train_dataloader_sts, train_loss_sts)]
-
-# Train the model
-model.fit(
-    train_objectives=train_objectives,
-    evaluator=evaluator,
-    epochs=num_epochs,
-    evaluation_steps=1000,
-    warmup_steps=warmup_steps,
-    output_path=model_save_path,
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # With ROUND_ROBIN you'll sample the same number of batches from each dataset, until one of the datasets is exhausted
+    # The alternative is PROPORTIONAL, which samples from each dataset in proportion to the dataset size,
+    # but that will lead to a lot of samples from the larger dataset (AllNLI in this case)
+    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="multi-task",  # Will be used in W&B if `wandb` is installed
 )

+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset={
+        "all-nli": nli_train_dataset,
+        "sts": stsb_train_dataset,
+    },
+    eval_dataset={
+        "all-nli": nli_eval_dataset,
+        "sts": stsb_eval_dataset,
+    },
+    loss={
+        "all-nli": train_loss_nli,
+        "sts": train_loss_sts,
+    },
+    evaluator=dev_evaluator,
+)
+trainer.train()

-##############################################################################
-#
-# Load the stored model and evaluate its performance on STS benchmark dataset
-#
-##############################################################################
-model = SentenceTransformer(model_save_path)
-test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_sts_samples, name="sts-test")
-test_evaluator(model, output_path=model_save_path)
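+# Note: each batch is drawn from a single dataset, and the trainer picks the matching entry from
+# the `loss` dict by dataset name, so "all-nli" batches train with SoftmaxLoss and "sts" batches
+# with CosineSimilarityLoss.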
+# 7. Evaluate the model performance on the STS Benchmark test dataset
+test_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=stsb_test_dataset["sentence1"],
+    sentences2=stsb_test_dataset["sentence2"],
+    scores=stsb_test_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-test",
+)
+test_evaluator(model)
+
+# 8. Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-multi-task")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-multi-task')`."
+    )
diff --git a/examples/training/other/training_wikipedia_sections.py b/examples/training/other/training_wikipedia_sections.py
index 33ca4f549..8f9b36dfa 100644
--- a/examples/training/other/training_wikipedia_sections.py
+++ b/examples/training/other/training_wikipedia_sections.py
@@ -4,106 +4,109 @@
-As corpus, we use the wikipedia sections dataset that was describd by Dor et al., 2018, Learning Thematic Similarity Metric Using Triplet Networks.
+As corpus, we use the Wikipedia Sections dataset that was described by Dor et al., 2018, Learning Thematic Similarity Metric Using Triplet Networks.
 """
-from sentence_transformers import SentenceTransformer, InputExample, LoggingHandler, losses, models, util
-from torch.utils.data import DataLoader
+import traceback
+from sentence_transformers import SentenceTransformer
 from sentence_transformers.evaluation import TripletEvaluator
+from sentence_transformers.losses import TripletLoss
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
 from datetime import datetime
-from zipfile import ZipFile
-
-import csv
+from datasets import load_dataset
 import logging
-import os
-
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
-)
-logger = logging.getLogger(__name__)
+# Set the log level to INFO to get more information
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

 # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
 model_name = "distilbert-base-uncased"
+batch_size = 16
+num_train_epochs = 1

-dataset_path = "datasets/wikipedia-sections"
-if not os.path.exists(dataset_path):
-    os.makedirs(dataset_path, exist_ok=True)
-    filepath = os.path.join(dataset_path, "wikipedia-sections-triplets.zip")
-    util.http_get("https://sbert.net/datasets/wikipedia-sections-triplets.zip", filepath)
-    with ZipFile(filepath, "r") as zip:
-        zip.extractall(dataset_path)
-
-
-### Create a torch.DataLoader that passes training batch instances to our model
-train_batch_size = 16
-output_path = "output/training-wikipedia-sections-" + model_name + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-num_epochs = 1
-
+output_dir = "output/training-wikipedia-sections-" + model_name + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

-### Configure sentence transformers for training and train on the provided dataset
-# Use
Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings -word_embedding_model = models.Transformer(model_name) +# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically +# create one with "mean" pooling. +model = SentenceTransformer(model_name) +# If we want, we can limit the maximum sequence length for the model +# model.max_seq_length = 75 +logging.info(model) -# Apply mean pooling to get one fixed sized sentence vector -pooling_model = models.Pooling( - word_embedding_model.get_word_embedding_dimension(), - pooling_mode_mean_tokens=True, - pooling_mode_cls_token=False, - pooling_mode_max_tokens=False, +# 2. Load the Wikipedia-Sections dataset: https://huggingface.co/datasets/sentence-transformers/wikipedia-sections +train_dataset = load_dataset("sentence-transformers/wikipedia-sections", "triplet", split="train").select( + range(10_000) ) - -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - - -logger.info("Read Triplet train dataset") -train_examples = [] -with open(os.path.join(dataset_path, "train.csv"), encoding="utf-8") as fIn: - reader = csv.DictReader(fIn, delimiter=",", quoting=csv.QUOTE_MINIMAL) - for row in reader: - train_examples.append(InputExample(texts=[row["Sentence1"], row["Sentence2"], row["Sentence3"]], label=0)) - - -train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size) -train_loss = losses.TripletLoss(model=model) - -logger.info("Read Wikipedia Triplet dev dataset") -dev_examples = [] -with open(os.path.join(dataset_path, "validation.csv"), encoding="utf-8") as fIn: - reader = csv.DictReader(fIn, delimiter=",", quoting=csv.QUOTE_MINIMAL) - for row in reader: - dev_examples.append(InputExample(texts=[row["Sentence1"], row["Sentence2"], row["Sentence3"]])) - - if len(dev_examples) >= 1000: - break - -evaluator = TripletEvaluator.from_input_examples(dev_examples, name="dev") - - -warmup_steps = int(len(train_dataloader) * num_epochs * 0.1) # 10% of train data - - -# Train the model -model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=output_path, +eval_dataset = load_dataset("sentence-transformers/wikipedia-sections", "triplet", split="validation").select( + range(1000) +) +test_dataset = load_dataset("sentence-transformers/wikipedia-sections", "triplet", split="test").select(range(1000)) +logging.info(train_dataset) + +# 3. Define our training loss +# TripletLoss (https://sbert.net/docs/package_reference/losses.html#tripletloss) needs three text columns +train_loss = TripletLoss(model) + +# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss. +dev_evaluator = TripletEvaluator( + anchors=eval_dataset[:1000]["anchor"], + positives=eval_dataset[:1000]["positive"], + negatives=eval_dataset[:1000]["negative"], + name="wikipedia-sections-dev", ) -############################################################################## -# -# Load the stored model and evaluate its performance on STS benchmark dataset -# -############################################################################## +# 5. 
Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="wikipedia-sections-triplet",  # Will be used in W&B if `wandb` is installed
+)

-logger.info("Read test examples")
-test_examples = []
-with open(os.path.join(dataset_path, "test.csv"), encoding="utf-8") as fIn:
-    reader = csv.DictReader(fIn, delimiter=",", quoting=csv.QUOTE_MINIMAL)
-    for row in reader:
-        test_examples.append(InputExample(texts=[row["Sentence1"], row["Sentence2"], row["Sentence3"]]))

+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+)
+trainer.train()

-model = SentenceTransformer(output_path)
-test_evaluator = TripletEvaluator.from_input_examples(test_examples, name="test")
-test_evaluator(model, output_path=output_path)
+# 7. Evaluate the model performance on the Wikipedia Sections test dataset
+test_evaluator = TripletEvaluator(
+    anchors=test_dataset["anchor"],
+    positives=test_dataset["positive"],
+    negatives=test_dataset["negative"],
+    name="wikipedia-sections-test",
+)
+test_evaluator(model)
+
+# 8. Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-wikipedia-sections-triplet")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-wikipedia-sections-triplet')`."
+    )
diff --git a/examples/training/paraphrases/training.py b/examples/training/paraphrases/training.py
index 9f58e9603..6ad6b5877 100644
--- a/examples/training/paraphrases/training.py
+++ b/examples/training/paraphrases/training.py
@@ -1,101 +1,145 @@
-from sentence_transformers import models, losses
-from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
+"""
+Note: This script was modified with the v3 release of Sentence Transformers.
+As a result, it does not produce exactly the same behaviour as the original script.
+"""
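+
+# A note on the loss used below: MultipleNegativesRankingLoss treats, for each anchor, the
+# positives of the other pairs in a batch as negatives, so with batch size B each anchor has to
+# rank its own positive above B - 1 in-batch negatives; this is why larger batches usually help.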
+""" + +import traceback +from sentence_transformers import SentenceTransformer from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator +from sentence_transformers.losses import MultipleNegativesRankingLoss +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import ( + BatchSamplers, + MultiDatasetBatchSamplers, + SentenceTransformerTrainingArguments, +) import logging from datetime import datetime -import sys -import os -import gzip -import csv -from MultiDatasetDataLoader import MultiDatasetDataLoader +from datasets import load_dataset -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -#### /print debug information to stdout + +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) model_name = "distilroberta-base" num_epochs = 1 -sts_dataset_path = "data-eval/stsbenchmark.tsv.gz" -batch_size_pairs = 384 -batch_size_triplets = 256 +batch_size = 128 max_seq_length = 128 -use_amp = True # Set to False, if you use a CPU or your GPU does not support FP16 operations -evaluation_steps = 500 -warmup_steps = 500 - -##### - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - # Save path of the model -model_save_path = ( +output_dir = ( "output/training_paraphrases_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) - -## SentenceTransformer model -word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length) -pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension()) -model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - -dataset_list = [] -for filepath in sys.argv[1:]: - dataset = [] - with_guid = "with-guid" in filepath # Some datasets have a guid in the first column - - with gzip.open(filepath, "rt", encoding="utf8") as fIn: - for line in fIn: - splits = line.strip().split("\t") - if with_guid: - guid = splits[0] - texts = splits[1:] - else: - guid = None - texts = splits - - dataset.append(InputExample(texts=texts, guid=guid)) - - dataset_list.append(dataset) - - -train_dataloader = MultiDatasetDataLoader( - dataset_list, batch_size_pairs=batch_size_pairs, batch_size_triplets=batch_size_triplets +# 2. Load some training dataset from: https://huggingface.co/datasets?other=sentence-transformers +# Notably, we are looking for datasets compatible with MultipleNegativesRankingLoss, which accepts +# triplets of sentences (anchor, positive, negative) and pairs of sentences (anchor, positive). 
+all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
+sentence_compression_train_dataset = load_dataset("sentence-transformers/sentence-compression", split="train")
+simple_wiki_train_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
+altlex_train_dataset = load_dataset("sentence-transformers/altlex", split="train")
+quora_train_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
+coco_train_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
+flickr_train_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
+yahoo_answers_train_dataset = load_dataset(
+    "sentence-transformers/yahoo-answers", "title-question-answer-pair", split="train"
+)
+stack_exchange_train_dataset = load_dataset(
+    "sentence-transformers/stackexchange-duplicates", "title-title-pair", split="train"
 )

+train_dataset_dict = {
+    "all-nli": all_nli_train_dataset,
+    "sentence-compression": sentence_compression_train_dataset,
+    "simple-wiki": simple_wiki_train_dataset,
+    "altlex": altlex_train_dataset,
+    "quora-duplicates": quora_train_dataset,
+    "coco-captions": coco_train_dataset,
+    "flickr30k-captions": flickr_train_dataset,
+    "yahoo-answers": yahoo_answers_train_dataset,
+    "stack-exchange": stack_exchange_train_dataset,
+}
+logging.info(train_dataset_dict)
+
+# 2. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
+# create one with "mean" pooling.
+model = SentenceTransformer(model_name)
+# If we want, we can limit the maximum sequence length for the model
+model.max_seq_length = max_seq_length
+logging.info(model)
+
+# 3. Define our training loss
+train_loss = MultipleNegativesRankingLoss(model)
+
+# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
+stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
+dev_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=stsb_eval_dataset["sentence1"],
+    sentences2=stsb_eval_dataset["sentence2"],
+    scores=stsb_eval_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-dev",
+)

-# Our training loss
-train_loss = losses.MultipleNegativesRankingLoss(model)
-
-
-# Read STSbenchmark dataset and use it as development set
-logging.info("Read STSbenchmark dev dataset")
-dev_samples = []
-with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        if row["split"] == "dev":
-            score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
-            dev_samples.append(InputExample(texts=[row["sentence1"], row["sentence2"]], label=score))
-
-dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
-
-# Configure the training
-logging.info("Warmup-steps: {}".format(warmup_steps))
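+# The multi-dataset batch sampler configured below controls how these nine datasets are mixed.
+# As a rough sketch, ROUND_ROBIN draws one batch per dataset in turn (all-nli, sentence-compression,
+# ..., stack-exchange, all-nli, ...) and stops once the smallest dataset runs out of samples.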
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+    # We can use ROUND_ROBIN or PROPORTIONAL - to avoid focusing too much on one dataset, we will
+    # use round robin, which samples the same number of batches from each dataset, until one dataset is empty
+    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=1000,
+    save_strategy="steps",
+    save_steps=1000,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="paraphrases-multi",  # Will be used in W&B if `wandb` is installed
+)

-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset_dict,
+    loss=train_loss,
     evaluator=dev_evaluator,
-    epochs=num_epochs,
-    evaluation_steps=evaluation_steps,
-    warmup_steps=warmup_steps,
-    output_path=model_save_path,
-    use_amp=use_amp,
-    checkpoint_path=model_save_path,
-    checkpoint_save_steps=1000,
-    checkpoint_save_total_limit=3,
 )
+trainer.train()
+
+# 7. Evaluate the model performance on the STS Benchmark test dataset
+test_dataset = load_dataset("sentence-transformers/stsb", split="test")
+test_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=test_dataset["sentence1"],
+    sentences2=test_dataset["sentence2"],
+    scores=test_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-test",
+)
+test_evaluator(model)
+
+# 8. Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-paraphrases-multi")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-paraphrases-multi')`."
+    )
diff --git a/examples/training/quora_duplicate_questions/training_MultipleNegativesRankingLoss.py b/examples/training/quora_duplicate_questions/training_MultipleNegativesRankingLoss.py
index 8b1a1b4a2..9ddea4dd2 100644
--- a/examples/training/quora_duplicate_questions/training_MultipleNegativesRankingLoss.py
+++ b/examples/training/quora_duplicate_questions/training_MultipleNegativesRankingLoss.py
@@ -11,66 +11,47 @@
 The model we get works well for duplicate questions mining and for duplicate questions information retrieval.
-For question pair classification, other losses (like OnlineConstrativeLoss) work better.
+For question pair classification, other losses (like OnlineContrastiveLoss) work better.
""" -from torch.utils.data import DataLoader -from sentence_transformers import losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation -from sentence_transformers.readers import InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import SentenceTransformer +from sentence_transformers.evaluation import ( + BinaryClassificationEvaluator, + InformationRetrievalEvaluator, + ParaphraseMiningEvaluator, + SequentialEvaluator, +) +from sentence_transformers.losses import MultipleNegativesRankingLoss +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments import logging from datetime import datetime -import csv -import os -from zipfile import ZipFile import random -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -logger = logging.getLogger(__name__) -#### /print debug information to stdout - +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # As base model, we use DistilBERT-base that was pre-trained on NLI and STSb data -model = SentenceTransformer("stsb-distilbert-base") - +model_name = "stsb-distilbert-base" +model = SentenceTransformer(model_name) # Training for multiple epochs can be beneficial, as in each epoch a mini-batch is sampled differently # hence, we get different negatives for each positive -num_epochs = 10 - +num_train_epochs = 1 # Increasing the batch size improves the performance for MultipleNegativesRankingLoss. Choose it as large as possible -# I achieved the good results with a batch size of 300-350 (requires about 30 GB of GPU memory) -train_batch_size = 64 - -dataset_path = "quora-IR-dataset" -model_save_path = "output/training_MultipleNegativesRankingLoss-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - -os.makedirs(model_save_path, exist_ok=True) +# I achieved the good results with a batch size of 300-350 +batch_size = 64 -# Check if the dataset exists. If not, download and extract -if not os.path.exists(dataset_path): - logger.info("Dataset not found. 
Download") - zip_save_path = "quora-IR-dataset.zip" - util.http_get(url="https://sbert.net/datasets/quora-IR-dataset.zip", path=zip_save_path) - with ZipFile(zip_save_path, "r") as zip: - zip.extractall(dataset_path) +output_dir = "output/training_mnrl-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") +################### Load Quora Duplicate Questions dataset ################## -######### Read train data ########## -train_samples = [] -with open(os.path.join(dataset_path, "classification/train_pairs.tsv"), encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - if row["is_duplicate"] == "1": - train_samples.append(InputExample(texts=[row["question1"], row["question2"]], label=1)) - train_samples.append( - InputExample(texts=[row["question2"], row["question1"]], label=1) - ) # if A is a duplicate of B, then B is a duplicate of A - - -# After reading the train_samples, we create a DataLoader -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) -train_loss = losses.MultipleNegativesRankingLoss(model) +# https://huggingface.co/datasets/sentence-transformers/quora-duplicates +dataset = load_dataset( + "sentence-transformers/quora-duplicates", "triplet", split="train" +) # The "pair" subset also works +train_dataset = dataset.select(range(100000)) +eval_dataset = dataset.select(range(100000, 101000)) +train_loss = MultipleNegativesRankingLoss(model=model) ################### Development Evaluators ################## # We add 3 evaluators, that evaluate the model on Duplicate Questions pair classification, @@ -78,50 +59,37 @@ evaluators = [] ###### Classification ###### -# Given (quesiton1, question2), is this a duplicate or not? +# Given (question1, question2), is this a duplicate or not? # The evaluator will compute the embeddings for both questions and then compute # a cosine similarity. If the similarity is above a threshold, we have a duplicate. -dev_sentences1 = [] -dev_sentences2 = [] -dev_labels = [] -with open(os.path.join(dataset_path, "classification/dev_pairs.tsv"), encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - dev_sentences1.append(row["question1"]) - dev_sentences2.append(row["question2"]) - dev_labels.append(int(row["is_duplicate"])) - - -binary_acc_evaluator = evaluation.BinaryClassificationEvaluator(dev_sentences1, dev_sentences2, dev_labels) + +duplicate_classes_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train[-1000:]") +binary_acc_evaluator = BinaryClassificationEvaluator( + sentences1=duplicate_classes_dataset["sentence1"], + sentences2=duplicate_classes_dataset["sentence2"], + labels=duplicate_classes_dataset["label"], + name="quora-duplicates", +) evaluators.append(binary_acc_evaluator) ###### Duplicate Questions Mining ###### # Given a large corpus of questions, identify all duplicates in that corpus. -# For faster processing, we limit the development corpus to only 10,000 sentences. 
-max_dev_samples = 10000
-dev_sentences = {}
-dev_duplicates = []
-with open(os.path.join(dataset_path, "duplicate-mining/dev_corpus.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        dev_sentences[row["qid"]] = row["question"]
-
-        if len(dev_sentences) >= max_dev_samples:
-            break
-
-with open(os.path.join(dataset_path, "duplicate-mining/dev_duplicates.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        if row["qid1"] in dev_sentences and row["qid2"] in dev_sentences:
-            dev_duplicates.append([row["qid1"], row["qid2"]])
+# Load the Quora Duplicates Mining dataset
+# https://huggingface.co/datasets/sentence-transformers/quora-duplicates-mining
+questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
+duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")
+# Create a mapping from qid to question & a list of duplicates (qid1, qid2)
+qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
+duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

 # The ParaphraseMiningEvaluator computes the cosine similarity between all sentences and
 # extracts a list with the pairs that have the highest similarity. Given the duplicate
-# information in dev_duplicates, it then computes and F1 score how well our duplicate mining worked
-paraphrase_mining_evaluator = evaluation.ParaphraseMiningEvaluator(dev_sentences, dev_duplicates, name="dev")
+# information in `duplicates`, it then computes an F1 score for how well the duplicate mining worked
+paraphrase_mining_evaluator = ParaphraseMiningEvaluator(qid_to_questions, duplicates, name="quora-duplicates-dev")
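+# (F1 here is the harmonic mean of precision and recall over the mined pairs: it balances how many
+# of the mined pairs are true duplicates against how many of the true duplicates were found.)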
+
 evaluators.append(paraphrase_mining_evaluator)


@@ -129,64 +97,85 @@
-# Given a question and a large corpus of thousands questions, find the most relevant (i.e. duplicate) question
+# Given a question and a large corpus of thousands of questions, find the most relevant (i.e. duplicate) question
 # in that corpus.
-# For faster processing, we limit the development corpus to only 10,000 sentences.
-max_corpus_size = 10000
-
-ir_queries = {}  # Our queries (qid => question)
-ir_needed_qids = set()  # QIDs we need in the corpus
-ir_corpus = {}  # Our corpus (qid => question)
-ir_relevant_docs = {}  # Mapping of relevant documents for a given query (qid => set([relevant_question_ids])
-
-with open(os.path.join(dataset_path, "information-retrieval/dev-queries.tsv"), encoding="utf8") as fIn:
-    next(fIn)  # Skip header
-    for line in fIn:
-        qid, query, duplicate_ids = line.strip().split("\t")
-        duplicate_ids = duplicate_ids.split(",")
-        ir_queries[qid] = query
-        ir_relevant_docs[qid] = set(duplicate_ids)
-
-        for qid in duplicate_ids:
-            ir_needed_qids.add(qid)
-
-# First get all needed relevant documents (i.e., we must ensure, that the relevant questions are actually in the corpus
-distraction_questions = {}
-with open(os.path.join(dataset_path, "information-retrieval/corpus.tsv"), encoding="utf8") as fIn:
-    next(fIn)  # Skip header
-    for line in fIn:
-        qid, question = line.strip().split("\t")
-
-        if qid in ir_needed_qids:
-            ir_corpus[qid] = question
-        else:
-            distraction_questions[qid] = question
-
-# Now, also add some irrelevant questions to fill our corpus
-other_qid_list = list(distraction_questions.keys())
-random.shuffle(other_qid_list)
-
-for qid in other_qid_list[0 : max(0, max_corpus_size - len(ir_corpus))]:
-    ir_corpus[qid] = distraction_questions[qid]
+# https://huggingface.co/datasets/BeIR/quora
+# https://huggingface.co/datasets/BeIR/quora-qrels
+new_ir_corpus = load_dataset("BeIR/quora", "corpus", split="corpus")
+new_ir_queries = load_dataset("BeIR/quora", "queries", split="queries")
+new_ir_relevant_docs_data = load_dataset("BeIR/quora-qrels", split="validation")
+
+# Shrink the corpus size heavily to only the relevant documents + 10,000 random documents
+# (a set makes the membership check in the filter below much faster than a list)
+required_corpus_ids = set(map(str, new_ir_relevant_docs_data["corpus-id"]))
+required_corpus_ids |= set(random.sample(new_ir_corpus["_id"], k=10_000))
+new_ir_corpus = new_ir_corpus.filter(lambda x: x["_id"] in required_corpus_ids)
+
+# Convert the datasets to dictionaries
+new_ir_corpus = dict(zip(new_ir_corpus["_id"], new_ir_corpus["text"]))  # Our corpus (qid => question)
+new_ir_queries = dict(zip(new_ir_queries["_id"], new_ir_queries["text"]))  # Our queries (qid => question)
+new_ir_relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_question_ids])
+for qid, corpus_ids in zip(new_ir_relevant_docs_data["query-id"], new_ir_relevant_docs_data["corpus-id"]):
+    qid = str(qid)
+    corpus_ids = str(corpus_ids)
+    if qid not in new_ir_relevant_docs:
+        new_ir_relevant_docs[qid] = set()
+    new_ir_relevant_docs[qid].add(corpus_ids)

 # Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR
-# metrices. For our use case MRR@k and Accuracy@k are relevant.
-ir_evaluator = evaluation.InformationRetrievalEvaluator(ir_queries, ir_corpus, ir_relevant_docs)
-
+# metrics. For our use case MRR@k and Accuracy@k are relevant.
+ir_evaluator = InformationRetrievalEvaluator(new_ir_queries, new_ir_corpus, new_ir_relevant_docs)
 evaluators.append(ir_evaluator)
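+# (For reference: MRR@k averages 1 / rank of the first relevant document within the top-k results,
+# and Accuracy@k is the fraction of queries with at least one relevant document in the top-k.)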
# We optimize the model with respect to the score from the last evaluator (scores[-1])
-seq_evaluator = evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
-
-
-logger.info("Evaluate model without training")
-seq_evaluator(model, epoch=0, steps=0, output_path=model_save_path)
-
+seq_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
+
+logging.info("Evaluate model without training")
+seq_evaluator(model, epoch=0, steps=0)
+
+# Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=250,
+    save_strategy="steps",
+    save_steps=250,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="mnrl",  # Will be used in W&B if `wandb` is installed
+)
-
-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
+# Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
    evaluator=seq_evaluator,
-    epochs=num_epochs,
-    warmup_steps=1000,
-    output_path=model_save_path,
)
+trainer.train()
+
+# Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-mnrl")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-mnrl')`."
+    )
diff --git a/examples/training/quora_duplicate_questions/training_OnlineContrastiveLoss.py b/examples/training/quora_duplicate_questions/training_OnlineContrastiveLoss.py
index a492d0df0..51fc34b0a 100644
--- a/examples/training/quora_duplicate_questions/training_OnlineContrastiveLoss.py
+++ b/examples/training/quora_duplicate_questions/training_OnlineContrastiveLoss.py
@@ -9,63 +9,46 @@
An issue with contrastive loss is that it might push sentences away that are already well positioned in vector space.
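For intuition, the contrastive objective with cosine distance and a margin roughly amounts to the following per pair (a minimal sketch, up to a constant factor; the model and texts are illustrative assumptions):

import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice
emb = model.encode(["How do I cook rice?", "How should rice be cooked?"], convert_to_tensor=True)
distance = 1 - util.cos_sim(emb[0], emb[1])  # cosine distance
label, margin = 1, 0.5  # 1 = duplicate pair, 0 = non-duplicate pair
# Duplicates are pulled together; non-duplicates are only pushed apart while distance < margin.
loss = label * distance.pow(2) + (1 - label) * torch.relu(margin - distance).pow(2)

The "online" variant additionally restricts the loss to the hard positive and hard negative pairs within each batch.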
""" -from torch.utils.data import DataLoader -from sentence_transformers import losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation -from sentence_transformers.readers import InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import SentenceTransformer +from sentence_transformers.evaluation import ( + BinaryClassificationEvaluator, + InformationRetrievalEvaluator, + ParaphraseMiningEvaluator, + SequentialEvaluator, +) +from sentence_transformers.losses import OnlineContrastiveLoss +from sentence_transformers.losses.ContrastiveLoss import SiameseDistanceMetric +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments import logging from datetime import datetime -import csv -import os -from zipfile import ZipFile import random -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -logger = logging.getLogger(__name__) -#### /print debug information to stdout - +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # As base model, we use DistilBERT-base that was pre-trained on NLI and STSb data -model = SentenceTransformer("stsb-distilbert-base") -num_epochs = 10 -train_batch_size = 64 +model_name = "stsb-distilbert-base" +model = SentenceTransformer(model_name) +num_train_epochs = 1 +batch_size = 64 -# As distance metric, we use cosine distance (cosine_distance = 1-cosine_similarity) -distance_metric = losses.SiameseDistanceMetric.COSINE_DISTANCE +output_dir = "output/training_ocl-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + +################### Load Quora Duplicate Questions dataset ################## +# https://huggingface.co/datasets/sentence-transformers/quora-duplicates +dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train") +dataset = dataset.train_test_split(test_size=1000) +train_dataset = dataset["train"].select(range(100000)) +eval_dataset = dataset["test"] # Negative pairs should have a distance of at least 0.5 margin = 0.5 - -dataset_path = "quora-IR-dataset" -model_save_path = "output/training_OnlineConstrativeLoss-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - -os.makedirs(model_save_path, exist_ok=True) - -# Check if the dataset exists. If not, download and extract -if not os.path.exists(dataset_path): - logger.info("Dataset not found. 
Download") - zip_save_path = "quora-IR-dataset.zip" - util.http_get(url="https://sbert.net/datasets/quora-IR-dataset.zip", path=zip_save_path) - with ZipFile(zip_save_path, "r") as zip: - zip.extractall(dataset_path) - - -######### Read train data ########## -# Read train data -train_samples = [] -with open(os.path.join(dataset_path, "classification/train_pairs.tsv"), encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - sample = InputExample(texts=[row["question1"], row["question2"]], label=int(row["is_duplicate"])) - train_samples.append(sample) - - -train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) -train_loss = losses.OnlineContrastiveLoss(model=model, distance_metric=distance_metric, margin=margin) - +# As distance metric, we use cosine distance (cosine_distance = 1-cosine_similarity) +distance_metric = SiameseDistanceMetric.COSINE_DISTANCE +train_loss = OnlineContrastiveLoss(model=model, distance_metric=distance_metric, margin=margin) ################### Development Evaluators ################## # We add 3 evaluators, that evaluate the model on Duplicate Questions pair classification, @@ -73,50 +56,36 @@ evaluators = [] ###### Classification ###### -# Given (quesiton1, question2), is this a duplicate or not? +# Given (question1, question2), is this a duplicate or not? # The evaluator will compute the embeddings for both questions and then compute # a cosine similarity. If the similarity is above a threshold, we have a duplicate. -dev_sentences1 = [] -dev_sentences2 = [] -dev_labels = [] -with open(os.path.join(dataset_path, "classification/dev_pairs.tsv"), encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - dev_sentences1.append(row["question1"]) - dev_sentences2.append(row["question2"]) - dev_labels.append(int(row["is_duplicate"])) - - -binary_acc_evaluator = evaluation.BinaryClassificationEvaluator(dev_sentences1, dev_sentences2, dev_labels) + +binary_acc_evaluator = BinaryClassificationEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + labels=eval_dataset["label"], + name="quora-duplicates", +) evaluators.append(binary_acc_evaluator) ###### Duplicate Questions Mining ###### # Given a large corpus of questions, identify all duplicates in that corpus. -# For faster processing, we limit the development corpus to only 10,000 sentences. 
-max_dev_samples = 10000
-dev_sentences = {}
-dev_duplicates = []
-with open(os.path.join(dataset_path, "duplicate-mining/dev_corpus.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        dev_sentences[row["qid"]] = row["question"]
-
-        if len(dev_sentences) >= max_dev_samples:
-            break
-
-with open(os.path.join(dataset_path, "duplicate-mining/dev_duplicates.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        if row["qid1"] in dev_sentences and row["qid2"] in dev_sentences:
-            dev_duplicates.append([row["qid1"], row["qid2"]])
+# Load the Quora Duplicates Mining dataset
+# https://huggingface.co/datasets/sentence-transformers/quora-duplicates-mining
+questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
+duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")
+# Create a mapping from qid to question & a list of duplicates (qid1, qid2)
+qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
+duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

# The ParaphraseMiningEvaluator computes the cosine similarity between all sentences and
# extracts a list with the pairs that have the highest similarity. Given the duplicate
# information in dev_duplicates, it then computes an F1 score measuring how well our duplicate mining worked
-paraphrase_mining_evaluator = evaluation.ParaphraseMiningEvaluator(dev_sentences, dev_duplicates, name="dev")
+paraphrase_mining_evaluator = ParaphraseMiningEvaluator(qid_to_questions, duplicates, name="quora-duplicates-dev")
+
evaluators.append(paraphrase_mining_evaluator)


@@ -124,64 +93,85 @@
# Given a question and a large corpus of thousands of questions, find the most relevant (i.e. duplicate) question
# in that corpus.
-# For faster processing, we limit the development corpus to only 10,000 sentences.
-max_corpus_size = 100000
-
-ir_queries = {}  # Our queries (qid => question)
-ir_needed_qids = set()  # QIDs we need in the corpus
-ir_corpus = {}  # Our corpus (qid => question)
-ir_relevant_docs = {}  # Mapping of relevant documents for a given query (qid => set([relevant_question_ids])
-
-with open(os.path.join(dataset_path, "information-retrieval/dev-queries.tsv"), encoding="utf8") as fIn:
-    next(fIn)  # Skip header
-    for line in fIn:
-        qid, query, duplicate_ids = line.strip().split("\t")
-        duplicate_ids = duplicate_ids.split(",")
-        ir_queries[qid] = query
-        ir_relevant_docs[qid] = set(duplicate_ids)
-
-        for qid in duplicate_ids:
-            ir_needed_qids.add(qid)
-
-# First get all needed relevant documents (i.e., we must ensure, that the relevant questions are actually in the corpus
-distraction_questions = {}
-with open(os.path.join(dataset_path, "information-retrieval/corpus.tsv"), encoding="utf8") as fIn:
-    next(fIn)  # Skip header
-    for line in fIn:
-        qid, question = line.strip().split("\t")
-
-        if qid in ir_needed_qids:
-            ir_corpus[qid] = question
-        else:
-            distraction_questions[qid] = question
-
-# Now, also add some irrelevant questions to fill our corpus
-other_qid_list = list(distraction_questions.keys())
-random.shuffle(other_qid_list)
-
-for qid in other_qid_list[0 : max(0, max_corpus_size - len(ir_corpus))]:
-    ir_corpus[qid] = distraction_questions[qid]
+# https://huggingface.co/datasets/BeIR/quora
+# https://huggingface.co/datasets/BeIR/quora-qrels
+new_ir_corpus = load_dataset("BeIR/quora", "corpus", split="corpus")
+new_ir_queries = load_dataset("BeIR/quora", "queries", split="queries")
+new_ir_relevant_docs_data = load_dataset("BeIR/quora-qrels", split="validation")
+
+# Shrink the corpus size heavily to only the relevant documents + 10,000 random documents
+required_corpus_ids = list(map(str, new_ir_relevant_docs_data["corpus-id"]))
+required_corpus_ids += random.sample(new_ir_corpus["_id"], k=10_000)
+new_ir_corpus = new_ir_corpus.filter(lambda x: x["_id"] in required_corpus_ids)
+
+# Convert the datasets to dictionaries
+new_ir_corpus = dict(zip(new_ir_corpus["_id"], new_ir_corpus["text"]))  # Our corpus (qid => question)
+new_ir_queries = dict(zip(new_ir_queries["_id"], new_ir_queries["text"]))  # Our queries (qid => question)
+new_ir_relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_question_ids]))
+for qid, corpus_ids in zip(new_ir_relevant_docs_data["query-id"], new_ir_relevant_docs_data["corpus-id"]):
+    qid = str(qid)
+    corpus_ids = str(corpus_ids)
+    if qid not in new_ir_relevant_docs:
+        new_ir_relevant_docs[qid] = set()
+    new_ir_relevant_docs[qid].add(corpus_ids)

# Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR
# metrics. For our use case MRR@k and Accuracy@k are relevant.
-ir_evaluator = evaluation.InformationRetrievalEvaluator(ir_queries, ir_corpus, ir_relevant_docs)
-
+ir_evaluator = InformationRetrievalEvaluator(new_ir_queries, new_ir_corpus, new_ir_relevant_docs)
evaluators.append(ir_evaluator)


# Create a SequentialEvaluator. This SequentialEvaluator runs all three evaluators in a sequential order.
# We optimize the model with respect to the score from the last evaluator (scores[-1])
-seq_evaluator = evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
-
-
-logger.info("Evaluate model without training")
-seq_evaluator(model, epoch=0, steps=0, output_path=model_save_path)
-
+seq_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
+
+logging.info("Evaluate model without training")
+seq_evaluator(model, epoch=0, steps=0)
+
+# Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,  # OCL benefits from no duplicate samples in a batch
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=250,
+    save_strategy="steps",
+    save_steps=250,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="online-contrastive-loss",  # Will be used in W&B if `wandb` is installed
+)
-
-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
+# Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
    evaluator=seq_evaluator,
-    epochs=num_epochs,
-    warmup_steps=1000,
-    output_path=model_save_path,
)
+trainer.train()
+
+# Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-ocl")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-ocl')`."
+    )
diff --git a/examples/training/quora_duplicate_questions/training_multi-task-learning.py b/examples/training/quora_duplicate_questions/training_multi-task-learning.py
index 7ba18cdd0..15d8c227c 100644
--- a/examples/training/quora_duplicate_questions/training_multi-task-learning.py
+++ b/examples/training/quora_duplicate_questions/training_multi-task-learning.py
@@ -11,85 +11,58 @@
model.fit(train_objectives=[(train_dataloader_MultipleNegativesRankingLoss, train_loss_MultipleNegativesRankingLoss),
                            (train_dataloader_constrative_loss, train_loss_constrative_loss)] ...)
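With the v3 trainer, the same multi-task setup is instead expressed by mapping dataset names to losses. A condensed sketch of the pattern used later in this script:

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset={"mnrl": mnrl_train_dataset, "cl": cl_train_dataset},
    loss={"mnrl": mnrl_train_loss, "cl": cl_train_loss},
)
trainer.train()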
""" -from torch.utils.data import DataLoader -from sentence_transformers import losses, util -from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation -from sentence_transformers.readers import InputExample +import traceback +from datasets import load_dataset +from sentence_transformers import SentenceTransformer +from sentence_transformers.evaluation import ( + BinaryClassificationEvaluator, + InformationRetrievalEvaluator, + ParaphraseMiningEvaluator, + SequentialEvaluator, +) +from sentence_transformers.losses import ContrastiveLoss, MultipleNegativesRankingLoss +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import ( + BatchSamplers, + MultiDatasetBatchSamplers, + SentenceTransformerTrainingArguments, +) import logging from datetime import datetime -import csv -import os -from zipfile import ZipFile import random -#### Just some code to print debug information to stdout -logging.basicConfig( - format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] -) -logger = logging.getLogger(__name__) -#### /print debug information to stdout - +# Set the log level to INFO to get more information +logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) # As base model, we use DistilBERT-base that was pre-trained on NLI and STSb data -model = SentenceTransformer("stsb-distilbert-base") - +model_name = "stsb-distilbert-base" +model = SentenceTransformer(model_name) # Training for multiple epochs can be beneficial, as in each epoch a mini-batch is sampled differently # hence, we get different negatives for each positive -num_epochs = 10 - +num_train_epochs = 1 # Increasing the batch size improves the performance for MultipleNegativesRankingLoss. Choose it as large as possible -# I achieved the good results with a batch size of 300-350 (requires about 30 GB of GPU memory) -train_batch_size = 64 - -# As distance metric, we use cosine distance (cosine_distance = 1-cosine_similarity) -distance_metric = losses.SiameseDistanceMetric.COSINE_DISTANCE - -# Negative pairs should have a distance of at least 0.5 -margin = 0.5 - -dataset_path = "quora-IR-dataset" -model_save_path = "output/training_multi-task-learning" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - -os.makedirs(model_save_path, exist_ok=True) - -# Check if the dataset exists. If not, download and extract -if not os.path.exists(dataset_path): - logger.info("Dataset not found. 
Download") - zip_save_path = "quora-IR-dataset.zip" - util.http_get(url="https://sbert.net/datasets/quora-IR-dataset.zip", path=zip_save_path) - with ZipFile(zip_save_path, "r") as zip: - zip.extractall(dataset_path) - - -######### Read train data ########## -train_samples_MultipleNegativesRankingLoss = [] -train_samples_ConstrativeLoss = [] - -with open(os.path.join(dataset_path, "classification/train_pairs.tsv"), encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - train_samples_ConstrativeLoss.append( - InputExample(texts=[row["question1"], row["question2"]], label=int(row["is_duplicate"])) - ) - if row["is_duplicate"] == "1": - train_samples_MultipleNegativesRankingLoss.append( - InputExample(texts=[row["question1"], row["question2"]], label=1) - ) - train_samples_MultipleNegativesRankingLoss.append( - InputExample(texts=[row["question2"], row["question1"]], label=1) - ) # if A is a duplicate of B, then B is a duplicate of A - -# Create data loader and loss for MultipleNegativesRankingLoss -train_dataloader_MultipleNegativesRankingLoss = DataLoader( - train_samples_MultipleNegativesRankingLoss, shuffle=True, batch_size=train_batch_size -) -train_loss_MultipleNegativesRankingLoss = losses.MultipleNegativesRankingLoss(model) +# I achieved the good results with a batch size of 300-350 +batch_size = 64 +output_dir = "output/training_mnrl-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -# Create data loader and loss for OnlineContrastiveLoss -train_dataloader_ConstrativeLoss = DataLoader(train_samples_ConstrativeLoss, shuffle=True, batch_size=train_batch_size) -train_loss_ConstrativeLoss = losses.OnlineContrastiveLoss(model=model, distance_metric=distance_metric, margin=margin) +################### Load Quora Duplicate Questions dataset ################## +# https://huggingface.co/datasets/sentence-transformers/quora-duplicates +mnrl_dataset = load_dataset( + "sentence-transformers/quora-duplicates", "triplet", split="train" +) # The "pair" subset also works +mnrl_train_dataset = mnrl_dataset.select(range(100000)) +mnrl_eval_dataset = mnrl_dataset.select(range(100000, 101000)) + +mnrl_train_loss = MultipleNegativesRankingLoss(model=model) + +# https://huggingface.co/datasets/sentence-transformers/quora-duplicates +cl_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train") +cl_train_dataset = cl_dataset.select(range(100000)) +cl_eval_dataset = cl_dataset.select(range(100000, 101000)) + +cl_train_loss = ContrastiveLoss(model=model, margin=0.5) ################### Development Evaluators ################## # We add 3 evaluators, that evaluate the model on Duplicate Questions pair classification, @@ -97,50 +70,37 @@ evaluators = [] ###### Classification ###### -# Given (quesiton1, question2), is this a duplicate or not? +# Given (question1, question2), is this a duplicate or not? # The evaluator will compute the embeddings for both questions and then compute # a cosine similarity. If the similarity is above a threshold, we have a duplicate. 
-dev_sentences1 = []
-dev_sentences2 = []
-dev_labels = []
-with open(os.path.join(dataset_path, "classification/dev_pairs.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        dev_sentences1.append(row["question1"])
-        dev_sentences2.append(row["question2"])
-        dev_labels.append(int(row["is_duplicate"]))
-
-
-binary_acc_evaluator = evaluation.BinaryClassificationEvaluator(dev_sentences1, dev_sentences2, dev_labels)
+
+duplicate_classes_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train[-1000:]")
+binary_acc_evaluator = BinaryClassificationEvaluator(
+    sentences1=duplicate_classes_dataset["sentence1"],
+    sentences2=duplicate_classes_dataset["sentence2"],
+    labels=duplicate_classes_dataset["label"],
+    name="quora-duplicates",
+)
evaluators.append(binary_acc_evaluator)


###### Duplicate Questions Mining ######
# Given a large corpus of questions, identify all duplicates in that corpus.
-# For faster processing, we limit the development corpus to only 10,000 sentences.
-max_dev_samples = 10000
-dev_sentences = {}
-dev_duplicates = []
-with open(os.path.join(dataset_path, "duplicate-mining/dev_corpus.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        dev_sentences[row["qid"]] = row["question"]
-
-        if len(dev_sentences) >= max_dev_samples:
-            break
-
-with open(os.path.join(dataset_path, "duplicate-mining/dev_duplicates.tsv"), encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        if row["qid1"] in dev_sentences and row["qid2"] in dev_sentences:
-            dev_duplicates.append([row["qid1"], row["qid2"]])
+# Load the Quora Duplicates Mining dataset
+# https://huggingface.co/datasets/sentence-transformers/quora-duplicates-mining
+questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
+duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")
+# Create a mapping from qid to question & a list of duplicates (qid1, qid2)
+qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
+duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

# The ParaphraseMiningEvaluator computes the cosine similarity between all sentences and
# extracts a list with the pairs that have the highest similarity. Given the duplicate
# information in dev_duplicates, it then computes an F1 score measuring how well our duplicate mining worked
-paraphrase_mining_evaluator = evaluation.ParaphraseMiningEvaluator(dev_sentences, dev_duplicates, name="dev")
+paraphrase_mining_evaluator = ParaphraseMiningEvaluator(qid_to_questions, duplicates, name="quora-duplicates-dev")
+
evaluators.append(paraphrase_mining_evaluator)


@@ -148,67 +108,95 @@
# Given a question and a large corpus of thousands of questions, find the most relevant (i.e. duplicate) question
# in that corpus.
-# For faster processing, we limit the development corpus to only 10,000 sentences.
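For reference, MRR@k (reported by the evaluator below) averages the reciprocal rank of the first relevant document per query, counting zero when it is not in the top k. A toy computation with assumed ranks:

# Ranks (1-indexed) of the first relevant document for three illustrative queries
first_relevant_ranks = [1, 3, 12]

def mrr_at_k(ranks, k=10):
    return sum(1.0 / rank if rank <= k else 0.0 for rank in ranks) / len(ranks)

print(mrr_at_k(first_relevant_ranks))  # (1 + 1/3 + 0) / 3 ≈ 0.444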
-max_corpus_size = 100000
-
-ir_queries = {}  # Our queries (qid => question)
-ir_needed_qids = set()  # QIDs we need in the corpus
-ir_corpus = {}  # Our corpus (qid => question)
-ir_relevant_docs = {}  # Mapping of relevant documents for a given query (qid => set([relevant_question_ids])
-
-with open(os.path.join(dataset_path, "information-retrieval/dev-queries.tsv"), encoding="utf8") as fIn:
-    next(fIn)  # Skip header
-    for line in fIn:
-        qid, query, duplicate_ids = line.strip().split("\t")
-        duplicate_ids = duplicate_ids.split(",")
-        ir_queries[qid] = query
-        ir_relevant_docs[qid] = set(duplicate_ids)
-
-        for qid in duplicate_ids:
-            ir_needed_qids.add(qid)
-
-# First get all needed relevant documents (i.e., we must ensure, that the relevant questions are actually in the corpus
-distraction_questions = {}
-with open(os.path.join(dataset_path, "information-retrieval/corpus.tsv"), encoding="utf8") as fIn:
-    next(fIn)  # Skip header
-    for line in fIn:
-        qid, question = line.strip().split("\t")
-
-        if qid in ir_needed_qids:
-            ir_corpus[qid] = question
-        else:
-            distraction_questions[qid] = question
-
-# Now, also add some irrelevant questions to fill our corpus
-other_qid_list = list(distraction_questions.keys())
-random.shuffle(other_qid_list)
-
-for qid in other_qid_list[0 : max(0, max_corpus_size - len(ir_corpus))]:
-    ir_corpus[qid] = distraction_questions[qid]
+# https://huggingface.co/datasets/BeIR/quora
+# https://huggingface.co/datasets/BeIR/quora-qrels
+new_ir_corpus = load_dataset("BeIR/quora", "corpus", split="corpus")
+new_ir_queries = load_dataset("BeIR/quora", "queries", split="queries")
+new_ir_relevant_docs_data = load_dataset("BeIR/quora-qrels", split="validation")
+
+# Shrink the corpus size heavily to only the relevant documents + 10,000 random documents
+required_corpus_ids = list(map(str, new_ir_relevant_docs_data["corpus-id"]))
+required_corpus_ids += random.sample(new_ir_corpus["_id"], k=10_000)
+new_ir_corpus = new_ir_corpus.filter(lambda x: x["_id"] in required_corpus_ids)
+
+# Convert the datasets to dictionaries
+new_ir_corpus = dict(zip(new_ir_corpus["_id"], new_ir_corpus["text"]))  # Our corpus (qid => question)
+new_ir_queries = dict(zip(new_ir_queries["_id"], new_ir_queries["text"]))  # Our queries (qid => question)
+new_ir_relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_question_ids]))
+for qid, corpus_ids in zip(new_ir_relevant_docs_data["query-id"], new_ir_relevant_docs_data["corpus-id"]):
+    qid = str(qid)
+    corpus_ids = str(corpus_ids)
+    if qid not in new_ir_relevant_docs:
+        new_ir_relevant_docs[qid] = set()
+    new_ir_relevant_docs[qid].add(corpus_ids)

# Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR
# metrics. For our use case MRR@k and Accuracy@k are relevant.
-ir_evaluator = evaluation.InformationRetrievalEvaluator(ir_queries, ir_corpus, ir_relevant_docs)
-
+ir_evaluator = InformationRetrievalEvaluator(new_ir_queries, new_ir_corpus, new_ir_relevant_docs)
evaluators.append(ir_evaluator)


# Create a SequentialEvaluator. This SequentialEvaluator runs all three evaluators in a sequential order.
# We optimize the model with respect to the score from the last evaluator (scores[-1])
-seq_evaluator = evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
-
-
-logger.info("Evaluate model without training")
-seq_evaluator(model, epoch=0, steps=0, output_path=model_save_path)
-
+seq_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[-1])
+
+logging.info("Evaluate model without training")
+seq_evaluator(model, epoch=0, steps=0)
+
+# Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
+    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,  # PROPORTIONAL or ROUND_ROBIN
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=250,
+    save_strategy="steps",
+    save_steps=250,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="mnrl-cl-multi",  # Will be used in W&B if `wandb` is installed
+)
-
-# Train the model
-model.fit(
-    train_objectives=[
-        (train_dataloader_MultipleNegativesRankingLoss, train_loss_MultipleNegativesRankingLoss),
-        (train_dataloader_ConstrativeLoss, train_loss_ConstrativeLoss),
-    ],
+# Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset={
+        "mnrl": mnrl_train_dataset,
+        "cl": cl_train_dataset,
+    },
+    eval_dataset={
+        "mnrl": mnrl_eval_dataset,
+        "cl": cl_eval_dataset,
+    },
+    loss={
+        "mnrl": mnrl_train_loss,
+        "cl": cl_train_loss,
+    },
    evaluator=seq_evaluator,
-    epochs=num_epochs,
-    warmup_steps=1000,
-    output_path=model_save_path,
)
+trainer.train()
+
+# Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-mnrl-cl-multi")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-mnrl-cl-multi')`."
+    )
diff --git a/examples/training/sts/training_stsbenchmark.py b/examples/training/sts/training_stsbenchmark.py
index 42fafd93d..ea194640e 100644
--- a/examples/training/sts/training_stsbenchmark.py
+++ b/examples/training/sts/training_stsbenchmark.py
@@ -9,105 +9,109 @@
python training_nli.py pretrained_transformer_model_name
"""

-from torch.utils.data import DataLoader
-import math
-from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util
+import traceback
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
-from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import sys
-import os
-import gzip
-import csv

-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
-)
-#### /print debug information to stdout
-
-
-# Check if dataset exists. If not, download and extract it
-sts_dataset_path = "datasets/stsbenchmark.tsv.gz"
+from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments

-if not os.path.exists(sts_dataset_path):
-    util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)
+# Set the log level to INFO to get more information
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

-
-# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
+# You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert-base-uncased"
-
-# Read the dataset
train_batch_size = 16
num_epochs = 4
-model_save_path = (
+output_dir = (
    "output/training_stsbenchmark_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

-# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
-word_embedding_model = models.Transformer(model_name)
-
-# Apply mean pooling to get one fixed sized sentence vector
-pooling_model = models.Pooling(
-    word_embedding_model.get_word_embedding_dimension(),
-    pooling_mode_mean_tokens=True,
-    pooling_mode_cls_token=False,
-    pooling_mode_max_tokens=False,
-)
-
-model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
-
-# Convert the dataset to a DataLoader ready for training
-logging.info("Read STSbenchmark train dataset")
-
-train_samples = []
-dev_samples = []
-test_samples = []
-with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
-        inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)
+# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
+# create one with "mean" pooling.
+model = SentenceTransformer(model_name)

-        if row["split"] == "dev":
-            dev_samples.append(inp_example)
-        elif row["split"] == "test":
-            test_samples.append(inp_example)
-        else:
-            train_samples.append(inp_example)
+# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
+train_dataset = load_dataset("sentence-transformers/stsb", split="train")
+eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
+test_dataset = load_dataset("sentence-transformers/stsb", split="test")
+logging.info(train_dataset)

-
-train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
+# 3. Define our training loss
+# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and one
+# similarity score column (between 0 and 1)
train_loss = losses.CosineSimilarityLoss(model=model)
+# train_loss = losses.CoSENTLoss(model=model)
+
+# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
+dev_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=eval_dataset["sentence1"],
+    sentences2=eval_dataset["sentence2"],
+    scores=eval_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-dev",
+)

-
-logging.info("Read STSbenchmark dev dataset")
-evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
-
-
-# Configure the training. We skip evaluation in this example
-warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
-logging.info("Warmup-steps: {}".format(warmup_steps))
-
-
-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    evaluator=evaluator,
-    epochs=num_epochs,
-    evaluation_steps=1000,
-    warmup_steps=warmup_steps,
-    output_path=model_save_path,
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
+    per_device_eval_batch_size=train_batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="sts",  # Will be used in W&B if `wandb` is installed
)

+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+)
+trainer.train()

-##############################################################################
-#
-# Load the stored model and evaluate its performance on STS benchmark dataset
-#
-##############################################################################
-model = SentenceTransformer(model_save_path)
-test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
-test_evaluator(model, output_path=model_save_path)
+# 7. Evaluate the model performance on the STS Benchmark test dataset
+test_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=test_dataset["sentence1"],
+    sentences2=test_dataset["sentence2"],
+    scores=test_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-test",
+)
+test_evaluator(model)
+
+# 8. Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-sts")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-sts')`."
+    )
diff --git a/examples/training/sts/training_stsbenchmark_continue_training.py b/examples/training/sts/training_stsbenchmark_continue_training.py
index c06b2a62d..c902306a1 100644
--- a/examples/training/sts/training_stsbenchmark_continue_training.py
+++ b/examples/training/sts/training_stsbenchmark_continue_training.py
@@ -1,97 +1,113 @@
"""
-This example loads the pre-trained SentenceTransformer model 'nli-distilroberta-base-v2' from the server.
+This example loads a pre-trained SentenceTransformer model ('sentence-transformers/all-mpnet-base-v2' by default) from Hugging Face.
It then fine-tunes this model for some epochs on the STS benchmark dataset.

Note: In this example, you must specify a SentenceTransformer model.
If you want to fine-tune a huggingface/transformers model like bert-base-uncased, see training_nli.py and training_stsbenchmark.py
"""

-from torch.utils.data import DataLoader
-import math
-from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util, InputExample
+import traceback
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import logging
from datetime import datetime
-import os
-import gzip
-import csv
+import sys

-#### Just some code to print debug information to stdout
-logging.basicConfig(
-    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
-)
-#### /print debug information to stdout
-
-# Check if dataset exists. If not, download and extract it
-sts_dataset_path = "datasets/stsbenchmark.tsv.gz"
+from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments

-if not os.path.exists(sts_dataset_path):
-    util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)
+# Set the log level to INFO to get more information
+logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

-
-# Read the dataset
-model_name = "nli-distilroberta-base-v2"
+# You can specify any Sentence Transformer model here, for example all-mpnet-base-v2, all-MiniLM-L6-v2, mixedbread-ai/mxbai-embed-large-v1
+model_name = sys.argv[1] if len(sys.argv) > 1 else "sentence-transformers/all-mpnet-base-v2"
train_batch_size = 16
num_epochs = 4
-model_save_path = (
-    "output/training_stsbenchmark_continue_training-" + model_name + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+output_dir = (
+    "output/training_stsbenchmark_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

-
-# Load a pre-trained sentence transformer model
+# 1. Here we define our SentenceTransformer model.
model = SentenceTransformer(model_name)

-# Convert the dataset to a DataLoader ready for training
-logging.info("Read STSbenchmark train dataset")
+# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
+train_dataset = load_dataset("sentence-transformers/stsb", split="train")
+eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
+test_dataset = load_dataset("sentence-transformers/stsb", split="test")
+logging.info(train_dataset)

-train_samples = []
-dev_samples = []
-test_samples = []
-with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
-        inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)
-
-        if row["split"] == "dev":
-            dev_samples.append(inp_example)
-        elif row["split"] == "test":
-            test_samples.append(inp_example)
-        else:
-            train_samples.append(inp_example)
-
-
-train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
+# 3. Define our training loss
+# CosineSimilarityLoss (https://sbert.net/docs/package_reference/losses.html#cosinesimilarityloss) needs two text columns and one
+# similarity score column (between 0 and 1)
train_loss = losses.CosineSimilarityLoss(model=model)
+# train_loss = losses.CoSENTLoss(model=model)
+
+# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
+dev_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=eval_dataset["sentence1"],
+    sentences2=eval_dataset["sentence2"],
+    scores=eval_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-dev",
+)

-
-# Development set: Measure correlation between cosine score and gold labels
-logging.info("Read STSbenchmark dev dataset")
-evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
-
-
-# Configure the training. We skip evaluation in this example
-warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
-logging.info("Warmup-steps: {}".format(warmup_steps))
-
-
-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    evaluator=evaluator,
-    epochs=num_epochs,
-    evaluation_steps=1000,
-    warmup_steps=warmup_steps,
-    output_path=model_save_path,
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=output_dir,
+    # Optional training parameters:
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
+    per_device_eval_batch_size=train_batch_size,
+    warmup_ratio=0.1,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    eval_strategy="steps",
+    eval_steps=100,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,
+    logging_steps=100,
+    run_name="sts",  # Will be used in W&B if `wandb` is installed
)

+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+)
+trainer.train()

-##############################################################################
-#
-# Load the stored model and evaluate its performance on STS benchmark dataset
-#
-##############################################################################
-model = SentenceTransformer(model_save_path)
-test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
-test_evaluator(model, output_path=model_save_path)
+# 7. Evaluate the model performance on the STS Benchmark test dataset
+test_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=test_dataset["sentence1"],
+    sentences2=test_dataset["sentence2"],
+    scores=test_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-test",
+)
+test_evaluator(model)
+
+# 8. Save the trained & evaluated model locally
+final_output_dir = f"{output_dir}/final"
+model.save(final_output_dir)
+
+# 9. (Optional) save the model to the Hugging Face Hub!
+# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
+model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
+try:
+    model.push_to_hub(f"{model_name}-sts")
+except Exception:
+    logging.error(
+        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
+        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
+        f"and saving it using `model.push_to_hub('{model_name}-sts')`."
+    )
diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 53f61332c..852750404 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -73,6 +73,31 @@ class SentenceTransformer(nn.Sequential, FitMixin):
    :param token: Hugging Face authentication token to download private models.
    :param truncate_dim: The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
        only applicable during inference when `.encode` is called.
+
+    Example
+        ::
+
+            from sentence_transformers import SentenceTransformer
+
+            # Load a pre-trained SentenceTransformer model
+            model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Encode some texts
+            sentences = [
+                "The weather is lovely today.",
+                "It's so sunny outside!",
+                "He drove to the stadium.",
+            ]
+            embeddings = model.encode(sentences)
+            print(embeddings.shape)
+            # (3, 768)
+
+            # Get the similarity scores between all sentences
+            similarities = model.similarity(embeddings, embeddings)
+            print(similarities)
+            # tensor([[1.0000, 0.6817, 0.0492],
+            #         [0.6817, 1.0000, 0.0421],
+            #         [0.0492, 0.0421, 1.0000]])
    """

    def __init__(
@@ -444,6 +469,7 @@ def encode(

    @property
    def similarity_fn_name(self) -> Optional[str]:
+        """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`."""
        return self._similarity_fn_name

    @similarity_fn_name.setter
diff --git a/sentence_transformers/evaluation/BinaryClassificationEvaluator.py b/sentence_transformers/evaluation/BinaryClassificationEvaluator.py
index a18ec9f85..e15a61d9d 100644
--- a/sentence_transformers/evaluation/BinaryClassificationEvaluator.py
+++ b/sentence_transformers/evaluation/BinaryClassificationEvaluator.py
@@ -36,6 +36,58 @@ class BinaryClassificationEvaluator(SentenceEvaluator):
    :param write_csv: Write results to a CSV file
    :param truncate_dim: The dimension to truncate sentence embeddings to. `None` uses the model's
        current truncation dimension. Defaults to None.
+
+    Example
+        ::
+
+            from sentence_transformers import SentenceTransformer
+            from sentence_transformers.evaluation import BinaryClassificationEvaluator
+            from datasets import load_dataset
+
+            # Load a model
+            model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Load a dataset with two text columns and a class label column (https://huggingface.co/datasets/sentence-transformers/quora-duplicates)
+            eval_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train[-1000:]")
+
+            # Initialize the evaluator
+            binary_acc_evaluator = BinaryClassificationEvaluator(
+                sentences1=eval_dataset["sentence1"],
+                sentences2=eval_dataset["sentence2"],
+                labels=eval_dataset["label"],
+                name="quora-duplicates-dev",
+            )
+            results = binary_acc_evaluator(model)
+            '''
+            Binary Accuracy Evaluation of the model on the quora-duplicates-dev dataset:
+            Accuracy with Cosine-Similarity: 81.60 (Threshold: 0.8352)
+            F1 with Cosine-Similarity: 75.27 (Threshold: 0.7715)
+            Precision with Cosine-Similarity: 65.81
+            Recall with Cosine-Similarity: 87.89
+            Average Precision with Cosine-Similarity: 76.03
+
+            Accuracy with Dot-Product: 81.60 (Threshold: 0.8352)
+            F1 with Dot-Product: 75.27 (Threshold: 0.7715)
+            Precision with Dot-Product: 65.81
+            Recall with Dot-Product: 87.89
+            Average Precision with Dot-Product: 76.03
+
+            Accuracy with Manhattan-Distance: 81.50 (Threshold: 12.0727)
+            F1 with Manhattan-Distance: 74.97 (Threshold: 15.2269)
+            Precision with Manhattan-Distance: 63.89
+            Recall with Manhattan-Distance: 90.68
+            Average Precision with Manhattan-Distance: 75.66
+
+            Accuracy with Euclidean-Distance: 81.60 (Threshold: 0.5741)
+            F1 with Euclidean-Distance: 75.27 (Threshold: 0.6760)
+            Precision with Euclidean-Distance: 65.81
+            Recall with Euclidean-Distance: 87.89
+            Average Precision with Euclidean-Distance: 76.03
+            '''
+            print(binary_acc_evaluator.primary_metric)
+            # => "quora-duplicates-dev_max_ap"
+            print(results[binary_acc_evaluator.primary_metric])
+            # => 0.760277070888393
    """

    def __init__(
diff --git a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
index 0f2e9ca39..5bbbaa4fe 100644
--- a/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
+++ b/sentence_transformers/evaluation/EmbeddingSimilarityEvaluator.py
@@ -23,7 +23,36 @@ class EmbeddingSimilarityEvaluator(SentenceEvaluator):
    The metrics are the cosine similarity as well as euclidean and Manhattan distance.
    The returned score is the Spearman correlation with a specified metric.

-    The results are written in a CSV. If a CSV already exists, then values are appended.
+    Example
+        ::
+
+            from datasets import load_dataset
+            from sentence_transformers import SentenceTransformer
+            from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
+
+            # Load a model
+            model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Load the STSB dataset (https://huggingface.co/datasets/nyu-mll/glue/viewer/stsb)
+            eval_dataset = load_dataset("nyu-mll/glue", "stsb", split="validation")
+
+            # Initialize the evaluator
+            dev_evaluator = EmbeddingSimilarityEvaluator(
+                sentences1=eval_dataset["sentence1"],
+                sentences2=eval_dataset["sentence2"],
+                scores=[score / 5 for score in eval_dataset["label"]],
+                main_similarity=SimilarityFunction.COSINE,
+                name="sts-dev",
+            )
+            results = dev_evaluator(model)
+            '''
+            EmbeddingSimilarityEvaluator: Evaluating the model on the sts-dev dataset:
+            Cosine-Similarity : Pearson: 0.7874 Spearman: 0.8004
+            Manhattan-Distance: Pearson: 0.7823 Spearman: 0.7827
+            Euclidean-Distance: Pearson: 0.7824 Spearman: 0.7827
+            Dot-Product-Similarity: Pearson: 0.7192 Spearman: 0.7126
+            '''
+            print(dev_evaluator.primary_metric)
+            # => "sts-dev_spearman_cosine"
+            print(results[dev_evaluator.primary_metric])
+            # => 0.8004
    """

    def __init__(
diff --git a/sentence_transformers/evaluation/InformationRetrievalEvaluator.py b/sentence_transformers/evaluation/InformationRetrievalEvaluator.py
index 917574c3e..235bd1b75 100644
--- a/sentence_transformers/evaluation/InformationRetrievalEvaluator.py
+++ b/sentence_transformers/evaluation/InformationRetrievalEvaluator.py
@@ -23,6 +23,89 @@ class InformationRetrievalEvaluator(SentenceEvaluator):
    Given a set of queries and a large corpus set, it will retrieve for each query the top-k most similar documents.
    It measures Mean Reciprocal Rank (MRR), Recall@k, and Normalized Discounted Cumulative Gain (NDCG)
+
+    Example
+        ::
+
+            import random
+            from sentence_transformers import SentenceTransformer
+            from sentence_transformers.evaluation import InformationRetrievalEvaluator
+            from datasets import load_dataset
+
+            # Load a model
+            model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Load the Quora IR dataset (https://huggingface.co/datasets/BeIR/quora, https://huggingface.co/datasets/BeIR/quora-qrels)
+            corpus = load_dataset("BeIR/quora", "corpus", split="corpus")
+            queries = load_dataset("BeIR/quora", "queries", split="queries")
+            relevant_docs_data = load_dataset("BeIR/quora-qrels", split="validation")
+
+            # Shrink the corpus size heavily to only the relevant documents + 10,000 random documents
+            required_corpus_ids = list(map(str, relevant_docs_data["corpus-id"]))
+            required_corpus_ids += random.sample(corpus["_id"], k=10_000)
+            corpus = corpus.filter(lambda x: x["_id"] in required_corpus_ids)
+
+            # Convert the datasets to dictionaries
+            corpus = dict(zip(corpus["_id"], corpus["text"]))  # Our corpus (qid => question)
+            queries = dict(zip(queries["_id"], queries["text"]))  # Our queries (qid => question)
+            relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_question_ids]))
+            for qid, corpus_ids in zip(relevant_docs_data["query-id"], relevant_docs_data["corpus-id"]):
+                qid = str(qid)
+                corpus_ids = str(corpus_ids)
+                if qid not in relevant_docs:
+                    relevant_docs[qid] = set()
+                relevant_docs[qid].add(corpus_ids)
+
+            # Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR metrics.
+            ir_evaluator = InformationRetrievalEvaluator(
+                queries=queries,
+                corpus=corpus,
+                relevant_docs=relevant_docs,
+                name="BeIR-quora-dev",
+            )
+            results = ir_evaluator(model)
+            '''
+            Information Retrieval Evaluation of the model on the BeIR-quora-dev dataset:
+            Queries: 5000
+            Corpus: 17476
+
+            Score-Function: cosine
+            Accuracy@1: 96.26%
+            Accuracy@3: 99.38%
+            Accuracy@5: 99.74%
+            Accuracy@10: 99.94%
+            Precision@1: 96.26%
+            Precision@3: 43.01%
+            Precision@5: 27.66%
+            Precision@10: 14.58%
+            Recall@1: 82.93%
+            Recall@3: 96.28%
+            Recall@5: 98.38%
+            Recall@10: 99.55%
+            MRR@10: 0.9782
+            NDCG@10: 0.9807
+            MAP@100: 0.9732
+            Score-Function: dot
+            Accuracy@1: 96.26%
+            Accuracy@3: 99.38%
+            Accuracy@5: 99.74%
+            Accuracy@10: 99.94%
+            Precision@1: 96.26%
+            Precision@3: 43.01%
+            Precision@5: 27.66%
+            Precision@10: 14.58%
+            Recall@1: 82.93%
+            Recall@3: 96.28%
+            Recall@5: 98.38%
+            Recall@10: 99.55%
+            MRR@10: 0.9782
+            NDCG@10: 0.9807
+            MAP@100: 0.9732
+            '''
+            print(ir_evaluator.primary_metric)
+            # => "BeIR-quora-dev_cosine_map@100"
+            print(results[ir_evaluator.primary_metric])
+            # => 0.9732046108457585
    """

    def __init__(
diff --git a/sentence_transformers/evaluation/MSEEvaluator.py b/sentence_transformers/evaluation/MSEEvaluator.py
index 80ae29899..1cca98e6c 100644
--- a/sentence_transformers/evaluation/MSEEvaluator.py
+++ b/sentence_transformers/evaluation/MSEEvaluator.py
@@ -28,6 +28,38 @@ class MSEEvaluator(SentenceEvaluator):
    :param write_csv: Write results to CSV file
    :param truncate_dim: The dimension to truncate sentence embeddings to. `None` uses the model's
        current truncation dimension. Defaults to None.
+
+    Example
+        ::
+
+            from sentence_transformers import SentenceTransformer
+            from sentence_transformers.evaluation import MSEEvaluator
+            from datasets import load_dataset
+
+            # Load a model
+            student_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
+            teacher_model = SentenceTransformer('all-mpnet-base-v2')
+
+            # Load any dataset with some texts
+            dataset = load_dataset("sentence-transformers/stsb", split="validation")
+            sentences = dataset["sentence1"] + dataset["sentence2"]
+
+            # Given source and target sentences plus a teacher model, the MSEEvaluator computes the mean squared error between the student and teacher embeddings
+            mse_evaluator = MSEEvaluator(
+                source_sentences=sentences,
+                target_sentences=sentences,
+                teacher_model=teacher_model,
+                name="stsb-dev",
+            )
+            results = mse_evaluator(student_model)
+            '''
+            MSE evaluation (lower = better) on the stsb-dev dataset:
+            MSE (*100): 0.805045
+            '''
+            print(mse_evaluator.primary_metric)
+            # => "stsb-dev_negative_mse"
+            print(results[mse_evaluator.primary_metric])
+            # => -0.8050452917814255
    """

    def __init__(
@@ -60,7 +92,7 @@ def __init__(
        self.write_csv = write_csv
        self.primary_metric = "negative_mse"

-    def __call__(self, model: SentenceTransformer, output_path, epoch=-1, steps=-1) -> Dict[str, float]:
+    def __call__(self, model: SentenceTransformer, output_path: str = None, epoch=-1, steps=-1) -> Dict[str, float]:
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
diff --git a/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py b/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py
index 264cf7782..46f2f4fdb 100644
--- a/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py
+++ b/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py
@@ -18,6 +18,45 @@ class ParaphraseMiningEvaluator(SentenceEvaluator):
    Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and
    identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs
    with a set of gold labels and computes the F1 score.
diff --git a/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py b/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py index 264cf7782..46f2f4fdb 100644 --- a/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py +++ b/sentence_transformers/evaluation/ParaphraseMiningEvaluator.py @@ -18,6 +18,45 @@ class ParaphraseMiningEvaluator(SentenceEvaluator): Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs with a set of gold labels and computes the F1 score. + + Example + :: + + from datasets import load_dataset + from sentence_transformers import SentenceTransformer + from sentence_transformers.evaluation import ParaphraseMiningEvaluator + + # Load a model + model = SentenceTransformer('all-mpnet-base-v2') + + # Load the Quora Duplicates Mining dataset + questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev") + duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev") + + # Create a mapping from qid to question & a list of duplicates (qid1, qid2) + qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"])) + duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"])) + + # Initialize the paraphrase mining evaluator + paraphrase_mining_evaluator = ParaphraseMiningEvaluator( + sentences_map=qid_to_questions, + duplicates_list=duplicates, + name="quora-duplicates-dev", + ) + results = paraphrase_mining_evaluator(model) + ''' + Paraphrase Mining Evaluation of the model on the quora-duplicates-dev dataset: + Number of candidate pairs: 250564 + Average Precision: 56.51 + Optimal threshold: 0.8325 + Precision: 52.76 + Recall: 59.19 + F1: 55.79 + ''' + print(paraphrase_mining_evaluator.primary_metric) + # => "quora-duplicates-dev_average_precision" + print(results[paraphrase_mining_evaluator.primary_metric]) + # => 0.5650940787776353 """ def __init__(
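The evaluator builds on the library's paraphrase mining utility, which can also be used on its own. A small sketch, reusing the `model` and `qid_to_questions` mapping from the example above (the slice size is an arbitrary choice):

::

    from sentence_transformers import util

    sentences = list(qid_to_questions.values())[:1000]
    # Returns [score, id1, id2] triplets, sorted by decreasing cosine similarity
    pairs = util.paraphrase_mining(model, sentences, top_k=10)
    for score, i, j in pairs[:5]:
        print(f"{score:.4f} | {sentences[i]} | {sentences[j]}")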
diff --git a/sentence_transformers/evaluation/TranslationEvaluator.py b/sentence_transformers/evaluation/TranslationEvaluator.py index 199d4cad1..603cbe744 100644 --- a/sentence_transformers/evaluation/TranslationEvaluator.py +++ b/sentence_transformers/evaluation/TranslationEvaluator.py @@ -18,6 +18,36 @@ class TranslationEvaluator(SentenceEvaluator): Given two sets of sentences in different languages, e.g. (en_1, en_2, en_3...) and (fr_1, fr_2, fr_3, ...), where fr_i is assumed to be the translation of en_i, this evaluator checks whether vec(en_i) has the highest similarity to vec(fr_i). Computes the accuracy in both directions. + + Example + :: + + from sentence_transformers import SentenceTransformer + from sentence_transformers.evaluation import TranslationEvaluator + from datasets import load_dataset + + # Load a model + model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2') + + # Load a parallel sentences dataset + dataset = load_dataset("sentence-transformers/parallel-sentences-news-commentary", "en-nl", split="train[:1000]") + + # Initialize the TranslationEvaluator using the same texts from two languages + translation_evaluator = TranslationEvaluator( + source_sentences=dataset["english"], + target_sentences=dataset["non_english"], + name="news-commentary-en-nl", + ) + results = translation_evaluator(model) + ''' + Evaluating translation matching Accuracy of the model on the news-commentary-en-nl dataset: + Accuracy src2trg: 90.80 + Accuracy trg2src: 90.40 + ''' + print(translation_evaluator.primary_metric) + # => "news-commentary-en-nl_mean_accuracy" + print(results[translation_evaluator.primary_metric]) + # => 0.906 """ def __init__( diff --git a/sentence_transformers/evaluation/TripletEvaluator.py b/sentence_transformers/evaluation/TripletEvaluator.py index 3b9a908c2..86ff42ba8 100644 --- a/sentence_transformers/evaluation/TripletEvaluator.py +++ b/sentence_transformers/evaluation/TripletEvaluator.py @@ -17,7 +17,40 @@ class TripletEvaluator(SentenceEvaluator): """ Evaluate a model based on a triplet: (sentence, positive_example, negative_example). - Checks if distance(sentence, positive_example) < distance(sentence, negative_example). + Checks if distance(sentence, positive_example) < distance(sentence, negative_example). + + Example + :: + + from sentence_transformers import SentenceTransformer + from sentence_transformers.evaluation import TripletEvaluator + from datasets import load_dataset + + # Load a model + model = SentenceTransformer('all-mpnet-base-v2') + + # Load a dataset with (anchor, positive, negative) triplets + dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev") + + # Initialize the TripletEvaluator using anchors, positives, and negatives + triplet_evaluator = TripletEvaluator( + anchors=dataset[:1000]["anchor"], + positives=dataset[:1000]["positive"], + negatives=dataset[:1000]["negative"], + name="all-nli-dev", + ) + results = triplet_evaluator(model) + ''' + TripletEvaluator: Evaluating the model on the all-nli-dev dataset: + Accuracy Cosine Distance: 95.60 + Accuracy Dot Product: 4.40 + Accuracy Manhattan Distance: 95.40 + Accuracy Euclidean Distance: 95.60 + ''' + print(triplet_evaluator.primary_metric) + # => "all-nli-dev_max_accuracy" + print(results[triplet_evaluator.primary_metric]) + # => 0.956 """ def __init__( diff --git a/sentence_transformers/model_card_template.md b/sentence_transformers/model_card_template.md index f503c6770..bf1ac2896 100644 --- a/sentence_transformers/model_card_template.md +++ b/sentence_transformers/model_card_template.md @@ -91,7 +91,7 @@ print(embeddings.shape) # [{{ (predict_example or ["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."]) | length}}, {{ output_dimensionality | default(1024, true) }}] # Get the similarity scores for the embeddings -similarities = model.similarity(embeddings) +similarities = model.similarity(embeddings, embeddings) print(similarities.shape) # [{{ (predict_example or ["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."]) | length}}, {{ (predict_example or ["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."]) | length}}] ```
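The template fix above matters because `model.similarity` takes two embedding collections; passing the embeddings twice yields the square self-similarity matrix that the model card prints. The same method compares two different sets just as well, as in this small sketch (the model and texts are arbitrary choices):

::

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-mpnet-base-v2")
    query_embeddings = model.encode(["What is the capital of France?"])
    document_embeddings = model.encode([
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ])
    # One row per query, one column per document
    similarities = model.similarity(query_embeddings, document_embeddings)
    print(similarities.shape)  # torch.Size([1, 2])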
is lovely today.", "It's so sunny outside!", "He drove to the stadium."]) | length}}] ``` diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index f4bec3ff8..9803a1503 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -79,6 +79,11 @@ def __init__( if data_collator is None: data_collator = SentenceTransformerDataCollator(tokenize_fn=model.tokenize) + + if isinstance(train_dataset, dict) and not isinstance(train_dataset, DatasetDict): + train_dataset = DatasetDict(train_dataset) + if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, Dataset): + eval_dataset = DatasetDict(eval_dataset) super().__init__( model=model, args=args,