perf: log predictions for prompts
anujsinha3 committed Feb 15, 2024
1 parent 5f99fb1 commit 8db2bc0
Showing 1 changed file with 8 additions and 7 deletions.
src/autora/doc/pipelines/main.py: 15 changes (8 additions & 7 deletions)
@@ -111,19 +111,20 @@ def eval_prompt(
     pred_time = timer_end - timer_start
     mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
     for i in range(len(inputs)):
         mlflow.log_text(labels[i], f"prompt_{prompt_index}_label_{i}.txt")
         mlflow.log_text(inputs[i], f"prompt_{prompt_index}_input_{i}.py")
         mlflow.log_text(predictions[i], f"prompt_{prompt_index}_prediction_{i}.txt")

     # flatten predictions for counting tokens
     predictions_flat = list(itertools.chain.from_iterable(predictions))
     tokens = pred.tokenize(predictions_flat)["input_ids"]
     total_tokens = sum([len(token) for token in tokens])
-    mlflow.log_metric(f"prompt_{prompt_index}_total_tokens", total_tokens)
-    mlflow.log_metric(f"prompt_{prompt_index}_tokens/sec", total_tokens / pred_time)
-    mlflow.log_metric(f"prompt_{prompt_index}_bleu_score", round(bleu, 5))
-    mlflow.log_metric(f"prompt_{prompt_index}_meteor_score", round(meteor, 5))
-    mlflow.log_metric(f"prompt_{prompt_index}_semscore", round(semscore, 5))
+    metrics_dict = {
+        f"prompt_{prompt_index}_total_tokens": total_tokens,
+        f"prompt_{prompt_index}_tokens/sec": total_tokens / pred_time,
+        f"prompt_{prompt_index}_bleu_score": round(bleu, 5),
+        f"prompt_{prompt_index}_meteor_score": round(meteor, 5),
+        f"prompt_{prompt_index}_semscore": round(semscore, 5),
+    }
+    mlflow.log_metrics(metrics_dict)
     return EvalResult(predictions, prompt, bleu, meteor, semscore)


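For reference, below is a minimal standalone sketch of the logging pattern this change moves to: collect the per-prompt metrics in a dictionary and send them with a single mlflow.log_metrics() call instead of one mlflow.log_metric() call per value, with per-example strings stored as text artifacts via mlflow.log_text(). This is not the project's pipeline code; the run setup and all values (prompt_index, pred_time, total_tokens, the example prediction) are illustrative assumptions.

import mlflow

# Illustrative stand-in values (assumptions, not taken from the commit)
prompt_index = 0
pred_time = 2.5        # seconds spent generating predictions
total_tokens = 1200    # tokens across all generated predictions

with mlflow.start_run():
    # Batched metric logging: one call instead of one log_metric per key.
    metrics_dict = {
        f"prompt_{prompt_index}_total_tokens": total_tokens,
        f"prompt_{prompt_index}_tokens/sec": total_tokens / pred_time,
    }
    mlflow.log_metrics(metrics_dict)

    # Per-example artifact logging, as in the loop in the diff above:
    # log_text stores a string under the given artifact file name.
    mlflow.log_text("example prediction text", f"prompt_{prompt_index}_prediction_0.txt")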
