
Add MNSRL with GradCache #2879

Merged: 10 commits into UKPLab:master on Aug 29, 2024

Conversation

@madhavthaker1 (Contributor) commented on Aug 7, 2024

Add Cached Multiple Negatives Symmetric Ranking Loss (MNSRL)

This PR introduces a new loss function, Cached Multiple Negatives Symmetric Ranking Loss (Cached MNSRL)

Motivation

I've seen promising value in the symmetric properties of MNSRL even with relatively small batch sizes. As with MNRL, however, memory constraints limit how large those batches can be. This implementation aims to address that limitation.

Implementation

  • Inspired by the existing Cached MNRL implementation
  • Leverages gradient caching to reduce memory footprint
  • Enables larger batch sizes for improved performance
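
For context, here is a minimal usage sketch of the new loss (the base model name is just an illustrative choice, and mini_batch_size=64 is an arbitrary value); the training batch size can stay large for more in-batch negatives while mini_batch_size bounds peak GPU memory during the cached passes:

from sentence_transformers import SentenceTransformer, losses

# Illustrative base model; any SentenceTransformer model should work here.
model = SentenceTransformer("all-MiniLM-L6-v2")

# mini_batch_size controls the chunk size of the cached forward/backward passes,
# not the number of in-batch negatives, so it only limits peak memory.
loss = losses.CachedMultipleNegativesSymmetricRankingLoss(model, mini_batch_size=64)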

Notes

  • Test cases are not included as the gradient caching functionality is already covered in existing tests

Tophatting

[image: tophatting screenshot]

@madhavthaker1 marked this pull request as ready for review on August 7, 2024 02:50
@tomaarsen (Collaborator) commented on Aug 7, 2024

Hello!

This looks awesome, thanks a bunch! As a heads up, the CI started having issues unrelated to this PR, so we can ignore those failures.
Beyond that, I applied one patch via 074feba for an issue that I discovered while testing. In short: normally e is the end of the slice, set with b + mini_batch_size. Imagine this scenario:

  • 2048 batch size
  • 1568 samples in the batch (partial batch, because it's the last one)
  • mini-batch size of 64
  • b is 1536, then e is 1600
  • Because e > 1568, i.e. the number of samples in the batch, this is also a partial minibatch: 1568 (samples in batch) - 1536 (b) = 32 samples in this minibatch.
  • We slice e.g. embeddings_a[b:e] in the forward_loss, which results in 32 samples because embeddings_a only has 1568 samples in total, and in Python the upper bound can just be way larger than the collection length: len([1, 2, 3, 4, 5][2:10000]) == 3.
  • But, scores itself has shape [32, (1 + nneg) * bsz]. If you have 1 negative, then that's [32, 3136]. Slicing with scores[:, b:e] for positive_scores then results in [32, 64], but positive_scores should always be square. We only want the "first" [32, 32]; the last half actually matches anchors with negatives.

I hope I was able to explain that somewhat clearly.
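
To make the shapes concrete, here is a minimal standalone sketch of that scenario (illustrative numbers; clamping e at the end is one way to keep positive_scores square and is not necessarily the literal patch in 074feba):

import torch

batch_size = 1568        # samples in the partial last batch
mini_batch_size = 64
nneg = 1                 # one negative per anchor

b = 1536
e = b + mini_batch_size  # 1600, past the end of the partial batch

embeddings_a = torch.randn(batch_size, 8)

# Python slicing silently clips to the collection length, so this minibatch has 32 anchors ...
anchors = embeddings_a[b:e]
print(anchors.shape)         # torch.Size([32, 8])

# ... but scores has (1 + nneg) * batch_size columns, so slicing its columns with the
# unclamped e yields a non-square "positive" block.
scores = anchors @ torch.randn(8, (1 + nneg) * batch_size)
print(scores.shape)          # torch.Size([32, 3136])
print(scores[:, b:e].shape)  # torch.Size([32, 64])  -> should be [32, 32]

# Clamping e to the number of samples in the batch restores a square positive_scores.
e = min(b + mini_batch_size, batch_size)
print(scores[:, b:e].shape)  # torch.Size([32, 32])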

I used this training script to verify:

from collections import defaultdict
import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    SentenceTransformerTrainingArguments
)
from sentence_transformers.models import Transformer, Pooling

def to_triplets(dataset):
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0]) # <- entailment
            negatives.append(sentences[2]) # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })

if __name__ == "__main__":
    snli_ds = datasets.load_dataset("snli")
    snli_ds = datasets.DatasetDict({
        "train": to_triplets(snli_ds["train"]),
        "validation": to_triplets(snli_ds["validation"]),
        "test": to_triplets(snli_ds["test"]),
    })
    multi_nli_ds = datasets.load_dataset("multi_nli")
    multi_nli_ds = datasets.DatasetDict({
        "train": to_triplets(multi_nli_ds["train"]),
        "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
    })

    all_nli_ds = datasets.DatasetDict({
        "train": datasets.concatenate_datasets([snli_ds["train"].select(range(10000)), snli_ds["train"].select(range(10000))]),
        "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
        "test": snli_ds["test"]
    })

    stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
    stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

    training_args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        report_to="none",
        num_train_epochs=1,
        seed=33,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=1,
        evaluation_strategy="steps",
        eval_steps=10,
        save_steps=10,
        save_total_limit=2,
        metric_for_best_model="eval_sts-dev_spearman_cosine",
        greater_is_better=True,
    )

    transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[transformer, pooling])

    tokenizer = model.tokenizer
    # loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
    loss = losses.CachedMultipleNegativesSymmetricRankingLoss(model, mini_batch_size=64)
    dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
        stsb_dev["sentence1"],
        stsb_dev["sentence2"],
        [score / 5 for score in stsb_dev["score"]],
        main_similarity=evaluation.SimilarityFunction.COSINE,
        name="sts-dev",
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        evaluator=dev_evaluator,
        args=training_args,
        train_dataset=all_nli_ds["train"],
        eval_dataset=all_nli_ds["validation"],
        loss=loss,
    )
    trainer.train()
    # breakpoint()

    test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
        stsb_test["sentence1"],
        stsb_test["sentence2"],
        [score / 5 for score in stsb_test["score"]],
        main_similarity=evaluation.SimilarityFunction.COSINE,
        name="sts-test",
    )
    results = test_evaluator(model)
    print(results)

After this patch, these were my results for CMNRL and CMNSRL:

Baseline (CMNRL):

{'loss': 8.2448, 'grad_norm': 59.990360260009766, 'learning_rate': 2e-05, 'epoch': 0.1}
{'loss': 8.2097, 'grad_norm': 36.9515266418457, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.2}                                                           
{'loss': 7.9031, 'grad_norm': 11.334012031555176, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.3}                                                          
{'loss': 7.6541, 'grad_norm': 161.5382843017578, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.4}                                                          
{'loss': 7.5152, 'grad_norm': 14.368782997131348, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.5}                                                         
{'loss': 7.3256, 'grad_norm': 17.532875061035156, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.6}                                                          
{'loss': 6.9168, 'grad_norm': 19.92694091796875, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.7}                                                           
{'loss': 6.7518, 'grad_norm': 20.580337524414062, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.8}                                                          
{'loss': 6.5581, 'grad_norm': 21.919233322143555, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.9}                                                          
{'loss': 6.1663, 'grad_norm': 22.764673233032227, 'learning_rate': 0.0, 'epoch': 1.0}                                                                            
{'eval_loss': 5.834962844848633, 'eval_sts-dev_pearson_cosine': 0.7220457826047636, 'eval_sts-dev_spearman_cosine': 0.7537410931185097, 'eval_sts-dev_pearson_manhattan': 0.7933553010829005, 'eval_sts-dev_spearman_manhattan': 0.790214816875259, 'eval_sts-dev_pearson_euclidean': 0.7631229397066142, 'eval_sts-dev_spearman_euclidean': 0.7637853164530657, 'eval_sts-dev_pearson_dot': 0.19616294036806228, 'eval_sts-dev_spearman_dot': 0.19220155417768314, 'eval_sts-dev_pearson_max': 0.7933553010829005, 'eval_sts-dev_spearman_max': 0.790214816875259, 'eval_runtime': 14.92, 'eval_samples_per_second': 393.766, 'eval_steps_per_second': 0.201, 'epoch': 1.0}
{'train_runtime': 126.7861, 'train_samples_per_second': 157.746, 'train_steps_per_second': 0.079, 'train_loss': 7.32453179359436, 'epoch': 1.0}                  
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:06<00:00, 12.68s/it]
{'sts-test_pearson_cosine': 0.6779585327495794, 'sts-test_spearman_cosine': 0.6732490122269325, 'sts-test_pearson_manhattan': 0.744016590529333, 'sts-test_spearman_manhattan': 0.7175918780931619, 'sts-test_pearson_euclidean': 0.7121500800151421, 'sts-test_spearman_euclidean': 0.6893059151171314, 'sts-test_pearson_dot': 0.2518168867731734, 'sts-test_spearman_dot': 0.27739541010179164, 'sts-test_pearson_max': 0.744016590529333, 'sts-test_spearman_max': 0.7175918780931619}

CMNSRL:

{'loss': 5.9862, 'grad_norm': 1181.5284423828125, 'learning_rate': 2e-05, 'epoch': 0.1}
{'loss': 5.9864, 'grad_norm': 669.2378540039062, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.2}
{'loss': 5.7147, 'grad_norm': 396.2860107421875, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.3}
{'loss': 5.4414, 'grad_norm': 472.13311767578125, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.4}
{'loss': 5.2008, 'grad_norm': 607.52783203125, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.5}
{'loss': 4.8169, 'grad_norm': 596.0315551757812, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.6}
{'loss': 4.5048, 'grad_norm': 623.6679077148438, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.7}
{'loss': 4.3086, 'grad_norm': 623.7384033203125, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.8}
{'loss': 4.1499, 'grad_norm': 630.1085815429688, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.9}
{'loss': 3.8008, 'grad_norm': 461.3891296386719, 'learning_rate': 0.0, 'epoch': 1.0}
{'eval_loss': 3.3768537044525146, 'eval_sts-dev_pearson_cosine': 0.7250278238797213, 'eval_sts-dev_spearman_cosine': 0.7565831495674673, 'eval_sts-dev_pearson_manhattan': 0.7968535297781607, 'eval_sts-dev_spearman_manhattan': 0.7928590108685135, 'eval_sts-dev_pearson_euclidean': 0.7757140841096808, 'eval_sts-dev_spearman_euclidean': 0.773871004899004, 'eval_sts-dev_pearson_dot': 0.2865826010351892, 'eval_sts-dev_spearman_dot': 0.2652764301652554, 'eval_sts-dev_pearson_max': 0.7968535297781607, 'eval_sts-dev_spearman_max': 0.7928590108685135, 'eval_runtime': 14.3992, 'eval_samples_per_second': 408.009, 'eval_steps_per_second': 0.208, 'epoch': 1.0}
{'train_runtime': 120.5207, 'train_samples_per_second': 165.947, 'train_steps_per_second': 0.083, 'train_loss': 4.991061210632324, 'epoch': 1.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:00<00:00, 12.05s/it]
{'sts-test_pearson_cosine': 0.6863237365690719, 'sts-test_spearman_cosine': 0.6789686786781693, 'sts-test_pearson_manhattan': 0.749486883823326, 'sts-test_spearman_manhattan': 0.7226188211038197, 'sts-test_pearson_euclidean': 0.725871099029778, 'sts-test_spearman_euclidean': 0.7002108088046105, 'sts-test_pearson_dot': 0.32759836256312636, 'sts-test_spearman_dot': 0.34280069115519984, 'sts-test_pearson_max': 0.749486883823326, 'sts-test_spearman_max': 0.7226188211038197}

So, training runtime was about equivalent, and performance was even slightly better than with CMNRL (likely because STS is a symmetric task, so it makes sense that it would benefit from the "backward" loss).
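
For readers unfamiliar with the symmetric variant: roughly, MNSRL adds a "backward" cross-entropy term that asks each positive to rank its own anchor, on top of the usual "forward" term that asks each anchor to rank its own positive. A minimal sketch of that idea (not the library's exact implementation, and without hard negatives):

import torch
import torch.nn.functional as F

def symmetric_mnr_loss_sketch(anchors: torch.Tensor, positives: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    # Scaled cosine similarities: row i, column j = sim(anchor_i, positive_j).
    scores = F.normalize(anchors, dim=-1) @ F.normalize(positives, dim=-1).T * scale
    labels = torch.arange(scores.size(0), device=scores.device)
    forward_loss = F.cross_entropy(scores, labels)     # anchor -> which positive?
    backward_loss = F.cross_entropy(scores.T, labels)  # positive -> which anchor?
    return (forward_loss + backward_loss) / 2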

I'll further prepare this PR for release by fixing some formatting and adding some documentation.

  • Tom Aarsen

@madhavthaker1 (Contributor, Author) commented

Hey @tomaarsen -- Thanks for catching this bug and the thorough explanation. Your patch makes sense to me!

@tomaarsen merged commit 09fe766 into UKPLab:master on Aug 29, 2024
11 checks passed
@tomaarsen (Collaborator) commented

Apologies for the delays. I believe this is all ready to go, so I've merged it into master. Thank you for leading this work, it's much appreciated!

  • Tom Aarsen

@madhavthaker1 (Contributor, Author) commented

Thanks @tomaarsen! I'm assuming this will get picked up and made available in the next release? Any timeline for this?

@tomaarsen (Collaborator) commented

Indeed, it will be included in the next release, which is scheduled to drop some time this week.

  • Tom Aarsen
