Skip to content

Commit

Permalink
Train via the Sentence Transformers Trainer from ST v3 (#554)
Browse files Browse the repository at this point in the history
* Train via the Sentence Transformers Trainer from ST v3

* Simplify some init code; docstring

* Prevent breaking changes by updating TrainerCallback

* Replace ST Training Args with SetFit Training Args

* Remove unused properties

* Require 'accelerate' when training SetFit models

* Remove log in docs as it is no longer used

* Fix docs issue

* Require installing sentence-transformers[train]

* Keep not having to override metric_for_best_model by default

It'll just keep using the loss of whatever trainer you're using.

* Ensure logs directory is made in Callbacks example

* Fix outdated docstring
  • Loading branch information
tomaarsen authored Sep 18, 2024
1 parent 72f4d1e commit fb91f67
Show file tree
Hide file tree
Showing 16 changed files with 280 additions and 441 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/how_to/callbacks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@ trainer.train()
SetFit supports custom callbacks in the same way that `transformers` does: by subclassing [`TrainerCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerCallback). This class implements a lot of `on_...` methods that can be overridden. For example, the following script shows a custom callback that saves plots of the tSNE of the training and evaluation embeddings during training.

```py
import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

class EmbeddingPlotCallback(TrainerCallback):
"""Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
os.makedirs("logs", exist_ok=True)

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
train_embeddings = model.encode(train_dataset["text"])
eval_embeddings = model.encode(eval_dataset["text"])
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/installation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Before you start, you'll need to setup your environment and install the appropri

## pip

The most straightforward way to install 🤗 Datasets is with pip:
The most straightforward way to install 🤗 SetFit is with pip:

```bash
pip install setfit
Expand Down
2 changes: 0 additions & 2 deletions docs/source/en/reference/trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
- apply_hyperparameters
- evaluate
- hyperparameter_search
- log
- pop_callback
- push_to_hub
- remove_callback
Expand All @@ -31,7 +30,6 @@
- apply_hyperparameters
- evaluate
- hyperparameter_search
- log
- pop_callback
- push_to_hub
- remove_callback
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
INTEGRATIONS_REQUIRE = ["optuna"]
REQUIRED_PKGS = [
"datasets>=2.15.0",
"sentence-transformers>=2.2.1",
"sentence-transformers[train]>=3",
"transformers>=4.41.0",
"evaluate>=0.3.0",
"huggingface_hub>=0.23.0",
Expand Down
15 changes: 5 additions & 10 deletions src/setfit/model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,6 @@ def __init__(self, trainer: "Trainer") -> None:
super().__init__()
self.trainer = trainer

callbacks = [
callback
for callback in self.trainer.callback_handler.callbacks
if isinstance(callback, CodeCarbonCallback)
]
if callbacks:
trainer.model.model_card_data.code_carbon_callback = callbacks[0]

def on_init_end(
self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: "SetFitModel", **kwargs
):
Expand Down Expand Up @@ -109,19 +101,22 @@ def on_evaluate(
metrics: Dict[str, float],
**kwargs,
) -> None:
keys = {"eval_embedding_loss", "eval_polarity_embedding_loss", "eval_aspect_embedding_loss"} & set(metrics)
if not keys:
return
if (
model.model_card_data.eval_lines_list
and model.model_card_data.eval_lines_list[-1]["Step"] == state.global_step
):
model.model_card_data.eval_lines_list[-1]["Validation Loss"] = metrics["eval_embedding_loss"]
model.model_card_data.eval_lines_list[-1]["Validation Loss"] = metrics[keys.pop()]
else:
model.model_card_data.eval_lines_list.append(
{
# "Training Loss": self.state.log_history[-1]["loss"] if "loss" in self.state.log_history[-1] else "-",
"Epoch": state.epoch,
"Step": state.global_step,
"Training Loss": "-",
"Validation Loss": metrics["eval_embedding_loss"],
"Validation Loss": metrics[keys.pop()],
}
)

Expand Down
34 changes: 19 additions & 15 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from itertools import zip_longest
from typing import Generator, Iterable, List, Optional
from typing import Dict, Generator, Iterable, List, Optional, Union

import numpy as np
import torch
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging
Expand Down Expand Up @@ -35,7 +34,8 @@ def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Genera
class ContrastiveDataset(IterableDataset):
def __init__(
self,
examples: List[InputExample],
sentences: List[str],
labels: List[Union[int, float]],
multilabel: bool,
num_iterations: Optional[None] = None,
sampling_strategy: str = "oversampling",
Expand All @@ -44,7 +44,8 @@ def __init__(
"""Generates positive and negative text pairs for contrastive learning.
Args:
examples (InputExample): text and labels in a text transformer dataclass
sentences (List[str]): text sentences to generate pairs from
labels (List[Union[int, float]]): labels for each sentence
multilabel: set to process "multilabel" labels array
sampling_strategy: "unique", "oversampling", or "undersampling"
num_iterations: if provided explicitly sets the number of pairs to be generated
Expand All @@ -57,8 +58,8 @@ def __init__(
self.neg_index = 0
self.pos_pairs = []
self.neg_pairs = []
self.sentences = np.array([s.texts[0] for s in examples])
self.labels = np.array([s.label for s in examples])
self.sentences = sentences
self.labels = labels
self.sentence_labels = list(zip(self.sentences, self.labels))
self.max_pairs = max_pairs

Expand Down Expand Up @@ -89,23 +90,23 @@ def __init__(
def generate_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
if _label == label:
self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
else:
self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
break

def generate_multilabel_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
if any(np.logical_and(_label, label)):
# logical_and checks if labels are both set for each class
self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
else:
self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
break

def get_positive_pairs(self) -> List[InputExample]:
def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]:
pairs = []
for _ in range(self.len_pos_pairs):
if self.pos_index >= len(self.pos_pairs):
Expand All @@ -114,7 +115,7 @@ def get_positive_pairs(self) -> List[InputExample]:
self.pos_index += 1
return pairs

def get_negative_pairs(self) -> List[InputExample]:
def get_negative_pairs(self) -> List[Dict[str, Union[str, float]]]:
pairs = []
for _ in range(self.len_neg_pairs):
if self.neg_index >= len(self.neg_pairs):
Expand All @@ -137,15 +138,16 @@ def __len__(self) -> int:
class ContrastiveDistillationDataset(ContrastiveDataset):
def __init__(
self,
examples: List[InputExample],
sentences: List[str],
cos_sim_matrix: torch.Tensor,
num_iterations: Optional[None] = None,
sampling_strategy: str = "oversampling",
max_pairs: int = -1,
) -> None:
self.cos_sim_matrix = cos_sim_matrix
super().__init__(
examples,
sentences,
[0] * len(sentences),
multilabel=False,
num_iterations=num_iterations,
sampling_strategy=sampling_strategy,
Expand All @@ -163,6 +165,8 @@ def __init__(

def generate_pairs(self) -> None:
for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
self.pos_pairs.append(
{"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]}
)
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
break
27 changes: 13 additions & 14 deletions src/setfit/span/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,35 @@ def __init__(
model.aspect_model, model.polarity_model, eval_dataset
)

# Set a better default value for the metric for best model for the aspect and polarity models
aspect_args = args if args is not None else TrainingArguments()
polarity_args = (polarity_args or args or TrainingArguments()).copy()
if aspect_args.metric_for_best_model == "embedding_loss":
aspect_args.metric_for_best_model = "aspect_embedding_loss"
if polarity_args.metric_for_best_model == "embedding_loss":
polarity_args.metric_for_best_model = "polarity_embedding_loss"

self.aspect_trainer = Trainer(
model.aspect_model,
args=args,
args=aspect_args,
train_dataset=aspect_train_dataset,
eval_dataset=aspect_eval_dataset,
metric=metric,
metric_kwargs=metric_kwargs,
callbacks=callbacks,
)
self.aspect_trainer._set_logs_mapper(
{
"eval_embedding_loss": "eval_aspect_embedding_loss",
"embedding_loss": "aspect_embedding_loss",
}
)
self.aspect_trainer._set_logs_prefix("aspect_embedding")

self.polarity_trainer = Trainer(
model.polarity_model,
args=polarity_args or args,
args=polarity_args,
train_dataset=polarity_train_dataset,
eval_dataset=polarity_eval_dataset,
metric=metric,
metric_kwargs=metric_kwargs,
callbacks=callbacks,
)
self.polarity_trainer._set_logs_mapper(
{
"eval_embedding_loss": "eval_polarity_embedding_loss",
"embedding_loss": "polarity_embedding_loss",
}
)
self.polarity_trainer._set_logs_prefix("polarity_embedding")

def preprocess_dataset(
self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Dataset
Expand Down
Loading

0 comments on commit fb91f67

Please sign in to comment.