From 39b6eae9642ac874879502b53d3ae34aba222367 Mon Sep 17 00:00:00 2001 From: JINO ROHIT Date: Mon, 2 Dec 2024 15:38:33 +0530 Subject: [PATCH] Update train_sts_seed_optimization with SentenceTransformerTrainer (#3092) * update train_sts_seed_optimization with SentenceTransformerTrainer * ruff lint * fixes * Add stopping callback, should work now --------- Co-authored-by: Tom Aarsen --- .../train_sts_seed_optimization.py | 135 +++++++++++------- 1 file changed, 80 insertions(+), 55 deletions(-) diff --git a/examples/training/data_augmentation/train_sts_seed_optimization.py b/examples/training/data_augmentation/train_sts_seed_optimization.py index 36b12c483..7f5993c4a 100644 --- a/examples/training/data_augmentation/train_sts_seed_optimization.py +++ b/examples/training/data_augmentation/train_sts_seed_optimization.py @@ -23,35 +23,27 @@ python train_sts_seed_optimization.py bert-base-uncased 10 0.3 """ -import csv -import gzip import logging import math -import os +import pprint import random import sys import numpy as np import torch -from torch.utils.data import DataLoader +from datasets import load_dataset +from transformers import TrainerCallback, TrainerControl, TrainerState -from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models, util +from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator -from sentence_transformers.readers import InputExample +from sentence_transformers.similarity_functions import SimilarityFunction +from sentence_transformers.trainer import SentenceTransformerTrainer +from sentence_transformers.training_args import SentenceTransformerTrainingArguments #### Just some code to print debug information to stdout logging.basicConfig( format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()] ) -#### /print debug information to stdout - - -# Check if dataset exists. If not, download and extract it -sts_dataset_path = "datasets/stsbenchmark.tsv.gz" - -if not os.path.exists(sts_dataset_path): - util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path) - # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased" @@ -60,6 +52,8 @@ logging.info(f"Train and Evaluate: {seed_count} Random Seeds") +scores_per_seed = {} + for seed in range(seed_count): # Setting seed for all random initializations logging.info(f"##### Seed {seed} #####") @@ -69,7 +63,6 @@ # Read the dataset train_batch_size = 16 - num_epochs = 1 model_save_path = "output/bi-encoder/training_stsbenchmark_" + model_name + "/seed-" + str(seed) # Use Hugging Face/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings @@ -85,49 +78,81 @@ model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) - # Convert the dataset to a DataLoader ready for training - logging.info("Read STSbenchmark train dataset") - - train_samples = [] - dev_samples = [] - test_samples = [] - with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn: - reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE) - for row in reader: - score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1 - inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score) - - if row["split"] == "dev": - dev_samples.append(inp_example) - elif row["split"] == "test": - test_samples.append(inp_example) - else: - train_samples.append(inp_example) - - train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size) - train_loss = losses.CosineSimilarityLoss(model=model) + # 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb + train_dataset = load_dataset("sentence-transformers/stsb", split="train") + eval_dataset = load_dataset("sentence-transformers/stsb", split="validation") + test_dataset = load_dataset("sentence-transformers/stsb", split="test") + logging.info(train_dataset) - logging.info("Read STSbenchmark dev dataset") - evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev") + train_loss = losses.CosineSimilarityLoss(model=model) - # Configure the training. We skip evaluation in this example - warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up + # 4. Define an evaluator for use during training. + dev_evaluator = EmbeddingSimilarityEvaluator( + sentences1=eval_dataset["sentence1"], + sentences2=eval_dataset["sentence2"], + scores=eval_dataset["score"], + main_similarity=SimilarityFunction.COSINE, + name="sts-dev", + show_progress_bar=True, + ) # Stopping and Evaluating after 30% of training data (less than 1 epoch) # We find from (Dodge et al.) that 20-30% is often ideal for convergence of random seed - steps_per_epoch = math.ceil(len(train_dataloader) * stop_after) - - logging.info(f"Warmup-steps: {warmup_steps}") - - logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data") + num_steps_until_stop = math.ceil(len(train_dataset) / train_batch_size * stop_after) + + logging.info(f"Early-stopping: {stop_after:.2%} ({num_steps_until_stop} steps) of the training-data") + + # 5. Create a Training Callback that stops training after a certain number of steps + class SeedTestingEarlyStoppingCallback(TrainerCallback): + def __init__(self, num_steps_until_stop: int): + self.num_steps_until_stop = num_steps_until_stop + + def on_step_end( + self, args: SentenceTransformerTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs + ): + if state.global_step >= self.num_steps_until_stop: + control.should_training_stop = True + + seed_testing_early_stopping_callback = SeedTestingEarlyStoppingCallback(num_steps_until_stop) + + # 6. Define the training arguments + args = SentenceTransformerTrainingArguments( + # Required parameter: + output_dir=model_save_path, + # Optional training parameters: + num_train_epochs=1, + per_device_train_batch_size=train_batch_size, + per_device_eval_batch_size=train_batch_size, + warmup_ratio=0.1, + fp16=True, # Set to False if you get an error that your GPU can't run on FP16 + bf16=False, # Set to True if you have a GPU that supports BF16 + # Optional tracking/debugging parameters: + logging_steps=num_steps_until_stop // 10, # Log every 10% of the steps + seed=seed, + run_name=f"sts-{seed}", # Will be used in W&B if `wandb` is installed + ) - # Train the model - model.fit( - train_objectives=[(train_dataloader, train_loss)], - evaluator=evaluator, - epochs=num_epochs, - steps_per_epoch=steps_per_epoch, - evaluation_steps=1000, - warmup_steps=warmup_steps, - output_path=model_save_path, + # 7. Create the trainer & start training + trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluator, + callbacks=[seed_testing_early_stopping_callback], + ) + trainer.train() + + # 8. With the partial train, evaluate this seed on the dev set + dev_score = dev_evaluator(model) + logging.info(f"Evaluator Scores for Seed {seed} after early stopping: {dev_score}") + primary_dev_score = dev_score[dev_evaluator.primary_metric] + scores_per_seed[seed] = primary_dev_score + scores_per_seed = dict(sorted(scores_per_seed.items(), key=lambda item: item[1], reverse=True)) + logging.info( + f"Current {dev_evaluator.primary_metric} Scores per Seed:\n{pprint.pformat(scores_per_seed, sort_dicts=False)}" ) + + # 9. Save the model for this seed + model.save_pretrained(model_save_path)