Skip to content

Commit

Permalink
Update train_sts_seed_optimization with SentenceTransformerTrainer (#3092)

Browse files Browse the repository at this point in the history

* update train_sts_seed_optimization with SentenceTransformerTrainer

* ruff lint

* fixes

* Add stopping callback, should work now

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
JINO-ROHIT and tomaarsen authored Dec 2, 2024
1 parent 6ce518a commit 39b6eae
Showing 1 changed file with 80 additions and 55 deletions.
135 changes: 80 additions & 55 deletions examples/training/data_augmentation/train_sts_seed_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,27 @@
python train_sts_seed_optimization.py bert-base-uncased 10 0.3
"""

import csv
import gzip
import logging
import math
import os
import pprint
import random
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import TrainerCallback, TrainerControl, TrainerState

from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models, util
from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

#### Just some code to print debug information to stdout
logging.basicConfig(
format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout


# Check if dataset exists. If not, download and extract it
sts_dataset_path = "datasets/stsbenchmark.tsv.gz"

if not os.path.exists(sts_dataset_path):
util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)


# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased"
Expand All @@ -60,6 +52,8 @@

logging.info(f"Train and Evaluate: {seed_count} Random Seeds")

scores_per_seed = {}

for seed in range(seed_count):
# Setting seed for all random initializations
logging.info(f"##### Seed {seed} #####")
Expand All @@ -69,7 +63,6 @@

# Read the dataset
train_batch_size = 16
num_epochs = 1
model_save_path = "output/bi-encoder/training_stsbenchmark_" + model_name + "/seed-" + str(seed)

# Use Hugging Face/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
Expand All @@ -85,49 +78,81 @@

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Convert the dataset to a DataLoader ready for training
logging.info("Read STSbenchmark train dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
for row in reader:
score = float(row["score"]) / 5.0 # Normalize score to range 0 ... 1
inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)

if row["split"] == "dev":
dev_samples.append(inp_example)
elif row["split"] == "test":
test_samples.append(inp_example)
else:
train_samples.append(inp_example)

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)
# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
# Each split yields "sentence1"/"sentence2" pairs plus a gold similarity "score".
train_dataset = load_dataset("sentence-transformers/stsb", split="train")
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
# NOTE(review): test_dataset is loaded but never used in the visible code — confirm intent.
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
logging.info(train_dataset)

logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
train_loss = losses.CosineSimilarityLoss(model=model)

# Configure the training. We skip evaluation in this example
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up
# 4. Define an evaluator for use during training.
# Compares cosine similarity of the model's embeddings against the gold
# scores on the validation split; reported under the "sts-dev" name.
dev_evaluator = EmbeddingSimilarityEvaluator(
sentences1=eval_dataset["sentence1"],
sentences2=eval_dataset["sentence2"],
scores=eval_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-dev",
show_progress_bar=True,
)

# Stopping and Evaluating after 30% of training data (less than 1 epoch)
# We find from (Dodge et al.) that 20-30% is often ideal for convergence of random seed
steps_per_epoch = math.ceil(len(train_dataloader) * stop_after)

logging.info(f"Warmup-steps: {warmup_steps}")

logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data")
# Number of optimizer steps that roughly covers `stop_after` of one epoch.
# NOTE(review): assumes single-device training with no gradient accumulation —
# with either, the effective steps-per-epoch differs; confirm if that matters here.
num_steps_until_stop = math.ceil(len(train_dataset) / train_batch_size * stop_after)

logging.info(f"Early-stopping: {stop_after:.2%} ({num_steps_until_stop} steps) of the training-data")

# 5. Create a Training Callback that stops training after a certain number of steps
class SeedTestingEarlyStoppingCallback(TrainerCallback):
    """Halt training once a fixed number of optimizer steps has completed.

    Lets each random seed train on only a fraction of an epoch before its
    dev-set score is compared against the other seeds.
    """

    def __init__(self, num_steps_until_stop: int):
        # Step budget after which training should be stopped.
        self.num_steps_until_stop = num_steps_until_stop

    def on_step_end(
        self, args: SentenceTransformerTrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
    ):
        # Budget not yet exhausted: let training continue.
        if state.global_step < self.num_steps_until_stop:
            return
        # Signal the Trainer to stop at the end of the current step.
        control.should_training_stop = True

seed_testing_early_stopping_callback = SeedTestingEarlyStoppingCallback(num_steps_until_stop)

# 6. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=model_save_path,
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    # Log ~10 times before the early stop. The max(..., 1) guards against
    # logging_steps=0 (rejected by transformers) when num_steps_until_stop < 10.
    logging_steps=max(num_steps_until_stop // 10, 1),
    seed=seed,
    run_name=f"sts-{seed}",  # Will be used in W&B if `wandb` is installed
)

# Train the model
model.fit(
train_objectives=[(train_dataloader, train_loss)],
evaluator=evaluator,
epochs=num_epochs,
steps_per_epoch=steps_per_epoch,
evaluation_steps=1000,
warmup_steps=warmup_steps,
output_path=model_save_path,
# 7. Create the trainer & start training
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=dev_evaluator,
# Stops training after num_steps_until_stop optimizer steps (a partial epoch).
callbacks=[seed_testing_early_stopping_callback],
)
trainer.train()

# 8. With the partial train, evaluate this seed on the dev set
dev_score = dev_evaluator(model)
logging.info(f"Evaluator Scores for Seed {seed} after early stopping: {dev_score}")
# The evaluator returns a dict of metrics; track only its primary metric per seed.
primary_dev_score = dev_score[dev_evaluator.primary_metric]
scores_per_seed[seed] = primary_dev_score
# Keep the running leaderboard sorted best-first so the log is easy to scan.
scores_per_seed = dict(sorted(scores_per_seed.items(), key=lambda item: item[1], reverse=True))
logging.info(
f"Current {dev_evaluator.primary_metric} Scores per Seed:\n{pprint.pformat(scores_per_seed, sort_dicts=False)}"
)

# 9. Save the model for this seed
model.save_pretrained(model_save_path)

0 comments on commit 39b6eae

Please sign in to comment.