[feat] Allow passing a list of evaluators to the Trainer #2716

Merged 1 commit on Jun 5, 2024
16 changes: 11 additions & 5 deletions sentence_transformers/trainer.py
@@ -15,7 +15,7 @@
from transformers.training_args import ParallelMode

from sentence_transformers.data_collator import SentenceTransformerDataCollator
-from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
+from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator
from sentence_transformers.losses.CoSENTLoss import CoSENTLoss
from sentence_transformers.model_card import ModelCardCallback
from sentence_transformers.models.Transformer import Transformer
@@ -79,9 +79,12 @@ class SentenceTransformerTrainer(Trainer):
dataset names to functions that return a loss class instance given a model. In practice, the latter two
are primarily used for hyper-parameter optimization. Will default to
:class:`~sentence_transformers.losses.CoSENTLoss` if no ``loss`` is provided.
-evaluator (:class:`~sentence_transformers.evaluation.SentenceEvaluator`, *optional*):
-The evaluator class to use for evaluation alongside the evaluation dataset. An evaluator will display more
-useful metrics than the loss function.
+evaluator (Union[:class:`~sentence_transformers.evaluation.SentenceEvaluator`,\
+List[:class:`~sentence_transformers.evaluation.SentenceEvaluator`]], *optional*):
+The evaluator instance for useful evaluation metrics during training. You can use an ``evaluator`` with
+or without an ``eval_dataset``, and vice versa. Generally, the metrics that an ``evaluator`` returns
+are more useful than the loss value returned from the ``eval_dataset``. A list of evaluators will be
+wrapped in a :class:`~sentence_transformers.evaluation.SequentialEvaluator` to run them sequentially.
callbacks (List of [:class:`transformers.TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
@@ -123,7 +126,7 @@ def __init__(
Dict[str, Callable[["SentenceTransformer"], torch.nn.Module]],
]
] = None,
-evaluator: Optional[SentenceEvaluator] = None,
+evaluator: Optional[Union[SentenceEvaluator, List[SentenceEvaluator]]] = None,
data_collator: Optional[DataCollator] = None,
tokenizer: Optional[Union[PreTrainedTokenizerBase, Callable]] = None,
model_init: Optional[Callable[[], "SentenceTransformer"]] = None,
@@ -218,6 +221,9 @@ def __init__(
)
else:
self.loss = self.prepare_loss(loss, model)
+# If evaluator is a list, we wrap it in a SequentialEvaluator; None is left untouched
+if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
+    evaluator = SequentialEvaluator(evaluator)
self.evaluator = evaluator

# Add a callback responsible for automatically tracking data required for the automatic model card generation
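The wrapping step in the last hunk can be sketched in isolation. This is a minimal sketch, not the library code: `SentenceEvaluator` and `SequentialEvaluator` below are simplified stand-ins for the classes in `sentence_transformers.evaluation`, while `ConstantEvaluator` and `normalize_evaluator` are hypothetical names introduced only for illustration. The sketch guards against `None` explicitly so that an omitted evaluator stays unset.

```python
from typing import Dict, Iterable


class SentenceEvaluator:
    """Simplified stand-in for sentence_transformers.evaluation.SentenceEvaluator."""

    def __call__(self, model) -> Dict[str, float]:
        raise NotImplementedError


class SequentialEvaluator(SentenceEvaluator):
    """Simplified stand-in: runs evaluators in order and merges their metrics."""

    def __init__(self, evaluators: Iterable[SentenceEvaluator]):
        self.evaluators = list(evaluators)

    def __call__(self, model) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        for evaluator in self.evaluators:
            metrics.update(evaluator(model))
        return metrics


class ConstantEvaluator(SentenceEvaluator):
    """Hypothetical evaluator that always reports one fixed metric."""

    def __init__(self, name: str, value: float):
        self.name = name
        self.value = value

    def __call__(self, model) -> Dict[str, float]:
        return {self.name: self.value}


def normalize_evaluator(evaluator):
    """Hypothetical helper: keep None and single evaluators, wrap any other iterable."""
    if evaluator is not None and not isinstance(evaluator, SentenceEvaluator):
        evaluator = SequentialEvaluator(evaluator)
    return evaluator


single = ConstantEvaluator("sts_spearman", 0.8)
wrapped = normalize_evaluator([single, ConstantEvaluator("nli_accuracy", 0.9)])
```

Because the check is "not a `SentenceEvaluator`" rather than "is a list", any iterable of evaluators (a tuple, for instance) would be wrapped the same way, and a single evaluator passes through unchanged.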