
Commit

Modified the test cases
RashmikaReddy committed Jan 12, 2024
1 parent 368e73c commit 2376a6d
Showing 3 changed files with 14 additions and 10 deletions.
3 changes: 1 addition & 2 deletions docs/requirements.txt
@@ -6,5 +6,4 @@ ipython
jupytext
jupyter
matplotlib
numpy
nltk
numpy
11 changes: 7 additions & 4 deletions src/autora/doc/pipelines/main.py
@@ -18,15 +18,17 @@
format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
)
logger = logging.getLogger(__name__)
nltk.download("wordnet")


def evaluate_documentation(predictions: List[List[str]], references: List[str]) -> Tuple[float, float]:
# Tokenize predictions and references
tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]
nltk.download("wordnet")

# Tokenize references
tokenized_references = [[ref.split()] for ref in references]
tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]

# Calculate BLEU score
# Calculate BLEU score with smoothing function
# SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference.
bleu = corpus_bleu(
tokenized_references, tokenized_predictions, smoothing_function=SmoothingFunction().method1
)
@@ -80,6 +82,7 @@ def eval(
timer_start = timer()
predictions = pred.predict(sys_prompt, instr_prompt, inputs, **param_dict)
bleu, meteor = evaluate_documentation(predictions, labels)

timer_end = timer()
pred_time = timer_end - timer_start
mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
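A minimal sketch of how the updated `evaluate_documentation` helper might be exercised, based on the signature shown in the diff above. The example strings and the assumption that the module is importable as `autora.doc.pipelines.main` outside the test suite are illustrative, not part of this commit:

```python
# Hedged sketch: assumes evaluate_documentation(predictions, references)
# returns (bleu, meteor), as shown in the diff above.
from autora.doc.pipelines.main import evaluate_documentation

references = ["Run the experiment and plot the results."]   # illustrative strings
predictions = [["Run the experiment and plot the results."]]  # List[List[str]], one candidate per input

bleu, meteor = evaluate_documentation(predictions, references)
# With identical text, both scores should be close to 1 (cf. the updated test below).
print(f"BLEU: {bleu:.3f}  METEOR: {meteor:.3f}")
```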
10 changes: 6 additions & 4 deletions tests/test_main.py
@@ -1,6 +1,7 @@
from pathlib import Path

import jsonlines
import pytest

from autora.doc.pipelines.main import eval, evaluate_documentation, generate
from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts
@@ -18,15 +19,16 @@ def test_predict() -> None:


def test_evaluation() -> None:
# Test Case: Valid Scores in the range of 0 and 1
# Test Case: Meteor and Bleu scores are close to 1
data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
with jsonlines.open(data) as reader:
items = [item for item in reader]
labels = [item["output"] for item in items]
predictions = [[item["output"]] for item in items]

bleu, meteor = evaluate_documentation(labels, labels)
assert bleu >= 0 and bleu <= 1, "BLEU score should be between 0 and 1"
assert meteor >= 0 and meteor <= 1, "METEOR score should be between 0 and 1"
bleu, meteor = evaluate_documentation(predictions, labels)
assert bleu == pytest.approx(1, 0.01), f"BLEU Score is {bleu}"
assert meteor == pytest.approx(1, 0.01), f"METEOR Score is {meteor}"


def test_generate() -> None:
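For reference, a standalone sketch of the `pytest.approx` pattern the updated assertions rely on; the 0.999 value is made up for illustration:

```python
import pytest

# pytest.approx(1, 0.01) accepts any value within a 1% relative tolerance of 1,
# which is why near-perfect BLEU/METEOR scores on identical text pass the test.
score = 0.999  # illustrative value, e.g. a BLEU score for identical predictions and labels
assert score == pytest.approx(1, 0.01)
```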
