Allow passing a list of evaluators to the Trainer
They'll be combined with a SequentialEvaluator internally
tomaarsen committed Jun 5, 2024
1 parent a3e1b86 commit 20bd5c0
Showing 1 changed file with 11 additions and 5 deletions.
sentence_transformers/trainer.py (11 additions, 5 deletions)
@@ -15,7 +15,7 @@
 from transformers.training_args import ParallelMode

 from sentence_transformers.data_collator import SentenceTransformerDataCollator
-from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
+from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
 from sentence_transformers.losses.CoSENTLoss import CoSENTLoss
 from sentence_transformers.model_card import ModelCardCallback
 from sentence_transformers.models.Transformer import Transformer
@@ -79,9 +79,12 @@ class SentenceTransformerTrainer(Trainer):
             dataset names to functions that return a loss class instance given a model. In practice, the latter two
             are primarily used for hyper-parameter optimization. Will default to
             :class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
-        evaluator (:class:`~sentence_transformers.evaluation.SentenceEvaluator`, *optional*):
-            The evaluator class to use for evaluation alongside the evaluation dataset. An evaluator will display more
-            useful metrics than the loss function.
+        evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
+            List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
+            The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
+            or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
+            are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
+            wrapped in a :class:`~sentence_transformers.evaluation.SequentialEvaluator` to run them sequentially.
         callbacks (List of [:class:`transformers.TrainerCallback`], *optional*):
             A list of callbacks to customize the training loop. Will add those to the list of default callbacks
             detailed in [here](callback).
@@ -123,7 +126,7 @@ def __init__(
                 Dict[str, Callable[["SentenceTransformer"], torch.nn.Module]],
             ]
         ] = None,
-        evaluator: Optional[SentenceEvaluator] = None,
+        evaluator: Optional[Union[SentenceEvaluator, List[SentenceEvaluator]]] = None,
         data_collator: Optional[DataCollator] = None,
         tokenizer: Optional[Union[PreTrainedTokenizerBase, Callable]] = None,
         model_init: Optional[Callable[[], "SentenceTransformer"]] = None,
@@ -218,6 +221,9 @@ def __init__(
             )
         else:
             self.loss = self.prepare_loss(loss, model)
+        # If evaluator is a list, we wrap it in a SequentialEvaluator
+        if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
+            evaluator = SequentialEvaluator(evaluator)
         self.evaluator = evaluator

         # Add a callback responsible for automatically tracking data required for the automatic model card generation
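For context, a minimal usage sketch of the new API. This is illustrative only: the model name, the toy dataset, the loss choice, and the evaluator data below are assumptions for the example, not part of this commit.

from datasets import Dataset

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, TripletEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Hypothetical setup; any SentenceTransformer model and (anchor, positive) pairs work.
model = SentenceTransformer("all-MiniLM-L6-v2")
train_dataset = Dataset.from_dict(
    {
        "anchor": ["A man is eating food.", "A plane is taking off."],
        "positive": ["A man is eating a meal.", "An air plane is taking off."],
    }
)
loss = MultipleNegativesRankingLoss(model)

# Two evaluators passed as a plain list; the Trainer wraps them in a SequentialEvaluator.
evaluators = [
    EmbeddingSimilarityEvaluator(
        sentences1=["A man is eating food."],
        sentences2=["A man is eating a meal."],
        scores=[0.9],
        name="sts-dev",
    ),
    TripletEvaluator(
        anchors=["A man is eating food."],
        positives=["A man is eating a meal."],
        negatives=["A plane is taking off."],
        name="triplet-dev",
    ),
]

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
    evaluator=evaluators,
)
trainer.train()

Passing evaluator=evaluators here behaves the same as passing SequentialEvaluator(evaluators) yourself; the wrapping only normalizes the type before it is stored on self.evaluator.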
