
See #1446: Adds huggingface trainer for sentence transformers #1733

Closed

Conversation

matthewfranglen
Contributor

This implements a Hugging Face Transformers-compatible trainer that works for a task like the CosineSimilarityLoss example in the Training documentation. It should be easy to extend to multi-task training if desired; a rough sketch follows the example below.

You would use it as follows:

from typing import Dict

import datasets
from transformers import EvalPrediction, TrainingArguments

from sentence_transformers import SentenceTransformer, evaluation, losses

# SentenceTransformersTrainer and SentenceTransformersCollator are the two
# classes added by this PR.

sick_ds = datasets.load_dataset("sick")

training_args = TrainingArguments(
    output_dir=...,
    num_train_epochs=10,
    seed=33,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_steps=100,
    optim="adamw_torch",
    #
    # checkpoint settings
    # (load_best_model_at_end requires matching evaluation and save strategies)
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir=...,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="cosine_similarity",
    greater_is_better=True,
    #
    # needed to get sentence_A and sentence_B
    remove_unused_columns=False,
)

model = SentenceTransformer("nli-distilroberta-base-v2")
tokenizer = model.tokenizer
loss = losses.CosineSimilarityLoss(model)
evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sick_ds["validation"]["sentence_A"],
    sick_ds["validation"]["sentence_B"],
    sick_ds["validation"]["label"],
    main_similarity=evaluation.SimilarityFunction.COSINE,
)
# The EvalPrediction argument is ignored: the evaluator re-encodes the
# validation split with the current model itself.
def compute_metrics(predictions: EvalPrediction) -> Dict[str, float]:
    return {
        "cosine_similarity": evaluator(model)
    }

# Collates raw text rows into tokenized features for the two sentence columns.
data_collator = SentenceTransformersCollator(
    tokenizer=tokenizer,
    text_columns=["sentence_A", "sentence_B"],
)

trainer = SentenceTransformersTrainer(
    model=model,
    args=training_args,
    train_dataset=sick_ds["train"],
    eval_dataset=sick_ds["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # custom arguments
    loss=loss,
    text_columns=["sentence_A", "sentence_B"],
)

trainer.train()
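
Multi-task training could be handled by subclassing the trainer and routing each batch to a per-task loss. A rough, hypothetical sketch, continuing from the example above (nothing below is in this PR: the losses_by_task mapping, the batch-level "task" key, and the collect_features helper are all assumptions):

from typing import Dict

from torch import nn

class MultiTaskSentenceTransformersTrainer(SentenceTransformersTrainer):
    """Hypothetical multi-task variant: picks a loss per batch by task name."""

    def __init__(self, *args, losses_by_task: Dict[str, nn.Module], **kwargs):
        super().__init__(*args, **kwargs)
        self.losses_by_task = losses_by_task

    def compute_loss(self, model, inputs, return_outputs=False):
        # Assumes the collator attaches a batch-level "task" key, and that a
        # collect_features helper splits the inputs into features and labels.
        task = inputs.pop("task")
        features, labels = self.collect_features(inputs)
        loss = self.losses_by_task[task](features, labels)
        return (loss, features) if return_outputs else loss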

@matthewfranglen
Contributor Author

Ooh, I might've linked the wrong issue. I think #1446 is more appropriate.

@matthewfranglen matthewfranglen changed the title See #1638: Adds huggingface trainer for sentence transformers See #1446: Adds huggingface trainer for sentence transformers Oct 26, 2022
vaibhavad added a commit to vaibhavad/sentence-transformers that referenced this pull request Jan 7, 2024
@tomaarsen
Collaborator

Hello!

This has been fully extended and implemented in the v3.0 refactor via #2449. Thanks a bunch for starting this work.
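
For reference, the equivalent run under the v3.0 SentenceTransformerTrainer looks roughly like the following (a minimal sketch; the select_columns call is an assumption about how the sick columns map onto the loss inputs, so check the v3.0 training docs for the exact API):

from datasets import load_dataset

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CosineSimilarityLoss

model = SentenceTransformer("nli-distilroberta-base-v2")
sick_ds = load_dataset("sick")

# The v3.0 trainer maps dataset columns onto the loss inputs positionally,
# so keep only the two text columns and the label (mirroring the PR example).
train_dataset = sick_ds["train"].select_columns(["sentence_A", "sentence_B", "label"])
eval_dataset = sick_ds["validation"].select_columns(["sentence_A", "sentence_B", "label"])

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=CosineSimilarityLoss(model),
)
trainer.train()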

  • Tom Aarsen

@tomaarsen tomaarsen closed this Jun 4, 2024