From 2376a6d7c3bca0874e4b5cce32092d00d04a35d3 Mon Sep 17 00:00:00 2001
From: RashmikaReddy
Date: Fri, 12 Jan 2024 09:08:40 -0800
Subject: [PATCH] Modified the test cases

---
 docs/requirements.txt            |  3 +--
 src/autora/doc/pipelines/main.py | 11 +++++++----
 tests/test_main.py               | 10 ++++++----
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 25ac169..2b5c37d 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -6,5 +6,4 @@ ipython
 jupytext
 jupyter
 matplotlib
-numpy
-nltk
+numpy
\ No newline at end of file
diff --git a/src/autora/doc/pipelines/main.py b/src/autora/doc/pipelines/main.py
index 2087fb3..9ffa311 100644
--- a/src/autora/doc/pipelines/main.py
+++ b/src/autora/doc/pipelines/main.py
@@ -18,15 +18,17 @@
     format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
 )
 logger = logging.getLogger(__name__)
-nltk.download("wordnet")


 def evaluate_documentation(predictions: List[List[str]], references: List[str]) -> Tuple[float, float]:
-    # Tokenize predictions and references
-    tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]
+    nltk.download("wordnet")
+
+    # Tokenize references
     tokenized_references = [[ref.split()] for ref in references]
+    tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]

-    # Calculate BLEU score
+    # Calculate BLEU score with smoothing function
+    # SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference.
     bleu = corpus_bleu(
         tokenized_references, tokenized_predictions, smoothing_function=SmoothingFunction().method1
     )
@@ -80,6 +82,7 @@ def eval(
     timer_start = timer()
     predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
     bleu, meteor = evaluate_documentation(predictions, labels)
+
     timer_end = timer()
     pred_time = timer_end - timer_start
     mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
diff --git a/tests/test_main.py b/tests/test_main.py
index 02aee63..f92acf9 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,6 +1,7 @@
 from pathlib import Path

 import jsonlines
+import pytest

 from autora.doc.pipelines.main import eval, evaluate_documentation, generate
 from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts
@@ -18,15 +19,16 @@ def test_predict() -> None:


 def test_evaluation() -> None:
-    # Test Case: Valid Scores in the range of 0 and 1
+    # Test Case: Meteor and Bleu scores are close to 1
     data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
     with jsonlines.open(data) as reader:
         items = [item for item in reader]
         labels = [item["output"] for item in items]
+        predictions = [[item["output"]] for item in items]

-    bleu, meteor = evaluate_documentation(labels, labels)
-    assert bleu >= 0 and bleu <= 1, "BLEU score should be between 0 and 1"
-    assert meteor >= 0 and meteor <= 1, "METEOR score should be between 0 and 1"
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    assert bleu == pytest.approx(1, 0.01), f"BLEU Score is {bleu}"
+    assert meteor == pytest.approx(1, 0.01), f"METEOR Score is {meteor}"


 def test_generate() -> None:
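
Note (not part of the patch): the change above moves nltk.download("wordnet") into evaluate_documentation and smooths the BLEU computation. Below is a minimal, self-contained sketch of how the NLTK calls used here behave on identical prediction/reference pairs; the example sentences are made up, and the per-sentence METEOR averaging is an assumption for illustration, since that part of the function is not visible in this hunk.

import nltk
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu
from nltk.translate.meteor_score import single_meteor_score

nltk.download("wordnet")  # METEOR relies on WordNet synonym lookups

references = ["calculate the mean of the observations"]
predictions = [["calculate the mean of the observations"]]

# corpus_bleu expects one list of tokenized references and one tokenized
# hypothesis per sentence.
tokenized_references = [[ref.split()] for ref in references]
tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]

# method1 adds a tiny count to zero n-gram matches so a single missing
# n-gram does not drive the whole corpus score to 0.
bleu = corpus_bleu(
    tokenized_references, tokenized_predictions, smoothing_function=SmoothingFunction().method1
)

# Assumed for illustration: METEOR scored per sentence, then averaged.
meteor = sum(
    single_meteor_score(refs[0], pred)
    for refs, pred in zip(tokenized_references, tokenized_predictions)
) / len(tokenized_predictions)

print(f"BLEU={bleu:.3f} METEOR={meteor:.3f}")  # both near 1.0 for identical text

With identical strings, as in the updated test_evaluation, both scores come out close to 1, which is what the pytest.approx(1, 0.01) assertions check.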