From 7160460a31e702a1c8b5f2f7d3b3335297bdd755 Mon Sep 17 00:00:00 2001 From: Lars <77671944+hummus-love@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:49:23 +0100 Subject: [PATCH 1/2] Add classification report per class to compute_metrics --- moralization/transformers_model_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/moralization/transformers_model_manager.py b/moralization/transformers_model_manager.py index ea4d51c..5a685de 100644 --- a/moralization/transformers_model_manager.py +++ b/moralization/transformers_model_manager.py @@ -19,6 +19,7 @@ import frontmatter from huggingface_hub import HfApi import shutil +from sklearn.metrics import classification_report IGNORED_LABEL = -100 @@ -318,6 +319,7 @@ def compute_metrics(self, eval_preds: tuple) -> Dict: "recall": all_metrics["overall_recall"], "f1": all_metrics["overall_f1"], "accuracy": all_metrics["overall_accuracy"], + "classification report": classification_report(y_true=true_labels, y_pred=true_predictions), } def _set_id2label(self) -> None: From 9ae052fa6cfaad20893858a233f1a0d3b322d3cb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Apr 2024 09:51:54 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- moralization/transformers_model_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/moralization/transformers_model_manager.py b/moralization/transformers_model_manager.py index 5a685de..04cb597 100644 --- a/moralization/transformers_model_manager.py +++ b/moralization/transformers_model_manager.py @@ -319,7 +319,9 @@ def compute_metrics(self, eval_preds: tuple) -> Dict: "recall": all_metrics["overall_recall"], "f1": all_metrics["overall_f1"], "accuracy": all_metrics["overall_accuracy"], - "classification report": classification_report(y_true=true_labels, y_pred=true_predictions), + "classification report": classification_report( + y_true=true_labels, y_pred=true_predictions + ), } def _set_id2label(self) -> None: