Skip to content

Commit

Permalink
Add ST annotation to evaluators (UKPLab#2586)
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey authored Apr 11, 2024
1 parent 99674c7 commit d105ec8
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from . import SentenceEvaluator
import logging
Expand Down Expand Up @@ -111,7 +112,7 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
scores.append(example.label)
return cls(sentences1, sentences2, scores, **kwargs)

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from contextlib import nullcontext

from sentence_transformers import SentenceTransformer
from . import SentenceEvaluator, SimilarityFunction
import logging
import os
Expand Down Expand Up @@ -101,7 +103,7 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
scores.append(example.label)
return cls(sentences1, sentences2, scores, **kwargs)

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from . import SentenceEvaluator
import torch
Expand Down Expand Up @@ -94,7 +95,9 @@ def __init__(
for k in map_at_k:
self.csv_headers.append("{}-MAP@{}".format(score_name, k))

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs) -> float:
def __call__(
self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs
) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down Expand Up @@ -147,7 +150,9 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int =
else:
return scores[self.main_score_function]["map@k"][max(self.map_at_k)]

def compute_metrices(self, model, corpus_model=None, corpus_embeddings: Tensor = None) -> Dict[str, float]:
def compute_metrices(
self, model: SentenceTransformer, corpus_model=None, corpus_embeddings: Tensor = None
) -> Dict[str, float]:
if corpus_model is None:
corpus_model = model

Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/evaluation/LabelAccuracyEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from . import SentenceEvaluator
import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -37,7 +38,7 @@ def __init__(self, dataloader: DataLoader, name: str = "", softmax_model=None, w
self.csv_file = "accuracy_evaluation" + name + "_results.csv"
self.csv_headers = ["epoch", "steps", "accuracy"]

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
model.eval()
total = 0
correct = 0
Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/evaluation/MSEEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from sentence_transformers.evaluation import SentenceEvaluator
import logging
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(
self.csv_headers = ["epoch", "steps", "MSE"]
self.write_csv = write_csv

def __call__(self, model, output_path, epoch=-1, steps=-1):
def __call__(self, model: SentenceTransformer, output_path, epoch=-1, steps=-1):
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
all_src_embeddings = teacher_model.encode(all_source_sentences, batch_size=self.batch_size)
self.teacher_embeddings = {sent: emb for sent, emb in zip(all_source_sentences, all_src_embeddings)}

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1):
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1):
model.eval()

mse_scores = []
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from . import SentenceEvaluator
import logging
Expand Down Expand Up @@ -99,7 +100,7 @@ def __init__(
self.csv_headers = ["epoch", "steps", "precision", "recall", "f1", "threshold", "average_precision"]
self.write_csv = write_csv

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/evaluation/RerankingEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from . import SentenceEvaluator
import logging
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(
]
self.write_csv = write_csv

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down
5 changes: 4 additions & 1 deletion sentence_transformers/evaluation/SentenceEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from sentence_transformers import SentenceTransformer


class SentenceEvaluator:
"""
Base class for all evaluators
Extend this class and implement __call__ for custom evaluators.
"""

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
"""
This is called during training to evaluate the model.
It returns a score for the evaluation with a higher score indicating a better result.
Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/evaluation/SequentialEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from . import SentenceEvaluator
from typing import Iterable

Expand All @@ -14,7 +15,7 @@ def __init__(self, evaluators: Iterable[SentenceEvaluator], main_score_function=
self.evaluators = evaluators
self.main_score_function = main_score_function

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
scores = []
for evaluator in self.evaluators:
scores.append(evaluator(model, output_path, epoch, steps))
Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/evaluation/TranslationEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from . import SentenceEvaluator
import logging
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
self.csv_headers = ["epoch", "steps", "src2trg", "trg2src"]
self.write_csv = write_csv

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down
3 changes: 2 additions & 1 deletion sentence_transformers/evaluation/TripletEvaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sentence_transformers import SentenceTransformer
from contextlib import nullcontext
from . import SentenceEvaluator, SimilarityFunction
import logging
Expand Down Expand Up @@ -75,7 +76,7 @@ def from_input_examples(cls, examples: List[InputExample], **kwargs):
negatives.append(example.texts[2])
return cls(anchors, positives, negatives, **kwargs)

def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
def __call__(self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = f" after epoch {epoch}"
Expand Down

0 comments on commit d105ec8

Please sign in to comment.