-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: Move metrics to a different module, rename function (#37)
- Loading branch information
Showing
4 changed files
with
91 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import List, Tuple | ||
|
||
import nltk | ||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu | ||
from nltk.translate.meteor_score import single_meteor_score | ||
|
||
|
||
def eval_bleu_meteor(predictions: List[str], references: List[str]) -> Tuple[float, float]: | ||
nltk.download("wordnet") | ||
|
||
# Tokenize references | ||
tokenized_references = [ref.split() for ref in references] | ||
# Currently there is only 1 prediction for 1 reference, need to avg in future | ||
tokenized_predictions = [pred.split() if pred else [] for pred in predictions] | ||
|
||
# Calculate BLEU score with smoothing function | ||
# SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference. | ||
bleu = corpus_bleu( | ||
# Wrap each reference list in another list | ||
[[tokenized_ref] for tokenized_ref in tokenized_references], | ||
tokenized_predictions, | ||
smoothing_function=SmoothingFunction().method1, | ||
) | ||
|
||
# Calculate METEOR scores | ||
meteor_scores = [ | ||
single_meteor_score(tokenized_ref, tokenized_pred) | ||
for tokenized_ref, tokenized_pred in zip(tokenized_references, tokenized_predictions) | ||
] | ||
meteor = sum(meteor_scores) / len(predictions) if predictions else 0 | ||
|
||
return (bleu, meteor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from pathlib import Path | ||
|
||
import jsonlines | ||
import pytest | ||
|
||
from autora.doc.pipelines.metrics import eval_bleu_meteor | ||
|
||
|
||
def test_evaluation() -> None: | ||
# Test Case: Meteor and Bleu scores are close to 1 | ||
data = Path(__file__).parent.joinpath("../data/sweetpea/data.jsonl").resolve() | ||
with jsonlines.open(data) as reader: | ||
items = [item for item in reader] | ||
labels = [item["output"] for item in items] | ||
predictions = [item["output"] for item in items] | ||
|
||
bleu, meteor = eval_bleu_meteor(predictions, labels) | ||
assert bleu == pytest.approx(1, 0.01), f"BLEU Score is {bleu}" | ||
assert meteor == pytest.approx(1, 0.01), f"METEOR Score is {meteor}" | ||
|
||
|
||
def test_extra_token_in_prediction() -> None: | ||
# Test Case bleu score should be less due to brevity penalty and meteor is robust to small mistakes | ||
labels = ["this is a test"] | ||
predictions = ["this is a test extra"] | ||
bleu, meteor = eval_bleu_meteor(predictions, labels) | ||
assert 0.6 <= bleu <= 0.8, f"BLEU Score is {bleu}" | ||
assert 0.8 <= meteor <= 1, f"METEOR Score is {meteor}" | ||
|
||
|
||
def test_missing_token_in_prediction() -> None: | ||
# bleu score is less, meteor is higher | ||
labels = ["this is a test"] | ||
predictions = ["this is a"] | ||
bleu, meteor = eval_bleu_meteor(predictions, labels) | ||
assert 0.4 <= bleu <= 0.6, f"BLEU Score is {bleu}" | ||
assert 0.6 <= meteor <= 0.8, f"METEOR Score is {meteor}" | ||
|
||
|
||
def test_completely_different_tokens() -> None: | ||
# both scores are less, as no common tokens | ||
labels = ["this is a test"] | ||
predictions = ["completely different sentence"] | ||
bleu, meteor = eval_bleu_meteor(predictions, labels) | ||
assert bleu <= 0.1, f"BLEU Score is {bleu}" | ||
assert meteor <= 0.1, f"METEOR Score is {meteor}" | ||
|
||
|
||
def test_partially_matching_tokens() -> None: | ||
# As ngrams arent matching because of extra token within, BLEU score is very less. Meteor gives a good score only. | ||
labels = ["this is a test"] | ||
predictions = ["this is a different test"] | ||
bleu, meteor = eval_bleu_meteor(predictions, labels) | ||
assert 0.25 <= bleu <= 0.4, f"BLEU Score is {bleu}" | ||
assert 0.8 <= meteor <= 0.95, f"METEOR Score is {meteor}" |