[v3] Training refactor - MultiGPU, loss logging, bf16, etc. #2449

Merged (90 commits) on Apr 25, 2024

Conversation

tomaarsen (Collaborator) commented on Jan 25, 2024

Resolves #2436, Resolves #1446

Hello!

Pull Request overview

  • Overhaul the training backend, relying on the transformers Trainer rather than a manual training loop.
  • Ideally, this preserves full fit backwards compatibility.
    • Multi-task training is still possible.
  • Using the SentenceTransformerTrainer would become the recommended training approach. It features:
    • Integrations with various third-party applications, such as Weights & Biases and Tensorboard, for loss & metric logging.
    • MultiGPU training (needs testing).
    • bf16 training.
    • Improved callback support.
    • Automatic model card generation.

Details

This PR expands on work from @matthewfranglen & updates it to work in many more situations, e.g. with different dataset formats. Notably, training now centers around:

  1. a training Dataset or DatasetDict. This class is much more suited for sharing & efficient modifications than lists/DataLoaders of InputExample instances. A Dataset can contain multiple text columns that will be fed in order to the corresponding loss function. So, if the loss expects (anchor, positive, negative) triplets, then your dataset should also have 3 columns. The names of these columns are irrelevant at this time. If there is a "label" column, it is treated separately, and used as the labels during training.
    A DatasetDict can be used to train with multiple datasets at once, e.g.:
    DatasetDict({
        multi_nli: Dataset({
            features: ['premise', 'hypothesis', 'label'],
            num_rows: 392702
        })
        snli: Dataset({
            features: ['snli_premise', 'hypothesis', 'label'],
            num_rows: 549367
        })
        stsb: Dataset({
            features: ['sentence1', 'sentence2', 'label'],
            num_rows: 5749
        })
    })
    When a DatasetDict is used, the loss parameter to the SentenceTransformerTrainer must also be a dictionary with these dataset keys, e.g.:
    {
        'multi_nli': SoftmaxLoss(...),
        'snli': SoftmaxLoss(...),
        'stsb': CosineSimilarityLoss(...),
    }
    Just like multi-task training in Sentence Transformers today, these datasets are by default sampled in a round-robin fashion. I intend to eventually update this PR with a smarter sampler that samples from each dataset in proportion to its size. A short sketch of this multi-dataset setup follows this list.
  2. A loss function, or a dictionary of loss functions as described above. These loss functions do not require changes compared to before this PR.
  3. A SentenceTransformerTrainingArguments instance, a subclass of the transformers TrainingArguments. This powerful class controls the specific details of the training.
  4. An optional SentenceEvaluator instance. I am considering some breaking changes here, as these evaluators currently return floats on __call__, but I would like them to return a dictionary with all values that they've computed instead.
    After this PR, these instances either return a float, or a dictionary with metric keys and values. If the latter, the class must also define evaluator.primary_metric so that e.g. the "best model" checkpointing can be based on an evaluator score.
    Models can now be evaluated both on an evaluation dataset with some loss function and/or a SentenceEvaluator instance.
  5. The new SentenceTransformerTrainer instance. This instance is provided with a SentenceTransformer model, a SentenceTransformerTrainingArguments class, a SentenceEvaluator, a training and evaluation Dataset/DatasetDict and a loss function/dict of loss functions. Most of these parameters are optional. Once provided, all you have to do is call train().
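
To make the multi-dataset setup above concrete, here is a minimal sketch with toy data and default training arguments; the dataset and column names are purely illustrative, and the full API is shown in the complete examples below.

from datasets import Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
)

model = SentenceTransformer("nli-distilroberta-base-v2")

# Toy data: one (anchor, positive, negative) dataset and one scored-pair dataset.
# Column names other than "label" are irrelevant; a "label" column is used as the labels.
train_dataset = DatasetDict({
    "triplets": Dataset.from_dict({
        "anchor": ["A man is eating food."],
        "positive": ["A man is eating a meal."],
        "negative": ["A man is playing a guitar."],
    }),
    "pairs": Dataset.from_dict({
        "sentence1": ["A plane is taking off."],
        "sentence2": ["An air plane is taking off."],
        "label": [0.95],
    }),
})

# The keys of the loss dict must match the keys of the DatasetDict.
loss = {
    "triplets": losses.MultipleNegativesRankingLoss(model),
    "pairs": losses.CosineSimilarityLoss(model),
}

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()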

Let's have a look at some example scripts of what training could look like in v3:

Example usage

Note: The old training loop (with DataLoaders, InputExamples and model.fit) still works; you're free to keep using it. However, the new training loop offers more options and features.

Example 1

This example finetunes nli-distilroberta-base-v2 using in-batch negatives on SICK, while evaluating on the SICK evaluation set both via the loss and via the EmbeddingSimilarityEvaluator.

Training Script
import datasets
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
)

sick_ds = datasets.load_dataset("sick").select_columns(["sentence_A", "sentence_B", "label"])

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=10,
    seed=33,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_steps=100,
    # bf16=True,
    evaluation_strategy="steps",
    eval_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="spearman_cosine",
    greater_is_better=True,
)

model = SentenceTransformer("nli-distilroberta-base-v2")
tokenizer = model.tokenizer
# loss = losses.CosineSimilarityLoss(model)
loss = losses.MultipleNegativesRankingLoss(model)
evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)


trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=evaluator,
    args=training_args,
    train_dataset=sick_ds["train"],
    eval_dataset=sick_ds["validation"],
    loss=loss,
)
trainer.train()
Training Logs
wandb: Currently logged in as: tomaarsen. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in C:\code\sentence-transformers\wandb\run-20240125_162938-w24rm86v
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run giddy-fire-2366
wandb:  View project at https://wandb.ai/tomaarsen/huggingface
wandb:  View run at https://wandb.ai/tomaarsen/huggingface/runs/w24rm86v
{'eval_loss': 2.0557382106781006, 'eval_pearson_cosine': -0.3753512001269012, 'eval_spearman_cosine': -0.5005078933275353, 'eval_pearson_manhattan': -0.47722573527673684, 'eval_spearman_manhattan': -0.5167731987749575, 'eval_pearson_euclidean': -0.4704088166967001, 'eval_spearman_euclidean': -0.5098770298516158, 'eval_pearson_dot': -0.3201333973487447, 'eval_spearman_dot': -0.3861894670367086, 'eval_pearson_max': -0.3201333973487447, 'eval_spearman_max': -0.3861894670367086, 'eval_runtime': 0.8787, 'eval_samples_per_second': 563.338, 'eval_steps_per_second': 2.276, 'epoch': 0.56}
{'eval_loss': 1.817664384841919, 'eval_pearson_cosine': -0.24721726065042743, 'eval_spearman_cosine': -0.3644746386390808, 'eval_pearson_manhattan': -0.3356430032744311, 'eval_spearman_manhattan': -0.3807555281459245, 'eval_pearson_euclidean': -0.32878220335299, 'eval_spearman_euclidean': -0.3705844804499084, 'eval_pearson_dot': -0.21022034392122274, 'eval_spearman_dot': -0.2605570136060241, 'eval_pearson_max': -0.21022034392122274, 'eval_spearman_max': -0.2605570136060241, 'eval_runtime': 0.8976, 'eval_samples_per_second': 551.491, 'eval_steps_per_second': 2.228, 'epoch': 1.11}
{'eval_loss': 1.5894440412521362, 'eval_pearson_cosine': -0.18499820961635288, 'eval_spearman_cosine': -0.28155007742302574, 'eval_pearson_manhattan': -0.2395728163369895, 'eval_spearman_manhattan': -0.28307508895210926, 'eval_pearson_euclidean': -0.23533046586854228, 'eval_spearman_euclidean': -0.2785379013663104, 'eval_pearson_dot': -0.1759999454506243, 'eval_spearman_dot': -0.23249858250624952, 'eval_pearson_max': -0.1759999454506243, 'eval_spearman_max': -0.23249858250624952, 'eval_runtime': 0.8883, 'eval_samples_per_second': 557.213, 'eval_steps_per_second': 2.251, 'epoch': 1.67}
{'eval_loss': 1.482645034790039, 'eval_pearson_cosine': -0.15276100339493737, 'eval_spearman_cosine': -0.22813014802104117, 'eval_pearson_manhattan': -0.18319135744431875, 'eval_spearman_manhattan': -0.22689499999719367, 'eval_pearson_euclidean': -0.18065463740981458, 'eval_spearman_euclidean': -0.2246517859582089, 'eval_pearson_dot': -0.13941230741897542, 'eval_spearman_dot': -0.1926149671175988, 'eval_pearson_max': -0.13941230741897542, 'eval_spearman_max': -0.1926149671175988, 'eval_runtime': 0.8666, 'eval_samples_per_second': 571.185, 'eval_steps_per_second': 2.308, 'epoch': 2.22}
{'eval_loss': 1.478642225265503, 'eval_pearson_cosine': -0.13162227944894606, 'eval_spearman_cosine': -0.18599129660470928, 'eval_pearson_manhattan': -0.1466982899028075, 'eval_spearman_manhattan': -0.1873769421166824, 'eval_pearson_euclidean': -0.14433414954196194, 'eval_spearman_euclidean': -0.1839163904074744, 'eval_pearson_dot': -0.09070207777437565, 'eval_spearman_dot': -0.13684585207256467, 'eval_pearson_max': -0.09070207777437565, 'eval_spearman_max': -0.13684585207256467, 'eval_runtime': 0.8554, 'eval_samples_per_second': 578.69, 'eval_steps_per_second': 2.338, 'epoch': 2.78}
{'eval_loss': 1.3875240087509155, 'eval_pearson_cosine': -0.13270682993051464, 'eval_spearman_cosine': -0.18586929568238258, 'eval_pearson_manhattan': -0.14390943834255362, 'eval_spearman_manhattan': -0.18597259573340372, 'eval_pearson_euclidean': -0.14226529421604445, 'eval_spearman_euclidean': -0.18380819250920663, 'eval_pearson_dot': -0.08998810631488594, 'eval_spearman_dot': -0.1380128754937991, 'eval_pearson_max': -0.08998810631488594, 'eval_spearman_max': -0.1380128754937991, 'eval_runtime': 0.8821, 'eval_samples_per_second': 561.184, 'eval_steps_per_second': 2.267, 'epoch': 3.33}
{'eval_loss': 1.3624027967453003, 'eval_pearson_cosine': -0.12331743470388387, 'eval_spearman_cosine': -0.16960087342307598, 'eval_pearson_manhattan': -0.12815922903595642, 'eval_spearman_manhattan': -0.16951449320799797, 'eval_pearson_euclidean': -0.12672328147478792, 'eval_spearman_euclidean': -0.16779312253020615, 'eval_pearson_dot': -0.08434422205700165, 'eval_spearman_dot': -0.12757645352878041, 'eval_pearson_max': -0.08434422205700165, 'eval_spearman_max': -0.12757645352878041, 'eval_runtime': 0.9083, 'eval_samples_per_second': 544.995, 'eval_steps_per_second': 2.202, 'epoch': 3.89}
{'eval_loss': 1.3132266998291016, 'eval_pearson_cosine': -0.12224681574277577, 'eval_spearman_cosine': -0.16913201586391538, 'eval_pearson_manhattan': -0.12654365799009917, 'eval_spearman_manhattan': -0.16713547522262762, 'eval_pearson_euclidean': -0.1253854784944141, 'eval_spearman_euclidean': -0.1664333020309883, 'eval_pearson_dot': -0.08280671016392793, 'eval_spearman_dot': -0.12137444313794694, 'eval_pearson_max': -0.08280671016392793, 'eval_spearman_max': -0.12137444313794694, 'eval_runtime': 0.9057, 'eval_samples_per_second': 546.542, 'eval_steps_per_second': 2.208, 'epoch': 4.44}
{'eval_loss': 1.2858315706253052, 'eval_pearson_cosine': -0.11658100659135247, 'eval_spearman_cosine': -0.15828016740060913, 'eval_pearson_manhattan': -0.11478957785394783, 'eval_spearman_manhattan': -0.15680146279094884, 'eval_pearson_euclidean': -0.11401371257012077, 'eval_spearman_euclidean': -0.15521767709490505, 'eval_pearson_dot': -0.09937152050147363, 'eval_spearman_dot': -0.1394239007596869, 'eval_pearson_max': -0.09937152050147363, 'eval_spearman_max': -0.1394239007596869, 'eval_runtime': 0.8645, 'eval_samples_per_second': 572.571, 'eval_steps_per_second': 2.313, 'epoch': 5.0}
{'eval_loss': 1.2543511390686035, 'eval_pearson_cosine': -0.110348365014459, 'eval_spearman_cosine': -0.14773910960804748, 'eval_pearson_manhattan': -0.1069052682217499, 'eval_spearman_manhattan': -0.14631027398853533, 'eval_pearson_euclidean': -0.10584470148705244, 'eval_spearman_euclidean': -0.14485694913278976, 'eval_pearson_dot': -0.08572066430992958, 'eval_spearman_dot': -0.1231073905455946, 'eval_pearson_max': -0.08572066430992958, 'eval_spearman_max': -0.1231073905455946, 'eval_runtime': 0.8629, 'eval_samples_per_second': 573.677, 'eval_steps_per_second': 2.318, 'epoch': 5.56}
{'eval_loss': 1.2348754405975342, 'eval_pearson_cosine': -0.11138877545858725, 'eval_spearman_cosine': -0.14934872031684662, 'eval_pearson_manhattan': -0.10805216194950466, 'eval_spearman_manhattan': -0.14867726998520928, 'eval_pearson_euclidean': -0.10690479169242902, 'eval_spearman_euclidean': -0.14702446916887113, 'eval_pearson_dot': -0.08223201017485372, 'eval_spearman_dot': -0.12037439178194063, 'eval_pearson_max': -0.08223201017485372, 'eval_spearman_max': -0.12037439178194063, 'eval_runtime': 0.916, 'eval_samples_per_second': 540.391, 'eval_steps_per_second': 2.183, 'epoch': 6.11}
{'eval_loss': 1.2055180072784424, 'eval_pearson_cosine': -0.10806545807578179, 'eval_spearman_cosine': -0.1434103031596533, 'eval_pearson_manhattan': -0.10041631171809369, 'eval_spearman_manhattan': -0.13974003453651881, 'eval_pearson_euclidean': -0.09931433498355985, 'eval_spearman_euclidean': -0.1384042580146938, 'eval_pearson_dot': -0.09247201436069705, 'eval_spearman_dot': -0.13313328385957238, 'eval_pearson_max': -0.09247201436069705, 'eval_spearman_max': -0.13313328385957238, 'eval_runtime': 0.9172, 'eval_samples_per_second': 539.704, 'eval_steps_per_second': 2.181, 'epoch': 6.67}
{'eval_loss': 1.1998401880264282, 'eval_pearson_cosine': -0.10611831391316866, 'eval_spearman_cosine': -0.14052101854294582, 'eval_pearson_manhattan': -0.09753730834313956, 'eval_spearman_manhattan': -0.1366775442308147, 'eval_pearson_euclidean': -0.0969604381725838, 'eval_spearman_euclidean': -0.13714105868388798, 'eval_pearson_dot': -0.0886595697921101, 'eval_spearman_dot': -0.12854533676594412, 'eval_pearson_max': -0.0886595697921101, 'eval_spearman_max': -0.12854533676594412, 'eval_runtime': 0.8862, 'eval_samples_per_second': 558.587, 'eval_steps_per_second': 2.257, 'epoch': 7.22}
{'eval_loss': 1.1815522909164429, 'eval_pearson_cosine': -0.10563221719744154, 'eval_spearman_cosine': -0.1391255773364793, 'eval_pearson_manhattan': -0.0955089359917892, 'eval_spearman_manhattan': -0.13595444387366676, 'eval_pearson_euclidean': -0.09521087645713486, 'eval_spearman_euclidean': -0.13508218180491505, 'eval_pearson_dot': -0.09127287577283417, 'eval_spearman_dot': -0.13276861686911418, 'eval_pearson_max': -0.09127287577283417, 'eval_spearman_max': -0.13276861686911418, 'eval_runtime': 0.8601, 'eval_samples_per_second': 575.495, 'eval_steps_per_second': 2.325, 'epoch': 7.78}
{'eval_loss': 1.176088571548462, 'eval_pearson_cosine': -0.10316310694349418, 'eval_spearman_cosine': -0.13499802788404006, 'eval_pearson_manhattan': -0.09195560989180353, 'eval_spearman_manhattan': -0.13107796905132432, 'eval_pearson_euclidean': -0.09135332597230716, 'eval_spearman_euclidean': -0.13142170887627397, 'eval_pearson_dot': -0.09421083390919384, 'eval_spearman_dot': -0.1364905355177592, 'eval_pearson_max': -0.09135332597230716, 'eval_spearman_max': -0.13107796905132432, 'eval_runtime': 0.8593, 'eval_samples_per_second': 576.069, 'eval_steps_per_second': 2.328, 'epoch': 8.33}
{'eval_loss': 1.171457290649414, 'eval_pearson_cosine': -0.10361218999499006, 'eval_spearman_cosine': -0.1369139766751777, 'eval_pearson_manhattan': -0.09236153445551434, 'eval_spearman_manhattan': -0.13166927279165216, 'eval_pearson_euclidean': -0.09158588010184697, 'eval_spearman_euclidean': -0.131754762489049, 'eval_pearson_dot': -0.0942246469738964, 'eval_spearman_dot': -0.13543304577131443, 'eval_pearson_max': -0.09158588010184697, 'eval_spearman_max': -0.13166927279165216, 'eval_runtime': 0.9345, 'eval_samples_per_second': 529.706, 'eval_steps_per_second': 2.14, 'epoch': 8.89}
{'eval_loss': 1.172788381576538, 'eval_pearson_cosine': -0.10263800057098856, 'eval_spearman_cosine': -0.13656177693225652, 'eval_pearson_manhattan': -0.09181227752250928, 'eval_spearman_manhattan': -0.1313651610035167, 'eval_pearson_euclidean': -0.09098441573129593, 'eval_spearman_euclidean': -0.13108153112204918, 'eval_pearson_dot': -0.09341017697327209, 'eval_spearman_dot': -0.13326508047639243, 'eval_pearson_max': -0.09098441573129593, 'eval_spearman_max': -0.13108153112204918, 'eval_runtime': 0.9382, 'eval_samples_per_second': 527.628, 'eval_steps_per_second': 2.132, 'epoch': 9.44}
{'eval_loss': 1.1713230609893799, 'eval_pearson_cosine': -0.10235168311680716, 'eval_spearman_cosine': -0.13579192439684473, 'eval_pearson_manhattan': -0.09135034055161731, 'eval_spearman_manhattan': -0.13077296674550765, 'eval_pearson_euclidean': -0.09054475658492558, 'eval_spearman_euclidean': -0.13079166761681318, 'eval_pearson_dot': -0.09255427649321141, 'eval_spearman_dot': -0.13230287612183786, 'eval_pearson_max': -0.09054475658492558, 'eval_spearman_max': -0.13077296674550765, 'eval_runtime': 0.8791, 'eval_samples_per_second': 563.104, 'eval_steps_per_second': 2.275, 'epoch': 10.0}
{'train_runtime': 68.4966, 'train_samples_per_second': 648.062, 'train_steps_per_second': 2.628, 'train_loss': 1.3919714185926648, 'epoch': 10.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 180/180 [01:03<00:00,  2.82it/s] 
wandb: \ 0.044 MB of 0.044 MB uploaded (0.004 MB deduped)
wandb: Run history:
wandb:                      eval/loss █▆▄▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁
wandb:            eval/pearson_cosine ▁▄▆▇▇▇▇▇██████████
wandb:               eval/pearson_dot ▁▄▅▆████▇█████████
wandb:         eval/pearson_euclidean ▁▄▅▆▇▇▇▇██████████
wandb:         eval/pearson_manhattan ▁▄▅▆▇▇▇▇██████████
wandb:               eval/pearson_max ▁▄▅▆████▇█████████
wandb:                   eval/runtime ▃▅▄▂▁▃▅▅▂▂▆▆▄▁▁██▃
wandb:        eval/samples_per_second ▆▄▅▇█▆▃▄▇▇▃▃▅██▁▁▆
wandb:           eval/spearman_cosine ▁▄▅▆▇▇▇▇██████████
wandb:              eval/spearman_dot ▁▄▅▆████▇█████████
wandb:        eval/spearman_euclidean ▁▄▅▆▇▇▇▇██████████
wandb:        eval/spearman_manhattan ▁▃▅▆▇▇▇▇██████████
wandb:              eval/spearman_max ▁▄▅▆████▇█████████
wandb:          eval/steps_per_second ▆▄▅▇█▆▃▄▇▇▃▃▅██▁▁▆
wandb:                    train/epoch ▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇███
wandb:              train/global_step ▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇███
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:                      eval/loss 1.17132
wandb:            eval/pearson_cosine -0.10235
wandb:               eval/pearson_dot -0.09255
wandb:         eval/pearson_euclidean -0.09054
wandb:         eval/pearson_manhattan -0.09135
wandb:               eval/pearson_max -0.09054
wandb:                   eval/runtime 0.8791
wandb:        eval/samples_per_second 563.104
wandb:           eval/spearman_cosine -0.13579
wandb:              eval/spearman_dot -0.1323
wandb:        eval/spearman_euclidean -0.13079
wandb:        eval/spearman_manhattan -0.13077
wandb:              eval/spearman_max -0.13077
wandb:          eval/steps_per_second 2.275
wandb:                    train/epoch 10.0
wandb:              train/global_step 180
wandb:               train/total_flos 0.0
wandb:               train/train_loss 1.39197
wandb:            train/train_runtime 68.4966
wandb: train/train_samples_per_second 648.062
wandb:   train/train_steps_per_second 2.628
wandb:
wandb:  View run giddy-fire-2366 at: https://wandb.ai/tomaarsen/huggingface/runs/w24rm86v
wandb:  View job at https://wandb.ai/tomaarsen/huggingface/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzMjYxMDU1Mw==/version_details/v13
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20240125_162938-w24rm86v\logs
Training Notes

W&B correctly logged the Spearman & Pearson correlation coefficients, as well as the evaluation loss.
(screenshots: W&B charts of the evaluation loss and the correlation metrics)

Example 2

This example finetunes nli-distilroberta-base-v2 using in-batch negatives on SNLI, while evaluating both on the SNLI development set via the loss and on the SICK evaluation set via the EmbeddingSimilarityEvaluator.

Training Script
from collections import defaultdict
import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
)

def to_triplets(dataset):
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0]) # <- entailment
            negatives.append(sentences[2]) # <- contradiction
    return Dataset.from_dict({
        "query": queries,
        "positive": positives,
        "negative": negatives,
    })

snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})
sick_ds = datasets.load_dataset("sick")

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=33,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_steps=100,
    bf16=True,
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=200,
    save_steps=2000,
    save_total_limit=2,
    metric_for_best_model="spearman_cosine",
    greater_is_better=True,
)

model = SentenceTransformer("nli-distilroberta-base-v2")
loss = losses.MultipleNegativesRankingLoss(model)
evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)

trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=evaluator,
    args=training_args,
    train_dataset=snli_ds["train"],
    eval_dataset=snli_ds["validation"],
    loss=loss,
)
trainer.train()
Training Logs
wandb: Currently logged in as: tomaarsen. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in C:\code\sentence-transformers\wandb\run-20240125_163910-c5qa7i3e
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run crimson-fog-2367
wandb:  View project at https://wandb.ai/tomaarsen/huggingface
wandb:  View run at https://wandb.ai/tomaarsen/huggingface/runs/c5qa7i3e
{'loss': 1.0551, 'learning_rate': 2e-05, 'epoch': 0.17}
{'loss': 1.0411, 'learning_rate': 1.58592132505176e-05, 'epoch': 0.34}
{'eval_loss': 0.9336137175559998, 'eval_pearson_cosine': -0.4100684362249894, 'eval_spearman_cosine': -0.523609703013658, 'eval_pearson_manhattan': -0.5110499009157363, 'eval_spearman_manhattan': -0.5437162564790085, 'eval_pearson_euclidean': -0.5017938611550188, 'eval_spearman_euclidean': -0.5332353086299289, 'eval_pearson_dot': -0.3409841885470256, 'eval_spearman_dot': -0.39745362518641814, 'eval_pearson_max': -0.3409841885470256, 'eval_spearman_max': -0.39745362518641814, 'eval_runtime': 2.9673, 'eval_samples_per_second': 991.817, 'eval_steps_per_second': 4.044, 'epoch': 0.34}
{'loss': 0.9867, 'learning_rate': 1.1718426501035198e-05, 'epoch': 0.51}
{'loss': 1.0002, 'learning_rate': 7.577639751552796e-06, 'epoch': 0.69}
{'eval_loss': 0.8982349038124084, 'eval_pearson_cosine': -0.46315280772669326, 'eval_spearman_cosine': -0.5587441876083802, 'eval_pearson_manhattan': -0.5491469271882455, 'eval_spearman_manhattan': -0.5764067152976315, 'eval_pearson_euclidean': -0.5415142463941801, 'eval_spearman_euclidean': -0.5675977143950363, 'eval_pearson_dot': -0.411831727580533, 'eval_spearman_dot': -0.4797610576493901, 'eval_pearson_max': -0.411831727580533, 'eval_spearman_max': -0.4797610576493901, 'eval_runtime': 2.7469, 'eval_samples_per_second': 1071.407, 'eval_steps_per_second': 4.369, 'epoch': 0.69}
{'loss': 0.9891, 'learning_rate': 3.436853002070394e-06, 'epoch': 0.86}
{'train_runtime': 158.9256, 'train_samples_per_second': 938.458, 'train_steps_per_second': 3.668, 'train_loss': 1.0075248862212558, 'epoch': 1.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 583/583 [02:33<00:00,  3.79it/s] 
wandb: - 0.040 MB of 0.040 MB uploaded (0.004 MB deduped)
wandb: Run history:
wandb:                      eval/loss █▁
wandb:            eval/pearson_cosine █▁
wandb:               eval/pearson_dot █▁
wandb:         eval/pearson_euclidean █▁
wandb:         eval/pearson_manhattan █▁
wandb:               eval/pearson_max █▁
wandb:                   eval/runtime █▁
wandb:        eval/samples_per_second ▁█
wandb:           eval/spearman_cosine █▁
wandb:              eval/spearman_dot █▁
wandb:        eval/spearman_euclidean █▁
wandb:        eval/spearman_manhattan █▁
wandb:              eval/spearman_max █▁
wandb:          eval/steps_per_second ▁█
wandb:                    train/epoch ▁▂▂▄▅▅▇█
wandb:              train/global_step ▁▂▂▄▅▅▇█
wandb:            train/learning_rate █▆▅▃▁
wandb:                     train/loss █▇▁▂▁
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:                      eval/loss 0.89823
wandb:            eval/pearson_cosine -0.46315
wandb:               eval/pearson_dot -0.41183
wandb:         eval/pearson_euclidean -0.54151
wandb:         eval/pearson_manhattan -0.54915
wandb:               eval/pearson_max -0.41183
wandb:                   eval/runtime 2.7469
wandb:        eval/samples_per_second 1071.407
wandb:           eval/spearman_cosine -0.55874
wandb:              eval/spearman_dot -0.47976
wandb:        eval/spearman_euclidean -0.5676
wandb:        eval/spearman_manhattan -0.57641
wandb:              eval/spearman_max -0.47976
wandb:          eval/steps_per_second 4.369
wandb:                    train/epoch 1.0
wandb:              train/global_step 583
wandb:            train/learning_rate 0.0
wandb:                     train/loss 0.9891
wandb:               train/total_flos 0.0
wandb:               train/train_loss 1.00752
wandb:            train/train_runtime 158.9256
wandb: train/train_samples_per_second 938.458
wandb:   train/train_steps_per_second 3.668
wandb:
wandb:  View run crimson-fog-2367 at: https://wandb.ai/tomaarsen/huggingface/runs/c5qa7i3e
wandb:  View job at https://wandb.ai/tomaarsen/huggingface/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzMjY0NzQxMA==/version_details/v5
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20240125_163910-c5qa7i3e\logs
Training Notes

W&B captures the training & evaluation loss decreasing. It also tracks the runtime, learning rate, steps, & all evaluator values for the Spearman/Pearson correlation coefficients.
(screenshot: W&B charts for this run)

Example 3

This example finetunes microsoft/mpnet-base using in-batch negatives on AllNLI, a combination of SNLI and MultiNLI, while evaluating both on the AllNLI evaluation set via the loss and on the STSBenchmark development set via the EmbeddingSimilarityEvaluator. It uses a cached variant of the MultipleNegativesRankingLoss (in-batch negatives loss), allowing us to set the batch size very high while using much smaller mini-batches. This results in a noticeably stronger final model. After training, the model performance is tested against the STSBenchmark test set.

Training Script
from collections import defaultdict
from typing import Dict
import datasets
from datasets import Dataset
from transformers import EvalPrediction
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
)
from sentence_transformers.models import Transformer, Pooling

def to_triplets(dataset):
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0]) # <- entailment
            negatives.append(sentences[2]) # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })

snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": to_triplets(multi_nli_ds["train"]),
    "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
})

all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),
    "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
    "test": snli_ds["test"]
})

stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=33,
    per_device_train_batch_size=2048,
    per_device_eval_batch_size=2048,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=1,
    evaluation_strategy="steps",
    eval_steps=10,
    save_steps=10,
    save_total_limit=2,
    metric_for_best_model="spearman_cosine",
    greater_is_better=True,
)

transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])

tokenizer = model.tokenizer
loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    loss=loss,
)
trainer.train()

test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)
Training Logs
Some weights of MPNetModel were not initialized from the model checkpoint at microsoft/mpnet-base and are newly initialized: ['mpnet.pooler.dense.weight', 'mpnet.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
wandb: Currently logged in as: tomaarsen. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in C:\code\sentence-transformers\wandb\run-20240125_164316-yku359rr
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run dashing-water-2368
wandb:  View project at https://wandb.ai/tomaarsen/huggingface
wandb:  View run at https://wandb.ai/tomaarsen/huggingface/runs/yku359rr
{'loss': 7.7962, 'learning_rate': 1.4285714285714286e-06, 'epoch': 0.01}
{'loss': 7.858, 'learning_rate': 2.8571428571428573e-06, 'epoch': 0.01}
{'loss': 7.5358, 'learning_rate': 4.2857142857142855e-06, 'epoch': 0.02}
{'loss': 7.5463, 'learning_rate': 5.7142857142857145e-06, 'epoch': 0.03}
{'loss': 7.5718, 'learning_rate': 7.1428571428571436e-06, 'epoch': 0.04}
{'loss': 7.4983, 'learning_rate': 8.571428571428571e-06, 'epoch': 0.04}
{'loss': 7.3537, 'learning_rate': 1e-05, 'epoch': 0.05}
{'loss': 7.2904, 'learning_rate': 1.1428571428571429e-05, 'epoch': 0.06}
{'loss': 6.9924, 'learning_rate': 1.2857142857142859e-05, 'epoch': 0.07}
{'loss': 6.793, 'learning_rate': 1.4285714285714287e-05, 'epoch': 0.07}
{'eval_loss': 5.983082294464111, 'eval_pearson_cosine': 0.6945335472192309, 'eval_spearman_cosine': 0.7335610086696948, 'eval_pearson_manhattan': 0.7746120404670955, 'eval_spearman_manhattan': 0.7742670799050834, 'eval_pearson_euclidean': 0.729024114319253, 'eval_spearman_euclidean': 0.7356061671381438, 'eval_pearson_dot': 0.20475617702742557, 'eval_spearman_dot': 0.19217080148744858, 'eval_pearson_max': 0.7746120404670955, 'eval_spearman_max': 0.7742670799050834, 'eval_runtime': 69.87, 'eval_samples_per_second': 84.085, 'eval_steps_per_second': 0.043, 'epoch': 0.07}
{'loss': 6.3333, 'learning_rate': 1.5714285714285715e-05, 'epoch': 0.08}
{'loss': 5.9373, 'learning_rate': 1.7142857142857142e-05, 'epoch': 0.09}
{'loss': 5.4522, 'learning_rate': 1.8571428571428575e-05, 'epoch': 0.1}
{'loss': 4.9285, 'learning_rate': 2e-05, 'epoch': 0.1}
{'loss': 4.2263, 'learning_rate': 1.9836065573770492e-05, 'epoch': 0.11}
{'loss': 3.829, 'learning_rate': 1.9672131147540985e-05, 'epoch': 0.12}
{'loss': 3.5179, 'learning_rate': 1.9508196721311475e-05, 'epoch': 0.12}
{'loss': 3.36, 'learning_rate': 1.934426229508197e-05, 'epoch': 0.13}
{'loss': 3.1183, 'learning_rate': 1.918032786885246e-05, 'epoch': 0.14}
{'loss': 2.9756, 'learning_rate': 1.9016393442622952e-05, 'epoch': 0.15}
{'eval_loss': 2.2412540912628174, 'eval_pearson_cosine': 0.8264874720634756, 'eval_spearman_cosine': 0.8233842685619902, 'eval_pearson_manhattan': 0.8327103058194085, 'eval_spearman_manhattan': 0.8256104124450515, 'eval_pearson_euclidean': 0.8278859686864726, 'eval_spearman_euclidean': 0.8211117687750934, 'eval_pearson_dot': 0.8075555480779578, 'eval_spearman_dot': 0.7993778000589138, 'eval_pearson_max': 0.8327103058194085, 'eval_spearman_max': 0.8256104124450515, 'eval_runtime': 49.6763, 'eval_samples_per_second': 118.266, 'eval_steps_per_second': 0.06, 'epoch': 0.15}
{'loss': 2.8033, 'learning_rate': 1.8852459016393446e-05, 'epoch': 0.15}
{'loss': 2.7924, 'learning_rate': 1.8688524590163936e-05, 'epoch': 0.16}
{'loss': 2.6167, 'learning_rate': 1.852459016393443e-05, 'epoch': 0.17}
{'loss': 2.5549, 'learning_rate': 1.836065573770492e-05, 'epoch': 0.18}
{'loss': 2.3468, 'learning_rate': 1.8196721311475413e-05, 'epoch': 0.18}
{'loss': 2.2377, 'learning_rate': 1.8032786885245903e-05, 'epoch': 0.19}
{'loss': 2.3558, 'learning_rate': 1.7868852459016393e-05, 'epoch': 0.2}
{'loss': 2.3038, 'learning_rate': 1.7704918032786887e-05, 'epoch': 0.21}
{'loss': 2.1569, 'learning_rate': 1.7540983606557377e-05, 'epoch': 0.21}
{'loss': 2.1171, 'learning_rate': 1.737704918032787e-05, 'epoch': 0.22}
{'eval_loss': 1.681796908378601, 'eval_pearson_cosine': 0.8543958204407653, 'eval_spearman_cosine': 0.8523580323035814, 'eval_pearson_manhattan': 0.8561088982693472, 'eval_spearman_manhattan': 0.8516517453399247, 'eval_pearson_euclidean': 0.854381646380631, 'eval_spearman_euclidean': 0.8500210253547871, 'eval_pearson_dot': 0.8426460308873973, 'eval_spearman_dot': 0.8345694203991545, 'eval_pearson_max': 0.8561088982693472, 'eval_spearman_max': 0.8523580323035814, 'eval_runtime': 51.3468, 'eval_samples_per_second': 114.418, 'eval_steps_per_second': 0.058, 'epoch': 0.22}
{'loss': 2.0676, 'learning_rate': 1.721311475409836e-05, 'epoch': 0.23}
{'loss': 2.0422, 'learning_rate': 1.7049180327868854e-05, 'epoch': 0.24}
{'loss': 1.9205, 'learning_rate': 1.6885245901639347e-05, 'epoch': 0.24}
{'loss': 1.9148, 'learning_rate': 1.6721311475409837e-05, 'epoch': 0.25}
{'loss': 2.021, 'learning_rate': 1.655737704918033e-05, 'epoch': 0.26}
{'loss': 1.9944, 'learning_rate': 1.639344262295082e-05, 'epoch': 0.26}
{'loss': 1.7887, 'learning_rate': 1.6229508196721314e-05, 'epoch': 0.27}
{'loss': 1.9801, 'learning_rate': 1.6065573770491805e-05, 'epoch': 0.28}
{'loss': 1.8309, 'learning_rate': 1.5901639344262295e-05, 'epoch': 0.29}
{'loss': 1.8797, 'learning_rate': 1.5737704918032788e-05, 'epoch': 0.29}
{'eval_loss': 1.4784148931503296, 'eval_pearson_cosine': 0.863070233886015, 'eval_spearman_cosine': 0.8611035668009452, 'eval_pearson_manhattan': 0.8640751466341061, 'eval_spearman_manhattan': 0.8603862566980185, 'eval_pearson_euclidean': 0.8626572599779503, 'eval_spearman_euclidean': 0.8591038394440419, 'eval_pearson_dot': 0.8492961115874287, 'eval_spearman_dot': 0.8402706129406403, 'eval_pearson_max': 0.8640751466341061, 'eval_spearman_max': 0.8611035668009452, 'eval_runtime': 50.4952, 'eval_samples_per_second': 116.348, 'eval_steps_per_second': 0.059, 'epoch': 0.29}
{'loss': 1.8443, 'learning_rate': 1.5573770491803278e-05, 'epoch': 0.3}
{'loss': 1.8049, 'learning_rate': 1.5409836065573772e-05, 'epoch': 0.31}
{'loss': 1.6912, 'learning_rate': 1.5245901639344264e-05, 'epoch': 0.32}
{'loss': 1.8741, 'learning_rate': 1.5081967213114754e-05, 'epoch': 0.32}
{'loss': 1.785, 'learning_rate': 1.4918032786885249e-05, 'epoch': 0.33}
{'loss': 1.7228, 'learning_rate': 1.4754098360655739e-05, 'epoch': 0.34}
{'loss': 1.7783, 'learning_rate': 1.459016393442623e-05, 'epoch': 0.35}
{'loss': 1.6746, 'learning_rate': 1.4426229508196722e-05, 'epoch': 0.35}
{'loss': 1.6155, 'learning_rate': 1.4262295081967214e-05, 'epoch': 0.36}
{'loss': 1.6551, 'learning_rate': 1.4098360655737706e-05, 'epoch': 0.37}
{'eval_loss': 1.323501706123352, 'eval_pearson_cosine': 0.8688162943864073, 'eval_spearman_cosine': 0.867165113921055, 'eval_pearson_manhattan': 0.8700479348388257, 'eval_spearman_manhattan': 0.8674316995135701, 'eval_pearson_euclidean': 0.8686251452530193, 'eval_spearman_euclidean': 0.8662404343400781, 'eval_pearson_dot': 0.8495416421818938, 'eval_spearman_dot': 0.8415529290543254, 'eval_pearson_max': 0.8700479348388257, 'eval_spearman_max': 0.8674316995135701, 'eval_runtime': 50.4456, 'eval_samples_per_second': 116.462, 'eval_steps_per_second': 0.059, 'epoch': 0.37}
{'loss': 1.6531, 'learning_rate': 1.3934426229508198e-05, 'epoch': 0.38}
{'loss': 1.5554, 'learning_rate': 1.377049180327869e-05, 'epoch': 0.38}
{'loss': 1.5953, 'learning_rate': 1.3606557377049181e-05, 'epoch': 0.39}
{'loss': 1.5773, 'learning_rate': 1.3442622950819673e-05, 'epoch': 0.4}
{'loss': 1.6088, 'learning_rate': 1.3278688524590165e-05, 'epoch': 0.4}
{'loss': 1.5461, 'learning_rate': 1.3114754098360655e-05, 'epoch': 0.41}
{'loss': 1.5302, 'learning_rate': 1.295081967213115e-05, 'epoch': 0.42}
{'loss': 1.5672, 'learning_rate': 1.2786885245901642e-05, 'epoch': 0.43}
{'loss': 1.6247, 'learning_rate': 1.2622950819672132e-05, 'epoch': 0.43}
{'loss': 1.5126, 'learning_rate': 1.2459016393442624e-05, 'epoch': 0.44}
{'eval_loss': 1.2490781545639038, 'eval_pearson_cosine': 0.870816510562672, 'eval_spearman_cosine': 0.8698181457609595, 'eval_pearson_manhattan': 0.8731910157842028, 'eval_spearman_manhattan': 0.8710026643709499, 'eval_pearson_euclidean': 0.871917505757735, 'eval_spearman_euclidean': 0.869882668937103, 'eval_pearson_dot': 0.8508325237400651, 'eval_spearman_dot': 0.8432895740961073, 'eval_pearson_max': 0.8731910157842028, 'eval_spearman_max': 0.8710026643709499, 'eval_runtime': 49.1914, 'eval_samples_per_second': 119.431, 'eval_steps_per_second': 0.061, 'epoch': 0.44}
{'loss': 1.5141, 'learning_rate': 1.2295081967213116e-05, 'epoch': 0.45}
{'loss': 1.5369, 'learning_rate': 1.2131147540983608e-05, 'epoch': 0.46}
{'loss': 1.54, 'learning_rate': 1.19672131147541e-05, 'epoch': 0.46}
{'loss': 1.5295, 'learning_rate': 1.1803278688524591e-05, 'epoch': 0.47}
{'loss': 1.5697, 'learning_rate': 1.1639344262295083e-05, 'epoch': 0.48}
{'loss': 1.4429, 'learning_rate': 1.1475409836065575e-05, 'epoch': 0.49}
{'loss': 1.5297, 'learning_rate': 1.1311475409836066e-05, 'epoch': 0.49}
{'loss': 1.5191, 'learning_rate': 1.1147540983606557e-05, 'epoch': 0.5}
{'loss': 1.4843, 'learning_rate': 1.0983606557377052e-05, 'epoch': 0.51}
{'loss': 1.4394, 'learning_rate': 1.0819672131147544e-05, 'epoch': 0.51}
{'eval_loss': 1.2175294160842896, 'eval_pearson_cosine': 0.8720087901613824, 'eval_spearman_cosine': 0.8710475957811642, 'eval_pearson_manhattan': 0.8744030883108762, 'eval_spearman_manhattan': 0.8718406969770099, 'eval_pearson_euclidean': 0.8734299442499909, 'eval_spearman_euclidean': 0.8710172711425829, 'eval_pearson_dot': 0.8528592631334142, 'eval_spearman_dot': 0.8451704152500751, 'eval_pearson_max': 0.8744030883108762, 'eval_spearman_max': 0.8718406969770099, 'eval_runtime': 50.7433, 'eval_samples_per_second': 115.779, 'eval_steps_per_second': 0.059, 'epoch': 0.51}
{'loss': 1.5745, 'learning_rate': 1.0655737704918034e-05, 'epoch': 0.52}
{'loss': 1.5716, 'learning_rate': 1.0491803278688525e-05, 'epoch': 0.53}
{'loss': 1.4463, 'learning_rate': 1.0327868852459017e-05, 'epoch': 0.54}
{'loss': 1.5196, 'learning_rate': 1.0163934426229509e-05, 'epoch': 0.54}
{'loss': 1.4206, 'learning_rate': 1e-05, 'epoch': 0.55}
{'loss': 1.4496, 'learning_rate': 9.836065573770493e-06, 'epoch': 0.56}
{'loss': 1.4309, 'learning_rate': 9.672131147540984e-06, 'epoch': 0.57}
{'loss': 1.422, 'learning_rate': 9.508196721311476e-06, 'epoch': 0.57}
{'loss': 1.514, 'learning_rate': 9.344262295081968e-06, 'epoch': 0.58}
{'loss': 1.4069, 'learning_rate': 9.18032786885246e-06, 'epoch': 0.59}
{'eval_loss': 1.183715581893921, 'eval_pearson_cosine': 0.8723190452224101, 'eval_spearman_cosine': 0.8718898909610935, 'eval_pearson_manhattan': 0.8751171470150978, 'eval_spearman_manhattan': 0.8725464038690924, 'eval_pearson_euclidean': 0.8742333355276266, 'eval_spearman_euclidean': 0.871931366078647, 'eval_pearson_dot': 0.8534516682096298, 'eval_spearman_dot': 0.8463407843841212, 'eval_pearson_max': 0.8751171470150978, 'eval_spearman_max': 0.8725464038690924, 'eval_runtime': 50.0226, 'eval_samples_per_second': 117.447, 'eval_steps_per_second': 0.06, 'epoch': 0.59}
{'loss': 1.5403, 'learning_rate': 9.016393442622952e-06, 'epoch': 0.6}
{'loss': 1.4356, 'learning_rate': 8.852459016393443e-06, 'epoch': 0.6}
{'loss': 1.3933, 'learning_rate': 8.688524590163935e-06, 'epoch': 0.61}
{'loss': 1.4702, 'learning_rate': 8.524590163934427e-06, 'epoch': 0.62}
{'loss': 1.4998, 'learning_rate': 8.360655737704919e-06, 'epoch': 0.62}
{'loss': 1.4293, 'learning_rate': 8.19672131147541e-06, 'epoch': 0.63}
{'loss': 1.376, 'learning_rate': 8.032786885245902e-06, 'epoch': 0.64}
{'loss': 1.4829, 'learning_rate': 7.868852459016394e-06, 'epoch': 0.65}
{'loss': 1.4763, 'learning_rate': 7.704918032786886e-06, 'epoch': 0.65}
{'loss': 1.4594, 'learning_rate': 7.540983606557377e-06, 'epoch': 0.66}
{'eval_loss': 1.1624795198440552, 'eval_pearson_cosine': 0.8730231426393052, 'eval_spearman_cosine': 0.8728905296648886, 'eval_pearson_manhattan': 0.8755292260104687, 'eval_spearman_manhattan': 0.873301672296646, 'eval_pearson_euclidean': 0.8746889783072623, 'eval_spearman_euclidean': 0.8725728283785887, 'eval_pearson_dot': 0.8544512262940392, 'eval_spearman_dot': 0.8476722382849567, 'eval_pearson_max': 0.8755292260104687, 'eval_spearman_max': 0.873301672296646, 'eval_runtime': 49.5882, 'eval_samples_per_second': 118.476, 'eval_steps_per_second': 0.06, 'epoch': 0.66}
{'loss': 1.3685, 'learning_rate': 7.3770491803278695e-06, 'epoch': 0.67}
{'loss': 1.4248, 'learning_rate': 7.213114754098361e-06, 'epoch': 0.68}
{'loss': 1.4592, 'learning_rate': 7.049180327868853e-06, 'epoch': 0.68}
{'loss': 1.4485, 'learning_rate': 6.885245901639345e-06, 'epoch': 0.69}
{'loss': 1.3478, 'learning_rate': 6.721311475409837e-06, 'epoch': 0.7}
{'loss': 1.5044, 'learning_rate': 6.5573770491803276e-06, 'epoch': 0.71}
{'loss': 1.4572, 'learning_rate': 6.393442622950821e-06, 'epoch': 0.71}
{'loss': 1.3696, 'learning_rate': 6.229508196721312e-06, 'epoch': 0.72}
{'loss': 1.4805, 'learning_rate': 6.065573770491804e-06, 'epoch': 0.73}
{'loss': 1.3772, 'learning_rate': 5.9016393442622956e-06, 'epoch': 0.74}
{'eval_loss': 1.1528023481369019, 'eval_pearson_cosine': 0.8742162991881371, 'eval_spearman_cosine': 0.873648059432648, 'eval_pearson_manhattan': 0.876316149310109, 'eval_spearman_manhattan': 0.8739778034321567, 'eval_pearson_euclidean': 0.8755322669249754, 'eval_spearman_euclidean': 0.8733355698934299, 'eval_pearson_dot': 0.8560738123407455, 'eval_spearman_dot': 0.8487698412184224, 'eval_pearson_max': 0.876316149310109, 'eval_spearman_max': 0.8739778034321567, 'eval_runtime': 49.198, 'eval_samples_per_second': 119.415, 'eval_steps_per_second': 0.061, 'epoch': 0.74} 
{'loss': 1.37, 'learning_rate': 5.737704918032787e-06, 'epoch': 0.74}
{'loss': 1.3549, 'learning_rate': 5.573770491803278e-06, 'epoch': 0.75}
{'loss': 1.3673, 'learning_rate': 5.409836065573772e-06, 'epoch': 0.76}
{'loss': 1.4036, 'learning_rate': 5.245901639344263e-06, 'epoch': 0.76}
{'loss': 1.4269, 'learning_rate': 5.0819672131147545e-06, 'epoch': 0.77}
{'loss': 1.3891, 'learning_rate': 4.918032786885246e-06, 'epoch': 0.78}
{'loss': 1.4457, 'learning_rate': 4.754098360655738e-06, 'epoch': 0.79}
{'loss': 1.3928, 'learning_rate': 4.59016393442623e-06, 'epoch': 0.79}
{'loss': 1.36, 'learning_rate': 4.426229508196722e-06, 'epoch': 0.8}
{'loss': 1.3561, 'learning_rate': 4.2622950819672135e-06, 'epoch': 0.81}
{'eval_loss': 1.1401088237762451, 'eval_pearson_cosine': 0.8743441047146987, 'eval_spearman_cosine': 0.8741483515924068, 'eval_pearson_manhattan': 0.8764166705084039, 'eval_spearman_manhattan': 0.8743234068732011, 'eval_pearson_euclidean': 0.8757018042177926, 'eval_spearman_euclidean': 0.8737513780091954, 'eval_pearson_dot': 0.8559313161192934, 'eval_spearman_dot': 0.8490041741201432, 'eval_pearson_max': 0.8764166705084039, 'eval_spearman_max': 0.8743234068732011, 'eval_runtime': 51.8942, 'eval_samples_per_second': 113.211, 'eval_steps_per_second': 0.058, 'epoch': 0.81}
{'loss': 1.3529, 'learning_rate': 4.098360655737705e-06, 'epoch': 0.82}
{'loss': 1.4062, 'learning_rate': 3.934426229508197e-06, 'epoch': 0.82}
{'loss': 1.5084, 'learning_rate': 3.7704918032786884e-06, 'epoch': 0.83}
{'loss': 1.3901, 'learning_rate': 3.6065573770491806e-06, 'epoch': 0.84}
{'loss': 1.3704, 'learning_rate': 3.4426229508196724e-06, 'epoch': 0.85}
{'loss': 1.3867, 'learning_rate': 3.2786885245901638e-06, 'epoch': 0.85}
{'loss': 1.4266, 'learning_rate': 3.114754098360656e-06, 'epoch': 0.86}
{'loss': 1.35, 'learning_rate': 2.9508196721311478e-06, 'epoch': 0.87}
{'loss': 1.3949, 'learning_rate': 2.786885245901639e-06, 'epoch': 0.88}
{'loss': 1.4045, 'learning_rate': 2.6229508196721314e-06, 'epoch': 0.88}
{'eval_loss': 1.1358762979507446, 'eval_pearson_cosine': 0.8745573437638567, 'eval_spearman_cosine': 0.8744105795311472, 'eval_pearson_manhattan': 0.8767977578954331, 'eval_spearman_manhattan': 0.874612351912206, 'eval_pearson_euclidean': 0.8761974413576089, 'eval_spearman_euclidean': 0.8741348515831022, 'eval_pearson_dot': 0.8560706610739197, 'eval_spearman_dot': 0.8491052773926727, 'eval_pearson_max': 0.8767977578954331, 'eval_spearman_max': 0.874612351912206, 'eval_runtime': 49.2326, 'eval_samples_per_second': 119.331, 'eval_steps_per_second': 0.061, 'epoch': 0.88}
{'loss': 1.4265, 'learning_rate': 2.459016393442623e-06, 'epoch': 0.89}
{'loss': 1.3779, 'learning_rate': 2.295081967213115e-06, 'epoch': 0.9}
{'loss': 1.4574, 'learning_rate': 2.1311475409836067e-06, 'epoch': 0.9}
{'loss': 1.379, 'learning_rate': 1.9672131147540985e-06, 'epoch': 0.91}
{'loss': 1.3866, 'learning_rate': 1.8032786885245903e-06, 'epoch': 0.92}
{'loss': 1.3821, 'learning_rate': 1.6393442622950819e-06, 'epoch': 0.93}
{'loss': 1.3757, 'learning_rate': 1.4754098360655739e-06, 'epoch': 0.93}
{'loss': 1.39, 'learning_rate': 1.3114754098360657e-06, 'epoch': 0.94}
{'loss': 1.3647, 'learning_rate': 1.1475409836065575e-06, 'epoch': 0.95}
{'loss': 1.4125, 'learning_rate': 9.836065573770493e-07, 'epoch': 0.96}
{'eval_loss': 1.1335006952285767, 'eval_pearson_cosine': 0.8751330413361709, 'eval_spearman_cosine': 0.874942665307125, 'eval_pearson_manhattan': 0.8773063740801834, 'eval_spearman_manhattan': 0.8750371191088792, 'eval_pearson_euclidean': 0.8767110990458997, 'eval_spearman_euclidean': 0.8746277380438068, 'eval_pearson_dot': 0.8567930004957343, 'eval_spearman_dot': 0.8498321782495405, 'eval_pearson_max': 0.8773063740801834, 'eval_spearman_max': 0.8750371191088792, 'eval_runtime': 49.3441, 'eval_samples_per_second': 119.062, 'eval_steps_per_second': 0.061, 'epoch': 0.96}
{'loss': 1.3547, 'learning_rate': 8.196721311475409e-07, 'epoch': 0.96}
{'loss': 1.4468, 'learning_rate': 6.557377049180328e-07, 'epoch': 0.97}
{'loss': 1.4018, 'learning_rate': 4.918032786885246e-07, 'epoch': 0.98}
{'loss': 1.4117, 'learning_rate': 3.278688524590164e-07, 'epoch': 0.99}
{'loss': 1.3463, 'learning_rate': 1.639344262295082e-07, 'epoch': 0.99}
{'loss': 0.9521, 'learning_rate': 0.0, 'epoch': 1.0}
{'train_runtime': 3171.9084, 'train_samples_per_second': 87.416, 'train_steps_per_second': 0.043, 'train_loss': 2.2297125625259735, 'epoch': 1.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 136/136 [52:47<00:00, 23.29s/it] 
{'pearson_cosine': 0.8461989869850812, 'spearman_cosine': 0.8456475704810019, 'pearson_manhattan': 0.8529623428547448, 'spearman_manhattan': 0.8429845406318258, 'pearson_euclidean': 0.8533027010833827, 'spearman_euclidean': 0.8437611502611355, 'pearson_dot': 0.8276465647995052, 'spearman_dot': 0.8172668293306752, 'pearson_max': 0.8533027010833827, 'spearman_max': 0.8456475704810019}
wandb: | 0.088 MB of 0.088 MB uploaded (0.004 MB deduped)
wandb: Run history:
wandb:                      eval/loss █▃▂▁▁▁▁▁▁▁▁▁▁
wandb:            eval/pearson_cosine ▁▆▇██████████
wandb:               eval/pearson_dot ▁▇███████████
wandb:         eval/pearson_euclidean ▁▆▇▇█████████
wandb:         eval/pearson_manhattan ▁▅▇▇█████████
wandb:               eval/pearson_max ▁▅▇▇█████████
wandb:                   eval/runtime █▁▂▁▁▁▂▁▁▁▂▁▁
wandb:        eval/samples_per_second ▁█▇▇▇█▇███▇██
wandb:           eval/spearman_cosine ▁▅▇▇█████████
wandb:              eval/spearman_dot ▁▇███████████
wandb:        eval/spearman_euclidean ▁▅▇▇█████████
wandb:        eval/spearman_manhattan ▁▅▆▇▇████████
wandb:              eval/spearman_max ▁▅▆▇▇████████
wandb:          eval/steps_per_second ▁█▇▇▇█▇███▇██
wandb:                    train/epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
wandb:              train/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:            train/learning_rate ▂▃▅▇███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
wandb:                     train/loss ██▇▆▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▂▁
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:                      eval/loss 1.1335
wandb:            eval/pearson_cosine 0.87513
wandb:               eval/pearson_dot 0.85679
wandb:         eval/pearson_euclidean 0.87671
wandb:         eval/pearson_manhattan 0.87731
wandb:               eval/pearson_max 0.87731
wandb:                   eval/runtime 49.3441
wandb:        eval/samples_per_second 119.062
wandb:           eval/spearman_cosine 0.87494
wandb:              eval/spearman_dot 0.84983
wandb:        eval/spearman_euclidean 0.87463
wandb:        eval/spearman_manhattan 0.87504
wandb:              eval/spearman_max 0.87504
wandb:          eval/steps_per_second 0.061
wandb:                    train/epoch 1.0
wandb:              train/global_step 136
wandb:            train/learning_rate 0.0
wandb:                     train/loss 0.9521
wandb:               train/total_flos 0.0
wandb:               train/train_loss 2.22971
wandb:            train/train_runtime 3171.9084
wandb: train/train_samples_per_second 87.416
wandb:   train/train_steps_per_second 0.043
wandb:
wandb:  View run dashing-water-2368 at: https://wandb.ai/tomaarsen/huggingface/runs/yku359rr
wandb:  View job at https://wandb.ai/tomaarsen/huggingface/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzMzAwOTcxNg==/version_details/v7
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20240125_164316-yku359rr\logs
Training Notes

This Spearman Cosine of 84.56 on the test set is actually superior to all-mpnet-base-v2, which scores 83.42. We can attribute this increase to the higher batch size enabled by the cached MNRL. Big thanks to @kwang2049 for implementing it.

W&B captures the training & evaluation loss on every step as you'd hope:
(screenshot: W&B training & evaluation loss charts)
Additionally, the evaluator results are also captured:
(screenshot: W&B evaluator metric charts)
For the interested, this also tracks training speed, disk usage, GPU usage, memory usage, etc.

See the first comment for more examples! I ran out of characters in the PR description.

Backwards compatibility

I want to point out that the existing examples should still work, e.g. the ones that train using:

...
# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    evaluation_steps=int(len(train_dataloader) * 0.1),
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=False,  # Set to True, if your GPU supports FP16 operations
)
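
For illustration only, here is a rough sketch of how the fit arguments above could map onto the new API. The variable names mirror the snippet above and are hypothetical; the argument names follow the SentenceTransformerTrainingArguments used in the examples in this PR.

# Hypothetical sketch: the fit() call above expressed with the new trainer.
args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,          # was output_path
    num_train_epochs=num_epochs,         # was epochs
    warmup_steps=warmup_steps,
    evaluation_strategy="steps",
    eval_steps=int(len(train_dataloader) * 0.1),  # was evaluation_steps
    fp16=False,                          # roughly corresponds to use_amp
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,         # a datasets.Dataset instead of a DataLoader
    loss=train_loss,
    evaluator=dev_evaluator,
)
trainer.train()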

Notes

At a glance, the training time & memory usage seem equivalent to those of the old manual training loop.

TODO's

Many TODOs still remain:

  • Add errors/warnings for common misconfigurations
  • Implement smart sampling (default would be non-smart Round Robin if fit is used)
  • Test checkpointing, pausing training and proceeding later, etc.
  • Test whether Trainer.evaluate works.
  • Add a bunch of tests to support this all.
  • Rework how model cards are created - we can't use the old code for that anymore with the new Trainer. Luckily I've written similar code for SpanMarker and SetFit.
  • Set the environment variable to update the W&B default project name to "sentence_transformers", e.g. as sketched below.
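
A minimal sketch of that last item, assuming the standard WANDB_PROJECT environment variable that W&B reads:

import os

# Default the W&B project name for Sentence Transformers runs,
# while still letting users override it with their own WANDB_PROJECT.
os.environ.setdefault("WANDB_PROJECT", "sentence_transformers")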

cc @bwanglzu @ir2718 @johneckberg @aamir-s18 as I know you're interested in my TODO list.
cc @osanseviero @LysandreJik

  • Tom Aarsen

tomaarsen (Collaborator, Author) commented on Jan 25, 2024

These are the remaining experiments that did not fit in the original PR description due to the character limit.

Example 4

This example finetunes microsoft/mpnet-base using in-batch negatives on SNLI as well as MultiNLI, but both as separate Dataset instances in a DatasetDict. This prevents mixing samples from one dataset with the other. Unlike the last example, this does not use a cached variant of the in-batch negative loss, so we must use a lower batch size. The script evaluates on both SNLI and MultiNLI via loss and on the STSBenchmark development set via the EmbeddingSimilarityEvaluator. After training, the model performance is tested against the STSBenchmark test set.

Training Script
from collections import defaultdict
import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
)
from sentence_transformers.models import Transformer, Pooling

def to_triplets(dataset):
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0]) # <- entailment
            negatives.append(sentences[2]) # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })

snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": to_triplets(multi_nli_ds["train"]),
    "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
})

all_nli_train_ds = datasets.DatasetDict({
    "multi_ds": multi_nli_ds["train"],
    "snli_ds": snli_ds["train"],
})
all_nli_eval_ds = datasets.DatasetDict({
    "multi_ds": multi_nli_ds["validation_matched"],
    "snli_ds": snli_ds["validation"],
})

stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=33,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=400,
    save_steps=400,
    save_total_limit=2,
    metric_for_best_model="spearman_cosine",
    greater_is_better=True,
)

transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])

loss = losses.MultipleNegativesRankingLoss(model)
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_train_ds,
    eval_dataset=all_nli_eval_ds,
    loss=loss,
)
trainer.train()

test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)
Training Logs
Some weights of MPNetModel were not initialized from the model checkpoint at microsoft/mpnet-base and are newly initialized: ['mpnet.pooler.dense.weight', 'mpnet.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
wandb: Currently logged in as: tomaarsen. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in C:\code\sentence-transformers\wandb\run-20240125_175042-qtcy12jb
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run hopeful-leaf-2371
wandb:  View project at https://wandb.ai/tomaarsen/huggingface
wandb:  View run at https://wandb.ai/tomaarsen/huggingface/runs/qtcy12jb
{'loss': 3.3544, 'learning_rate': 4.987531172069826e-06, 'epoch': 0.02}
{'loss': 1.0246, 'learning_rate': 9.975062344139652e-06, 'epoch': 0.05}
{'loss': 0.7226, 'learning_rate': 1.4962593516209478e-05, 'epoch': 0.07}
{'loss': 0.6364, 'learning_rate': 1.9950124688279304e-05, 'epoch': 0.1}
{'eval_multi_ds_loss': 0.28911375999450684, 'eval_pearson_cosine': 0.8710415103287994, 'eval_spearman_cosine': 0.8741214973686887, 'eval_pearson_manhattan': 0.8743257652783766, 'eval_spearman_manhattan': 0.8720788092407661, 'eval_pearson_euclidean': 0.8742792501949121, 'eval_spearman_euclidean': 0.8719767095262688, 'eval_pearson_dot': 0.8331795528410947, 'eval_spearman_dot': 0.8244226660748295, 'eval_pearson_max': 0.8743257652783766, 'eval_spearman_max': 0.8741214973686887, 'eval_multi_ds_runtime': 18.2833, 'eval_multi_ds_samples_per_second': 160.365, 'eval_multi_ds_steps_per_second': 2.516, 'epoch': 0.1}
{'eval_snli_ds_loss': 0.7055805325508118, 'eval_snli_ds_runtime': 9.9822, 'eval_snli_ds_samples_per_second': 294.824, 'eval_snli_ds_steps_per_second': 4.608, 'epoch': 0.1}
{'loss': 0.5915, 'learning_rate': 1.9450762829403606e-05, 'epoch': 0.12}
{'loss': 0.5297, 'learning_rate': 1.889597780859917e-05, 'epoch': 0.15}
{'loss': 0.5025, 'learning_rate': 1.834119278779473e-05, 'epoch': 0.17}
{'loss': 0.5053, 'learning_rate': 1.778640776699029e-05, 'epoch': 0.2}
{'eval_multi_ds_loss': 0.2232046127319336, 'eval_pearson_cosine': 0.8725622603378349, 'eval_spearman_cosine': 0.8776622340170253, 'eval_pearson_manhattan': 0.8814790692050586, 'eval_spearman_manhattan': 0.8791403026500869, 'eval_pearson_euclidean': 0.8825901120918931, 'eval_spearman_euclidean': 0.8807073072874266, 'eval_pearson_dot': 0.8361209553185773, 'eval_spearman_dot': 0.8295279144077214, 'eval_pearson_max': 0.8825901120918931, 'eval_spearman_max': 0.8807073072874266, 'eval_multi_ds_runtime': 17.2035, 'eval_multi_ds_samples_per_second': 170.43, 'eval_multi_ds_steps_per_second': 2.674, 'epoch': 0.2}
{'eval_snli_ds_loss': 0.5936715006828308, 'eval_snli_ds_runtime': 9.2765, 'eval_snli_ds_samples_per_second': 317.255, 'eval_snli_ds_steps_per_second': 4.959, 'epoch': 0.2}
{'loss': 0.4816, 'learning_rate': 1.7231622746185855e-05, 'epoch': 0.22}
{'loss': 0.4576, 'learning_rate': 1.6676837725381415e-05, 'epoch': 0.25}
{'loss': 0.4728, 'learning_rate': 1.612205270457698e-05, 'epoch': 0.27}
{'loss': 0.4413, 'learning_rate': 1.556726768377254e-05, 'epoch': 0.3}
{'eval_multi_ds_loss': 0.20757625997066498, 'eval_pearson_cosine': 0.8780560238284023, 'eval_spearman_cosine': 0.882285711293379, 'eval_pearson_manhattan': 0.8842982368539201, 'eval_spearman_manhattan': 0.883945593220324, 'eval_pearson_euclidean': 0.8845710974590998, 'eval_spearman_euclidean': 0.885126124283082, 'eval_pearson_dot': 0.8472135106132146, 'eval_spearman_dot': 0.8433860261198792, 'eval_pearson_max': 0.8845710974590998, 'eval_spearman_max': 0.885126124283082, 'eval_multi_ds_runtime': 16.6207, 'eval_multi_ds_samples_per_second': 176.406, 'eval_multi_ds_steps_per_second': 2.768, 'epoch': 0.3}
{'eval_snli_ds_loss': 0.5469626188278198, 'eval_snli_ds_runtime': 9.9305, 'eval_snli_ds_samples_per_second': 296.361, 'eval_snli_ds_steps_per_second': 4.632, 'epoch': 0.3}
{'loss': 0.4297, 'learning_rate': 1.5012482662968102e-05, 'epoch': 0.32}
{'loss': 0.4362, 'learning_rate': 1.4457697642163662e-05, 'epoch': 0.35}
{'loss': 0.4212, 'learning_rate': 1.3902912621359224e-05, 'epoch': 0.37}
{'loss': 0.4412, 'learning_rate': 1.3348127600554785e-05, 'epoch': 0.4}
{'eval_multi_ds_loss': 0.18478551506996155, 'eval_pearson_cosine': 0.8725622397515482, 'eval_spearman_cosine': 0.8777594634024195, 'eval_pearson_manhattan': 0.8794692932545162, 'eval_spearman_manhattan': 0.880215097766623, 'eval_pearson_euclidean': 0.8812674507768679, 'eval_spearman_euclidean': 0.8825907186210344, 'eval_pearson_dot': 0.8379367185748164, 'eval_spearman_dot': 0.8353757679299957, 'eval_pearson_max': 0.8812674507768679, 'eval_spearman_max': 0.8825907186210344, 'eval_multi_ds_runtime': 16.7454, 'eval_multi_ds_samples_per_second': 175.093, 'eval_multi_ds_steps_per_second': 2.747, 'epoch': 0.4}
{'eval_snli_ds_loss': 0.5001863837242126, 'eval_snli_ds_runtime': 10.49, 'eval_snli_ds_samples_per_second': 280.553, 'eval_snli_ds_steps_per_second': 4.385, 'epoch': 0.4}
{'loss': 0.4108, 'learning_rate': 1.2793342579750347e-05, 'epoch': 0.42}
{'loss': 0.4025, 'learning_rate': 1.223855755894591e-05, 'epoch': 0.45}
{'loss': 0.3935, 'learning_rate': 1.1683772538141471e-05, 'epoch': 0.47}
{'loss': 0.4108, 'learning_rate': 1.1128987517337034e-05, 'epoch': 0.5}
{'eval_multi_ds_loss': 0.17254070937633514, 'eval_pearson_cosine': 0.8766792473934603, 'eval_spearman_cosine': 0.8807392653876771, 'eval_pearson_manhattan': 0.882369467310926, 'eval_spearman_manhattan': 0.8818773480226023, 'eval_pearson_euclidean': 0.8839126635694313, 'eval_spearman_euclidean': 0.8838642372923656, 'eval_pearson_dot': 0.8461220917109644, 'eval_spearman_dot': 0.8416885961243507, 'eval_pearson_max': 0.8839126635694313, 'eval_spearman_max': 0.8838642372923656, 'eval_multi_ds_runtime': 12.8856, 'eval_multi_ds_samples_per_second': 227.541, 'eval_multi_ds_steps_per_second': 3.57, 'epoch': 0.5}
{'eval_snli_ds_loss': 0.4814876914024353, 'eval_snli_ds_runtime': 7.1963, 'eval_snli_ds_samples_per_second': 408.96, 'eval_snli_ds_steps_per_second': 6.392, 'epoch': 0.5}
{'loss': 0.3911, 'learning_rate': 1.0574202496532596e-05, 'epoch': 0.52}
{'loss': 0.3802, 'learning_rate': 1.0019417475728156e-05, 'epoch': 0.55}
{'loss': 0.3871, 'learning_rate': 9.464632454923718e-06, 'epoch': 0.57}
{'loss': 0.392, 'learning_rate': 8.90984743411928e-06, 'epoch': 0.6}
{'eval_multi_ds_loss': 0.16512537002563477, 'eval_pearson_cosine': 0.8739355275303382, 'eval_spearman_cosine': 0.8796963250593532, 'eval_pearson_manhattan': 0.8810262639648156, 'eval_spearman_manhattan': 0.8806344468100958, 'eval_pearson_euclidean': 0.8831772126774371, 'eval_spearman_euclidean': 0.8835790039996912, 'eval_pearson_dot': 0.8433156354844191, 'eval_spearman_dot': 0.8406846932972077, 'eval_pearson_max': 0.8831772126774371, 'eval_spearman_max': 0.8835790039996912, 'eval_multi_ds_runtime': 13.1991, 'eval_multi_ds_samples_per_second': 222.137, 'eval_multi_ds_steps_per_second': 3.485, 'epoch': 0.6}
{'eval_snli_ds_loss': 0.46691277623176575, 'eval_snli_ds_runtime': 7.5443, 'eval_snli_ds_samples_per_second': 390.094, 'eval_snli_ds_steps_per_second': 6.097, 'epoch': 0.6}
{'loss': 0.3845, 'learning_rate': 8.355062413314841e-06, 'epoch': 0.62}
{'loss': 0.392, 'learning_rate': 7.800277392510403e-06, 'epoch': 0.65}
{'loss': 0.3811, 'learning_rate': 7.2454923717059646e-06, 'epoch': 0.67}
{'loss': 0.3857, 'learning_rate': 6.690707350901527e-06, 'epoch': 0.7}
{'eval_multi_ds_loss': 0.15821050107479095, 'eval_pearson_cosine': 0.8745454429671125, 'eval_spearman_cosine': 0.8796248189978303, 'eval_pearson_manhattan': 0.8808940344105197, 'eval_spearman_manhattan': 0.8807140731994956, 'eval_pearson_euclidean': 0.882800182996055, 'eval_spearman_euclidean': 0.8830748089170558, 'eval_pearson_dot': 0.8401340698704884, 'eval_spearman_dot': 0.8367086399910203, 'eval_pearson_max': 0.882800182996055, 'eval_spearman_max': 0.8830748089170558, 'eval_multi_ds_runtime': 13.1346, 'eval_multi_ds_samples_per_second': 223.228, 'eval_multi_ds_steps_per_second': 3.502, 'epoch': 0.7}
{'eval_snli_ds_loss': 0.4553277790546417, 'eval_snli_ds_runtime': 7.5063, 'eval_snli_ds_samples_per_second': 392.071, 'eval_snli_ds_steps_per_second': 6.128, 'epoch': 0.7}
{'loss': 0.3709, 'learning_rate': 6.135922330097088e-06, 'epoch': 0.72}
{'loss': 0.3809, 'learning_rate': 5.581137309292649e-06, 'epoch': 0.75}
{'loss': 0.3575, 'learning_rate': 5.026352288488211e-06, 'epoch': 0.77}
{'loss': 0.3511, 'learning_rate': 4.471567267683773e-06, 'epoch': 0.8}
{'eval_multi_ds_loss': 0.15566690266132355, 'eval_pearson_cosine': 0.8754714842839091, 'eval_spearman_cosine': 0.8797573910454676, 'eval_pearson_manhattan': 0.8804203622285345, 'eval_spearman_manhattan': 0.8804234057393749, 'eval_pearson_euclidean': 0.8819507122458277, 'eval_spearman_euclidean': 0.8826128752199354, 'eval_pearson_dot': 0.8465130943161607, 'eval_spearman_dot': 0.8427793815736526, 'eval_pearson_max': 0.8819507122458277, 'eval_spearman_max': 0.8826128752199354, 'eval_multi_ds_runtime': 13.3088, 'eval_multi_ds_samples_per_second': 220.306, 'eval_multi_ds_steps_per_second': 3.456, 'epoch': 0.8}
{'eval_snli_ds_loss': 0.4491104781627655, 'eval_snli_ds_runtime': 7.4502, 'eval_snli_ds_samples_per_second': 395.023, 'eval_snli_ds_steps_per_second': 6.174, 'epoch': 0.8}
{'loss': 0.341, 'learning_rate': 3.916782246879335e-06, 'epoch': 0.82}
{'loss': 0.3568, 'learning_rate': 3.3619972260748964e-06, 'epoch': 0.85}
{'loss': 0.3656, 'learning_rate': 2.807212205270458e-06, 'epoch': 0.87}
{'loss': 0.3649, 'learning_rate': 2.2524271844660195e-06, 'epoch': 0.9}
{'eval_multi_ds_loss': 0.1532323807477951, 'eval_pearson_cosine': 0.8747244603573544, 'eval_spearman_cosine': 0.8788013083189345, 'eval_pearson_manhattan': 0.8801747983324684, 'eval_spearman_manhattan': 0.8799745139709142, 'eval_pearson_euclidean': 0.8816418607710814, 'eval_spearman_euclidean': 0.882319636583665, 'eval_pearson_dot': 0.8439543676297461, 'eval_spearman_dot': 0.8406034939529871, 'eval_pearson_max': 0.8816418607710814, 'eval_spearman_max': 0.882319636583665, 'eval_multi_ds_runtime': 13.708, 'eval_multi_ds_samples_per_second': 213.889, 'eval_multi_ds_steps_per_second': 3.356, 'epoch': 0.9}
{'eval_snli_ds_loss': 0.4429143965244293, 'eval_snli_ds_runtime': 7.4127, 'eval_snli_ds_samples_per_second': 397.022, 'eval_snli_ds_steps_per_second': 6.206, 'epoch': 0.9}
{'loss': 0.3533, 'learning_rate': 1.6976421636615814e-06, 'epoch': 0.92}
{'loss': 0.3601, 'learning_rate': 1.142857142857143e-06, 'epoch': 0.95}
{'loss': 0.3582, 'learning_rate': 5.880721220527046e-07, 'epoch': 0.97}
{'loss': 0.3543, 'learning_rate': 3.3287101248266304e-08, 'epoch': 1.0}
{'eval_multi_ds_loss': 0.15337558090686798, 'eval_pearson_cosine': 0.8734623607752061, 'eval_spearman_cosine': 0.8778499591258418, 'eval_pearson_manhattan': 0.8795537148237653, 'eval_spearman_manhattan': 0.8796291329499518, 'eval_pearson_euclidean': 0.8810244766032308, 'eval_spearman_euclidean': 0.8818175010062657, 'eval_pearson_dot': 0.8435036634636086, 'eval_spearman_dot': 0.8402251129092807, 'eval_pearson_max': 0.8810244766032308, 'eval_spearman_max': 0.8818175010062657, 'eval_multi_ds_runtime': 13.8355, 'eval_multi_ds_samples_per_second': 211.919, 'eval_multi_ds_steps_per_second': 3.325, 'epoch': 1.0}
{'eval_snli_ds_loss': 0.43842339515686035, 'eval_snli_ds_runtime': 7.8348, 'eval_snli_ds_samples_per_second': 375.633, 'eval_snli_ds_steps_per_second': 5.871, 'epoch': 1.0}
{'train_runtime': 1496.5601, 'train_samples_per_second': 185.276, 'train_steps_per_second': 2.677, 'train_loss': 0.5116543066364732, 'epoch': 1.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4006/4006 [24:51<00:00,  2.69it/s]
{'pearson_cosine': 0.8634115009036718, 'spearman_cosine': 0.8706690831355366, 'pearson_manhattan': 0.8668466048928992, 'spearman_manhattan': 0.8636961613051664, 'pearson_euclidean': 0.8675179947674445, 'spearman_euclidean': 0.8650155767596514, 'pearson_dot': 0.8399016778679229, 'spearman_dot': 0.8369974602704098, 'pearson_max': 0.8675179947674445, 'spearman_max': 0.8706690831355366}
wandb: \ 0.059 MB of 0.059 MB uploaded (0.004 MB deduped)
wandb: Run history:
wandb:               eval/multi_ds_loss █▅▄▃▂▂▁▁▁▁
wandb:            eval/multi_ds_runtime █▇▆▆▁▁▁▂▂▂
wandb: eval/multi_ds_samples_per_second ▁▂▃▃█▇█▇▇▆
wandb:   eval/multi_ds_steps_per_second ▁▂▃▃█▇█▇▇▆
wandb:              eval/pearson_cosine ▁▃█▃▇▄▄▅▅▃
wandb:                 eval/pearson_dot ▁▂█▃▇▆▄█▆▆
wandb:           eval/pearson_euclidean ▁▇█▆█▇▇▆▆▆
wandb:           eval/pearson_manhattan ▁▆█▅▇▆▆▅▅▅
wandb:                 eval/pearson_max ▁▇█▆█▇▇▆▆▆
wandb:                eval/snli_ds_loss █▅▄▃▂▂▁▁▁▁
wandb:             eval/snli_ds_runtime ▇▅▇█▁▂▂▂▁▂
wandb:  eval/snli_ds_samples_per_second ▂▃▂▁█▇▇▇▇▆
wandb:    eval/snli_ds_steps_per_second ▂▃▂▁█▇▇▇▇▆
wandb:             eval/spearman_cosine ▁▄█▄▇▆▆▆▅▄
wandb:                eval/spearman_dot ▁▃█▅▇▇▆█▇▇
wandb:          eval/spearman_euclidean ▁▆█▇▇▇▇▇▇▆
wandb:          eval/spearman_manhattan ▁▅█▆▇▆▆▆▆▅
wandb:                eval/spearman_max ▁▅█▆▇▇▇▆▆▆
wandb:                      train/epoch ▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
wandb:                train/global_step ▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
wandb:              train/learning_rate ▃▄▆███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
wandb:                       train/loss █▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                 train/total_flos ▁
wandb:                 train/train_loss ▁
wandb:              train/train_runtime ▁
wandb:   train/train_samples_per_second ▁
wandb:     train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:               eval/multi_ds_loss 0.15338
wandb:            eval/multi_ds_runtime 13.8355
wandb: eval/multi_ds_samples_per_second 211.919
wandb:   eval/multi_ds_steps_per_second 3.325
wandb:              eval/pearson_cosine 0.87346
wandb:                 eval/pearson_dot 0.8435
wandb:           eval/pearson_euclidean 0.88102
wandb:           eval/pearson_manhattan 0.87955
wandb:                 eval/pearson_max 0.88102
wandb:                eval/snli_ds_loss 0.43842
wandb:             eval/snli_ds_runtime 7.8348
wandb:  eval/snli_ds_samples_per_second 375.633
wandb:    eval/snli_ds_steps_per_second 5.871
wandb:             eval/spearman_cosine 0.87785
wandb:                eval/spearman_dot 0.84023
wandb:          eval/spearman_euclidean 0.88182
wandb:          eval/spearman_manhattan 0.87963
wandb:                eval/spearman_max 0.88182
wandb:                      train/epoch 1.0
wandb:                train/global_step 4006
wandb:              train/learning_rate 0.0
wandb:                       train/loss 0.3543
wandb:                 train/total_flos 0.0
wandb:                 train/train_loss 0.51165
wandb:              train/train_runtime 1496.5601
wandb:   train/train_samples_per_second 185.276
wandb:     train/train_steps_per_second 2.677
wandb:
wandb:  View run hopeful-leaf-2371 at: https://wandb.ai/tomaarsen/huggingface/runs/qtcy12jb
wandb:  View job at https://wandb.ai/tomaarsen/huggingface/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzMzI2OTQzOQ==/version_details/v7
wandb: Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20240125_175042-qtcy12jb\logs
Training Notes

As the logs show, we now run two bits of evaluation, producing an eval_multi_ds_loss and an eval_snli_ds_loss. We only run the evaluator once, as expected. Also, the Spearman Cosine is unexpectedly high: 87.06 is quite impressive for just finetuning from straight mpnet-base for 25 minutes. The best leaderboard models reach 89, and 87.06 would place 17th in MTEB on STSBenchmark out of 126 🎉

The loss decreases as we had hoped, while the Spearman/Pearson performances are much jumpier, perhaps because the much smaller batch size results in less "normalization".
[screenshots: W&B charts of the training loss and the evaluation Spearman/Pearson metrics]

Interestingly, if we set the logging_steps back to 1 & we use the cached MNRL again, then we can clearly see the differences in performance for each of the two datasets:
[screenshot: W&B chart of the per-dataset training losses]

The training loss keeps jumping back and forth, but the evaluation performance is nice and smooth. If we set logging_steps a bit higher, the training loss would also start looking much smoother. This shows that unlike in experiment 3, the dataset samples are kept fully separate here, and it shows that SNLI is much harder than MultiNLI. Fascinating!

Example 5

This example finetunes microsoft/mpnet-base using SoftmaxLoss on SNLI as well as MultiNLI, and using CosineSimilarityLoss on the STSBenchmark training set. No samples from one dataset are mixed with another, and only the corresponding loss functions are used.

The script evaluates on SNLI and MultiNLI via the Softmax loss, on STSBenchmark dev via the CosineSimilarityLoss, and on the STSBenchmark development set via the EmbeddingSimilarityEvaluator. After training, the model performance is tested against the STSBenchmark test set.

Training Script
import datasets
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
)
from sentence_transformers.models import Transformer, Pooling

snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": snli_ds["train"],
    "validation": snli_ds["validation"],
    "test": snli_ds["test"],
})
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": multi_nli_ds["train"].remove_columns(set(multi_nli_ds["train"].column_names) - {"premise", "hypothesis", "label"}),
    "validation_matched": multi_nli_ds["validation_matched"].remove_columns(set(multi_nli_ds["validation_matched"].column_names) - {"premise", "hypothesis", "label"}),
})

def normalize_label(sample):
    sample["label"] = sample["label"] / 5
    return sample

stsb_train = datasets.load_dataset("mteb/stsbenchmark-sts", split="train").select_columns(["sentence1", "sentence2", "score"]).rename_column("score", "label").map(normalize_label)
stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation").select_columns(["sentence1", "sentence2", "score"]).rename_column("score", "label").map(normalize_label)
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test").select_columns(["sentence1", "sentence2", "score"]).rename_column("score", "label").map(normalize_label)

train_dataset = datasets.DatasetDict({
    "multi_nli": multi_nli_ds["train"],
    "snli": snli_ds["train"].rename_column("premise", "snli_premise").filter(lambda x: x["label"] != -1),
    "stsb": stsb_train,
})
eval_dataset = datasets.DatasetDict({
    "multi_nli": multi_nli_ds["validation_matched"].select(range(100)),
    "snli": snli_ds["validation"].rename_column("premise", "snli_premise").filter(lambda x: x["label"] != -1),
    "stsb": stsb_dev,
})

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=33,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=1,
    evaluation_strategy="steps",
    eval_steps=10,
    save_steps=10,
    save_total_limit=2,
    metric_for_best_model="spearman_cosine",
    greater_is_better=True,
)

transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
model = SentenceTransformer(modules=[transformer, pooling])

nli_loss = losses.SoftmaxLoss(model=model, sentence_embedding_dimension=model.get_sentence_embedding_dimension(), num_labels=3)
cosine_loss = losses.CosineSimilarityLoss(model)
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    stsb_dev["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss={
        "multi_nli": nli_loss,
        "snli": nli_loss,
        "stsb": cosine_loss,
    },
)
trainer.train()

test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    stsb_test["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)
Training Logs
Some weights of MPNetModel were not initialized from the model checkpoint at microsoft/mpnet-base and are newly initialized: ['mpnet.pooler.dense.bias', 'mpnet.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
wandb: Currently logged in as: tomaarsen. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in C:\code\sentence-transformers\wandb\run-20240125_185336-e6dg4mqp
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fresh-lion-2382
wandb:  View project at https://wandb.ai/tomaarsen/huggingface
wandb:  View run at https://wandb.ai/tomaarsen/huggingface/runs/e6dg4mqp
{'loss': 1.095, 'learning_rate': 1.4285714285714286e-06, 'epoch': 0.01}
{'loss': 1.0956, 'learning_rate': 2.8571428571428573e-06, 'epoch': 0.01}
{'loss': 0.2263, 'learning_rate': 4.2857142857142855e-06, 'epoch': 0.02}
{'loss': 1.0997, 'learning_rate': 5.7142857142857145e-06, 'epoch': 0.03}
{'loss': 1.1024, 'learning_rate': 7.1428571428571436e-06, 'epoch': 0.04}
{'loss': 0.2288, 'learning_rate': 8.571428571428571e-06, 'epoch': 0.04}
{'loss': 1.1045, 'learning_rate': 1e-05, 'epoch': 0.05}
{'loss': 1.0962, 'learning_rate': 1.1428571428571429e-05, 'epoch': 0.06}
{'loss': 0.1999, 'learning_rate': 1.2857142857142859e-05, 'epoch': 0.07}
{'loss': 1.0998, 'learning_rate': 1.4285714285714287e-05, 'epoch': 0.07}
{'eval_multi_nli_loss': 1.0992306470870972, 'eval_pearson_cosine': 0.5353005932385343, 'eval_spearman_cosine': 0.6372898564938384, 'eval_pearson_manhattan': 0.7090841704796322, 'eval_spearman_manhattan': 0.7152218072573169, 'eval_pearson_euclidean': 0.6227140703948133, 'eval_spearman_euclidean': 0.6534494640256371, 'eval_pearson_dot': 0.09094878346605995, 'eval_spearman_dot': 0.07617662182989246, 'eval_pearson_max': 0.7090841704796322, 'eval_spearman_max': 0.7152218072573169, 'eval_multi_nli_runtime': 3.8658, 'eval_multi_nli_samples_per_second': 25.868, 'eval_multi_nli_steps_per_second': 0.259, 'epoch': 0.07}
{'eval_snli_loss': 1.0988188982009888, 'eval_snli_runtime': 12.5746, 'eval_snli_samples_per_second': 782.689, 'eval_snli_steps_per_second': 6.123, 'epoch': 0.07}
{'eval_stsb_loss': 0.2768329977989197, 'eval_stsb_runtime': 2.0719, 'eval_stsb_samples_per_second': 723.983, 'eval_stsb_steps_per_second': 5.792, 'epoch': 0.07}
{'loss': 1.1021, 'learning_rate': 1.5714285714285715e-05, 'epoch': 0.08}
{'loss': 0.2063, 'learning_rate': 1.7142857142857142e-05, 'epoch': 0.09}
{'loss': 1.0987, 'learning_rate': 1.8571428571428575e-05, 'epoch': 0.1}
{'loss': 1.0937, 'learning_rate': 2e-05, 'epoch': 0.1}
{'loss': 0.1481, 'learning_rate': 1.9834710743801656e-05, 'epoch': 0.11}
{'loss': 1.0968, 'learning_rate': 1.9669421487603307e-05, 'epoch': 0.12}
{'loss': 1.098, 'learning_rate': 1.950413223140496e-05, 'epoch': 0.13}
{'loss': 0.1005, 'learning_rate': 1.9338842975206613e-05, 'epoch': 0.13}
{'loss': 1.0982, 'learning_rate': 1.9173553719008268e-05, 'epoch': 0.14}
{'loss': 1.0942, 'learning_rate': 1.900826446280992e-05, 'epoch': 0.15}
{'eval_multi_nli_loss': 1.0952303409576416, 'eval_pearson_cosine': 0.5927309185248284, 'eval_spearman_cosine': 0.6166161812130626, 'eval_pearson_manhattan': 0.6944794569756229, 'eval_spearman_manhattan': 0.697486468267352, 'eval_pearson_euclidean': 0.6355324280880893, 'eval_spearman_euclidean': 0.6475659071706225, 'eval_pearson_dot': 0.4557444361434066, 'eval_spearman_dot': 0.45829608099219776, 'eval_pearson_max': 0.6944794569756229, 'eval_spearman_max': 0.697486468267352, 'eval_multi_nli_runtime': 4.1919, 'eval_multi_nli_samples_per_second': 23.855, 'eval_multi_nli_steps_per_second': 0.239, 'epoch': 0.15}
{'eval_snli_loss': 1.096569538116455, 'eval_snli_runtime': 13.5008, 'eval_snli_samples_per_second': 728.996, 'eval_snli_steps_per_second': 5.703, 'epoch': 0.15}
{'eval_stsb_loss': 0.12020454555749893, 'eval_stsb_runtime': 2.2282, 'eval_stsb_samples_per_second': 673.192, 'eval_stsb_steps_per_second': 5.386, 'epoch': 0.15}
{'loss': 0.0711, 'learning_rate': 1.884297520661157e-05, 'epoch': 0.16}
{'loss': 1.0798, 'learning_rate': 1.8677685950413225e-05, 'epoch': 0.16}
{'loss': 1.0977, 'learning_rate': 1.851239669421488e-05, 'epoch': 0.17}
{'loss': 0.061, 'learning_rate': 1.834710743801653e-05, 'epoch': 0.18}
{'loss': 1.1083, 'learning_rate': 1.8181818181818182e-05, 'epoch': 0.19}
{'loss': 1.1019, 'learning_rate': 1.8016528925619837e-05, 'epoch': 0.19}
{'loss': 0.0722, 'learning_rate': 1.7851239669421488e-05, 'epoch': 0.2}
{'loss': 1.096, 'learning_rate': 1.7685950413223143e-05, 'epoch': 0.21}
{'loss': 1.0926, 'learning_rate': 1.7520661157024794e-05, 'epoch': 0.21}
{'loss': 0.08, 'learning_rate': 1.735537190082645e-05, 'epoch': 0.22}
{'eval_multi_nli_loss': 1.092315435409546, 'eval_pearson_cosine': 0.6701809629091153, 'eval_spearman_cosine': 0.6743649167442864, 'eval_pearson_manhattan': 0.7164336589463204, 'eval_spearman_manhattan': 0.7186635575961594, 'eval_pearson_euclidean': 0.6784895229680746, 'eval_spearman_euclidean': 0.6847840312956299, 'eval_pearson_dot': 0.5796864418283512, 'eval_spearman_dot': 0.5919393207888627, 'eval_pearson_max': 0.7164336589463204, 'eval_spearman_max': 0.7186635575961594, 'eval_multi_nli_runtime': 4.2366, 'eval_multi_nli_samples_per_second': 23.604, 'eval_multi_nli_steps_per_second': 0.236, 'epoch': 0.22}
{'eval_snli_loss': 1.091275691986084, 'eval_snli_runtime': 13.4917, 'eval_snli_samples_per_second': 729.484, 'eval_snli_steps_per_second': 5.707, 'epoch': 0.22}
{'eval_stsb_loss': 0.08859114348888397, 'eval_stsb_runtime': 2.2082, 'eval_stsb_samples_per_second': 679.3, 'eval_stsb_steps_per_second': 5.434, 'epoch': 0.22}
{'loss': 1.0881, 'learning_rate': 1.71900826446281e-05, 'epoch': 0.23}
{'loss': 1.0955, 'learning_rate': 1.7024793388429754e-05, 'epoch': 0.24}
{'loss': 0.0672, 'learning_rate': 1.6859504132231405e-05, 'epoch': 0.24}
{'loss': 1.0936, 'learning_rate': 1.669421487603306e-05, 'epoch': 0.25}
{'loss': 1.0925, 'learning_rate': 1.652892561983471e-05, 'epoch': 0.26}
{'loss': 0.0726, 'learning_rate': 1.6363636363636366e-05, 'epoch': 0.27}
{'loss': 1.0895, 'learning_rate': 1.6198347107438017e-05, 'epoch': 0.27}
{'loss': 1.0926, 'learning_rate': 1.6033057851239672e-05, 'epoch': 0.28}
{'loss': 0.0705, 'learning_rate': 1.5867768595041323e-05, 'epoch': 0.29}
{'loss': 1.0912, 'learning_rate': 1.5702479338842978e-05, 'epoch': 0.3}
{'eval_multi_nli_loss': 1.0895724296569824, 'eval_pearson_cosine': 0.7256318330540757, 'eval_spearman_cosine': 0.7202277704180212, 'eval_pearson_manhattan': 0.7345125003534995, 'eval_spearman_manhattan': 0.7365752140693772, 'eval_pearson_euclidean': 0.7059874570626268, 'eval_spearman_euclidean': 0.7109035415716308, 'eval_pearson_dot': 0.6799191156548992, 'eval_spearman_dot': 0.6807418143778174, 'eval_pearson_max': 0.7345125003534995, 'eval_spearman_max': 0.7365752140693772, 'eval_multi_nli_runtime': 4.2095, 'eval_multi_nli_samples_per_second': 23.756, 'eval_multi_nli_steps_per_second': 0.238, 'epoch': 0.3}
{'eval_snli_loss': 1.0868031978607178, 'eval_snli_runtime': 13.4653, 'eval_snli_samples_per_second': 730.915, 'eval_snli_steps_per_second': 5.718, 'epoch': 0.3}
{'eval_stsb_loss': 0.07544665783643723, 'eval_stsb_runtime': 2.2381, 'eval_stsb_samples_per_second': 670.218, 'eval_stsb_steps_per_second': 5.362, 'epoch': 0.3}
{'loss': 1.0956, 'learning_rate': 1.553719008264463e-05, 'epoch': 0.3}
{'loss': 0.0575, 'learning_rate': 1.5371900826446283e-05, 'epoch': 0.31}
{'loss': 1.0876, 'learning_rate': 1.5206611570247936e-05, 'epoch': 0.32}
{'loss': 1.0735, 'learning_rate': 1.504132231404959e-05, 'epoch': 0.33}
{'loss': 0.0828, 'learning_rate': 1.487603305785124e-05, 'epoch': 0.33}
{'loss': 1.0998, 'learning_rate': 1.4710743801652893e-05, 'epoch': 0.34}
{'loss': 1.0937, 'learning_rate': 1.4545454545454546e-05, 'epoch': 0.35}
{'loss': 0.0659, 'learning_rate': 1.4380165289256201e-05, 'epoch': 0.36}
{'loss': 1.0897, 'learning_rate': 1.4214876033057852e-05, 'epoch': 0.36}
{'loss': 1.0796, 'learning_rate': 1.4049586776859505e-05, 'epoch': 0.37}
{'eval_multi_nli_loss': 1.0950977802276611, 'eval_pearson_cosine': 0.7602059656388824, 'eval_spearman_cosine': 0.7539993008115333, 'eval_pearson_manhattan': 0.7483021330521298, 'eval_spearman_manhattan': 0.7488072578630834, 'eval_pearson_euclidean': 0.735619996045517, 'eval_spearman_euclidean': 0.7388005269910116, 'eval_pearson_dot': 0.7344251097899126, 'eval_spearman_dot': 0.7305945017622866, 'eval_pearson_max': 0.7602059656388824, 'eval_spearman_max': 0.7539993008115333, 'eval_multi_nli_runtime': 4.2343, 'eval_multi_nli_samples_per_second': 23.616, 'eval_multi_nli_steps_per_second': 0.236, 'epoch': 0.37}
{'eval_snli_loss': 1.089049220085144, 'eval_snli_runtime': 13.3874, 'eval_snli_samples_per_second': 735.166, 'eval_snli_steps_per_second': 5.752, 'epoch': 0.37}
{'eval_stsb_loss': 0.0684237852692604, 'eval_stsb_runtime': 2.2361, 'eval_stsb_samples_per_second': 670.818, 'eval_stsb_steps_per_second': 5.367, 'epoch': 0.37}
{'loss': 0.0568, 'learning_rate': 1.3884297520661158e-05, 'epoch': 0.38}
{'loss': 1.104, 'learning_rate': 1.3719008264462813e-05, 'epoch': 0.39}
{'loss': 1.0763, 'learning_rate': 1.3553719008264464e-05, 'epoch': 0.39}
{'loss': 0.062, 'learning_rate': 1.3388429752066117e-05, 'epoch': 0.4}
{'loss': 1.0857, 'learning_rate': 1.322314049586777e-05, 'epoch': 0.41}
{'loss': 1.1024, 'learning_rate': 1.3057851239669424e-05, 'epoch': 0.41}
{'loss': 0.0571, 'learning_rate': 1.2892561983471074e-05, 'epoch': 0.42}
{'loss': 1.0973, 'learning_rate': 1.2727272727272728e-05, 'epoch': 0.43}
{'loss': 1.0936, 'learning_rate': 1.2561983471074381e-05, 'epoch': 0.44}
{'loss': 0.0532, 'learning_rate': 1.2396694214876034e-05, 'epoch': 0.44}
{'eval_multi_nli_loss': 1.0998018980026245, 'eval_pearson_cosine': 0.7812947108013641, 'eval_spearman_cosine': 0.7762512336075155, 'eval_pearson_manhattan': 0.7678945230056597, 'eval_spearman_manhattan': 0.7671076116552968, 'eval_pearson_euclidean': 0.7627800766149042, 'eval_spearman_euclidean': 0.7624585248252842, 'eval_pearson_dot': 0.7587217776613172, 'eval_spearman_dot': 0.7526352849675395, 'eval_pearson_max': 0.7812947108013641, 'eval_spearman_max': 0.7762512336075155, 'eval_multi_nli_runtime': 4.2318, 'eval_multi_nli_samples_per_second': 23.631, 'eval_multi_nli_steps_per_second': 0.236, 'epoch': 0.44}
{'eval_snli_loss': 1.0912853479385376, 'eval_snli_runtime': 13.4865, 'eval_snli_samples_per_second': 729.769, 'eval_snli_steps_per_second': 5.709, 'epoch': 0.44}
{'eval_stsb_loss': 0.06371311098337173, 'eval_stsb_runtime': 2.2262, 'eval_stsb_samples_per_second': 673.8, 'eval_stsb_steps_per_second': 5.39, 'epoch': 0.44}
{'loss': 1.0977, 'learning_rate': 1.2231404958677686e-05, 'epoch': 0.45}
{'loss': 1.0912, 'learning_rate': 1.206611570247934e-05, 'epoch': 0.46}
{'loss': 0.0575, 'learning_rate': 1.1900826446280993e-05, 'epoch': 0.47}
{'loss': 1.0989, 'learning_rate': 1.1735537190082646e-05, 'epoch': 0.47}
{'loss': 1.0761, 'learning_rate': 1.1570247933884297e-05, 'epoch': 0.48}
{'loss': 0.061, 'learning_rate': 1.1404958677685952e-05, 'epoch': 0.49}
{'loss': 1.097, 'learning_rate': 1.1239669421487605e-05, 'epoch': 0.5}
{'loss': 1.0932, 'learning_rate': 1.1074380165289258e-05, 'epoch': 0.5}
{'loss': 0.0531, 'learning_rate': 1.0909090909090909e-05, 'epoch': 0.51}
{'loss': 1.088, 'learning_rate': 1.0743801652892562e-05, 'epoch': 0.52}
{'eval_multi_nli_loss': 1.0866062641143799, 'eval_pearson_cosine': 0.7871914530962119, 'eval_spearman_cosine': 0.7803757033323617, 'eval_pearson_manhattan': 0.7756029556268689, 'eval_spearman_manhattan': 0.7754081322929453, 'eval_pearson_euclidean': 0.7711928033175355, 'eval_spearman_euclidean': 0.7717827401145304, 'eval_pearson_dot': 0.759673346341878, 'eval_spearman_dot': 0.7525053549420156, 'eval_pearson_max': 0.7871914530962119, 'eval_spearman_max': 0.7803757033323617, 'eval_multi_nli_runtime': 6.2649, 'eval_multi_nli_samples_per_second': 15.962, 'eval_multi_nli_steps_per_second': 0.16, 'epoch': 0.52}
{'eval_snli_loss': 1.0768821239471436, 'eval_snli_runtime': 15.3064, 'eval_snli_samples_per_second': 643.0, 'eval_snli_steps_per_second': 5.031, 'epoch': 0.52}
{'eval_stsb_loss': 0.0561554878950119, 'eval_stsb_runtime': 2.2457, 'eval_stsb_samples_per_second': 667.952, 'eval_stsb_steps_per_second': 5.344, 'epoch': 0.52}
{'loss': 1.0875, 'learning_rate': 1.0578512396694216e-05, 'epoch': 0.53}
{'loss': 0.0567, 'learning_rate': 1.041322314049587e-05, 'epoch': 0.53}
{'loss': 1.0928, 'learning_rate': 1.024793388429752e-05, 'epoch': 0.54}
{'loss': 1.0938, 'learning_rate': 1.0082644628099174e-05, 'epoch': 0.55}
{'loss': 0.0468, 'learning_rate': 9.917355371900828e-06, 'epoch': 0.56}
{'loss': 1.0825, 'learning_rate': 9.75206611570248e-06, 'epoch': 0.56}
{'loss': 1.0727, 'learning_rate': 9.586776859504134e-06, 'epoch': 0.57}
{'loss': 0.0475, 'learning_rate': 9.421487603305785e-06, 'epoch': 0.58}
{'loss': 1.0789, 'learning_rate': 9.25619834710744e-06, 'epoch': 0.59}
{'loss': 1.0879, 'learning_rate': 9.090909090909091e-06, 'epoch': 0.59}
{'eval_multi_nli_loss': 1.0869539976119995, 'eval_pearson_cosine': 0.7924287839974626, 'eval_spearman_cosine': 0.7878582867813361, 'eval_pearson_manhattan': 0.7820553059604191, 'eval_spearman_manhattan': 0.7814241773646012, 'eval_pearson_euclidean': 0.7769628784610649, 'eval_spearman_euclidean': 0.7771091619350827, 'eval_pearson_dot': 0.7660649225153495, 'eval_spearman_dot': 0.7602893122641841, 'eval_pearson_max': 0.7924287839974626, 'eval_spearman_max': 0.7878582867813361, 'eval_multi_nli_runtime': 4.0923, 'eval_multi_nli_samples_per_second': 24.436, 'eval_multi_nli_steps_per_second': 0.244, 'epoch': 0.59}
{'eval_snli_loss': 1.0761278867721558, 'eval_snli_runtime': 13.4861, 'eval_snli_samples_per_second': 729.789, 'eval_snli_steps_per_second': 5.71, 'epoch': 0.59}
{'eval_stsb_loss': 0.05780908837914467, 'eval_stsb_runtime': 2.2124, 'eval_stsb_samples_per_second': 678.011, 'eval_stsb_steps_per_second': 5.424, 'epoch': 0.59}
{'loss': 0.0601, 'learning_rate': 8.925619834710744e-06, 'epoch': 0.6}
{'loss': 1.0932, 'learning_rate': 8.760330578512397e-06, 'epoch': 0.61}
{'loss': 1.0758, 'learning_rate': 8.59504132231405e-06, 'epoch': 0.61}
{'loss': 0.0504, 'learning_rate': 8.429752066115703e-06, 'epoch': 0.62}
{'loss': 1.0819, 'learning_rate': 8.264462809917356e-06, 'epoch': 0.63}
{'loss': 1.0839, 'learning_rate': 8.099173553719009e-06, 'epoch': 0.64}
{'loss': 0.0543, 'learning_rate': 7.933884297520661e-06, 'epoch': 0.64}
{'loss': 1.0835, 'learning_rate': 7.768595041322314e-06, 'epoch': 0.65}
{'loss': 1.0627, 'learning_rate': 7.603305785123968e-06, 'epoch': 0.66}
{'loss': 0.043, 'learning_rate': 7.43801652892562e-06, 'epoch': 0.67}
{'eval_multi_nli_loss': 1.0874528884887695, 'eval_pearson_cosine': 0.7903836294144213, 'eval_spearman_cosine': 0.7870760513667576, 'eval_pearson_manhattan': 0.7832143180912645, 'eval_spearman_manhattan': 0.7825206176220711, 'eval_pearson_euclidean': 0.7781134269436615, 'eval_spearman_euclidean': 0.7782710986789753, 'eval_pearson_dot': 0.7625101087174764, 'eval_spearman_dot': 0.7573616892499614, 'eval_pearson_max': 0.7903836294144213, 'eval_spearman_max': 0.7870760513667576, 'eval_multi_nli_runtime': 4.2078, 'eval_multi_nli_samples_per_second': 23.765, 'eval_multi_nli_steps_per_second': 0.238, 'epoch': 0.67}
{'eval_snli_loss': 1.0738754272460938, 'eval_snli_runtime': 13.6675, 'eval_snli_samples_per_second': 720.104, 'eval_snli_steps_per_second': 5.634, 'epoch': 0.67}
{'eval_stsb_loss': 0.05803224816918373, 'eval_stsb_runtime': 2.2565, 'eval_stsb_samples_per_second': 664.735, 'eval_stsb_steps_per_second': 5.318, 'epoch': 0.67}
{'loss': 1.0916, 'learning_rate': 7.272727272727273e-06, 'epoch': 0.67}
{'loss': 1.0698, 'learning_rate': 7.107438016528926e-06, 'epoch': 0.68}
{'loss': 0.047, 'learning_rate': 6.942148760330579e-06, 'epoch': 0.69}
{'loss': 1.099, 'learning_rate': 6.776859504132232e-06, 'epoch': 0.7}
{'loss': 1.0796, 'learning_rate': 6.611570247933885e-06, 'epoch': 0.7}
{'loss': 0.0491, 'learning_rate': 6.446280991735537e-06, 'epoch': 0.71}
{'loss': 1.0955, 'learning_rate': 6.280991735537191e-06, 'epoch': 0.72}
{'loss': 1.0834, 'learning_rate': 6.115702479338843e-06, 'epoch': 0.73}
{'loss': 0.0541, 'learning_rate': 5.9504132231404965e-06, 'epoch': 0.73}
{'loss': 1.0899, 'learning_rate': 5.785123966942149e-06, 'epoch': 0.74}
{'eval_multi_nli_loss': 1.0859707593917847, 'eval_pearson_cosine': 0.7880593110224633, 'eval_spearman_cosine': 0.7852908757591515, 'eval_pearson_manhattan': 0.784064221417867, 'eval_spearman_manhattan': 0.7836140594637, 'eval_pearson_euclidean': 0.7786970630754979, 'eval_spearman_euclidean': 0.7792632986510772, 'eval_pearson_dot': 0.7567199531111137, 'eval_spearman_dot': 0.753033446722682, 'eval_pearson_max': 0.7880593110224633, 'eval_spearman_max': 0.7852908757591515, 'eval_multi_nli_runtime': 4.1854, 'eval_multi_nli_samples_per_second': 23.892, 'eval_multi_nli_steps_per_second': 0.239, 'epoch': 0.74}
{'eval_snli_loss': 1.068565011024475, 'eval_snli_runtime': 13.2992, 'eval_snli_samples_per_second': 740.043, 'eval_snli_steps_per_second': 5.79, 'epoch': 0.74}
{'eval_stsb_loss': 0.05798153579235077, 'eval_stsb_runtime': 2.1889, 'eval_stsb_samples_per_second': 685.278, 'eval_stsb_steps_per_second': 5.482, 'epoch': 0.74}
{'loss': 1.076, 'learning_rate': 5.619834710743802e-06, 'epoch': 0.75}
{'loss': 0.0446, 'learning_rate': 5.4545454545454545e-06, 'epoch': 0.76}
{'loss': 1.0778, 'learning_rate': 5.289256198347108e-06, 'epoch': 0.76}
{'loss': 1.0668, 'learning_rate': 5.12396694214876e-06, 'epoch': 0.77}
{'loss': 0.0461, 'learning_rate': 4.958677685950414e-06, 'epoch': 0.78}
{'loss': 1.0872, 'learning_rate': 4.793388429752067e-06, 'epoch': 0.79}
{'loss': 1.0685, 'learning_rate': 4.62809917355372e-06, 'epoch': 0.79}
{'loss': 0.0581, 'learning_rate': 4.462809917355372e-06, 'epoch': 0.8}
{'loss': 1.0749, 'learning_rate': 4.297520661157025e-06, 'epoch': 0.81}
{'loss': 1.0731, 'learning_rate': 4.132231404958678e-06, 'epoch': 0.81}
{'eval_multi_nli_loss': 1.084850549697876, 'eval_pearson_cosine': 0.7933557513384137, 'eval_spearman_cosine': 0.7900928830022685, 'eval_pearson_manhattan': 0.7867294337881144, 'eval_spearman_manhattan': 0.7861964959820836, 'eval_pearson_euclidean': 0.779952868149915, 'eval_spearman_euclidean': 0.7804624119690476, 'eval_pearson_dot': 0.7668674759177159, 'eval_spearman_dot': 0.7617434649371185, 'eval_pearson_max': 0.7933557513384137, 'eval_spearman_max': 0.7900928830022685, 'eval_multi_nli_runtime': 4.3507, 'eval_multi_nli_samples_per_second': 22.985, 'eval_multi_nli_steps_per_second': 0.23, 'epoch': 0.81}
{'eval_snli_loss': 1.0644443035125732, 'eval_snli_runtime': 13.3746, 'eval_snli_samples_per_second': 735.874, 'eval_snli_steps_per_second': 5.757, 'epoch': 0.81}
{'eval_stsb_loss': 0.05611236020922661, 'eval_stsb_runtime': 2.2315, 'eval_stsb_samples_per_second': 672.183, 'eval_stsb_steps_per_second': 5.377, 'epoch': 0.81}
{'loss': 0.0453, 'learning_rate': 3.966942148760331e-06, 'epoch': 0.82}
{'loss': 1.0676, 'learning_rate': 3.801652892561984e-06, 'epoch': 0.83}
{'loss': 1.074, 'learning_rate': 3.6363636363636366e-06, 'epoch': 0.84}
{'loss': 0.0452, 'learning_rate': 3.4710743801652895e-06, 'epoch': 0.84}
{'loss': 1.0786, 'learning_rate': 3.3057851239669424e-06, 'epoch': 0.85}
{'loss': 1.0742, 'learning_rate': 3.1404958677685953e-06, 'epoch': 0.86}
{'loss': 0.0433, 'learning_rate': 2.9752066115702483e-06, 'epoch': 0.87}
{'loss': 1.0887, 'learning_rate': 2.809917355371901e-06, 'epoch': 0.87}
{'loss': 1.0639, 'learning_rate': 2.644628099173554e-06, 'epoch': 0.88}
{'loss': 0.0509, 'learning_rate': 2.479338842975207e-06, 'epoch': 0.89}
{'eval_multi_nli_loss': 1.0812655687332153, 'eval_pearson_cosine': 0.7983019769536408, 'eval_spearman_cosine': 0.7935848619025935, 'eval_pearson_manhattan': 0.7890734857934782, 'eval_spearman_manhattan': 0.7889181024584033, 'eval_pearson_euclidean': 0.7823518999829207, 'eval_spearman_euclidean': 0.7829360129261643, 'eval_pearson_dot': 0.7720565705042415, 'eval_spearman_dot': 0.765449435037735, 'eval_pearson_max': 0.7983019769536408, 'eval_spearman_max': 0.7935848619025935, 'eval_multi_nli_runtime': 4.2421, 'eval_multi_nli_samples_per_second': 23.573, 'eval_multi_nli_steps_per_second': 0.236, 'epoch': 0.89}
{'eval_snli_loss': 1.0605069398880005, 'eval_snli_runtime': 13.3989, 'eval_snli_samples_per_second': 734.536, 'eval_snli_steps_per_second': 5.747, 'epoch': 0.89}
{'eval_stsb_loss': 0.05315973982214928, 'eval_stsb_runtime': 2.2073, 'eval_stsb_samples_per_second': 679.574, 'eval_stsb_steps_per_second': 5.437, 'epoch': 0.89}
{'loss': 1.0747, 'learning_rate': 2.31404958677686e-06, 'epoch': 0.9}
{'loss': 1.0715, 'learning_rate': 2.1487603305785124e-06, 'epoch': 0.9}
{'loss': 0.0451, 'learning_rate': 1.9834710743801654e-06, 'epoch': 0.91}
{'loss': 1.0878, 'learning_rate': 1.8181818181818183e-06, 'epoch': 0.92}
{'loss': 1.0655, 'learning_rate': 1.6528925619834712e-06, 'epoch': 0.93}
{'loss': 0.0566, 'learning_rate': 1.4876033057851241e-06, 'epoch': 0.93}
{'loss': 1.0869, 'learning_rate': 1.322314049586777e-06, 'epoch': 0.94}
{'loss': 1.0758, 'learning_rate': 1.15702479338843e-06, 'epoch': 0.95}
{'loss': 0.06, 'learning_rate': 9.917355371900827e-07, 'epoch': 0.96}
{'loss': 1.0749, 'learning_rate': 8.264462809917356e-07, 'epoch': 0.96}
{'eval_multi_nli_loss': 1.0818829536437988, 'eval_pearson_cosine': 0.7979863315970133, 'eval_spearman_cosine': 0.7935847106569378, 'eval_pearson_manhattan': 0.7890500102795305, 'eval_spearman_manhattan': 0.7888443319449256, 'eval_pearson_euclidean': 0.7827435065490354, 'eval_spearman_euclidean': 0.7832764562213034, 'eval_pearson_dot': 0.772029153509997, 'eval_spearman_dot': 0.7653096253328342, 'eval_pearson_max': 0.7979863315970133, 'eval_spearman_max': 0.7935847106569378, 'eval_multi_nli_runtime': 5.2211, 'eval_multi_nli_samples_per_second': 19.153, 'eval_multi_nli_steps_per_second': 0.192, 'epoch': 0.96}
{'eval_snli_loss': 1.060091257095337, 'eval_snli_runtime': 13.4275, 'eval_snli_samples_per_second': 732.973, 'eval_snli_steps_per_second': 5.735, 'epoch': 0.96}
{'eval_stsb_loss': 0.053038228303194046, 'eval_stsb_runtime': 2.2009, 'eval_stsb_samples_per_second': 681.531, 'eval_stsb_steps_per_second': 5.452, 'epoch': 0.96}
{'loss': 1.0719, 'learning_rate': 6.611570247933885e-07, 'epoch': 0.97}
{'loss': 0.0497, 'learning_rate': 4.958677685950413e-07, 'epoch': 0.98}
{'loss': 1.0725, 'learning_rate': 3.3057851239669426e-07, 'epoch': 0.99}
{'loss': 1.0752, 'learning_rate': 1.6528925619834713e-07, 'epoch': 0.99}
{'loss': 0.0461, 'learning_rate': 0.0, 'epoch': 1.0}
{'train_runtime': 382.8386, 'train_samples_per_second': 2475.764, 'train_steps_per_second': 0.353, 'train_loss': 0.7491866049943147, 'epoch': 1.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 135/135 [06:18<00:00,  2.80s/it] 
{'pearson_cosine': 0.7268712657649615, 'spearman_cosine': 0.7057749245910342, 'pearson_manhattan': 0.7313316613855259, 'spearman_manhattan': 0.7110533973222474, 'pearson_euclidean': 0.7223762962847724, 'spearman_euclidean': 0.7036022336620789, 'pearson_dot': 0.6751281616886309, 'spearman_dot': 0.6573632212002909, 'pearson_max': 0.7313316613855259, 'spearman_max': 0.7110533973222474}
wandb: | 0.049 MB of 0.049 MB uploaded
wandb: Run history:
wandb:               eval/multi_nli_loss █▆▅▄▆█▃▃▃▃▂▁▁
wandb:            eval/multi_nli_runtime ▁▂▂▂▂▂█▂▂▂▂▂▅
wandb: eval/multi_nli_samples_per_second █▇▆▇▆▆▁▇▇▇▆▆▃
wandb:   eval/multi_nli_steps_per_second █▇▆▇▆▆▁▇▇▇▆▆▃
wandb:               eval/pearson_cosine ▁▃▅▆▇████████
wandb:                  eval/pearson_dot ▁▅▆▇█████████
wandb:            eval/pearson_euclidean ▁▂▃▅▆▇▇██████
wandb:            eval/pearson_manhattan ▂▁▃▄▅▆▇▇█████
wandb:                  eval/pearson_max ▂▁▂▄▅▇▇█▇▇███
wandb:                    eval/snli_loss ██▇▆▆▇▄▄▃▃▂▁▁
wandb:                 eval/snli_runtime ▁▃▃▃▃▃█▃▄▃▃▃▃
wandb:      eval/snli_samples_per_second █▅▅▅▆▅▁▅▅▆▆▆▆
wandb:        eval/snli_steps_per_second █▅▅▅▆▅▁▅▅▆▆▆▆
wandb:              eval/spearman_cosine ▂▁▃▅▆▇▇██████
wandb:                 eval/spearman_dot ▁▅▆▇█████████
wandb:           eval/spearman_euclidean ▁▁▃▄▆▇▇██████
wandb:           eval/spearman_manhattan ▂▁▃▄▅▆▇▇█████
wandb:                 eval/spearman_max ▂▁▃▄▅▇▇██▇███
wandb:                    eval/stsb_loss █▃▂▂▁▁▁▁▁▁▁▁▁
wandb:                 eval/stsb_runtime ▁▇▆▇▇▇█▆█▅▇▆▆
wandb:      eval/stsb_samples_per_second █▂▃▂▂▂▁▃▁▃▂▃▃
wandb:        eval/stsb_steps_per_second █▂▃▂▂▂▁▃▁▃▂▃▃
wandb:                       train/epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
wandb:                 train/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
wandb:               train/learning_rate ▂▃▅▇███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁
wandb:                        train/loss ███▂█▁█▁████▁█▁████▁█▁█▁██▁█▁█▁████▁█▁██
wandb:                  train/total_flos ▁
wandb:                  train/train_loss ▁
wandb:               train/train_runtime ▁
wandb:    train/train_samples_per_second ▁
wandb:      train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:               eval/multi_nli_loss 1.08188
wandb:            eval/multi_nli_runtime 5.2211
wandb: eval/multi_nli_samples_per_second 19.153
wandb:   eval/multi_nli_steps_per_second 0.192
wandb:               eval/pearson_cosine 0.79799
wandb:                  eval/pearson_dot 0.77203
wandb:            eval/pearson_euclidean 0.78274
wandb:            eval/pearson_manhattan 0.78905
wandb:                  eval/pearson_max 0.79799
wandb:                    eval/snli_loss 1.06009
wandb:                 eval/snli_runtime 13.4275
wandb:      eval/snli_samples_per_second 732.973
wandb:        eval/snli_steps_per_second 5.735
wandb:              eval/spearman_cosine 0.79358
wandb:                 eval/spearman_dot 0.76531
wandb:           eval/spearman_euclidean 0.78328
wandb:           eval/spearman_manhattan 0.78884
wandb:                 eval/spearman_max 0.79358
wandb:                    eval/stsb_loss 0.05304
wandb:                 eval/stsb_runtime 2.2009
wandb:      eval/stsb_samples_per_second 681.531
wandb:        eval/stsb_steps_per_second 5.452
wandb:                       train/epoch 1.0
wandb:                 train/global_step 135
wandb:               train/learning_rate 0.0
wandb:                        train/loss 0.0461
wandb:                  train/total_flos 0.0
wandb:                  train/train_loss 0.74919
wandb:               train/train_runtime 382.8386
wandb:    train/train_samples_per_second 2475.764
wandb:      train/train_steps_per_second 0.353
wandb:
wandb:  View run fresh-lion-2382 at: https://wandb.ai/tomaarsen/huggingface/runs/e6dg4mqp
wandb:  View job at https://wandb.ai/tomaarsen/huggingface/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzMzMzOTAyNA==/version_details/v11
wandb: Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20240125_185336-e6dg4mqp\logs
Training Notes

Although this model is not as good as the previous ones, that was to be expected given the less useful loss functions; the experiment was primarily about training a model with multiple loss functions effectively. Because logging_steps is just 1, we end up with a very jumpy training loss, as there is no averaging:
[screenshot: W&B chart of the training loss]

It's hard to tell what's going on there, but luckily the evaluation losses are much clearer:
[screenshots: W&B charts of the multi_nli, snli, and stsb evaluation losses]

Also note that training concluded fairly quickly: with the default round-robin behaviour that SentenceTransformer.fit uses, training stops once one dataset has exhausted its data, and the STSBenchmark training set is quite small. I want to update the default to smarter sampling, but keep the round-robin approach when a model is trained via .fit, to avoid any behaviour changes from v2 to v3 for users.
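
To make that round-robin behaviour concrete, here is a minimal sketch (not the actual implementation in this PR) of drawing one batch per dataset in turn and stopping as soon as any dataset runs out; the batch counts in the toy usage are made up.

from typing import Iterable, Iterator, List

def round_robin_batches(batch_iterators: List[Iterator]) -> Iterable:
    """Yield one batch per dataset in turn; stop once any dataset is exhausted."""
    while True:
        for iterator in batch_iterators:
            batch = next(iterator, None)
            if batch is None:
                # One dataset (e.g. the small STSBenchmark train set) ran out,
                # so training stops even though the other datasets have data left.
                return
            yield batch

# Toy usage: with 100, 80 and 4 batches, only 14 batches are consumed in total
# (4 full rounds of 3, plus a partial 5th round).
multi_nli_batches = iter(range(100))
snli_batches = iter(range(80))
stsb_batches = iter(range(4))
print(len(list(round_robin_batches([multi_nli_batches, snli_batches, stsb_batches]))))  # 14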

  • Tom Aarsen

@tomaarsen
Collaborator Author

@b5y you expressed an interest in training Sentence Transformer models using the transformers Trainer in #2446 - this PR should work for you. If you'd like, you're free to experiment with it!

You can install it (for now) with:

pip install -U git+https://github.com/tomaarsen/sentence-transformers@v3/trainer_refactor
  • Tom Aarsen

@tomaarsen tomaarsen linked an issue Jan 26, 2024 that may be closed by this pull request
@b5y

b5y commented Jan 26, 2024

Thanks a lot for mentioning!
I will take a look and come back to you soon.

@b5y

b5y commented Jan 26, 2024

Hey @tomaarsen !

Thanks for your work again!

I looked carefully and didn't find any examples for asymmetric semantic search using MSMARCO-based models and dot-product score.

Are you planning to add this kind of example, or do you expect others to try it out?

@tomaarsen
Collaborator Author

Heya!

You might be able to use any of the MSMarco examples from here, e.g. this one: https://github.com/tomaarsen/sentence-transformers/blob/v3/trainer_refactor/examples/training/ms_marco/train_bi-encoder_mnrl.py
In this PR, the fit is internally replaced with a SentenceTransformerTrainer instance.

I haven't yet tested whether this one works, but I would like it to work before I merge this PR.
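
As a conceptual sketch only (this is not the actual conversion code in the PR; the fit_like helper and its argument mapping are illustrative), the idea is roughly:

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)

def fit_like(model: SentenceTransformer, train_dataset, loss, epochs=1, warmup_steps=0, output_path="output", use_amp=False):
    # Rough mapping of the old fit() arguments onto the new trainer API.
    args = SentenceTransformerTrainingArguments(
        output_dir=output_path,     # was: output_path
        num_train_epochs=epochs,    # was: epochs
        warmup_steps=warmup_steps,  # unchanged
        fp16=use_amp,               # was: use_amp
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,  # a datasets.Dataset rather than a DataLoader of InputExamples
        loss=loss,
    )
    trainer.train()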

  • Tom Aarsen

@johnsutor

If it's using the transformers Trainer, does that mean that streamable datasets will be supported? (this is in relation to #2232)

@tomaarsen
Collaborator Author

Yes, it should be supported.
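
As a hedged sketch of what that could look like (assuming IterableDataset support behaves as it does for the regular transformers Trainer, where max_steps must be set because a streamed dataset has no known length; the model and dataset names are just examples):

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("microsoft/mpnet-base")
# streaming=True returns an IterableDataset that is read lazily instead of downloaded in full
train_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train", streaming=True)
loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    max_steps=1000,  # required here, as an IterableDataset has no known length
    per_device_train_batch_size=32,
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()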

@zhangjoy2

zhangjoy2 commented May 13, 2024 via email

@tomaarsen
Collaborator Author

tomaarsen commented May 13, 2024

Using this PR, is it now doable to use multiple GPUs for training?

Yes.
And it does result in speedups. See #2449 (comment) for some exact metrics that I gathered. tl;dr: 2.56x faster training if you use 4 GPUs with DDP instead of just 1 GPU, and you can use DDP by running your normal training scripts with torchrun ... train.py instead of python train.py.

Consider checking out the training examples in the v3.0 pre-release branch. Each of those scripts should work with multi-GPU, although you will likely need to wrap the entire script in def main() ... if __name__ == "__main__": main(), as otherwise each of the processes thinks that it's the main one.
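
For example, a minimal sketch of that wrapping (the training body is illustrative, not one of the actual example scripts), launched with torchrun --nproc_per_node=4 train.py instead of python train.py:

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss

def main():
    model = SentenceTransformer("microsoft/mpnet-base")
    train_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train")
    loss = MultipleNegativesRankingLoss(model)
    args = SentenceTransformerTrainingArguments(output_dir="models/mpnet-base-all-nli")
    trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_dataset, loss=loss)
    trainer.train()

if __name__ == "__main__":
    # As noted above: without wrapping the training code in main() and this guard,
    # each of the spawned processes may think it is the main one and re-run this code.
    main()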

  • Tom Aarsen

@Jakobhenningjensen

@tomaarsen Aah sorry, I missed that benchmark (thanks a bunch!).

Is the branch "v3.0-pre-release" the most "stable 3.0" branch, or is there another branch that would be better to clone if I want to use it?

@tomaarsen
Collaborator Author

No worries, there's a lot of comments/commits here! Even I scrolled past the benchmark 😄

v3.0-pre-release is the most stable v3 branch indeed. My intention is that the training won't change any further until the full release (apart from some bug fixes that I might find), so you should be pretty safe to use it.

In the coming days, #2632 will have more documentation on how training with v3 will work, but for now the best way is to look at the PR description for this merged PR and the updated training examples. Feel free to let me know if you have any issues/feedback.

  • Tom Aarsen

@Jakobhenningjensen

Jakobhenningjensen commented May 13, 2024

I'm currently looking at it and I struggle a bit with how to define the train_dataset passed to the SentenceTransformerTrainer.
As far as I can see it is either a dict or a Dataset, but I can't see where the structure is defined, i.e. whether the keys/columns need to have special names, and if so, which ones (say I want to train a SentenceTransformer with the MultipleNegativesRankingLoss)?

Before (v2.7) the training script was

from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

training_data = [InputExample(texts=[t1, t2]) for t1, t2 in zip(data["Texts"], data["MoreText"])]
model = SentenceTransformer("intfloat/multilingual-e5-small", device="cuda")
dataloader = DataLoader(training_data)
loss = losses.MultipleNegativesRankingLoss(model)
model.fit([(dataloader, loss)])

It seems like the column/key names etc. don't matter when using MultipleNegativesRankingLoss, but I don't quite get how it would work if we are using (anchor, positive, negative), i.e. how we can define which is which.

@tomaarsen
Collaborator Author

tomaarsen commented May 13, 2024

2 comments:

  1. That training script will still work. model.fit will still exist; it'll just create a SentenceTransformerTrainer instance behind the scenes. The primary difference is that model.fit won't expose the new training options that you get when you train with the new approach.

  2. The train_dataset is a Dataset instance from the datasets Python package, often initialized with load_dataset. If I have my data locally, then my favourite way to initialize it is with from_dict:

from datasets import Dataset

anchors = []
positives = []
for ...
    anchors.append(...)
    positives.append(...)

dataset = Dataset.from_dict({"anchor": anchors, "positive": positives})

There are a few rules for the dataset:

  • If a column is named "score" or "label", then it corresponds with the "Labels" from the Loss Overview.
  • Otherwise, each column (in order!) is passed to the loss function, regardless of the column name. So, the column names are ignored: the column order is used.

So, the dataset from my previous snippet corresponds with (anchor, positive) pairs and no labels in the Loss Overview, so I can use CachedMultipleNegativesRankingLoss, MultipleNegativesRankingLoss, etc., even if the column names in that snippet were actually "text" and "other_text" or something.
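
A tiny illustration of those two rules (the column names and sentences below are made up):

from datasets import Dataset

# Columns are consumed in order, so this matches the (anchor, positive) signature of
# MultipleNegativesRankingLoss even though the columns are called "text" and "other_text":
pair_dataset = Dataset.from_dict({
    "text": ["A plane is taking off.", "A man is playing a guitar."],
    "other_text": ["An air plane is taking off.", "Someone is playing a guitar."],
})

# A column named "label" (or "score") is split off and used as the labels,
# e.g. for CosineSimilarityLoss the remaining two columns become the sentence pair:
scored_dataset = Dataset.from_dict({
    "sentence1": ["A plane is taking off.", "A man is playing a flute."],
    "sentence2": ["An air plane is taking off.", "A man is playing a guitar."],
    "label": [0.95, 0.2],
})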


I have prepared a few datasets in https://huggingface.co/datasets?other=sentence-transformers that can be used directly with Sentence Transformers. You can load these e.g. with:

from datasets import load_dataset

dataset = load_dataset("sentence-transformers/all-nli", "pair")

Lastly, here is a complete training script containing essentially all training features:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

# 1. Load a model to finetune
model = SentenceTransformer("microsoft/mpnet-base")

# 2. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"]
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]

# 3. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 4. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)

# 5. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="all-nli-dev",
)
dev_evaluator(model)

# 6. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
)
trainer.train()

# Evaluate the trained model on the test set
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)
test_evaluator(model)

# Save the trained model and/or push it to the Hugging Face Hub
model.save_pretrained("models/mpnet-base-all-nli-triplet")
# model.push_to_hub("mpnet-base-all-nli-triplet")
  • Tom Aarsen

@Jakobhenningjensen
Copy link

Jakobhenningjensen commented May 13, 2024

That training script will still work. model.fit will still exist; it'll just create a SentenceTransformerTrainer instance behind the scenes. The primary difference is that model.fit won't expose the new training options that you get when you train with the new approach.

Yeah, that is why I wanted to use the "new" approach.

There are a few rules for the dataset:

If a column is named "score" or "label", then it corresponds with the "Labels" from the Loss Overview.
Otherwise, each column (in order!) is passed to the loss function, regardless of the column name. So, the column names are ignored: the column order is used.

Nice, thanks!

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)

I'm getting a "no keyword argument 'eval_strategy'" error?

@tomaarsen
Copy link
Collaborator Author

I'm getting a "no keyword argument 'eval_strategy'" error?

transformers recently renamed evaluation_strategy to eval_strategy (in v4.41.0). You can either upgrade your transformers installation or keep using evaluation_strategy.
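
For example, on an older transformers version (pre-4.41.0), a minimal sketch would be (the output directory and step count are placeholders):

from sentence_transformers import SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    output_dir="models/mpnet-base-all-nli-triplet",
    evaluation_strategy="steps",  # renamed to `eval_strategy` in transformers 4.41.0+
    eval_steps=100,
)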

  • Tom Aarsen

@zhanxlin
Copy link

Hi, how do I set max_seq_length?

@tomaarsen
Copy link
Collaborator Author

@zhanxlin

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("bert-base-uncased")
model.max_seq_length = 256
print(model.max_seq_length)
# => 256

This works both for Sentence Transformer models (e.g. "all-mpnet-base-v2") and non-Sentence Transformer models (e.g. "bert-base-cased").

  • Tom Aarsen

@Jakobhenningjensen
Copy link

Hello @tomaarsen
I've tried to follow the DDP training using this script:

from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses


def train_model():
    model = SentenceTransformer("intfloat/multilingual-e5-small",  device="cuda")
    loss = losses.MultipleNegativesRankingLoss(model)
    training_args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir="./sbert_fitted/",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=32,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=False,  # Set to True if you have a GPU that supports BF16
        # Optional tracking/debugging parameters:
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        logging_steps=100,
        run_name="sts",  # Will be used in W&B if `wandb` is installed,
        report_to="none"
    )
    data = get_main_data()
    train_dataset = Dataset.from_pandas(data[["Text1","Text2"]])


    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        loss=loss,
    )
    print("Training ...")
    trainer.train()


if __name__=="__main__":
    train_model()

It is run in a Docker container (on Google Cloud) using the following command:

ENTRYPOINT ["torchrun", "--nproc_per_node", "2", "train.py"]

which should use 2 GPUs.

I'm getting an error though: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. [torch.cuda.LongTensor [1, 9]] is at version 3; expected version 2 instead.

Is that something you have seen before?

On another note: it seems like get_main_data is being called twice; is that expected?

@tomaarsen
Copy link
Collaborator Author

tomaarsen commented May 16, 2024

Hello!

Thanks for testing this; yes, I have seen something similar: I ran into a comparable error in #2647. Let me quickly try to create a fix, then you can give that a try.
And yes, I believe that the data should be prepared for each GPU, but I'm not 100% sure on that.

Edit: FYI, your script looks good.

  • Tom Aarsen

@vkleban
Copy link

vkleban commented May 23, 2024

Will this also work for CrossEncoder training? It looks like cross encoder training still uses the old training backend.

@tomaarsen
Copy link
Collaborator Author

I'm afraid that CrossEncoder training will still use the old recipe, but updating CrossEncoder training is high on the TODO list after v3.

  • Tom Aarsen

@karan842
Copy link

Hello @tomaarsen,

I used your training script code from the example, wrapped the training and model initialization in main(), and called the main function. It works!!! Thanks

But a label is required for training with this script. I am planning to use sentence-transformers for semantic search, where I am fine-tuning my model using sbert. For that, this script is giving me an error that a label is needed.

How do I run sentence-transformers for a semantic search use case on multi-GPU, both with a float label score and without a label score?

@tomaarsen
Copy link
Collaborator Author

Hello!

It depends on your data and your loss functions. Some loss functions require a label, and some don't. There are a few common formats for your data, and you can read about them here: https://sbert.net/docs/training/loss_overview.html#loss-overview

So, if you for example have (anchor, positive) pairs without any labels, such as these datasets:

Then you can use all of these loss functions:
[screenshot: compatible loss functions from the Loss Overview]

A popular one for semantic search is MultipleNegativesRankingLoss, with one of these input types:
[screenshot: accepted input formats for MultipleNegativesRankingLoss]

There are a lot of examples of training datasets that fit this type here: https://huggingface.co/collections/sentence-transformers/embedding-model-datasets-6644d7a3673a511914aa7552
I already showed some (anchor, positive) examples, and this is an example of (anchor, positive, negative): https://huggingface.co/datasets/sentence-transformers/msmarco-mpnet-margin-mse-mean-v1
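
For instance, with (anchor, positive) pairs and no label column at all, a minimal training sketch could look like this (using the "pair" subset of all-nli from earlier and a placeholder base model):

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss

model = SentenceTransformer("microsoft/mpnet-base")
# (anchor, positive) pairs; no "label"/"score" column is required
train_dataset = load_dataset("sentence-transformers/all-nli", "pair", split="train")
loss = MultipleNegativesRankingLoss(model)

trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()

The same structure works for (anchor, positive, negative) triplets, like the msmarco dataset linked above; only the dataset (and possibly the loss) changes.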

Also, here's a bit of related work-in-progress documentation:
[screenshot: related work-in-progress documentation]

Hope this helps a bit

  • Tom Aarsen

@karan842
Copy link

Does the code below support multi-GPU in v3?
Or do I need to use accelerate in the code?

Code :

from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses, models

train_examples = []
train_data = data['text']

for text, pos in zip(data['text'], data['positive']):
    train_examples.append(InputExample(texts=[text, pos]))

train_dataloader = DataLoader(train_examples, shuffle=True)

word_embedding_model = models.Transformer('mixedbread-ai/mxbai-embed-large-v1')
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

train_loss = losses.AnglELoss(model=model)

num_epochs = 30
warmup_steps = 8

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=warmup_steps,
          show_progress_bar=True)


I ran this code but it is only using a single GPU for now!

@tomaarsen
Copy link
Collaborator Author

I believe this should already use multi-GPU by default. But the AnglELoss requires float similarity score labels, which you don't seem to have.
[screenshot: AnglELoss input requirements from the Loss Overview]

You have 2 texts (text and a positive text), so you can use these options without any preprocessing:
[screenshot: loss options for (sentence_A, sentence_B) pairs without labels]

Maybe then it'll work on multi-GPU out of the box, too (although you may have to wrap your code in def main(): ... if __name__ == "__main__": main() if you don't have that already); see the sketch below.
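
As a minimal sketch of that layout (placeholder data; using MultipleNegativesRankingLoss since your (text, positive) pairs have no labels, and the v3 trainer rather than model.fit):

from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss


def main():
    model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
    # Replace these placeholder rows with your own (text, positive) pairs
    train_dataset = Dataset.from_dict({
        "text": ["What is the capital of France?", "Who wrote Hamlet?"],
        "positive": ["Paris is the capital of France.", "Hamlet was written by William Shakespeare."],
    })
    loss = MultipleNegativesRankingLoss(model)
    trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
    trainer.train()


if __name__ == "__main__":
    main()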

  • Tom Aarsen

@karan842
Copy link

Yes, got it. I am using a label score now, so I'm using AnglELoss.

How do I run the file?

Using accelerate/torchrun, or just python?

I ran it with python and in the wandb project it is showing multi-GPU!

@tomaarsen
Copy link
Collaborator Author

torchrun is recommended: it gives you DDP (which is the best), e.g. torchrun --nproc_per_node 2 train.py like the ENTRYPOINT shown earlier. If you use plain python, you get DP, which is a bit worse. Here's some more of the yet-to-be-released docs:
[screenshots: work-in-progress documentation on distributed training (DDP vs. DP)]

  • Tom Aarsen

tomaarsen added a commit that referenced this pull request May 28, 2024

* [`v3`] Training refactor - MultiGPU, loss logging, bf16, etc. (#2449)

* See #1638: Adds huggingface trainer for sentence transformers

* Fix type of tokenizer

* Get the trainer using the feature collation

* Update the docstring to reflect changes

* Initial draft for refactoring training using the Transformers Trainer

* Separate 'fit' functionality (new and old) into a mixin

* Resolve test issues

* Reformat

* Update the imports

* Add TODO regarding custom label columns

* Remove dead code

* Don't provide the trainer to the eval sampler

* Introduce datasets as a dependency

* Introduce "accelerate" as a dependency

* Avoid use_amp on CPU tests

* Specify that SentenceTransformer is a class, not a module

* Avoid circular import

* Remove | used as an "or" operator in typing

* Use test evaluator after training, as intended

* Use tokenize function instead of tokenizer;

Add EvaluatorCallback which calls the evaluator on every epoch (for BC);
Stop saving "do_lower_case" from Transformer;

* Reformat

* Revert Transformer tokenizer changes

* Add support for the tokenizer to return more than just input_ids & attention_masks

Required for LSTM

* Use the test evaluators after training the examples

* Use pure torch for BoW tokenization

* Use dev evaluator for BiLSTM - test fails

* Add Trainer support for BoW-based models

* Pass epoch to evaluator in every-epoch callback

For fit backwards compatibility

* Run formatting

* Use steps_per_epoch to set max_steps if possible

* Ignore extracting dataloader arguments for now

* Remove dead code

* Allow both "label" and "score" columns for labels

* Reformatting

* Improve errors if datasets don't match with loss dictionary well

* Made tests more consistent; list instead of set

* Simplify trainer with DatasetDict

* Implement a proportional sampler in addition to round robin

* Add CLIP finetuning support to the Trainer

* Start updating evaluators to return dictionaries

* Reformat

* Hackishly insert the DataParallel model into the loss function

* Allow for fsdp=["full_shard", "auto_wrap"]

with fsdp_config={"transformer_layer_cls_to_wrap": "BertLayer"}

* Re-add support for DataParallel

* Use 'ParallelMode.NOT_PARALLEL'

* Prevent crash with DDP & an evaluation set

* When training with multiple datasets, add "dataset_name" column

Rather than relying on some Batch Sampler hacking (which fails with some distributed training approaches)

* Update type hints: make loss & evaluator optional

Co-authored-by: Wang Bo <[email protected]>

* Set correct superclasses for samplers

* Override 'accelerator.even_batches' as it's incompatible with multi-dataset

* Throw exception if "return_loss" or "dataset_name" columns are used

* Set min. version for accelerate

* Heavily extend model card generation

* Remove some dead code

* Fix evaluator type hints

* Ensure that 'model_card_template.md' is included in the built package

* Rephrase comments slightly

* Heavily refactor samplers; add no duplicates/group by label samplers

* Ensure that data_loader.dataset exists in FitMixin

* Adopt 8 as the default batch

* Fix logging error in example

* Remove the deprecated correct_bias

* Simplify with walrus operator

* Fix some bugs in set_widget_examples with short datasets

* Improve docstring slightly

* Add edge case in case training data has an unrecognized format

* Fix extracting dataset metadata

* Remove moot TYPE_CHECKING

* Set base model when loading a ST model also

* Add test_dataloader, add prefetch_factor to dataloaders

* Resolve predict_example fix; fix newlines in text

* Fix bug in compute_dataset_metrics examples

* Add call to action in ValueError

* Reuse original model card if no training is done

* Also collect nested losses (e.g. MatryoshkaLoss) and make losses in tags

* Remove generated tag; keep loss: prefix on tags

* Remove unused arguments

* Add support for "best model step" in model card

* Make hyperparameters code-formatted

* Fix load_best_model for Transformers models, prevent for non-Transformers

* Store base_model_revision in model_card_data

* Prevent crash when loading a local model

* Allow for bfloat16 inference

---------

Co-authored-by: Matthew Franglen <[email protected]>
Co-authored-by: Wang Bo <[email protected]>

* [`v3`] Add `similarity` and `similarity_pairwise` methods to Sentence Transformers (#2615)

* Add similarity function to model configuration

* Add more tests

* Replace util.cos_sim with model.similarity in some examples

* Reintroduce evaluation.SimilarityFunction

* Remove last references of score function in ST class

* Add similarity_fn_name to model card

* Add save_pretrained alias for save

* Introduce DOT alias for DOT_PRODUCT

* [`v3`] Fix various model card errors (#2616)

* Prevent model card save failure

* Print exceptions in more detail when they occur

* Fix edge case if dataset language is None

* [`v3`] Fix trainer `compute_loss` when evaluating/predicting if the `loss` updated the inputs in-place (#2617)

* Recompute the features if return_output

* Add SimilarityFunction to __init__, increment dev version

* Never return None in infer_datasets (#2620)

* Implement resume_from_checkpoint (#2621)

* [`v3`] Update example scripts to the new v3 training format (#2622)

* Update example scripts to the new v3 training format

* Add distillation training examples

* Add Matryoshka training examples

* Add NLI training examples

* Add STS training scripts

* Fix accidentally overriding eval set

* Update paraphrases multi-dataset training script

* Convert regular dicts to DatasetDict on Trainer init

* Update Quora duplicate training scripts

* Update "other" training scripts

* Update multilingual conversion script

* Add example scripts to Evaluators

* Add example to ST class itself

* Update docs formatting slightly

* Fix model card snippet

* Add short docstring for similarity_fn_name property

* Remove "return_outputs" as it's not strictly necessary. Avoids OOM & speeds up training (#2633)

* Fix crash from inferring the dataset_id from a local dataset (#2636)

See #2635

* Fix multilingual conversion script; extend MSELoss to multi-column (#2641)

And remove the now-unnecessary make_multilingual_sys.py

* Update evaluation scripts to use HF Datasets (#2642)

* Increment the version in setup.py (as well)

* Fix resume_from_checkpoint by also updating the loss (#2648)

I'm not very sure if updating the potentially wrapped model like this will also work; it seems a bit risky, but it's equally risky not to do it.

* Fix an issue with in-place variable overriding preventing backwards passes on MSELoss (#2647)

Only when there's multiple columns

* Simplify load_from_checkpoint using load_state_dict (#2650)

Overriding the model has several downsides, e.g. regarding the model card generation

* Don't override the labels variable to avoid inplace operation (#2651)

* Resolve "one of the variables needed for gradient computation has been modified by an inplace operation." (#2654)

* [`v3`] Add hyperparameter optimization support by letting `loss` be a Callable that accepts a `model` (#2655)

* Add HPO support by letting the 'loss' be a function

* Only add "dataset_name" column if required by the loss function

* Add tag hinting at the number of training samples (#2660)

* [`v3`] For the Cached losses; ignore gradients if grad is disabled (e.g. eval) (#2668)

* For the Cached losses; ignore gradients if grad is disabled (e.g. eval)

* Warn that Matryoshka/AdaptiveLayer losses are not compatible with Cached

* [`docs`] Rewrite the https://sbert.net documentation for v3.0 (#2632)

* Start restructuring/rewriting the docs

* Update Pretrained Models section for ST

* Update & add many docstrings

* Completely overhaul "Training Overview" docs page for ST

* Update dataset overview

* Remove kwargs from paraphrase_mining signature

* Add "aka sbert"

* Remove Hugging Face docs page

* Update ST Usages

* Fix some links

* Use the training examples corresponding to that model type

* Add hyperparameter optimization example script + docs

* Add distributed training docs

* Complete rewrite for the Sentence Transformer docs portion

* Update the CE part of the docs

* Specify if __name__ == "__main__" & dataloader_drop_last with DDP

* Update the entire project to Google-style docstring

* Remove contact page

* Update README with updated links, etc.

* Update the loss examples

* Fix formatting

* Add remove_columns/select_columns tip to dataset overview

* [`v3`] Chore - include import sorting in ruff (#2672)

* Include import sorting in ruff

* Remove deprecated ignore-init-module-imports

* Remove --select I from ruff.toml again after CI issues

* [`v3`] Prevent warning with 'model.fit' with transformers >= 4.41.0 due to evaluation_strategy (#2673)

* Prevent warning with 'model.fit' with transformers >= 4.41.0 due to evaluation_strategy

* Reformat

* [`v3`] Add various useful Sphinx packages (copy code, link to code, nicer tabs) (#2674)

* No longer hide toctrees in API Reference

* Add linkcode support

It's not perfect, as it'll always link to 'master', but it'll do pretty nicely for the most part.

* Add copy button to all code blocks

* Add nicer tabs

* Reformatted

* [`v3`] Make the "primary_metric" for evaluators a bit more robust (#2675)

* Make the "primary_metric" for evaluators a bit more robust

* Also remove some other TODOs that are not very important or already done

* Set `broadcast_buffers = False` when training with DDP (#2663)

* [`v3`] Warn about using DP instead of DDP + set dataloader_drop_last with DDP (#2677)

* Warn about using DP instead of DDP + set dataloader_drop_last with DDP

* Prevent duplicate warnings

* Remove note, done automatically now

* Avoid inequality comparison to True

* [`v3`] Add warning that Evaluators only run on 1 GPU when multi-GPU training (#2678)

* Add warning that Evaluators only run on 1 GPU when multi-GPU training

* Also add a note in the distributed training docs

* [`v3`] Move training dependencies into a "train" extra (#2676)

* Move training dependencies into a "train" extra

* Install the train extra with the CI tests

* Simplify dev install: also include train deps there

* Implement is_..._available in ST instead; add is_training_available

* Update references to the API ref (#2679)

* [`v3`] Add "dataset_size:" to the tag denoting the number of training samples (#2680)

* Prepend "dataset_size:" instead. I can always change the look of this later

On the HF side

* Fix formatting of Python modules

* Docs: pairwise_cosine_similarity -> pairwise_similarity

* Link to the yet-to-be-released release notes instead

* Update phrasing on local_files_only docstring

* Link directly to the 2DMSE preprint

* Add missing subset in quora-duplicates

* Add missing docstrings arguments for Cached... losses

* Update training overview docs based on the blogpost reviews

---------

Co-authored-by: Matthew Franglen <[email protected]>
Co-authored-by: Wang Bo <[email protected]>
@statsboy83
Copy link

Hi @tomaarsen ,

I am trying to fine-tune a sentence transformer with my own dataset and use multiple GPUs for the process. The documentation says I should wrap my training in a main() function. The question is: do I wrap my entire code in the main() function, or just the training part (the declaration of the loss function, training args, trainer, and test_evaluator)? The beginning of the script creates and cleans up the dataset to be used in fine-tuning, so I am not sure whether that part needs to be in main() or not for a parallel run. Thank you.


@tomaarsen
Copy link
Collaborator Author

tomaarsen commented Sep 19, 2024

Hello!

@statsboy83

I am trying to fine-tune a sentence transformer with my own dataset and use multiple GPUs for the process. The documentation says I should wrap my training in a main() function. The question is: do I wrap my entire code in the main() function, or just the training part (the declaration of the loss function, training args, trainer, and test_evaluator)? The beginning of the script creates and cleans up the dataset to be used in fine-tuning, so I am not sure whether that part needs to be in main() or not for a parallel run. Thank you.

Apologies for the delay.
I always wrap everything except 1) imports and 2) near-instant configuration things (if any), such as defining a logger/setting a logging level.
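In skeleton form, that split looks roughly like this (a minimal sketch; the body of main() is whatever your script already does, including the dataset creation and cleanup):

import logging

# Near-instant configuration (and imports) can stay at module level
logging.basicConfig(level=logging.INFO)


def main():
    # Everything else goes here: loading & cleaning the dataset, the model,
    # loss, training arguments, trainer, and evaluators
    ...


if __name__ == "__main__":
    main()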
For example, in this example training script, I would wrap this section. Hope that helps!

  • Tom Aarsen

