Add MNSRL with GradCache #2879
Conversation
Hello! This looks awesome, thanks a bunch! As a heads up, the CI started having issues unrelated to this PR, so we can ignore those failures.
I hope I was able to explain that somewhat clearly. I used this training script to verify:

```python
from collections import defaultdict

import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.models import Transformer, Pooling


def to_triplets(dataset):
    # Group hypotheses by premise and label: 0 = entailment, 2 = contradiction
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0])  # <- entailment
            negatives.append(sentences[2])  # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })


if __name__ == "__main__":
    snli_ds = datasets.load_dataset("snli")
    snli_ds = datasets.DatasetDict({
        "train": to_triplets(snli_ds["train"]),
        "validation": to_triplets(snli_ds["validation"]),
        "test": to_triplets(snli_ds["test"]),
    })
    multi_nli_ds = datasets.load_dataset("multi_nli")
    multi_nli_ds = datasets.DatasetDict({
        "train": to_triplets(multi_nli_ds["train"]),
        "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
    })
    all_nli_ds = datasets.DatasetDict({
        "train": datasets.concatenate_datasets([snli_ds["train"].select(range(10000)), multi_nli_ds["train"].select(range(10000))]),
        "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
        "test": snli_ds["test"],
    })

    stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
    stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

    training_args = SentenceTransformerTrainingArguments(
        output_dir="checkpoints",
        report_to="none",
        num_train_epochs=1,
        seed=33,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        bf16=True,
        logging_steps=1,
        evaluation_strategy="steps",
        eval_steps=10,
        save_steps=10,
        save_total_limit=2,
        metric_for_best_model="eval_sts-dev_spearman_cosine",
        greater_is_better=True,
    )

    transformer = Transformer("microsoft/mpnet-base", max_seq_length=384)
    pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[transformer, pooling])
    tokenizer = model.tokenizer

    # loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)
    loss = losses.CachedMultipleNegativesSymmetricRankingLoss(model, mini_batch_size=64)

    dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
        stsb_dev["sentence1"],
        stsb_dev["sentence2"],
        [score / 5 for score in stsb_dev["score"]],
        main_similarity=evaluation.SimilarityFunction.COSINE,
        name="sts-dev",
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        evaluator=dev_evaluator,
        args=training_args,
        train_dataset=all_nli_ds["train"],
        eval_dataset=all_nli_ds["validation"],
        loss=loss,
    )
    trainer.train()
    # breakpoint()

    test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
        stsb_test["sentence1"],
        stsb_test["sentence2"],
        [score / 5 for score in stsb_test["score"]],
        main_similarity=evaluation.SimilarityFunction.COSINE,
        name="sts-test",
    )
    results = test_evaluator(model)
    print(results)
```

After this patch, these were my results for CMNRL and CMNSRL:

Baseline (CMNRL):

CMNSRL:
So, training runtime was about equivalent, and performance was even slightly better than CMNRL (likely because STS is a symmetric task, so it makes sense that it would benefit from the "backward" loss). I'll further prepare this PR for release by fixing some formatting and adding some documentation.
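For context, the "backward" loss simply applies the same ranking objective in the positive-to-anchor direction. Here is a minimal sketch of the symmetric objective for a batch of (anchor, positive) pairs, ignoring hard negatives and the caching machinery; the function name is illustrative, and the scale of 20 mirrors the library's default:

```python
import torch
import torch.nn.functional as F


def symmetric_mnrl(anchors: torch.Tensor, positives: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    """Symmetric multiple negatives ranking loss for (anchor, positive) pairs.

    The i-th positive belongs to the i-th anchor; every other in-batch
    positive (or anchor, in the backward direction) acts as a negative.
    """
    anchors = F.normalize(anchors, dim=-1)
    positives = F.normalize(positives, dim=-1)
    scores = scale * anchors @ positives.T  # (batch, batch) cosine similarities
    labels = torch.arange(scores.size(0), device=scores.device)
    forward = F.cross_entropy(scores, labels)     # anchor -> positive
    backward = F.cross_entropy(scores.T, labels)  # positive -> anchor ("backward" loss)
    return (forward + backward) / 2
```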
Hey @tomaarsen -- Thanks for catching this bug and the thorough explanation. Your patch makes sense to me!
Apologies for the delays. I believe this is all ready to go, so I've merged it in.
Thanks @tomaarsen! I'm assuming this will get picked up and made available in the next release? Any timeline for this?
Indeed, it will be included in the next release, which is scheduled to drop some time this week.
Add Cached Multiple Negatives Symmetric Ranking Loss (MNSRL)
This PR introduces a new loss function: Cached Multiple Negatives Symmetric Ranking Loss (Cached MNSRL).
Motivation
I've seen promising value in the symmetric properties of MNSRL even with relatively small batch sizes. As with MNRL, memory constraints limit the batch sizes we can use. This implementation aims to address that limitation.
Implementation
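As the PR title says, the memory savings come from the GradCache technique: embeddings for the full batch are computed mini-batch by mini-batch without autograd, the full-batch loss and its gradients with respect to the embeddings are computed and cached, and a second encoder pass backpropagates those cached gradients one mini-batch at a time. Below is a simplified sketch of that idea; `encoder`, `batch`, and `loss_fn` are illustrative stand-ins, not the actual `CachedMultipleNegativesSymmetricRankingLoss` code:

```python
import torch


def gradcache_step(encoder, batch, loss_fn, mini_batch_size: int = 64):
    """GradCache-style step: full-batch contrastive loss with mini-batch memory."""
    chunks = [batch[i:i + mini_batch_size] for i in range(0, len(batch), mini_batch_size)]

    # Pass 1: embed every mini-batch without building an autograd graph.
    with torch.no_grad():
        embs = [encoder(chunk) for chunk in chunks]
    embs = [e.requires_grad_() for e in embs]

    # Full-batch loss over all embeddings; backprop only as far as the
    # embeddings, caching d(loss)/d(embedding) for each mini-batch.
    loss = loss_fn(torch.cat(embs))
    loss.backward()
    cached_grads = [e.grad for e in embs]

    # Pass 2: re-encode each mini-batch with grad enabled and push the
    # cached gradients through the encoder parameters.
    for chunk, grad in zip(chunks, cached_grads):
        reps = encoder(chunk)
        reps.backward(gradient=grad)

    return loss.detach()
```

In the verification script above, this is what `mini_batch_size=64` controls: the effective contrastive batch stays at `per_device_train_batch_size=2048`, while peak memory scales with the mini-batch size.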
Notes
Tophatting