diff --git a/GANDLF/metrics/classification.py b/GANDLF/metrics/classification.py index dd5936ed9..5b0feb9e3 100644 --- a/GANDLF/metrics/classification.py +++ b/GANDLF/metrics/classification.py @@ -101,14 +101,17 @@ def __convert_tensor_to_int(input_tensor: torch.Tensor) -> torch.Tensor: ), } for metric_name, calculator in calculators.items(): + metric_prediction = prediction + metric_target = target if "auroc" in metric_name: - output_metrics[metric_name] = get_output_from_calculator( - predictions_prob, target_wrap, calculator - ) - else: - output_metrics[metric_name] = get_output_from_calculator( - prediction, target, calculator - ) + metric_prediction = predictions_prob + metric_target = target_wrap + if task == "binary": + metric_prediction = predictions_prob[:, 1] + + output_metrics[metric_name] = get_output_from_calculator( + metric_prediction, metric_target, calculator + ) # metrics that do not need the "average" parameter calculators = {