
[fix] Fix model loading inconsistency after Peft training by using PeftModel #2980

Merged · 20 commits · Nov 8, 2024

Conversation

pesuchin
Contributor

@pesuchin pesuchin commented Oct 11, 2024

Resolves: #2465, #2979

Pull Request Overview

This PR adds an implementation that uses PeftModel to load models trained with PEFT, so that a saved checkpoint produces the same inference results as the in-memory model.
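As a rough illustration of what "load with PeftModel" can mean, here is a minimal sketch; the helper name and arguments below are hypothetical, not the PR's actual code:

```python
# Hypothetical sketch: reload a PEFT (LoRA) checkpoint via PeftModel.
# `load_peft_checkpoint` and `checkpoint_dir` are illustrative names only.
def load_peft_checkpoint(checkpoint_dir: str):
    # Imports live inside the function so the sketch stays importable
    # even in environments where peft is not installed.
    from peft import PeftConfig, PeftModel
    from transformers import AutoModel

    # Read the adapter config to find the base model it was trained on,
    # load that base model, then attach the adapter weights on top.
    config = PeftConfig.from_pretrained(checkpoint_dir)
    base_model = AutoModel.from_pretrained(config.base_model_name_or_path)
    return PeftModel.from_pretrained(base_model, checkpoint_dir)
```

The key point is that the adapter checkpoint only stores the LoRA weights, so loading must first reconstruct the base model before applying the adapter on top.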

Experiment

Evaluation results for the experiment using PeftModel

The following script was used to calculate the evaluation results.

from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SimilarityFunction,
    losses,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train").select(range(100))
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev").select(range(100))

model_name = "sentence-transformers-testing/stsb-bert-tiny-safetensors"
model = SentenceTransformer(model_name)

peft_config = LoraConfig(
    target_modules=["dense"],
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Wrap the underlying transformer module with the PEFT (LoRA) adapter
model._modules["0"].auto_model = get_peft_model(
    model._modules["0"].auto_model, peft_config
)

train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=1)
args = SentenceTransformerTrainingArguments("working_dir")
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
)
trainer.train()

test_dataset = load_dataset("sentence-transformers/stsb", split="test")
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    scores=test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
print(test_evaluator(model))
  • Evaluation results before saving the model:
{'sts-test_pearson_cosine': 0.7324155392944408, 'sts-test_spearman_cosine': 0.7308799021176352, 'sts-test_pearson_manhattan': 0.7278141028793592, 'sts-test_spearman_manhattan': 0.7103180463184993, 'sts-test_pearson_euclidean': 0.7296304300347718, 'sts-test_spearman_euclidean': 0.7114234949673607, 'sts-test_pearson_dot': 0.6415791460360187, 'sts-test_spearman_dot': 0.6223641892328629, 'sts-test_pearson_max': 0.7324155392944408, 'sts-test_spearman_max': 0.7308799021176352}
  • Evaluation results after loading the model:
# Reload the trained checkpoint from disk (same session and imports as above)
model = SentenceTransformer("working_dir/checkpoint-39")

test_dataset = load_dataset("sentence-transformers/stsb", split="test")
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    scores=test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
print(test_evaluator(model))
{'sts-test_pearson_cosine': 0.7324155392944408, 'sts-test_spearman_cosine': 0.7308799021176352, 'sts-test_pearson_manhattan': 0.7278141028793592, 'sts-test_spearman_manhattan': 0.7103180463184993, 'sts-test_pearson_euclidean': 0.7296304300347718, 'sts-test_spearman_euclidean': 0.7114234949673607, 'sts-test_pearson_dot': 0.6415791460360187, 'sts-test_spearman_dot': 0.6223641892328629, 'sts-test_pearson_max': 0.7324155392944408, 'sts-test_spearman_max': 0.7308799021176352}

@pesuchin pesuchin changed the title Fix model loading inconsistency after Peft training by using PeftModel for correct inference results [fix] Fix model loading inconsistency after Peft training by using PeftModel Oct 11, 2024
@tomaarsen tomaarsen merged commit 6baee57 into UKPLab:master Nov 8, 2024
9 checks passed
@tomaarsen
Collaborator

Hello @pesuchin!

Apologies for the radio silence on my part; this PR came in right as the new backends were being introduced, and I was worried that the two might clash. With some minor changes (e.g. raising an extra error if backend != "torch", and using is_peft_available() rather than try-except despite the EAFP principle, etc.), this is ready to be included as one of the 4 major features in the next update (scheduled for Monday!)
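The availability check mentioned above can be sketched with stdlib tooling; the real helper lives in the library's utilities, so the function names here are illustrative:

```python
import importlib.util

def is_peft_available_check() -> bool:
    # LBYL-style check, as preferred in the review: ask the import
    # machinery whether `peft` is resolvable without importing it.
    return importlib.util.find_spec("peft") is not None

def is_peft_available_eafp() -> bool:
    # EAFP-style alternative (try-except) that the review steered away from.
    try:
        import peft  # noqa: F401
        return True
    except ImportError:
        return False
```

Both functions agree on any given environment; the LBYL form simply avoids the cost and side effects of actually importing the package when it is only being probed.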

I also appreciate you answering some questions across the repository.

  • Tom Aarsen

Successfully merging this pull request may close these issues.

How to load lora model to sentencetransformer model?