From 28c1e106934184f8f18ba52c2ff1d26878a6b8f0 Mon Sep 17 00:00:00 2001 From: Viacheslav Kukushkin Date: Sat, 10 Aug 2024 15:33:34 +0300 Subject: [PATCH] binary AUROC requires (N,) shape probas just of class 1 --- GANDLF/metrics/classification.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/GANDLF/metrics/classification.py b/GANDLF/metrics/classification.py index dd5936ed9..5b0feb9e3 100644 --- a/GANDLF/metrics/classification.py +++ b/GANDLF/metrics/classification.py @@ -101,14 +101,17 @@ def __convert_tensor_to_int(input_tensor: torch.Tensor) -> torch.Tensor: ), } for metric_name, calculator in calculators.items(): + metric_prediction = prediction + metric_target = target if "auroc" in metric_name: - output_metrics[metric_name] = get_output_from_calculator( - predictions_prob, target_wrap, calculator - ) - else: - output_metrics[metric_name] = get_output_from_calculator( - prediction, target, calculator - ) + metric_prediction = predictions_prob + metric_target = target_wrap + if task == "binary": + metric_prediction = predictions_prob[:, 1] + + output_metrics[metric_name] = get_output_from_calculator( + metric_prediction, metric_target, calculator + ) # metrics that do not need the "average" parameter calculators = {